From bbd1e1786401bb88c34802b987d4da72e2364503 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 1 Feb 2023 14:18:57 -0800 Subject: Support `out` parameters in backward differentiation. (#2619) * Support `out` parameters in backward differentiation. * Fixes. * Fix cleanup. --------- Co-authored-by: Yong He --- source/slang/slang-ir-check-differentiability.cpp | 30 +++-------------------- 1 file changed, 4 insertions(+), 26 deletions(-) (limited to 'source/slang/slang-ir-check-differentiability.cpp') diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index ce3e563f5..e7d5a0e5c 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -5,28 +5,6 @@ namespace Slang { -IRInst* getSpecializedVal(IRInst* inst) -{ - int loopLimit = 1024; - while (inst && inst->getOp() == kIROp_Specialize) - { - inst = as(inst)->getBase(); - loopLimit--; - if (loopLimit == 0) - return inst; - } - return inst; -} - -IRInst* getLeafFunc(IRInst* func) -{ - func = getSpecializedVal(func); - if (!func) - return nullptr; - if (auto genericFunc = as(func)) - return findInnerMostGenericReturnVal(genericFunc); - return func; -} struct CheckDifferentiabilityPassContext : public InstPassBase { @@ -47,7 +25,7 @@ public: bool _isFuncMarkedForAutoDiff(IRInst* func) { - func = getLeafFunc(func); + func = getResolvedInstForDecorations(func); if (!func) return false; for (auto decorations : func->getDecorations()) @@ -65,7 +43,7 @@ public: bool _isDifferentiableFuncImpl(IRInst* func, DifferentiableLevel level) { - func = getLeafFunc(func); + func = getResolvedInstForDecorations(func); if (!func) return false; @@ -103,7 +81,7 @@ public: } } - func = getLeafFunc(func); + func = getResolvedInstForDecorations(func); if (!func) return false; @@ -332,7 +310,7 @@ public: sink->diagnose( inst, Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction, - getLeafFunc(call->getCallee()), + getResolvedInstForDecorations(call->getCallee()), requiredDiffLevel == DifferentiableLevel::Forward ? "forward" : "backward"); } } -- cgit v1.2.3