diff options
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 21 |
1 files changed, 14 insertions, 7 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 9001295e0..7f3c7bf01 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -487,18 +487,25 @@ public: { if (auto call = as<IRCall>(inst)) { + const auto callee = call->getCallee(); // If inst's type is differentiable, and it is in expectDiffInstWorkList, // then some user is expecting the result of the call to produce a derivative. // In this case we need to issue a diagnostic. if (isDifferentiableType(diffTypeContext, inst->getFullType()) && - !isDifferentiableFunc(call->getCallee(), requiredDiffLevel)) + !isDifferentiableFunc(callee, requiredDiffLevel)) { - sink->diagnose( - inst, - Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction, - getResolvedInstForDecorations(call->getCallee()), - requiredDiffLevel == DifferentiableLevel::Forward ? "forward" - : "backward"); + // No need to fail here if the function is no_diff in + // both inputs and all outputs, this is equivalent of + // inserting no_diff on this inst. + if (!isNeverDiffFuncType(cast<IRFuncType>(callee->getDataType()))) + { + sink->diagnose( + inst, + Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction, + getResolvedInstForDecorations(call->getCallee()), + requiredDiffLevel == DifferentiableLevel::Forward ? "forward" + : "backward"); + } } } } |
