From b40b711f54748145ed1340f2a3aa626dcb42b699 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 21 Jul 2023 16:28:22 -0400 Subject: Fix data-flow analysis not propagating diff property through differentiable calls (#3010) * Add test for nodiff diagnostic for non-diff call propagated through diff call * Add logic to disambiguate calls to differentiable and non-differentiable methods * Add expected results for test * Simplify test * Update slang-ir-check-differentiability.cpp * Added comments for TreatAsDifferentiableExpr flavors --------- Co-authored-by: Yong He --- source/slang/slang-ir-check-differentiability.cpp | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) (limited to 'source/slang/slang-ir-check-differentiability.cpp') diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index d4b93be5e..3207e0729 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -71,6 +71,15 @@ public: return false; } + bool shouldTreatCallAsDifferentiable(IRInst* callInst) + { + SLANG_ASSERT(as(callInst)); + + return ( + callInst->findDecoration() || + callInst->findDecoration()); + } + bool isDifferentiableFunc(IRInst* func, DifferentiableLevel level) { switch (func->getOp()) @@ -300,7 +309,7 @@ public: case kIROp_FloatLit: return true; case kIROp_Call: - return inst->findDecoration() || + return shouldTreatCallAsDifferentiable(inst) || isDifferentiableFunc(as(inst)->getCallee(), requiredDiffLevel) && isDifferentiableType(diffTypeContext, inst->getFullType()); case kIROp_Load: // We don't have more knowledge on whether diff is available at the destination address. @@ -330,7 +339,7 @@ public: case kIROp_DetachDerivative: return false; case kIROp_Call: - if (inst->findDecoration()) + if (shouldTreatCallAsDifferentiable(inst)) return false; return isDifferentiableFunc(as(inst)->getCallee(), requiredDiffLevel) && isDifferentiableType(diffTypeContext, inst->getFullType()); @@ -451,7 +460,8 @@ public: // If inst's type is differentiable, and it is in expectDiffInstWorkList, // then some user is expecting the result of the call to produce a derivative. // In this case we need to issue a diagnostic. - if (isDifferentiableType(diffTypeContext, inst->getFullType())) + if (isDifferentiableType(diffTypeContext, inst->getFullType()) && + !isDifferentiableFunc(call->getCallee(), requiredDiffLevel)) { sink->diagnose( inst, @@ -490,9 +500,7 @@ public: case kIROp_Call: { auto callInst = as(inst); - if (callInst->findDecoration()) - continue; - if (!isDifferentiableFunc(callInst->getCallee(), DifferentiableLevel::Forward)) + if (callInst->findDecoration()) continue; auto calleeFuncType = as(callInst->getCallee()->getFullType()); if (!calleeFuncType) continue; -- cgit v1.2.3