diff options
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
| -rw-r--r-- | source/slang/slang-ir-check-differentiability.cpp | 30 |
1 files changed, 4 insertions, 26 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index ce3e563f5..e7d5a0e5c 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -5,28 +5,6 @@ namespace Slang { -IRInst* getSpecializedVal(IRInst* inst) -{ - int loopLimit = 1024; - while (inst && inst->getOp() == kIROp_Specialize) - { - inst = as<IRSpecialize>(inst)->getBase(); - loopLimit--; - if (loopLimit == 0) - return inst; - } - return inst; -} - -IRInst* getLeafFunc(IRInst* func) -{ - func = getSpecializedVal(func); - if (!func) - return nullptr; - if (auto genericFunc = as<IRGeneric>(func)) - return findInnerMostGenericReturnVal(genericFunc); - return func; -} struct CheckDifferentiabilityPassContext : public InstPassBase { @@ -47,7 +25,7 @@ public: bool _isFuncMarkedForAutoDiff(IRInst* func) { - func = getLeafFunc(func); + func = getResolvedInstForDecorations(func); if (!func) return false; for (auto decorations : func->getDecorations()) @@ -65,7 +43,7 @@ public: bool _isDifferentiableFuncImpl(IRInst* func, DifferentiableLevel level) { - func = getLeafFunc(func); + func = getResolvedInstForDecorations(func); if (!func) return false; @@ -103,7 +81,7 @@ public: } } - func = getLeafFunc(func); + func = getResolvedInstForDecorations(func); if (!func) return false; @@ -332,7 +310,7 @@ public: sink->diagnose( inst, Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction, - getLeafFunc(call->getCallee()), + getResolvedInstForDecorations(call->getCallee()), requiredDiffLevel == DifferentiableLevel::Forward ? "forward" : "backward"); } } |
