diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-11-18 16:34:03 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-11-18 13:34:03 -0800 |
| commit | ec5e019fa9732b99b75b2a3ca4f2ff5a7a3d2f33 (patch) | |
| tree | f314f5b16ad18dd3325a7c3a4228242d6d448752 /tests/autodiff/diff-ptr-type-array.slang | |
| parent | 05903f708856a70d68bf41bbfb2b06620508dee0 (diff) | |
Add `IDifferentiablePtrType` support for arrays (#5576)
* Add `IDifferentiablePtrType` support for arrays
- Also fixes an issue with spirv-emit of constructors that contain references to global params
* Fix GLSL legalization for arrays of resource types
Diffstat (limited to 'tests/autodiff/diff-ptr-type-array.slang')
| -rw-r--r-- | tests/autodiff/diff-ptr-type-array.slang | 59 |
1 files changed, 59 insertions, 0 deletions
diff --git a/tests/autodiff/diff-ptr-type-array.slang b/tests/autodiff/diff-ptr-type-array.slang new file mode 100644 index 000000000..30e6fe963 --- /dev/null +++ b/tests/autodiff/diff-ptr-type-array.slang @@ -0,0 +1,59 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type +//DISABLE_TEST(compute):COMPARE_COMPUTE:-wgpu + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +// ----- MyPtrType definition ----- +struct MyPtrType : IDifferentiablePtrType +{ + typealias Differential = MyPtrType; + + RWStructuredBuffer<float> buffer; + uint offset; + + float load(uint idx) { return buffer[offset + idx]; } + void accumulate(uint idx, float value) { buffer[offset + idx] += value; } +} + +[BackwardDerivative(load_bwd)] +float load(MyPtrType[2] b, uint idx) +{ + return b[1].load(idx); +} + +void load_bwd(DifferentialPtrPair<MyPtrType[2]> b, uint idx, float grad) +{ + b.d[1].accumulate(idx, grad); +} + +// ------ +[Differentiable] +float reduce(MyPtrType a) +{ + return load( { a, a }, 0) + load( { a, a }, 1); +} + +[Differentiable] +float test(MyPtrType b) +{ + return reduce(b); +} + +[numthreads(1, 1, 1)] +void computeMain(uint id: SV_DispatchThreadID) +{ + outputBuffer[0] = 1; // CHECK: 1 + outputBuffer[1] = 2; // CHECK: 2 + + // Denote the first two elements in the buffer as the primal buffer and the last two elements + // for the derivative. + var b = DifferentialPtrPair<MyPtrType>( { outputBuffer, 0 }, { outputBuffer, 2 }); + + bwd_diff(test)(b, 1.5f); + + // Check locations [2] and [3] in the buffer + // CHECK: 1.5 + // CHECK: 1.5 +} |
