summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2025-03-17 12:02:37 -0700
committerGitHub <noreply@github.com>2025-03-17 19:02:37 +0000
commit0c7104e609d93a46d247f75d4ea8a16dc5ee5855 (patch)
treeff88be7ff64fde1860cbeb00f80255f665f5c5be
parent714ee76af46b96c32724f0d6edb159fddeffc6bf (diff)
Add auto-diff support for `GetOffsetPtr` (#6625)
-rw-r--r--source/slang/slang-ir-addr-inst-elimination.cpp1
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp1
-rw-r--r--tests/autodiff/get-offset-ptr.slang40
3 files changed, 42 insertions, 0 deletions
diff --git a/source/slang/slang-ir-addr-inst-elimination.cpp b/source/slang/slang-ir-addr-inst-elimination.cpp
index 8dcecf285..51477419b 100644
--- a/source/slang/slang-ir-addr-inst-elimination.cpp
+++ b/source/slang/slang-ir-addr-inst-elimination.cpp
@@ -174,6 +174,7 @@ struct AddressInstEliminationContext
case kIROp_FieldAddress:
case kIROp_Unmodified:
case kIROp_DebugValue:
+ case kIROp_GetOffsetPtr:
break;
default:
sink->diagnose(
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index e146ac3e0..92c35a618 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -2187,6 +2187,7 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_MakeCoopVector:
case kIROp_MakeCoopVectorFromValuePack:
case kIROp_GetCurrentStage:
+ case kIROp_GetOffsetPtr:
return transcribeNonDiffInst(builder, origInst);
// A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value,
diff --git a/tests/autodiff/get-offset-ptr.slang b/tests/autodiff/get-offset-ptr.slang
new file mode 100644
index 000000000..517acb54d
--- /dev/null
+++ b/tests/autodiff/get-offset-ptr.slang
@@ -0,0 +1,40 @@
+//TEST:SIMPLE(filecheck=CHECK): -target cuda -line-directive-mode none
+
+//CHECK: struct s_bwd_prop_function_Intermediates{{[_0-9]+}}
+//CHECK: {
+//CHECK: MyDiffPtr{{[_0-9]+}} {{[_A-Za-z0-9]+}};
+//CHECK: MyDiffPtr{{[_0-9]+}} {{[_A-Za-z0-9]+}};
+//CHECK: };
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out, name outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+struct MyDiffPtr
+{
+ uint offset;
+ uint d_offset;
+
+ [BackwardDerivative(__bwd_foo)]
+ float foo()
+ {
+ return outputBuffer[offset] * outputBuffer[offset];
+ }
+
+ void __bwd_foo(float grad)
+ {
+ outputBuffer[d_offset] = 2.f * outputBuffer[offset] * grad;
+ }
+};
+
+[Differentiable]
+float function(MyDiffPtr *i)
+{
+ return i[0].foo() + i[1].foo();
+}
+
+[numthreads(1, 1, 1), shader("compute")]
+void main(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ MyDiffPtr s[2] = {{0, 2}, {1, 3}};
+ __bwd_diff(function)(&s[0], 1.0f);
+} \ No newline at end of file