From 87f00a36a123e36b415eeea82e02a8366cc5b881 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 10 Jan 2025 03:16:24 +0530 Subject: [Auto-diff] Overhaul auto-diff type tracking + Overhaul dynamic dispatch for differentiable functions (#5866) * Overhauled the auto-diff system for dynamic dispatch * More fixes * remove intermediate dumps * Update slang-ast-type.h * More fixes + add a workaround for existential no-diff * Update reverse-control-flow-3.slang * remove dumps * remove more dumps * Delete working-reverse-control-flow-3.hlsl * Cleanup comments + unused variables * More comment cleanup * Add support for lowering `DiffPairType(TypePack)` & `MakePair(MakeValuePack, MakeValuePack)` * Fix array of issues in Falcor tests. * Update slang-ir-autodiff-pairs.cpp * More fixes for Falcor image tests * Small fixups. --------- Co-authored-by: Yong He --- source/slang/slang-ir-autodiff-primal-hoist.cpp | 141 +++++++++++++++++++----- 1 file changed, 114 insertions(+), 27 deletions(-) (limited to 'source/slang/slang-ir-autodiff-primal-hoist.cpp') diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index a3f6079ac..ef5161104 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -344,8 +344,18 @@ RefPtr AutodiffCheckpointPolicyBase::processFunc( continue; } + // General case: we'll add all primal operands to the work list. addPrimalOperandsToWorkList(child); + // Also add type annotations to the list, since these have to be made available to the + // function context. + // + if (as(child)) + { + checkpointInfo->recomputeSet.add(child); + addPrimalOperandsToWorkList(child); + } + // We'll be conservative with the decorations we consider as differential uses // of a primal inst, in order to avoid weird behaviour with some decorations // @@ -1333,7 +1343,7 @@ struct UseChain return result; } - void replace(IRBuilder* builder, IRInst* inst) + void replace(IROutOfOrderCloneContext* ctx, IRBuilder* builder, IRInst* inst) { SLANG_ASSERT(chain.getCount() > 0); @@ -1345,30 +1355,27 @@ struct UseChain return; } - IRCloneEnv env; - // Pop the last use, which is the base use that needs to be replaced. auto baseUse = chain.getLast(); chain.removeLast(); // Ensure that replacement inst is set as mapping for the baseUse. - env.mapOldValToNew[baseUse->get()] = inst; - - auto lastInstInChain = inst; + ctx->cloneEnv.mapOldValToNew[baseUse->get()] = inst; IRBuilder chainBuilder(builder->getModule()); setInsertAfterOrdinaryInst(&chainBuilder, inst); chain.reverse(); + chain.removeLast(); // Clone the rest of the chain. for (auto& use : chain) { - lastInstInChain = cloneInst(&env, &chainBuilder, use->get()); + ctx->cloneInstOutOfOrder(&chainBuilder, use->get()); } - // Replace the base use. - builder->replaceOperand(chain.getLast(), lastInstInChain); + // We won't actually replace the final use, because if there are multiple chains + // it can cause problems. The parent UseGraph will handle that. chain.clear(); } @@ -1380,13 +1387,93 @@ struct UseChain } }; +struct UseGraph +{ + // Set of linear paths to the base use. + // Note that some nodes may be common to multiple paths. + // + OrderedDictionary> chainSets; + + static UseGraph from( + IRInst* baseInst, + Func isRelevantUse, + Func passthroughInst) + { + UseGraph result; + for (auto use = baseInst->firstUse; use;) + { + auto nextUse = use->nextUse; + + auto chains = UseChain::from(use, isRelevantUse, passthroughInst); + for (auto& chain : chains) + { + auto finalUse = chain.chain.getFirst(); + + if (!result.chainSets.containsKey(finalUse)) + { + result.chainSets[finalUse] = List(); + } + + result.chainSets[finalUse].getValue().add(chain); + } + + use = nextUse; + } + return result; + } + + void replace(IRBuilder* builder, IRUse* use, IRInst* inst) + { + // Since we may have common nodes, we will use an out-of-order cloning context + // that can retroactively correct the uses as needed. + // + IROutOfOrderCloneContext ctx; + List chains = chainSets[use]; + for (auto chain : chains) + { + chain.replace(&ctx, builder, inst); + } + + if (!isTrivial()) + { + builder->setInsertBefore(use->getUser()); + auto lastInstInChain = ctx.cloneInstOutOfOrder(builder, use->get()); + + // Replace the base use. + builder->replaceOperand(use, lastInstInChain); + } + } + + bool isTrivial() + { + // We're trivial if there's only one chain, and it has only one use. + if (chainSets.getCount() != 1) + return false; + + auto& chain = chainSets.getFirst().value; + return chain.getCount() == 1; + } + + List getUniqueUses() const + { + List result; + + for (auto& pair : chainSets) + { + result.add(pair.key); + } + + return result; + } +}; + // Trim defBlockIndices based on the indices of out of scope uses. // static List maybeTrimIndices( const List& defBlockIndices, const Dictionary>& indexedBlockInfo, - const List& outOfScopeUses) + const List& outOfScopeUses) { // Go through uses, lookup the defBlockIndices, and remove any indices if they // are not present in any of the uses. (This is sort of slow...) @@ -1397,7 +1484,7 @@ static List maybeTrimIndices( bool found = false; for (const auto& use : outOfScopeUses) { - auto useInst = use.getUser(); + auto useInst = use->getUser(); auto useBlock = useInst->getParent(); auto useBlockIndices = indexedBlockInfo.getValue(as(useBlock)); if (useBlockIndices.contains(index)) @@ -1419,7 +1506,8 @@ bool canInstBeStored(IRInst* inst) // stored into variables or context structs as normal values. // if (as(inst->getDataType()) || as(inst->getDataType()) || - as(inst->getDataType()) || as(inst->getDataType())) + as(inst->getDataType()) || as(inst->getDataType()) || + !inst->getDataType()) return false; return true; @@ -1577,6 +1665,9 @@ RefPtr ensurePrimalAvailability( // auto isPassthroughInst = [&](IRInst* inst) { + if (as(inst)) + return false; + if (!canInstBeStored(inst)) return true; @@ -1590,16 +1681,9 @@ RefPtr ensurePrimalAvailability( return false; }; - List outOfScopeUses; - for (auto use = instToStore->firstUse; use;) - { - auto nextUse = use->nextUse; + UseGraph useGraph = UseGraph::from(instToStore, isRelevantUse, isPassthroughInst); - List useChains = UseChain::from(use, isRelevantUse, isPassthroughInst); - outOfScopeUses.addRange(useChains); - - use = nextUse; - } + List outOfScopeUses = useGraph.getUniqueUses(); if (outOfScopeUses.getCount() == 0) { @@ -1659,10 +1743,10 @@ RefPtr ensurePrimalAvailability( for (auto use : outOfScopeUses) { - setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use.getUser())); + setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser())); List& useBlockIndices = - indexedBlockInfo[getBlock(use.getUser())]; + indexedBlockInfo[getBlock(use->getUser())]; IRInst* loadAddr = emitIndexedLoadAddressForVar( &builder, @@ -1670,7 +1754,8 @@ RefPtr ensurePrimalAvailability( defBlock, defBlockIndices, useBlockIndices); - use.replace(&builder, loadAddr); + + useGraph.replace(&builder, use, loadAddr); } if (!isRecomputeInst) @@ -1729,11 +1814,13 @@ RefPtr ensurePrimalAvailability( for (auto use : outOfScopeUses) { + // TODO: Prevent terminator insts from being treated as passthrough.. List useBlockIndices = - indexedBlockInfo[getBlock(use.getUser())]; - setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use.getUser())); - use.replace( + indexedBlockInfo[getBlock(use->getUser())]; + setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser())); + useGraph.replace( &builder, + use, loadIndexedValue( &builder, localVar, -- cgit v1.2.3