diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 14 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 21 |
3 files changed, 29 insertions, 7 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index afd698e8b..f70e30b72 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -147,6 +147,20 @@ bool isNoDiffType(IRType* paramType) return false; } +// Return true if the result type and all the parameter types are no_diff +bool isNeverDiffFuncType(IRFuncType* const funcType) +{ + const auto resultType = funcType->getResultType(); + if (!isNoDiffType(resultType)) + return false; + for (const auto p : funcType->getParamTypes()) + { + if (!isNoDiffType(p)) + return false; + } + return true; +} + IRInst* lookupForwardDerivativeReference(IRInst* primalFunction) { if (auto jvpDefinition = primalFunction->findDecoration<IRForwardDerivativeDecoration>()) diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index 2cd08eb28..befd1f98a 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -563,6 +563,7 @@ void stripAutoDiffDecorations(IRModule* module); void stripTempDecorations(IRInst* inst); bool isNoDiffType(IRType* paramType); +bool isNeverDiffFuncType(IRFuncType* funcType); IRInst* lookupForwardDerivativeReference(IRInst* primalFunction); diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 9001295e0..7f3c7bf01 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -487,18 +487,25 @@ public: { if (auto call = as<IRCall>(inst)) { + const auto callee = call->getCallee(); // If inst's type is differentiable, and it is in expectDiffInstWorkList, // then some user is expecting the result of the call to produce a derivative. // In this case we need to issue a diagnostic. if (isDifferentiableType(diffTypeContext, inst->getFullType()) && - !isDifferentiableFunc(call->getCallee(), requiredDiffLevel)) + !isDifferentiableFunc(callee, requiredDiffLevel)) { - sink->diagnose( - inst, - Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction, - getResolvedInstForDecorations(call->getCallee()), - requiredDiffLevel == DifferentiableLevel::Forward ? "forward" - : "backward"); + // No need to fail here if the function is no_diff in + // both inputs and all outputs, this is equivalent of + // inserting no_diff on this inst. + if (!isNeverDiffFuncType(cast<IRFuncType>(callee->getDataType()))) + { + sink->diagnose( + inst, + Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction, + getResolvedInstForDecorations(call->getCallee()), + requiredDiffLevel == DifferentiableLevel::Forward ? "forward" + : "backward"); + } } } } |
