summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorkaizhangNV <149626564+kaizhangNV@users.noreply.github.com>2025-08-26 12:56:47 -0400
committerGitHub <noreply@github.com>2025-08-26 09:56:47 -0700
commit5060042bb63cbf42063f5e81c58881e1e8323857 (patch)
treef81dbca4524da449df0c68c492235c363666f6cc /source
parent7b30ad489198ecedb16a0265f290c1e32772514c (diff)
fix a autodiff crash (#8259)
close #8068. Currently the AutoDiff aggressively scan every IR inst in searching the differentiable IR. This is not efficient and could have bug, details in https://github.com/shader-slang/slang/issues/8068#issuecomment-3214856668. This PR change the behavior. It will do a initial filter to only gather the global differentiable IRs and IRFunc and IRGeneric as well. For IRGeneric, we will pick it only when it's used in other generic function (it's only useful when dealing with dynamic dispatch). Then we will start searching reachable insts from this IR list by using the same method as before.
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff.cpp85
1 files changed, 85 insertions, 0 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 7876d7eeb..a46b20b5f 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -3345,6 +3345,91 @@ struct AutoDiffPass : public InstPassBase
return true;
}
+ // This function will check whether a global inst can be gathered as a candidate for
+ // differentiable IR. Before we do the recursive search on every IR, we will first filter
+ // the global IR to find out some candidates, and then we will start the recursive search
+ // on this filtered list.
+ // For a generic, we will add it to the list only when it's a function and it's used by other
+ // function, because that is the case of dynamic dispatch.
+ bool isReachableInst(IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_ForwardDifferentiate:
+ case kIROp_BackwardDifferentiate:
+ case kIROp_BackwardDifferentiatePrimal:
+ case kIROp_BackwardDifferentiatePropagate:
+ case kIROp_BackwardDiffIntermediateContextType:
+ case kIROp_Func:
+ return true;
+ case kIROp_Generic:
+ // For generic, if it's a generic function and it's used by any other reachable
+ // inst, we will consider it reachable.
+ auto genericIR = as<IRGeneric>(inst);
+ if (as<IRFunc>(findInnerMostGenericReturnVal(genericIR)))
+ {
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ auto user = use->getUser();
+ if (as<IRModuleInst>(user->parent))
+ return true;
+
+ for (; user; user = user->parent)
+ {
+ if (auto genericUser = as<IRGeneric>(user))
+ return (
+ as<IRFunc>(findInnerMostGenericReturnVal(genericUser)) != nullptr);
+
+ else if (as<IRFunc>(user))
+ return true;
+ }
+ }
+ }
+ }
+ return false;
+ }
+
+ template<typename Func>
+ void processAllReachableInsts(const Func& f)
+ {
+ workList.clear();
+ workListSet.clear();
+
+ // We will do the first around of filter to include only functions and generic functions
+ for (auto child = module->getModuleInst()->getFirstChild(); child;
+ child = child->getNextInst())
+ {
+ if (isReachableInst(child))
+ {
+ addToWorkList(child);
+ }
+ }
+
+ while (workList.getCount() != 0)
+ {
+ IRInst* inst = pop(false);
+ f(inst);
+ for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
+ {
+ if (as<IRDecoration>(child))
+ break;
+ switch (child->getOp())
+ {
+ case kIROp_GenericSpecializationDictionary:
+ case kIROp_ExistentialFuncSpecializationDictionary:
+ case kIROp_ExistentialTypeSpecializationDictionary:
+ case kIROp_DebugInlinedAt:
+ case kIROp_DebugFunction:
+ continue;
+ default:
+ break;
+ }
+ SLANG_ASSERT(child);
+ if (shouldInstBeLiveIfParentIsLive(child, IRDeadCodeEliminationOptions()))
+ addToWorkList(child);
+ }
+ }
+ }
// Process all differentiate calls, and recursively generate code for forward and backward
// derivative functions.
//