summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2025-02-05 13:49:19 -0800
committerGitHub <noreply@github.com>2025-02-05 13:49:19 -0800
commitf6cbb81e1c0080518185294ee94705f5e93aa849 (patch)
treec7f7271559e0f468e2ea892d2edaa2b0bbace30c
parent7911c9437333692db275d2dff41264f4c8023be8 (diff)
Fix DCE for calls to functions that have associations (#6272)
* Fix DCE for calls to functions that have associations * Update slang-ir-util.cpp * Update slang-ir-util.cpp
-rw-r--r--source/slang/slang-ir-util.cpp82
-rw-r--r--tests/autodiff/custom-diff-empty-func.slang35
2 files changed, 115 insertions, 2 deletions
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index dd7819e1a..a3cf28a68 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -1215,8 +1215,65 @@ bool isSideEffectFreeFunctionalCall(IRCall* call, SideEffectAnalysisOptions opti
return false;
}
+// Enumerate any associated functions of 'func'
+// that might be used by a pass (e.g. auto-diff)
+//
+template<typename TFunc>
+void forEachAssociatedFunction(IRInst* func, TFunc callback)
+{
+ // Resolve the function to get all its decorations
+ auto resolvedFunc = getResolvedInstForDecorations(func);
+ if (!resolvedFunc)
+ return;
+
+ // We'll scan for appropriate decorations and return
+ // the function references.
+ //
+ // TODO: In the future, as we get more function transformation
+ // passes, we might want to create a parent class for such
+ // decorations that associate functions with each other.
+ //
+ for (auto decor : resolvedFunc->getDecorations())
+ {
+ switch (decor->getOp())
+ {
+ case kIROp_UserDefinedBackwardDerivativeDecoration:
+ if (as<IRUserDefinedBackwardDerivativeDecoration>(decor))
+ {
+ auto associatedCallee = as<IRUserDefinedBackwardDerivativeDecoration>(decor)
+ ->getBackwardDerivativeFunc();
+ callback(associatedCallee);
+ }
+ break;
+
+ case kIROp_ForwardDerivativeDecoration:
+ if (as<IRForwardDerivativeDecoration>(decor))
+ {
+ auto associatedCallee =
+ as<IRForwardDerivativeDecoration>(decor)->getForwardDerivativeFunc();
+ callback(associatedCallee);
+ }
+ break;
+
+ case kIROp_PrimalSubstituteDecoration:
+ if (as<IRPrimalSubstituteDecoration>(decor))
+ {
+ auto associatedCallee =
+ as<IRPrimalSubstituteDecoration>(decor)->getPrimalSubstituteFunc();
+ callback(associatedCallee);
+ }
+ break;
+
+ default:
+ break;
+ }
+ }
+}
+
bool doesCalleeHaveSideEffect(IRInst* callee)
{
+ bool sideEffect = true;
+
for (auto decor : getResolvedInstForDecorations(callee)->getDecorations())
{
switch (decor->getOp())
@@ -1224,10 +1281,31 @@ bool doesCalleeHaveSideEffect(IRInst* callee)
case kIROp_NoSideEffectDecoration:
case kIROp_ReadNoneDecoration:
case kIROp_IgnoreSideEffectsDecoration:
- return false;
+ sideEffect = false;
+ break;
+ default:
+ break;
}
}
- return true;
+
+ // If the callee has no side effect, check if any of its associated functions have side effect.
+ // If so, we want to keep the callee around.
+ //
+ // Typically, once the relevant pass has completed, the association is removed,
+ // and at that point we can remove the function.
+ //
+ if (!sideEffect)
+ {
+ forEachAssociatedFunction(
+ callee,
+ [&](IRInst* associatedCallee)
+ {
+ sideEffect |= doesCalleeHaveSideEffect(associatedCallee);
+ return;
+ });
+ }
+
+ return sideEffect;
}
IRInst* findInterfaceRequirement(IRInterfaceType* type, IRInst* key)
diff --git a/tests/autodiff/custom-diff-empty-func.slang b/tests/autodiff/custom-diff-empty-func.slang
new file mode 100644
index 000000000..c566d9e19
--- /dev/null
+++ b/tests/autodiff/custom-diff-empty-func.slang
@@ -0,0 +1,35 @@
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type -g0
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+void foo_bwd(float a, inout DifferentialPair<float> dpx)
+{
+ outputBuffer[2] = 2.f;
+}
+
+[Differentiable, BackwardDerivative(foo_bwd)]
+void foo(no_diff float a, float x)
+{ }
+
+[Differentiable]
+float outerFunc(no_diff float a, float x)
+{
+ foo(a, x);
+ return 1.f;
+}
+
+[numthreads(1, 1, 1)]
+[shader("compute")]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ float a = 10.0;
+ DifferentialPair<float> dpx = DifferentialPair<float>(4.f, 1.f);
+ bwd_diff(outerFunc)(a, dpx, 1.0);
+
+ // CHECK: type: float
+ // CHECK: 0.0
+ // CHECK: 0.0
+ // CHECK: 2.0
+ // CHECK: 0.0
+} \ No newline at end of file