From 5e1974e8cad3396a8c4bedfd63c1ad31b82ec8eb Mon Sep 17 00:00:00 2001 From: Yong He Date: Tue, 30 May 2023 10:41:34 -0700 Subject: Fix derivative signature bug in checkDerivativeAttribute. (#2905) --- source/slang/slang-check-decl.cpp | 3 ++- .../reverse-inout-param-custom-derivative.slang | 28 ++++++++++++++++++++++ ...nout-param-custom-derivative.slang.expected.txt | 2 ++ 3 files changed, 32 insertions(+), 1 deletion(-) create mode 100644 tests/autodiff/reverse-inout-param-custom-derivative.slang create mode 100644 tests/autodiff/reverse-inout-param-custom-derivative.slang.expected.txt diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index c7a955f06..4e2f146a1 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -7027,7 +7027,8 @@ namespace Slang List imaginaryArguments; auto isOutParam = [&](ParamDecl* param) { - return param->findModifier() != nullptr && param->findModifier() == nullptr; + return param->findModifier() != nullptr + && param->findModifier() == nullptr && param->findModifier() == nullptr; }; for (auto param : originalFuncDecl->getParameters()) diff --git a/tests/autodiff/reverse-inout-param-custom-derivative.slang b/tests/autodiff/reverse-inout-param-custom-derivative.slang new file mode 100644 index 000000000..8769a33c7 --- /dev/null +++ b/tests/autodiff/reverse-inout-param-custom-derivative.slang @@ -0,0 +1,28 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +float rng(inout int state, float x) +{ + return state + x; +} + +[BackwardDerivativeOf(rng)] +void rng_bwd(int inState, inout DifferentialPair x, float dOut) +{ + x = diffPair(x.p, (float)inState + dOut - 1.0); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var x = diffPair(2.0, 1.0); + + __bwd_diff(rng)(4, x, 3.0); + + outputBuffer[0] = x.d; // should be 6 + +} diff --git a/tests/autodiff/reverse-inout-param-custom-derivative.slang.expected.txt b/tests/autodiff/reverse-inout-param-custom-derivative.slang.expected.txt new file mode 100644 index 000000000..bb8f4137d --- /dev/null +++ b/tests/autodiff/reverse-inout-param-custom-derivative.slang.expected.txt @@ -0,0 +1,2 @@ +type: float +6.000000 -- cgit v1.2.3