From 2179480e28bdd46c71cec269a8f55ba93aa54f53 Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 27 Mar 2023 10:05:07 -0700 Subject: Fix lowering crash in [BackwardDerivativeOf]. (#2737) Co-authored-by: Yong He --- tests/autodiff/custom-derivative-array-param.slang | 30 ++++++++++++++++++++++ ...ustom-derivative-array-param.slang.expected.txt | 5 ++++ 2 files changed, 35 insertions(+) create mode 100644 tests/autodiff/custom-derivative-array-param.slang create mode 100644 tests/autodiff/custom-derivative-array-param.slang.expected.txt (limited to 'tests') diff --git a/tests/autodiff/custom-derivative-array-param.slang b/tests/autodiff/custom-derivative-array-param.slang new file mode 100644 index 000000000..d50454b7a --- /dev/null +++ b/tests/autodiff/custom-derivative-array-param.slang @@ -0,0 +1,30 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +struct NonDiff +{ + float a; +} + +void getCode(uint2 x, out Array v) +{ + for (int i = 0; i < 8; i++) + v[i] = (float)i; +} + +[BackwardDerivativeOf(getCode)] +void getCode_bwd(uint2 x, Array dout) +{ + outputBuffer[0] = dout[0]; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + float a = 10.0; + float inArray[2] = { 1, 2 }; + __bwd_diff(getCode)(uint2(1,2), inArray); +} \ No newline at end of file diff --git a/tests/autodiff/custom-derivative-array-param.slang.expected.txt b/tests/autodiff/custom-derivative-array-param.slang.expected.txt new file mode 100644 index 000000000..5fce3dc6d --- /dev/null +++ b/tests/autodiff/custom-derivative-array-param.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +1.000000 +0.000000 +0.000000 +0.000000 \ No newline at end of file -- cgit v1.2.3