diff options
| author | kaizhangNV <149626564+kaizhangNV@users.noreply.github.com> | 2025-08-26 12:56:47 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-08-26 09:56:47 -0700 |
| commit | 5060042bb63cbf42063f5e81c58881e1e8323857 (patch) | |
| tree | f81dbca4524da449df0c68c492235c363666f6cc /source/slang | |
| parent | 7b30ad489198ecedb16a0265f290c1e32772514c (diff) | |
fix a autodiff crash (#8259)
close #8068.
Currently the AutoDiff aggressively scan every IR inst in searching the
differentiable IR. This is not efficient and could have bug, details in
https://github.com/shader-slang/slang/issues/8068#issuecomment-3214856668.
This PR change the behavior. It will do a initial filter to only gather
the global differentiable IRs and IRFunc and IRGeneric as well. For
IRGeneric, we will pick it only when it's used in other generic function
(it's only useful when dealing with dynamic dispatch).
Then we will start searching reachable insts from this IR list by using
the same method as before.
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 85 |
1 files changed, 85 insertions, 0 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 7876d7eeb..a46b20b5f 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -3345,6 +3345,91 @@ struct AutoDiffPass : public InstPassBase return true; } + // This function will check whether a global inst can be gathered as a candidate for + // differentiable IR. Before we do the recursive search on every IR, we will first filter + // the global IR to find out some candidates, and then we will start the recursive search + // on this filtered list. + // For a generic, we will add it to the list only when it's a function and it's used by other + // function, because that is the case of dynamic dispatch. + bool isReachableInst(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_ForwardDifferentiate: + case kIROp_BackwardDifferentiate: + case kIROp_BackwardDifferentiatePrimal: + case kIROp_BackwardDifferentiatePropagate: + case kIROp_BackwardDiffIntermediateContextType: + case kIROp_Func: + return true; + case kIROp_Generic: + // For generic, if it's a generic function and it's used by any other reachable + // inst, we will consider it reachable. + auto genericIR = as<IRGeneric>(inst); + if (as<IRFunc>(findInnerMostGenericReturnVal(genericIR))) + { + for (auto use = inst->firstUse; use; use = use->nextUse) + { + auto user = use->getUser(); + if (as<IRModuleInst>(user->parent)) + return true; + + for (; user; user = user->parent) + { + if (auto genericUser = as<IRGeneric>(user)) + return ( + as<IRFunc>(findInnerMostGenericReturnVal(genericUser)) != nullptr); + + else if (as<IRFunc>(user)) + return true; + } + } + } + } + return false; + } + + template<typename Func> + void processAllReachableInsts(const Func& f) + { + workList.clear(); + workListSet.clear(); + + // We will do the first around of filter to include only functions and generic functions + for (auto child = module->getModuleInst()->getFirstChild(); child; + child = child->getNextInst()) + { + if (isReachableInst(child)) + { + addToWorkList(child); + } + } + + while (workList.getCount() != 0) + { + IRInst* inst = pop(false); + f(inst); + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + if (as<IRDecoration>(child)) + break; + switch (child->getOp()) + { + case kIROp_GenericSpecializationDictionary: + case kIROp_ExistentialFuncSpecializationDictionary: + case kIROp_ExistentialTypeSpecializationDictionary: + case kIROp_DebugInlinedAt: + case kIROp_DebugFunction: + continue; + default: + break; + } + SLANG_ASSERT(child); + if (shouldInstBeLiveIfParentIsLive(child, IRDeadCodeEliminationOptions())) + addToWorkList(child); + } + } + } // Process all differentiate calls, and recursively generate code for forward and backward // derivative functions. // |
