diff options
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 20 |
1 files changed, 14 insertions, 6 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index d4b93be5e..3207e0729 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -71,6 +71,15 @@ public: return false; } + bool shouldTreatCallAsDifferentiable(IRInst* callInst) + { + SLANG_ASSERT(as<IRCall>(callInst)); + + return ( + callInst->findDecoration<IRTreatCallAsDifferentiableDecoration>() || + callInst->findDecoration<IRDifferentiableCallDecoration>()); + } + bool isDifferentiableFunc(IRInst* func, DifferentiableLevel level) { switch (func->getOp()) @@ -300,7 +309,7 @@ public: case kIROp_FloatLit: return true; case kIROp_Call: - return inst->findDecoration<IRTreatAsDifferentiableDecoration>() || + return shouldTreatCallAsDifferentiable(inst) || isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel) && isDifferentiableType(diffTypeContext, inst->getFullType()); case kIROp_Load: // We don't have more knowledge on whether diff is available at the destination address. @@ -330,7 +339,7 @@ public: case kIROp_DetachDerivative: return false; case kIROp_Call: - if (inst->findDecoration<IRTreatAsDifferentiableDecoration>()) + if (shouldTreatCallAsDifferentiable(inst)) return false; return isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel) && isDifferentiableType(diffTypeContext, inst->getFullType()); @@ -451,7 +460,8 @@ public: // If inst's type is differentiable, and it is in expectDiffInstWorkList, // then some user is expecting the result of the call to produce a derivative. // In this case we need to issue a diagnostic. - if (isDifferentiableType(diffTypeContext, inst->getFullType())) + if (isDifferentiableType(diffTypeContext, inst->getFullType()) && + !isDifferentiableFunc(call->getCallee(), requiredDiffLevel)) { sink->diagnose( inst, @@ -490,9 +500,7 @@ public: case kIROp_Call: { auto callInst = as<IRCall>(inst); - if (callInst->findDecoration<IRTreatAsDifferentiableDecoration>()) - continue; - if (!isDifferentiableFunc(callInst->getCallee(), DifferentiableLevel::Forward)) + if (callInst->findDecoration<IRTreatCallAsDifferentiableDecoration>()) continue; auto calleeFuncType = as<IRFuncType>(callInst->getCallee()->getFullType()); if (!calleeFuncType) continue; |
