From d48cd130aacbab34bb98d51bb237ad38ff37348c Mon Sep 17 00:00:00 2001 From: kaizhangNV <149626564+kaizhangNV@users.noreply.github.com> Date: Thu, 2 Jan 2025 14:29:57 -0600 Subject: Correct IR generation for no-diff pointer type (#5976) * Correct IR generation for no-diff pointer type Close #5805 There is an issue on checking whether a pointer type parameter is no_diff, we should first check whether this parameter is an Attribute type first, then check the data type. In the back-propagate pass, for the pointer type parameter, we should load this parameter to a temp variable, then pass it to the primal function call. Otherwise, the temp variable will no be initialized, which will cause the following calculation wrong. --- tests/autodiff/nodiff-ptr.slang | 40 ++++++++++++++++++++++++++++ tests/autodiff/nodiff-ptr.slang.expected.txt | 6 +++++ 2 files changed, 46 insertions(+) create mode 100644 tests/autodiff/nodiff-ptr.slang create mode 100644 tests/autodiff/nodiff-ptr.slang.expected.txt (limited to 'tests') diff --git a/tests/autodiff/nodiff-ptr.slang b/tests/autodiff/nodiff-ptr.slang new file mode 100644 index 000000000..d20abddac --- /dev/null +++ b/tests/autodiff/nodiff-ptr.slang @@ -0,0 +1,40 @@ + +[Differentiable] +float sumOfSquares(float x, float y, no_diff float4* test) +{ + return x * x + y * y * (test->x + test->y + test->z); +} + +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type -compile-arg -skip-spirv-validation -emit-spirv-directly + +//TEST_INPUT: set ptr = ubuffer(data=[1.0 2.0 3.0], stride=4) +uniform float* ptr; + +//TEST_INPUT:ubuffer(data=[0.0 0.0 0.0 0.0 0.0], stride=4):out, name outputBuffer +RWStructuredBuffer outputBuffer; + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain() +{ + float4* testPtr = (float4*)ptr; + + let result = sumOfSquares(2.0, 3.0, testPtr); + + // Use forward differentiation to compute the gradient of the output w.r.t. x only. + let diffX = fwd_diff(sumOfSquares)(diffPair(2.0, 1.0), diffPair(3.0, 0.0), testPtr); + + // Create a differentiable pair to pass in the primal value and to receive the gradient. + var dpX = diffPair(2.0); + var dpY = diffPair(3.0); + + // Propagate the gradient of the output (1.0f) to the input parameters. + bwd_diff(sumOfSquares)(dpX, dpY, testPtr, 1.0); + + outputBuffer[0] = result; // 2^2 + 3^2 * (1 + 2 + 3) = 58 + outputBuffer[1] = diffX.d; // 2*x * dx + 2*y * dy * (1 + 2 + 3) = 4 + outputBuffer[2] = diffX.p; // 2^2 + 3^2 * (1 + 2 + 3) = 58 + outputBuffer[3] = dpX.d; // 2*x = 4 + + outputBuffer[4] = dpY.d; // 2*y * (1 + 2 +3) = 36 +} diff --git a/tests/autodiff/nodiff-ptr.slang.expected.txt b/tests/autodiff/nodiff-ptr.slang.expected.txt new file mode 100644 index 000000000..959cc68e4 --- /dev/null +++ b/tests/autodiff/nodiff-ptr.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +58.000000 +4.000000 +58.000000 +4.000000 +36.000000 -- cgit v1.2.3