summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-check-differentiability.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-07-21 16:28:22 -0400
committerGitHub <noreply@github.com>2023-07-21 13:28:22 -0700
commitb40b711f54748145ed1340f2a3aa626dcb42b699 (patch)
tree5c63286c13d55c79b4f21f899e8e338393049f8f /source/slang/slang-ir-check-differentiability.cpp
parent32043a48b6503fe3e493082c33eac02865503031 (diff)
Fix data-flow analysis not propagating diff property through differentiable calls (#3010)
* Add test for nodiff diagnostic for non-diff call propagated through diff call * Add logic to disambiguate calls to differentiable and non-differentiable methods * Add expected results for test * Simplify test * Update slang-ir-check-differentiability.cpp * Added comments for TreatAsDifferentiableExpr flavors --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp20
1 files changed, 14 insertions, 6 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index d4b93be5e..3207e0729 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -71,6 +71,15 @@ public:
return false;
}
+ bool shouldTreatCallAsDifferentiable(IRInst* callInst)
+ {
+ SLANG_ASSERT(as<IRCall>(callInst));
+
+ return (
+ callInst->findDecoration<IRTreatCallAsDifferentiableDecoration>() ||
+ callInst->findDecoration<IRDifferentiableCallDecoration>());
+ }
+
bool isDifferentiableFunc(IRInst* func, DifferentiableLevel level)
{
switch (func->getOp())
@@ -300,7 +309,7 @@ public:
case kIROp_FloatLit:
return true;
case kIROp_Call:
- return inst->findDecoration<IRTreatAsDifferentiableDecoration>() ||
+ return shouldTreatCallAsDifferentiable(inst) ||
isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel) && isDifferentiableType(diffTypeContext, inst->getFullType());
case kIROp_Load:
// We don't have more knowledge on whether diff is available at the destination address.
@@ -330,7 +339,7 @@ public:
case kIROp_DetachDerivative:
return false;
case kIROp_Call:
- if (inst->findDecoration<IRTreatAsDifferentiableDecoration>())
+ if (shouldTreatCallAsDifferentiable(inst))
return false;
return isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel) &&
isDifferentiableType(diffTypeContext, inst->getFullType());
@@ -451,7 +460,8 @@ public:
// If inst's type is differentiable, and it is in expectDiffInstWorkList,
// then some user is expecting the result of the call to produce a derivative.
// In this case we need to issue a diagnostic.
- if (isDifferentiableType(diffTypeContext, inst->getFullType()))
+ if (isDifferentiableType(diffTypeContext, inst->getFullType()) &&
+ !isDifferentiableFunc(call->getCallee(), requiredDiffLevel))
{
sink->diagnose(
inst,
@@ -490,9 +500,7 @@ public:
case kIROp_Call:
{
auto callInst = as<IRCall>(inst);
- if (callInst->findDecoration<IRTreatAsDifferentiableDecoration>())
- continue;
- if (!isDifferentiableFunc(callInst->getCallee(), DifferentiableLevel::Forward))
+ if (callInst->findDecoration<IRTreatCallAsDifferentiableDecoration>())
continue;
auto calleeFuncType = as<IRFuncType>(callInst->getCallee()->getFullType());
if (!calleeFuncType) continue;