summaryrefslogtreecommitdiffstats
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/autodiff/reverse-continue-loop.slang2
-rw-r--r--tests/autodiff/reverse-loop-immediate-return.slang59
2 files changed, 60 insertions, 1 deletions
diff --git a/tests/autodiff/reverse-continue-loop.slang b/tests/autodiff/reverse-continue-loop.slang
index 2dfad0a61..51f17b611 100644
--- a/tests/autodiff/reverse-continue-loop.slang
+++ b/tests/autodiff/reverse-continue-loop.slang
@@ -16,7 +16,7 @@ float test_loop_with_continue(float y)
//CHK-DAG: note: 20 bytes (FixedArray<float, 5> ) used to checkpoint the following item:
float t = y;
- //CHK-DAG: note: 4 bytes (int32_t) used for a loop counter here:
+ //CHK-DAG: note: 4 bytes (int32_t) used to checkpoint the following item:
for (int i = 0; i < 3; i++)
{
if (t > 4.0)
diff --git a/tests/autodiff/reverse-loop-immediate-return.slang b/tests/autodiff/reverse-loop-immediate-return.slang
new file mode 100644
index 000000000..121836115
--- /dev/null
+++ b/tests/autodiff/reverse-loop-immediate-return.slang
@@ -0,0 +1,59 @@
+
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK): -slang -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+
+[BackwardDerivative(set_bwd)]
+void set(uint idx, float x)
+{
+ outputBuffer[idx] = x;
+}
+
+void set_bwd(uint idx, inout DifferentialPair<float> x)
+{
+ // For debugging, we'll set the derivative to 1.0
+ x = DifferentialPair<float>(x.p, 1.0f);
+}
+
+[Differentiable]
+void run(
+ uint idx,
+ float x)
+{
+ if (idx >= 1) return;
+
+ if (idx == 0)
+ { }
+
+ for (int i = 0; i < 1; i++)
+ {
+ if (idx > 0)
+ {
+ return;
+ }
+
+ if (idx == 0)
+ {
+ x = x * 2.0f;
+ }
+ }
+
+ if (idx == 0)
+ { }
+
+ set(idx, x);
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ // bwd_diff
+ DifferentialPair<float> dpa = DifferentialPair<float>(1.0, 0.0);
+ bwd_diff(run)(dispatchThreadID.x, dpa);
+ outputBuffer[dispatchThreadID.x] = dpa.d;
+
+ // CHECK: type: float
+ // CHECK: 2.0
+}