summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2025-01-10 13:33:02 -0800
committerGitHub <noreply@github.com>2025-01-10 13:33:02 -0800
commit4104aa7f95e0d29e877be5208031e2670fb5a77d (patch)
treee50d7642476668589a6aa5262fa773bd382461e8
parentf199640bb31e1e273e34a068ea0fb7a55f2afb5e (diff)
Fix `markNonContextParamsAsSideEffectFree`. (#6054)
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp1
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp2
-rw-r--r--source/slang/slang-ir-inst-defs.h2
-rw-r--r--tests/autodiff/max-iters.slang81
-rw-r--r--tests/autodiff/property.slang48
-rw-r--r--tests/autodiff/trivial-primal.slang41
6 files changed, 174 insertions, 1 deletions
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 5ac4016d7..65ce69877 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -1359,6 +1359,7 @@ ParameterBlockTransposeInfo BackwardDiffTranscriberBase::splitAndTransposeParame
auto ctxParam =
builder->emitParam(as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 1));
builder->addNameHintDecoration(ctxParam, UnownedStringSlice("_s_diff_ctx"));
+ builder->addDecoration(ctxParam, kIROp_PrimalContextDecoration);
result.primalFuncParams.add(ctxParam);
result.propagateFuncParams.add(ctxParam);
result.dOutParam = dOutParam;
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 49c1d9ff7..6bc428ad6 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -342,7 +342,7 @@ void markNonContextParamsAsSideEffectFree(IRBuilder* builder, IRFunc* func)
{
for (auto param : func->getParams())
{
- if (!isIntermediateContextType(param->getDataType()))
+ if (!param->findDecorationImpl(kIROp_PrimalContextDecoration))
builder->addDecoration(param, kIROp_IgnoreSideEffectsDecoration);
}
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index f5af73dfa..2f4c69820 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -1040,6 +1040,8 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace)
INST(BackwardDerivativePrimalContextDecoration, BackwardDerivativePrimalContextDecoration, 1, 0)
INST(BackwardDerivativePrimalReturnDecoration, BackwardDerivativePrimalReturnDecoration, 1, 0)
+ // Mark a parameter as autodiff primal context.
+ INST(PrimalContextDecoration, PrimalContextDecoration, 0, 0)
INST(LoopCounterDecoration, loopCounterDecoration, 0, 0)
INST(LoopCounterUpdateDecoration, loopCounterUpdateDecoration, 0, 0)
diff --git a/tests/autodiff/max-iters.slang b/tests/autodiff/max-iters.slang
new file mode 100644
index 000000000..c83057b43
--- /dev/null
+++ b/tests/autodiff/max-iters.slang
@@ -0,0 +1,81 @@
+//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type -d3d12 -use-dxil
+//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type -vk
+//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type -metal
+//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type -cuda
+//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type -cpu
+//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type -wgpu
+
+// Note: there is a bug in fxc compiler errorneously reporting infinite loop for this shader.
+// Skipping d3d11 test to avoid the bug.
+//DISABLE_TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type -d3d11
+
+struct GradientBuffer<let D : int>
+{
+ RWStructuredBuffer<float> primal;
+ StructuredBuffer<float> grad;
+ int strides[D];
+
+ int toIndex(int idx[D]) {
+ int result = 0;
+ for (int i = 0; i < D; ++i)
+ result += strides[i] * idx[i];
+ return result;
+ }
+
+ [Differentiable]
+ void write(int[D] idx, float v) { primal[toIndex(idx)] = detach(v); }
+
+ [BackwardDerivativeOf(write)]
+ void write_bwd(int[D] idx, inout DifferentialPair<float> d) { d = diffPair(d.p, grad[toIndex(idx)]); }
+
+ [Differentiable]
+ void store<let N : int>(int context[D - 1], in float value[N])
+ {
+ int idx[D];
+ //[ForceUnroll] /* Using ForceUnroll instead of MaxIters makes it work */
+ [MaxIters(2)]
+ for (int i = 0; i < D - 1; ++i)
+ idx[i] = context[i];
+ [ForceUnroll]
+ for (int i = 0; i < N; i++) {
+ idx[D - 1] = i;
+ write(idx, value[i]);
+ }
+ }
+}
+
+[Differentiable]
+void test(GradientBuffer<2> buf, int[1] base, float[3] value)
+{
+ buf.store(base, value);
+}
+
+float3 repro(RWStructuredBuffer<float> primal, StructuredBuffer<float> grad)
+{
+ float input[3];
+ input[0] = input[1] = input[2] = 1.0f;
+ var result = diffPair(input);
+ GradientBuffer<2> buf = { primal, grad, {3, 1} };
+ bwd_diff(test)(buf, { 1 }, result);
+ return float3(result.d[0], result.d[1], result.d[2]);
+}
+
+//TEST_INPUT: set grad_in = ubuffer(data=[101.0 102.0 103.0 104.0], stride=4)
+uniform StructuredBuffer<float> grad_in;
+
+//TEST_INPUT: set grad_out = ubuffer(data=[0 0 0 0], stride=4)
+uniform RWStructuredBuffer<float> grad_out;
+
+//TEST_INPUT: set output = out ubuffer(data=[0 0 0 0], stride=4)
+uniform RWStructuredBuffer<float> output;
+
+[shader("compute")]
+[numthreads(1,1,1)]
+void computeMain()
+{
+ let result = repro(grad_out, grad_in);
+ // CHECK: 104.0
+ output[0] = result.x;
+ output[1] = result.y;
+ output[2] = result.z;
+} \ No newline at end of file
diff --git a/tests/autodiff/property.slang b/tests/autodiff/property.slang
new file mode 100644
index 000000000..e15b9a75a
--- /dev/null
+++ b/tests/autodiff/property.slang
@@ -0,0 +1,48 @@
+//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type
+public struct ReadOnlyIndex
+{
+ private int _idx;
+ __init(int i) { _idx = i; }
+ public property int idx { get { return _idx; } }
+}
+struct GradientBuffer
+{
+ RWStructuredBuffer<float> primal;
+ StructuredBuffer<float> grad;
+
+ [Differentiable]
+ void write(int idx, float v) { primal[idx] = detach(v); }
+
+ [BackwardDerivativeOf(write)]
+ void write_bwd(int idx, inout DifferentialPair<float> d) { d = diffPair(d.p, grad[idx]); }
+
+ [Differentiable]
+ void store(ReadOnlyIndex idx, float v) { write(idx.idx, v); }
+}
+[Differentiable]
+void test(GradientBuffer buf, ReadOnlyIndex b, float x)
+{
+ buf.store(b, x);
+}
+public float repro(RWStructuredBuffer<float> primal, StructuredBuffer<float> grad)
+{
+ DifferentialPair<float> result = diffPair(1.0f);
+ GradientBuffer buf = { primal, grad };
+ bwd_diff(test)(buf, ReadOnlyIndex(5), result);
+ return result.d;
+}
+
+//TEST_INPUT: set output = out ubuffer(data=[0 0 0 0], stride=4)
+RWStructuredBuffer<float> output;
+
+//TEST_INPUT: set gPrimal = ubuffer(data=[0.0 1.0 2.0 3.0 4.0 5.0 6.0 7.0], stride=4)
+RWStructuredBuffer<float> gPrimal;
+//TEST_INPUT: set gGrad = ubuffer(data=[0.0 1.0 2.0 3.0 4.0 5.0 6.0 7.0], stride=4)
+StructuredBuffer<float> gGrad;
+
+[numthreads(1,1,1)]
+void computeMain()
+{
+ // CHECK: 5.0
+ output[0] = repro(gPrimal, gGrad);
+} \ No newline at end of file
diff --git a/tests/autodiff/trivial-primal.slang b/tests/autodiff/trivial-primal.slang
new file mode 100644
index 000000000..d56c46399
--- /dev/null
+++ b/tests/autodiff/trivial-primal.slang
@@ -0,0 +1,41 @@
+//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK):-output-using-type
+
+struct GradientBuffer
+{
+ StructuredBuffer<float> grads;
+
+ [Differentiable]
+ void write(int idx, float value) { /* Discard write */ }
+
+ [BackwardDerivativeOf(write)]
+ void write_bwd(int idx, inout DifferentialPair<float> d)
+ {
+ d = diffPair(d.p, grads[idx]);
+ }
+}
+
+[Differentiable]
+void test(GradientBuffer dst, int idx, float v)
+{
+ dst.write(idx, v);
+}
+
+//TEST_INPUT: set grad_in = ubuffer(data=[101.0 102.0 103.0 104.0], stride=4)
+uniform StructuredBuffer<float> grad_in;
+
+//TEST_INPUT: set grad_out = ubuffer(data=[0 0 0 0], stride=4)
+uniform RWStructuredBuffer<float> grad_out;
+
+//TEST_INPUT: set output = out ubuffer(data=[0 0 0 0], stride=4)
+uniform RWStructuredBuffer<float> output;
+
+[shader("compute")]
+[numthreads(1,1,1)]
+void computeMain()
+{
+ GradientBuffer grads = { grad_in };
+ DifferentialPair<float> result = diffPair(1.0f);
+ bwd_diff(test)(grads, 0, result);
+ // CHECK: 101.0
+ output[0] = result.d; // Should return grad_in[0], but returns 0.0f instead
+} \ No newline at end of file