diff options
| author | Yong He <yonghe@outlook.com> | 2023-05-30 10:41:34 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-05-30 10:41:34 -0700 |
| commit | 5e1974e8cad3396a8c4bedfd63c1ad31b82ec8eb (patch) | |
| tree | d47003966c2fd0bc6deb40c5220cc17f0855da78 | |
| parent | 4c1396c3532d6ad4973177d1c97578989385f347 (diff) | |
Fix derivative signature bug in checkDerivativeAttribute. (#2905)
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 3 | ||||
| -rw-r--r-- | tests/autodiff/reverse-inout-param-custom-derivative.slang | 28 | ||||
| -rw-r--r-- | tests/autodiff/reverse-inout-param-custom-derivative.slang.expected.txt | 2 |
3 files changed, 32 insertions, 1 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index c7a955f06..4e2f146a1 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -7027,7 +7027,8 @@ namespace Slang List<Expr*> imaginaryArguments; auto isOutParam = [&](ParamDecl* param) { - return param->findModifier<OutModifier>() != nullptr && param->findModifier<InModifier>() == nullptr; + return param->findModifier<OutModifier>() != nullptr + && param->findModifier<InModifier>() == nullptr && param->findModifier<InOutModifier>() == nullptr; }; for (auto param : originalFuncDecl->getParameters()) diff --git a/tests/autodiff/reverse-inout-param-custom-derivative.slang b/tests/autodiff/reverse-inout-param-custom-derivative.slang new file mode 100644 index 000000000..8769a33c7 --- /dev/null +++ b/tests/autodiff/reverse-inout-param-custom-derivative.slang @@ -0,0 +1,28 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +float rng(inout int state, float x) +{ + return state + x; +} + +[BackwardDerivativeOf(rng)] +void rng_bwd(int inState, inout DifferentialPair<float> x, float dOut) +{ + x = diffPair(x.p, (float)inState + dOut - 1.0); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var x = diffPair(2.0, 1.0); + + __bwd_diff(rng)(4, x, 3.0); + + outputBuffer[0] = x.d; // should be 6 + +} diff --git a/tests/autodiff/reverse-inout-param-custom-derivative.slang.expected.txt b/tests/autodiff/reverse-inout-param-custom-derivative.slang.expected.txt new file mode 100644 index 000000000..bb8f4137d --- /dev/null +++ b/tests/autodiff/reverse-inout-param-custom-derivative.slang.expected.txt @@ -0,0 +1,2 @@ +type: float +6.000000 |
