summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorEllie Hermaszewska <ellieh@nvidia.com>2025-04-04 15:25:20 +0800
committerGitHub <noreply@github.com>2025-04-04 15:25:20 +0800
commit83a42cb76feb1f702ff730040f359cabc01c571a (patch)
tree442ff945665ace3bbdafaf410f2664825419fd7d
parent4233d69cf88f1623cb573c8edb61456b24dc5339 (diff)
Do no fail on missing no_diff annotation on non-differentiable (inputs and output) function outputs (#6737)
Closes https://github.com/shader-slang/slang/issues/6632
-rw-r--r--source/slang/slang-ir-autodiff.cpp14
-rw-r--r--source/slang/slang-ir-autodiff.h1
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp21
-rw-r--r--tests/bugs/gh-6632.slang29
4 files changed, 58 insertions, 7 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index afd698e8b..f70e30b72 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -147,6 +147,20 @@ bool isNoDiffType(IRType* paramType)
return false;
}
+// Return true if the result type and all the parameter types are no_diff
+bool isNeverDiffFuncType(IRFuncType* const funcType)
+{
+ const auto resultType = funcType->getResultType();
+ if (!isNoDiffType(resultType))
+ return false;
+ for (const auto p : funcType->getParamTypes())
+ {
+ if (!isNoDiffType(p))
+ return false;
+ }
+ return true;
+}
+
IRInst* lookupForwardDerivativeReference(IRInst* primalFunction)
{
if (auto jvpDefinition = primalFunction->findDecoration<IRForwardDerivativeDecoration>())
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index 2cd08eb28..befd1f98a 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -563,6 +563,7 @@ void stripAutoDiffDecorations(IRModule* module);
void stripTempDecorations(IRInst* inst);
bool isNoDiffType(IRType* paramType);
+bool isNeverDiffFuncType(IRFuncType* funcType);
IRInst* lookupForwardDerivativeReference(IRInst* primalFunction);
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 9001295e0..7f3c7bf01 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -487,18 +487,25 @@ public:
{
if (auto call = as<IRCall>(inst))
{
+ const auto callee = call->getCallee();
// If inst's type is differentiable, and it is in expectDiffInstWorkList,
// then some user is expecting the result of the call to produce a derivative.
// In this case we need to issue a diagnostic.
if (isDifferentiableType(diffTypeContext, inst->getFullType()) &&
- !isDifferentiableFunc(call->getCallee(), requiredDiffLevel))
+ !isDifferentiableFunc(callee, requiredDiffLevel))
{
- sink->diagnose(
- inst,
- Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction,
- getResolvedInstForDecorations(call->getCallee()),
- requiredDiffLevel == DifferentiableLevel::Forward ? "forward"
- : "backward");
+ // No need to fail here if the function is no_diff in
+ // both inputs and all outputs, this is equivalent of
+ // inserting no_diff on this inst.
+ if (!isNeverDiffFuncType(cast<IRFuncType>(callee->getDataType())))
+ {
+ sink->diagnose(
+ inst,
+ Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction,
+ getResolvedInstForDecorations(call->getCallee()),
+ requiredDiffLevel == DifferentiableLevel::Forward ? "forward"
+ : "backward");
+ }
}
}
}
diff --git a/tests/bugs/gh-6632.slang b/tests/bugs/gh-6632.slang
new file mode 100644
index 000000000..f9240e63b
--- /dev/null
+++ b/tests/bugs/gh-6632.slang
@@ -0,0 +1,29 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cpu
+
+// CHECK: 40000000
+
+//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
+
+// Non-differentiable function with no_diff parameters and return type
+no_diff float targetFunc(no_diff float x)
+{
+ return x * 2.0f;
+}
+
+[Differentiable]
+float errorForward(no_diff float x)
+{
+ float result = targetFunc(x);
+ return result;
+}
+
+RWStructuredBuffer<float> outputBuffer;
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ float input = 1.0f;
+
+ outputBuffer[0] = errorForward(input);
+}
+