summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-check-differentiability.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp20
1 files changed, 14 insertions, 6 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index d4b93be5e..3207e0729 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -71,6 +71,15 @@ public:
return false;
}
+ bool shouldTreatCallAsDifferentiable(IRInst* callInst)
+ {
+ SLANG_ASSERT(as<IRCall>(callInst));
+
+ return (
+ callInst->findDecoration<IRTreatCallAsDifferentiableDecoration>() ||
+ callInst->findDecoration<IRDifferentiableCallDecoration>());
+ }
+
bool isDifferentiableFunc(IRInst* func, DifferentiableLevel level)
{
switch (func->getOp())
@@ -300,7 +309,7 @@ public:
case kIROp_FloatLit:
return true;
case kIROp_Call:
- return inst->findDecoration<IRTreatAsDifferentiableDecoration>() ||
+ return shouldTreatCallAsDifferentiable(inst) ||
isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel) && isDifferentiableType(diffTypeContext, inst->getFullType());
case kIROp_Load:
// We don't have more knowledge on whether diff is available at the destination address.
@@ -330,7 +339,7 @@ public:
case kIROp_DetachDerivative:
return false;
case kIROp_Call:
- if (inst->findDecoration<IRTreatAsDifferentiableDecoration>())
+ if (shouldTreatCallAsDifferentiable(inst))
return false;
return isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel) &&
isDifferentiableType(diffTypeContext, inst->getFullType());
@@ -451,7 +460,8 @@ public:
// If inst's type is differentiable, and it is in expectDiffInstWorkList,
// then some user is expecting the result of the call to produce a derivative.
// In this case we need to issue a diagnostic.
- if (isDifferentiableType(diffTypeContext, inst->getFullType()))
+ if (isDifferentiableType(diffTypeContext, inst->getFullType()) &&
+ !isDifferentiableFunc(call->getCallee(), requiredDiffLevel))
{
sink->diagnose(
inst,
@@ -490,9 +500,7 @@ public:
case kIROp_Call:
{
auto callInst = as<IRCall>(inst);
- if (callInst->findDecoration<IRTreatAsDifferentiableDecoration>())
- continue;
- if (!isDifferentiableFunc(callInst->getCallee(), DifferentiableLevel::Forward))
+ if (callInst->findDecoration<IRTreatCallAsDifferentiableDecoration>())
continue;
auto calleeFuncType = as<IRFuncType>(callInst->getCallee()->getFullType());
if (!calleeFuncType) continue;