summaryrefslogtreecommitdiffstats
path: root/tests/autodiff/get-offset-ptr.slang
diff options
context:
space:
mode:
Diffstat (limited to 'tests/autodiff/get-offset-ptr.slang')
-rw-r--r--tests/autodiff/get-offset-ptr.slang40
1 files changed, 16 insertions, 24 deletions
diff --git a/tests/autodiff/get-offset-ptr.slang b/tests/autodiff/get-offset-ptr.slang
index 517acb54d..e497f1e48 100644
--- a/tests/autodiff/get-offset-ptr.slang
+++ b/tests/autodiff/get-offset-ptr.slang
@@ -1,40 +1,32 @@
-//TEST:SIMPLE(filecheck=CHECK): -target cuda -line-directive-mode none
+//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -cuda -output-using-type
-//CHECK: struct s_bwd_prop_function_Intermediates{{[_0-9]+}}
-//CHECK: {
-//CHECK: MyDiffPtr{{[_0-9]+}} {{[_A-Za-z0-9]+}};
-//CHECK: MyDiffPtr{{[_0-9]+}} {{[_A-Za-z0-9]+}};
-//CHECK: };
+// This test just ensures that we compile and run the code.
+// It does not check the correctness of the autodiff.
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out, name outputBuffer
RWStructuredBuffer<float> outputBuffer;
struct MyDiffPtr
{
- uint offset;
- uint d_offset;
-
- [BackwardDerivative(__bwd_foo)]
- float foo()
- {
- return outputBuffer[offset] * outputBuffer[offset];
- }
-
- void __bwd_foo(float grad)
- {
- outputBuffer[d_offset] = 2.f * outputBuffer[offset] * grad;
- }
+ float data1;
+ float data2;
};
[Differentiable]
-float function(MyDiffPtr *i)
+float function(Ptr<MyDiffPtr, Access::ReadWrite, AddressSpace::GroupShared> i)
{
- return i[0].foo() + i[1].foo();
+ return i[0].data1 + i[1].data2;
}
+groupshared MyDiffPtr s[2];
[numthreads(1, 1, 1), shader("compute")]
-void main(uint3 dispatchThreadID: SV_DispatchThreadID)
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
{
- MyDiffPtr s[2] = {{0, 2}, {1, 3}};
- __bwd_diff(function)(&s[0], 1.0f);
+ s = { { 0, 2 }, { 1, 3 } };
+ float result = 1.0f;
+ let pair = __fwd_diff(function)(__getAddress(s[0]));
+ outputBuffer[0] = pair.getPrimal();
+ outputBuffer[1] = pair.getDifferential();
+ // CHECK: 3.0
+ // CHECK-NEXT: 0.0
} \ No newline at end of file