summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-check-differentiability.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp21
1 files changed, 14 insertions, 7 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index 9001295e0..7f3c7bf01 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -487,18 +487,25 @@ public:
{
if (auto call = as<IRCall>(inst))
{
+ const auto callee = call->getCallee();
// If inst's type is differentiable, and it is in expectDiffInstWorkList,
// then some user is expecting the result of the call to produce a derivative.
// In this case we need to issue a diagnostic.
if (isDifferentiableType(diffTypeContext, inst->getFullType()) &&
- !isDifferentiableFunc(call->getCallee(), requiredDiffLevel))
+ !isDifferentiableFunc(callee, requiredDiffLevel))
{
- sink->diagnose(
- inst,
- Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction,
- getResolvedInstForDecorations(call->getCallee()),
- requiredDiffLevel == DifferentiableLevel::Forward ? "forward"
- : "backward");
+ // No need to fail here if the function is no_diff in
+ // both inputs and all outputs, this is equivalent of
+ // inserting no_diff on this inst.
+ if (!isNeverDiffFuncType(cast<IRFuncType>(callee->getDataType())))
+ {
+ sink->diagnose(
+ inst,
+ Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction,
+ getResolvedInstForDecorations(call->getCallee()),
+ requiredDiffLevel == DifferentiableLevel::Forward ? "forward"
+ : "backward");
+ }
}
}
}