diff options
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 85 |
1 files changed, 85 insertions, 0 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 7876d7eeb..a46b20b5f 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -3345,6 +3345,91 @@ struct AutoDiffPass : public InstPassBase return true; } + // This function will check whether a global inst can be gathered as a candidate for + // differentiable IR. Before we do the recursive search on every IR, we will first filter + // the global IR to find out some candidates, and then we will start the recursive search + // on this filtered list. + // For a generic, we will add it to the list only when it's a function and it's used by other + // function, because that is the case of dynamic dispatch. + bool isReachableInst(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_ForwardDifferentiate: + case kIROp_BackwardDifferentiate: + case kIROp_BackwardDifferentiatePrimal: + case kIROp_BackwardDifferentiatePropagate: + case kIROp_BackwardDiffIntermediateContextType: + case kIROp_Func: + return true; + case kIROp_Generic: + // For generic, if it's a generic function and it's used by any other reachable + // inst, we will consider it reachable. + auto genericIR = as<IRGeneric>(inst); + if (as<IRFunc>(findInnerMostGenericReturnVal(genericIR))) + { + for (auto use = inst->firstUse; use; use = use->nextUse) + { + auto user = use->getUser(); + if (as<IRModuleInst>(user->parent)) + return true; + + for (; user; user = user->parent) + { + if (auto genericUser = as<IRGeneric>(user)) + return ( + as<IRFunc>(findInnerMostGenericReturnVal(genericUser)) != nullptr); + + else if (as<IRFunc>(user)) + return true; + } + } + } + } + return false; + } + + template<typename Func> + void processAllReachableInsts(const Func& f) + { + workList.clear(); + workListSet.clear(); + + // We will do the first around of filter to include only functions and generic functions + for (auto child = module->getModuleInst()->getFirstChild(); child; + child = child->getNextInst()) + { + if (isReachableInst(child)) + { + addToWorkList(child); + } + } + + while (workList.getCount() != 0) + { + IRInst* inst = pop(false); + f(inst); + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + if (as<IRDecoration>(child)) + break; + switch (child->getOp()) + { + case kIROp_GenericSpecializationDictionary: + case kIROp_ExistentialFuncSpecializationDictionary: + case kIROp_ExistentialTypeSpecializationDictionary: + case kIROp_DebugInlinedAt: + case kIROp_DebugFunction: + continue; + default: + break; + } + SLANG_ASSERT(child); + if (shouldInstBeLiveIfParentIsLive(child, IRDeadCodeEliminationOptions())) + addToWorkList(child); + } + } + } // Process all differentiate calls, and recursively generate code for forward and backward // derivative functions. // |
