summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ir-autodiff.cpp85
1 files changed, 85 insertions, 0 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 7876d7eeb..a46b20b5f 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -3345,6 +3345,91 @@ struct AutoDiffPass : public InstPassBase
return true;
}
+ // This function will check whether a global inst can be gathered as a candidate for
+ // differentiable IR. Before we do the recursive search on every IR, we will first filter
+ // the global IR to find out some candidates, and then we will start the recursive search
+ // on this filtered list.
+ // For a generic, we will add it to the list only when it's a function and it's used by other
+ // function, because that is the case of dynamic dispatch.
+ bool isReachableInst(IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_ForwardDifferentiate:
+ case kIROp_BackwardDifferentiate:
+ case kIROp_BackwardDifferentiatePrimal:
+ case kIROp_BackwardDifferentiatePropagate:
+ case kIROp_BackwardDiffIntermediateContextType:
+ case kIROp_Func:
+ return true;
+ case kIROp_Generic:
+ // For generic, if it's a generic function and it's used by any other reachable
+ // inst, we will consider it reachable.
+ auto genericIR = as<IRGeneric>(inst);
+ if (as<IRFunc>(findInnerMostGenericReturnVal(genericIR)))
+ {
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ auto user = use->getUser();
+ if (as<IRModuleInst>(user->parent))
+ return true;
+
+ for (; user; user = user->parent)
+ {
+ if (auto genericUser = as<IRGeneric>(user))
+ return (
+ as<IRFunc>(findInnerMostGenericReturnVal(genericUser)) != nullptr);
+
+ else if (as<IRFunc>(user))
+ return true;
+ }
+ }
+ }
+ }
+ return false;
+ }
+
+ template<typename Func>
+ void processAllReachableInsts(const Func& f)
+ {
+ workList.clear();
+ workListSet.clear();
+
+ // We will do the first around of filter to include only functions and generic functions
+ for (auto child = module->getModuleInst()->getFirstChild(); child;
+ child = child->getNextInst())
+ {
+ if (isReachableInst(child))
+ {
+ addToWorkList(child);
+ }
+ }
+
+ while (workList.getCount() != 0)
+ {
+ IRInst* inst = pop(false);
+ f(inst);
+ for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
+ {
+ if (as<IRDecoration>(child))
+ break;
+ switch (child->getOp())
+ {
+ case kIROp_GenericSpecializationDictionary:
+ case kIROp_ExistentialFuncSpecializationDictionary:
+ case kIROp_ExistentialTypeSpecializationDictionary:
+ case kIROp_DebugInlinedAt:
+ case kIROp_DebugFunction:
+ continue;
+ default:
+ break;
+ }
+ SLANG_ASSERT(child);
+ if (shouldInstBeLiveIfParentIsLive(child, IRDeadCodeEliminationOptions()))
+ addToWorkList(child);
+ }
+ }
+ }
// Process all differentiate calls, and recursively generate code for forward and backward
// derivative functions.
//