diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-09-19 03:10:28 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-09-19 00:10:28 -0700 |
| commit | ccc310fa4e8096cda8a6c127aacc1a1fa9d8503a (patch) | |
| tree | 435e9c462a78fb848ab3b36c23287543d1a859de /tests | |
| parent | 1781c2969eb65fb7ade01d3f0d7d9b8973bcd4d3 (diff) | |
Support `IDifferentiablePtrType` (#5031)
* initial diff-ref-type interface
* Initial support for `IDifferentiablePtrType`
* Fix unused vars
* More tests + fix switch case fallthrough.
* Update slang-ir-autodiff.cpp
* Update diff-ptr-type-loop.slang
* Add optimization to allow more complex pair types
* Update slang-ir-autodiff-primal-hoist.cpp
* Update diff-ptr-type-loop.slang
* Update slang-ir-autodiff-primal-hoist.cpp
* More fixes to address reviews
* Update slang-check-expr.cpp
* Optimizations + rename `differentiableRefInterfaceType` -> `differentiablePtrInterfaceType`
* Move pair logic to ir-builder, unify the type dictionaries.
---------
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/autodiff/diff-ptr-type-call.slang | 57 | ||||
| -rw-r--r-- | tests/autodiff/diff-ptr-type-loop.slang | 65 | ||||
| -rw-r--r-- | tests/autodiff/diff-ptr-type-smoke.slang | 49 |
3 files changed, 171 insertions, 0 deletions
diff --git a/tests/autodiff/diff-ptr-type-call.slang b/tests/autodiff/diff-ptr-type-call.slang new file mode 100644 index 000000000..258a4477b --- /dev/null +++ b/tests/autodiff/diff-ptr-type-call.slang @@ -0,0 +1,57 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +// ----- MyPtrType definition ----- +struct MyPtrType : IDifferentiablePtrType +{ + typealias Differential = MyPtrType; + + RWStructuredBuffer<float> buffer; + uint offset; + + float load(uint idx) { return buffer[offset + idx]; } + void accumulate(uint idx, float value) { buffer[offset + idx] += value; } +} + +[BackwardDerivative(load_bwd)] +float load(MyPtrType b, uint idx) +{ + return b.load(idx); +} + +void load_bwd(DifferentialPtrPair<MyPtrType> b, uint idx, float grad) +{ + b.d.accumulate(idx, grad); +} + +// ------ +[Differentiable] +float reduce(MyPtrType a) +{ + return load(a, 0) + load(a, 1); +} + +[Differentiable] +float test(MyPtrType b) +{ + return reduce(b); +} + +[numthreads(1, 1, 1)] +void computeMain(uint id : SV_DispatchThreadID) +{ + outputBuffer[0] = 1; // CHECK: 1 + outputBuffer[1] = 2; // CHECK: 2 + + // Denote the first two elements in the buffer as the primal buffer and the last two elements for the derivative. + var b = DifferentialPtrPair<MyPtrType>( { outputBuffer, 0 }, { outputBuffer, 2 } ); + + bwd_diff(test)(b, 1.5f); + + // Check locations [2] and [3] in the buffer + // CHECK: 1.5 + // CHECK: 1.5 +}
\ No newline at end of file diff --git a/tests/autodiff/diff-ptr-type-loop.slang b/tests/autodiff/diff-ptr-type-loop.slang new file mode 100644 index 000000000..a57c69b76 --- /dev/null +++ b/tests/autodiff/diff-ptr-type-loop.slang @@ -0,0 +1,65 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +// ----- MyPtrType definition ----- +struct MyPtrType : IDifferentiablePtrType +{ + typealias Differential = MyPtrType; + + RWStructuredBuffer<float> buffer; + uint offset; + + float load(uint idx) { return buffer[offset + idx]; } + void accumulate(uint idx, float value) { buffer[offset + idx] += value; } +} + +[BackwardDerivative(load_bwd)] +float load(MyPtrType b, uint idx) +{ + return b.load(idx); +} + +void load_bwd(DifferentialPtrPair<MyPtrType> b, uint idx, float grad) +{ + b.d.accumulate(idx, grad); +} + + +// ------ +[Differentiable] +float reduce(MyPtrType a, uint num) +{ + float sum = 0; + [MaxIters(3)] + for (uint i = 0; i < num; i++) + { + sum += load(a, i); + } + + return sum; +} + +[Differentiable] +float test(MyPtrType b, uint num) +{ + return reduce(b, num); +} + +[numthreads(1, 1, 1)] +void computeMain(uint id : SV_DispatchThreadID) +{ + outputBuffer[0] = 1; // CHECK: 1 + outputBuffer[1] = 2; // CHECK: 2 + + // Denote the first two elements in the buffer as the primal buffer and the last two elements for the derivative. + var b = DifferentialPtrPair<MyPtrType>( { outputBuffer, 0 }, { outputBuffer, 2 } ); + + bwd_diff(test)(b, 2, 1.5f); + + // Check locations [2] and [3] in the buffer + // CHECK: 1.5 + // CHECK: 1.5 +}
\ No newline at end of file diff --git a/tests/autodiff/diff-ptr-type-smoke.slang b/tests/autodiff/diff-ptr-type-smoke.slang new file mode 100644 index 000000000..e7e03c5e3 --- /dev/null +++ b/tests/autodiff/diff-ptr-type-smoke.slang @@ -0,0 +1,49 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +struct MyPtrType : IDifferentiablePtrType +{ + typealias Differential = MyPtrType; + + RWStructuredBuffer<float> buffer; + uint offset; + + float load(uint idx) { return buffer[offset + idx]; } + void accumulate(uint idx, float value) { buffer[offset + idx] += value; } +} + +[BackwardDerivative(load_bwd)] +float load(MyPtrType b, uint idx) +{ + return b.load(idx); +} + +void load_bwd(DifferentialPtrPair<MyPtrType> b, uint idx, float grad) +{ + b.d.accumulate(idx, grad); +} + +[BackwardDifferentiable] +float test(MyPtrType b, uint idx) +{ + return load(b, idx) + load(b, idx + 1); +} + +[numthreads(1, 1, 1)] +void computeMain(uint id : SV_DispatchThreadID) +{ + outputBuffer[0] = 1; // CHECK: 1 + outputBuffer[1] = 2; // CHECK: 2 + + // Denote the first two elements in the buffer as the primal buffer and the last two elements for the derivative. + var b = DifferentialPtrPair<MyPtrType>( { outputBuffer, 0 }, { outputBuffer, 2 } ); + + bwd_diff(test)(b, id, 1.5f); + + // Check locations [2] and [3] in the buffer + // CHECK: 1.5 + // CHECK: 1.5 +}
\ No newline at end of file |
