summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorEllie Hermaszewska <ellieh@nvidia.com>2025-04-04 15:25:20 +0800
committerGitHub <noreply@github.com>2025-04-04 15:25:20 +0800
commit83a42cb76feb1f702ff730040f359cabc01c571a (patch)
tree442ff945665ace3bbdafaf410f2664825419fd7d /source
parent4233d69cf88f1623cb573c8edb61456b24dc5339 (diff)
Do no fail on missing no_diff annotation on non-differentiable (inputs and output) function outputs (#6737)
Closes https://github.com/shader-slang/slang/issues/6632
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff.cpp14
-rw-r--r--source/slang/slang-ir-autodiff.h1
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp21
3 files changed, 29 insertions, 7 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index afd698e8b..f70e30b72 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -147,6 +147,20 @@ bool isNoDiffType(IRType* paramType)
return false;
}
+// Return true if the result type and all the parameter types are no_diff
+bool isNeverDiffFuncType(IRFuncType* const funcType)
+{
+ const auto resultType = funcType->getResultType();
+ if (!isNoDiffType(resultType))
+ return false;
+ for (const auto p : funcType->getParamTypes())
+ {
+ if (!isNoDiffType(p))
+ return false;
+ }
+ return true;
+}
+
IRInst* lookupForwardDerivativeReference(IRInst* primalFunction)
{
if (auto jvpDefinition = primalFunction->findDecoration<IRForwardDerivativeDecoration>())
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index 2cd08eb28..befd1f98a 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -563,6 +563,7 @@ void stripAutoDiffDecorations(IRModule* module);
void stripTempDecorations(IRInst* inst);
bool isNoDiffType(IRType* paramType);
+bool isNeverDiffFuncType(IRFuncType* funcType);
IRInst* lookupForwardDerivativeReference(IRInst* primalFunction);
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 9001295e0..7f3c7bf01 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -487,18 +487,25 @@ public:
{
if (auto call = as<IRCall>(inst))
{
+ const auto callee = call->getCallee();
// If inst's type is differentiable, and it is in expectDiffInstWorkList,
// then some user is expecting the result of the call to produce a derivative.
// In this case we need to issue a diagnostic.
if (isDifferentiableType(diffTypeContext, inst->getFullType()) &&
- !isDifferentiableFunc(call->getCallee(), requiredDiffLevel))
+ !isDifferentiableFunc(callee, requiredDiffLevel))
{
- sink->diagnose(
- inst,
- Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction,
- getResolvedInstForDecorations(call->getCallee()),
- requiredDiffLevel == DifferentiableLevel::Forward ? "forward"
- : "backward");
+ // No need to fail here if the function is no_diff in
+ // both inputs and all outputs, this is equivalent of
+ // inserting no_diff on this inst.
+ if (!isNeverDiffFuncType(cast<IRFuncType>(callee->getDataType())))
+ {
+ sink->diagnose(
+ inst,
+ Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction,
+ getResolvedInstForDecorations(call->getCallee()),
+ requiredDiffLevel == DifferentiableLevel::Forward ? "forward"
+ : "backward");
+ }
}
}
}