diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2025-03-17 12:02:37 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-03-17 19:02:37 +0000 |
| commit | 0c7104e609d93a46d247f75d4ea8a16dc5ee5855 (patch) | |
| tree | ff88be7ff64fde1860cbeb00f80255f665f5c5be | |
| parent | 714ee76af46b96c32724f0d6edb159fddeffc6bf (diff) | |
Add auto-diff support for `GetOffsetPtr` (#6625)
| -rw-r--r-- | source/slang/slang-ir-addr-inst-elimination.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 1 | ||||
| -rw-r--r-- | tests/autodiff/get-offset-ptr.slang | 40 |
3 files changed, 42 insertions, 0 deletions
diff --git a/source/slang/slang-ir-addr-inst-elimination.cpp b/source/slang/slang-ir-addr-inst-elimination.cpp index 8dcecf285..51477419b 100644 --- a/source/slang/slang-ir-addr-inst-elimination.cpp +++ b/source/slang/slang-ir-addr-inst-elimination.cpp @@ -174,6 +174,7 @@ struct AddressInstEliminationContext case kIROp_FieldAddress: case kIROp_Unmodified: case kIROp_DebugValue: + case kIROp_GetOffsetPtr: break; default: sink->diagnose( diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index e146ac3e0..92c35a618 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -2187,6 +2187,7 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_MakeCoopVector: case kIROp_MakeCoopVectorFromValuePack: case kIROp_GetCurrentStage: + case kIROp_GetOffsetPtr: return transcribeNonDiffInst(builder, origInst); // A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value, diff --git a/tests/autodiff/get-offset-ptr.slang b/tests/autodiff/get-offset-ptr.slang new file mode 100644 index 000000000..517acb54d --- /dev/null +++ b/tests/autodiff/get-offset-ptr.slang @@ -0,0 +1,40 @@ +//TEST:SIMPLE(filecheck=CHECK): -target cuda -line-directive-mode none + +//CHECK: struct s_bwd_prop_function_Intermediates{{[_0-9]+}} +//CHECK: { +//CHECK: MyDiffPtr{{[_0-9]+}} {{[_A-Za-z0-9]+}}; +//CHECK: MyDiffPtr{{[_0-9]+}} {{[_A-Za-z0-9]+}}; +//CHECK: }; + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out, name outputBuffer +RWStructuredBuffer<float> outputBuffer; + +struct MyDiffPtr +{ + uint offset; + uint d_offset; + + [BackwardDerivative(__bwd_foo)] + float foo() + { + return outputBuffer[offset] * outputBuffer[offset]; + } + + void __bwd_foo(float grad) + { + outputBuffer[d_offset] = 2.f * outputBuffer[offset] * grad; + } +}; + +[Differentiable] +float function(MyDiffPtr *i) +{ + return i[0].foo() + i[1].foo(); +} + +[numthreads(1, 1, 1), shader("compute")] +void main(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + MyDiffPtr s[2] = {{0, 2}, {1, 3}}; + __bwd_diff(function)(&s[0], 1.0f); +}
\ No newline at end of file |
