From 551bbb5fbd61b53253de8f6ba3303bb4d29f8c86 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Tue, 11 Feb 2025 03:08:27 -0800 Subject: Add checking for differentiability of the primal substitute function. (#6277) Co-authored-by: Yong He Co-authored-by: Ellie Hermaszewska --- tests/diagnostics/autodiff-primal-substitute.slang | 44 ++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 tests/diagnostics/autodiff-primal-substitute.slang (limited to 'tests/diagnostics') diff --git a/tests/diagnostics/autodiff-primal-substitute.slang b/tests/diagnostics/autodiff-primal-substitute.slang new file mode 100644 index 000000000..178698719 --- /dev/null +++ b/tests/diagnostics/autodiff-primal-substitute.slang @@ -0,0 +1,44 @@ +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): -target spirv -entry computeMain -stage compute + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBufferPrimal +RWStructuredBuffer outputBufferPrimal; + +//TEST_INPUT:ubuffer(data=[1 2 3 4], stride=4):name=gradBuffer +RWStructuredBuffer gradBuffer; + +struct BufferWithGrad +{ + RWStructuredBuffer primal; + RWStructuredBuffer grad; + + [Differentiable] + void add(float value) { primal[0] = primal[0] + detach(value); } + + // check for diagnostic: + // CHECK-DAG: ([[# @LINE+1]]): error 31158 + [PrimalSubstituteOf(add)] + void add_subst(float value) + { + } + + [BackwardDerivativeOf(add)] + void add_bwd(inout DifferentialPair d) + { + d = diffPair(d.p, grad[0]); + } +} + +[Differentiable] +void diffCall(BufferWithGrad result) +{ + result.add(1.0f); +} + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + BufferWithGrad bg = {outputBufferPrimal, gradBuffer}; + diffCall(bg); + bwd_diff(diffCall)(bg); +} -- cgit v1.2.3