summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-check-differentiability.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-02-01 14:18:57 -0800
committerGitHub <noreply@github.com>2023-02-01 14:18:57 -0800
commitbbd1e1786401bb88c34802b987d4da72e2364503 (patch)
tree99a4be95ae517fd710fc032a1debdac917dd3ac2 /source/slang/slang-ir-check-differentiability.cpp
parentc5895fb0b82fd14fbe45b58d5fc7f75d67625d15 (diff)
Support `out` parameters in backward differentiation. (#2619)
* Support `out` parameters in backward differentiation. * Fixes. * Fix cleanup. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-check-differentiability.cpp')
-rw-r--r--source/slang/slang-ir-check-differentiability.cpp30
1 files changed, 4 insertions, 26 deletions
diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp
index ce3e563f5..e7d5a0e5c 100644
--- a/source/slang/slang-ir-check-differentiability.cpp
+++ b/source/slang/slang-ir-check-differentiability.cpp
@@ -5,28 +5,6 @@
namespace Slang
{
-IRInst* getSpecializedVal(IRInst* inst)
-{
- int loopLimit = 1024;
- while (inst && inst->getOp() == kIROp_Specialize)
- {
- inst = as<IRSpecialize>(inst)->getBase();
- loopLimit--;
- if (loopLimit == 0)
- return inst;
- }
- return inst;
-}
-
-IRInst* getLeafFunc(IRInst* func)
-{
- func = getSpecializedVal(func);
- if (!func)
- return nullptr;
- if (auto genericFunc = as<IRGeneric>(func))
- return findInnerMostGenericReturnVal(genericFunc);
- return func;
-}
struct CheckDifferentiabilityPassContext : public InstPassBase
{
@@ -47,7 +25,7 @@ public:
bool _isFuncMarkedForAutoDiff(IRInst* func)
{
- func = getLeafFunc(func);
+ func = getResolvedInstForDecorations(func);
if (!func)
return false;
for (auto decorations : func->getDecorations())
@@ -65,7 +43,7 @@ public:
bool _isDifferentiableFuncImpl(IRInst* func, DifferentiableLevel level)
{
- func = getLeafFunc(func);
+ func = getResolvedInstForDecorations(func);
if (!func)
return false;
@@ -103,7 +81,7 @@ public:
}
}
- func = getLeafFunc(func);
+ func = getResolvedInstForDecorations(func);
if (!func)
return false;
@@ -332,7 +310,7 @@ public:
sink->diagnose(
inst,
Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction,
- getLeafFunc(call->getCallee()),
+ getResolvedInstForDecorations(call->getCallee()),
requiredDiffLevel == DifferentiableLevel::Forward ? "forward" : "backward");
}
}