diff options
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/autodiff/primal-substitute-4.slang | 46 | ||||
| -rw-r--r-- | tests/diagnostics/autodiff-primal-substitute.slang | 44 |
2 files changed, 90 insertions, 0 deletions
diff --git a/tests/autodiff/primal-substitute-4.slang b/tests/autodiff/primal-substitute-4.slang new file mode 100644 index 000000000..8f7720639 --- /dev/null +++ b/tests/autodiff/primal-substitute-4.slang @@ -0,0 +1,46 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type -g0 + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBufferPrimal +RWStructuredBuffer<float> outputBufferPrimal; + +//TEST_INPUT:ubuffer(data=[1 2 3 4], stride=4):name=gradBuffer +RWStructuredBuffer<float> gradBuffer; + +struct BufferWithGrad +{ + RWStructuredBuffer<float> primal; + RWStructuredBuffer<float> grad; + + [Differentiable] + void add(float value) { primal[0] = primal[0] + detach(value); } + + [PrimalSubstituteOf(add), Differentiable] + void add_subst(float value) + { + } + + [BackwardDerivativeOf(add)] + void add_bwd(inout DifferentialPair<float> d) + { + d = diffPair(d.p, grad[0]); + } +} + +[Differentiable] +void diffCall(BufferWithGrad result) +{ + result.add(1.0f); +} + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + BufferWithGrad bg = {outputBufferPrimal, gradBuffer}; + diffCall(bg); + bwd_diff(diffCall)(bg); + + // CHECK: type: float + // CHECK-NEXT: 1.0 + // CHECK-NEXT: 0.0 +} diff --git a/tests/diagnostics/autodiff-primal-substitute.slang b/tests/diagnostics/autodiff-primal-substitute.slang new file mode 100644 index 000000000..178698719 --- /dev/null +++ b/tests/diagnostics/autodiff-primal-substitute.slang @@ -0,0 +1,44 @@ +//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK): -target spirv -entry computeMain -stage compute + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBufferPrimal +RWStructuredBuffer<float> outputBufferPrimal; + +//TEST_INPUT:ubuffer(data=[1 2 3 4], stride=4):name=gradBuffer +RWStructuredBuffer<float> gradBuffer; + +struct BufferWithGrad +{ + RWStructuredBuffer<float> primal; + RWStructuredBuffer<float> grad; + + [Differentiable] + void add(float value) { primal[0] = primal[0] + detach(value); } + + // check for diagnostic: + // CHECK-DAG: ([[# @LINE+1]]): error 31158 + [PrimalSubstituteOf(add)] + void add_subst(float value) + { + } + + [BackwardDerivativeOf(add)] + void add_bwd(inout DifferentialPair<float> d) + { + d = diffPair(d.p, grad[0]); + } +} + +[Differentiable] +void diffCall(BufferWithGrad result) +{ + result.add(1.0f); +} + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + BufferWithGrad bg = {outputBufferPrimal, gradBuffer}; + diffCall(bg); + bwd_diff(diffCall)(bg); +} |
