summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-09-26 00:10:45 -0400
committerGitHub <noreply@github.com>2024-09-26 00:10:45 -0400
commit7398e1e09312ed4e19195e060de9a2c9a073fcc1 (patch)
treed5622ffa3095e156f9ada816146d260137145cfd
parentd752482c9223eef8deebb0d8f0b13ce9679781c4 (diff)
Always run AD cleanup pass. (#5157)
-rw-r--r--source/slang/slang-emit.cpp10
-rw-r--r--source/slang/slang-ir-autodiff.cpp2
-rw-r--r--tests/autodiff/no-diff-strip.slang45
-rw-r--r--tests/autodiff/no-diff-strip.slang.expected.txt6
4 files changed, 57 insertions, 6 deletions
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 71ef7ee33..a29142ba1 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -834,12 +834,10 @@ Result linkAndOptimizeIR(
if (codeGenContext->shouldReportCheckpointIntermediates())
reportCheckpointIntermediates(codeGenContext, sink, irModule);
- if (requiredLoweringPassSet.autodiff)
- finalizeAutoDiffPass(targetProgram, irModule);
-
- // Remove auto-diff related decorations.
- // We may have an autodiff decoration regardless of if autodiff is being used.
- stripAutoDiffDecorations(irModule);
+ // Finalization is always run so AD-related instructions can be removed,
+ // even the AD pass itself is not run.
+ //
+ finalizeAutoDiffPass(targetProgram, irModule);
finalizeSpecialization(irModule);
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 94a605a68..6c729ea63 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -2559,6 +2559,8 @@ bool finalizeAutoDiffPass(TargetProgram* target, IRModule* module)
stripNoDiffTypeAttribute(module);
+ stripAutoDiffDecorations(module);
+
return modified;
}
diff --git a/tests/autodiff/no-diff-strip.slang b/tests/autodiff/no-diff-strip.slang
new file mode 100644
index 000000000..ebcdd7972
--- /dev/null
+++ b/tests/autodiff/no-diff-strip.slang
@@ -0,0 +1,45 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+// This test just checks that we compile and run successfully.
+// "NDBuffer<float, 2>" & "no_diff NDBuffer<float, 2>" should resolve to the same code-gen type.
+//
+
+struct NDBuffer<T, let N : int>
+{
+ RWStructuredBuffer<T> buffer;
+ int[N] strides;
+ int[N] transform;
+
+ T get(int[N] index) { return buffer[index[0]]; }
+}
+
+float _read_slice(int2 index, NDBuffer<float, 2> texture)
+{
+ return texture.get({index.x, index.y});
+}
+
+[Differentiable]
+void _trampoline(no_diff in vector<int,2> index, in no_diff NDBuffer<float, 2> texture, no_diff out float _result)
+{
+ _result = _read_slice(index, texture);
+}
+
+
+[shader("compute")]
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ NDBuffer<float, 2> texture;
+ texture.buffer = outputBuffer;
+ texture.strides = {1, 1};
+
+ float result;
+ _trampoline({dispatchThreadID.x, dispatchThreadID.y}, texture, result);
+ outputBuffer[dispatchThreadID.x] = result;
+}
+
diff --git a/tests/autodiff/no-diff-strip.slang.expected.txt b/tests/autodiff/no-diff-strip.slang.expected.txt
new file mode 100644
index 000000000..e070cf84d
--- /dev/null
+++ b/tests/autodiff/no-diff-strip.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+0.000000
+0.000000
+0.000000
+0.000000
+0.000000