summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-check-differentiability.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp30
1 files changed, 4 insertions, 26 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index ce3e563f5..e7d5a0e5c 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -5,28 +5,6 @@
namespace Slang
{
-IRInst* getSpecializedVal(IRInst* inst)
-{
- int loopLimit = 1024;
- while (inst && inst->getOp() == kIROp_Specialize)
- {
- inst = as<IRSpecialize>(inst)->getBase();
- loopLimit--;
- if (loopLimit == 0)
- return inst;
- }
- return inst;
-}
-
-IRInst* getLeafFunc(IRInst* func)
-{
- func = getSpecializedVal(func);
- if (!func)
- return nullptr;
- if (auto genericFunc = as<IRGeneric>(func))
- return findInnerMostGenericReturnVal(genericFunc);
- return func;
-}
struct CheckDifferentiabilityPassContext : public InstPassBase
{
@@ -47,7 +25,7 @@ public:
bool _isFuncMarkedForAutoDiff(IRInst* func)
{
- func = getLeafFunc(func);
+ func = getResolvedInstForDecorations(func);
if (!func)
return false;
for (auto decorations : func->getDecorations())
@@ -65,7 +43,7 @@ public:
bool _isDifferentiableFuncImpl(IRInst* func, DifferentiableLevel level)
{
- func = getLeafFunc(func);
+ func = getResolvedInstForDecorations(func);
if (!func)
return false;
@@ -103,7 +81,7 @@ public:
}
}
- func = getLeafFunc(func);
+ func = getResolvedInstForDecorations(func);
if (!func)
return false;
@@ -332,7 +310,7 @@ public:
sink->diagnose(
inst,
Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction,
- getLeafFunc(call->getCallee()),
+ getResolvedInstForDecorations(call->getCallee()),
requiredDiffLevel == DifferentiableLevel::Forward ? "forward" : "backward");
}
}