diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2025-02-11 03:08:27 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-02-11 19:08:27 +0800 |
| commit | 551bbb5fbd61b53253de8f6ba3303bb4d29f8c86 (patch) | |
| tree | 44b21ceefd66ac2b92c1b165fe280dcfe276cf65 /tests | |
| parent | 0b4e463aee4107b383067424007c6a995f1f9f87 (diff) | |
Add checking for differentiability of the primal substitute function. (#6277)
Co-authored-by: Yong He <yonghe@outlook.com>
Co-authored-by: Ellie Hermaszewska <ellieh@nvidia.com>
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/autodiff/primal-substitute-4.slang | 46 | ||||
| -rw-r--r-- | tests/diagnostics/autodiff-primal-substitute.slang | 44 |
2 files changed, 90 insertions, 0 deletions
diff --git a/tests/autodiff/primal-substitute-4.slang b/tests/autodiff/primal-substitute-4.slang new file mode 100644 index 000000000..8f7720639 --- /dev/null +++ b/tests/autodiff/primal-substitute-4.slang @@ -0,0 +1,46 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type -g0 + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBufferPrimal +RWStructuredBuffer<float> outputBufferPrimal; + +//TEST_INPUT:ubuffer(data=[1 2 3 4], stride=4):name=gradBuffer +RWStructuredBuffer<float> gradBuffer; + +struct BufferWithGrad +{ + RWStructuredBuffer<float> primal; + RWStructuredBuffer<float> grad; + + [Differentiable] + void add(float value) { primal[0] = primal[0] + detach(value); } + + [PrimalSubstituteOf(add), Differentiable] + void add_subst(float value) + { + } + + [BackwardDerivativeOf(add)] + void add_bwd(inout DifferentialPair<float> d) + { + d = diffPair(d.p, grad[0]); + } +} + +[Differentiable] +void diffCall(BufferWithGrad result) +{ + result.add(1.0f); +} + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + BufferWithGrad bg = {outputBufferPrimal, gradBuffer}; + diffCall(bg); + bwd_diff(diffCall)(bg); + + // CHECK: type: float + // CHECK-NEXT: 1.0 + // CHECK-NEXT: 0.0 +} diff --git a/tests/diagnostics/autodiff-primal-substitute.slang b/tests/diagnostics/autodiff-primal-substitute.slang new file mode 100644 index 000000000..178698719 --- /dev/null +++ b/tests/diagnostics/autodiff-primal-substitute.slang @@ -0,0 +1,44 @@ +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): -target spirv -entry computeMain -stage compute + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBufferPrimal +RWStructuredBuffer<float> outputBufferPrimal; + +//TEST_INPUT:ubuffer(data=[1 2 3 4], stride=4):name=gradBuffer +RWStructuredBuffer<float> gradBuffer; + +struct BufferWithGrad +{ + RWStructuredBuffer<float> primal; + RWStructuredBuffer<float> grad; + + [Differentiable] + void add(float value) { primal[0] = primal[0] + detach(value); } + + // check for diagnostic: + // CHECK-DAG: ([[# @LINE+1]]): error 31158 + [PrimalSubstituteOf(add)] + void add_subst(float value) + { + } + + [BackwardDerivativeOf(add)] + void add_bwd(inout DifferentialPair<float> d) + { + d = diffPair(d.p, grad[0]); + } +} + +[Differentiable] +void diffCall(BufferWithGrad result) +{ + result.add(1.0f); +} + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + BufferWithGrad bg = {outputBufferPrimal, gradBuffer}; + diffCall(bg); + bwd_diff(diffCall)(bg); +} |
