diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-07-21 16:28:22 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-07-21 13:28:22 -0700 |
| commit | b40b711f54748145ed1340f2a3aa626dcb42b699 (patch) | |
| tree | 5c63286c13d55c79b4f21f899e8e338393049f8f /source/slang/slang-ir-check-differentiability.cpp | |
| parent | 32043a48b6503fe3e493082c33eac02865503031 (diff) | |
Fix data-flow analysis not propagating diff property through differentiable calls (#3010)
* Add test for nodiff diagnostic for non-diff call propagated through diff call
* Add logic to disambiguate calls to differentiable and non-differentiable methods
* Add expected results for test
* Simplify test
* Update slang-ir-check-differentiability.cpp
* Added comments for TreatAsDifferentiableExpr flavors
---------
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 20 |
1 files changed, 14 insertions, 6 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index d4b93be5e..3207e0729 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -71,6 +71,15 @@ public: return false; } + bool shouldTreatCallAsDifferentiable(IRInst* callInst) + { + SLANG_ASSERT(as<IRCall>(callInst)); + + return ( + callInst->findDecoration<IRTreatCallAsDifferentiableDecoration>() || + callInst->findDecoration<IRDifferentiableCallDecoration>()); + } + bool isDifferentiableFunc(IRInst* func, DifferentiableLevel level) { switch (func->getOp()) @@ -300,7 +309,7 @@ public: case kIROp_FloatLit: return true; case kIROp_Call: - return inst->findDecoration<IRTreatAsDifferentiableDecoration>() || + return shouldTreatCallAsDifferentiable(inst) || isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel) && isDifferentiableType(diffTypeContext, inst->getFullType()); case kIROp_Load: // We don't have more knowledge on whether diff is available at the destination address. @@ -330,7 +339,7 @@ public: case kIROp_DetachDerivative: return false; case kIROp_Call: - if (inst->findDecoration<IRTreatAsDifferentiableDecoration>()) + if (shouldTreatCallAsDifferentiable(inst)) return false; return isDifferentiableFunc(as<IRCall>(inst)->getCallee(), requiredDiffLevel) && isDifferentiableType(diffTypeContext, inst->getFullType()); @@ -451,7 +460,8 @@ public: // If inst's type is differentiable, and it is in expectDiffInstWorkList, // then some user is expecting the result of the call to produce a derivative. // In this case we need to issue a diagnostic. - if (isDifferentiableType(diffTypeContext, inst->getFullType())) + if (isDifferentiableType(diffTypeContext, inst->getFullType()) && + !isDifferentiableFunc(call->getCallee(), requiredDiffLevel)) { sink->diagnose( inst, @@ -490,9 +500,7 @@ public: case kIROp_Call: { auto callInst = as<IRCall>(inst); - if (callInst->findDecoration<IRTreatAsDifferentiableDecoration>()) - continue; - if (!isDifferentiableFunc(callInst->getCallee(), DifferentiableLevel::Forward)) + if (callInst->findDecoration<IRTreatCallAsDifferentiableDecoration>()) continue; auto calleeFuncType = as<IRFuncType>(callInst->getCallee()->getFullType()); if (!calleeFuncType) continue; |
