summary | refs | log | tree | commit | diff
path: root/source/slang/slang-ir-autodiff-primal-hoist.cpp
diff options
context:
space:
mode:
author: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>, 2025-01-10 03:16:24 +0530
committer: GitHub <noreply@github.com>, 2025-01-09 13:46:24 -0800
commit: 87f00a36a123e36b415eeea82e02a8366cc5b881 (patch)
tree: 719270397242dd0ea2cccf36f586118ac30a6ff1 /source/slang/slang-ir-autodiff-primal-hoist.cpp
parent: 6706c1a7764ae03d810e35ce766ba153ebf7ee03 (diff)
[Auto-diff] Overhaul auto-diff type tracking + Overhaul dynamic dispatch for differentiable functions (#5866)
* Overhauled the auto-diff system for dynamic dispatch
* More fixes
* remove intermediate dumps
* Update slang-ast-type.h
* More fixes + add a workaround for existential no-diff
* Update reverse-control-flow-3.slang
* remove dumps
* remove more dumps
* Delete working-reverse-control-flow-3.hlsl
* Cleanup comments + unused variables
* More comment cleanup
* Add support for lowering `DiffPairType(TypePack)` & `MakePair(MakeValuePack, MakeValuePack)`
* Fix array of issues in Falcor tests.
* Update slang-ir-autodiff-pairs.cpp
* More fixes for Falcor image tests
* Small fixups.

---------

Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-primal-hoist.cpp')
-rw-r--r-- source/slang/slang-ir-autodiff-primal-hoist.cpp | 141
1 files changed, 114 insertions, 27 deletions
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp
index a3f6079ac..ef5161104 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.cpp
+++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp
@@ -344,8 +344,18 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(
continue;
}
+ // General case: we'll add all primal operands to the work list.
addPrimalOperandsToWorkList(child);
+ // Also add type annotations to the list, since these have to be made available to the
+ // function context.
+ //
+ if (as<IRDifferentiableTypeAnnotation>(child))
+ {
+ checkpointInfo->recomputeSet.add(child);
+ addPrimalOperandsToWorkList(child);
+ }
+
// We'll be conservative with the decorations we consider as differential uses
// of a primal inst, in order to avoid weird behaviour with some decorations
//
@@ -1333,7 +1343,7 @@ struct UseChain
return result;
}
- void replace(IRBuilder* builder, IRInst* inst)
+ void replace(IROutOfOrderCloneContext* ctx, IRBuilder* builder, IRInst* inst)
{
SLANG_ASSERT(chain.getCount() > 0);
@@ -1345,30 +1355,27 @@ struct UseChain
return;
}
- IRCloneEnv env;
-
// Pop the last use, which is the base use that needs to be replaced.
auto baseUse = chain.getLast();
chain.removeLast();
// Ensure that replacement inst is set as mapping for the baseUse.
- env.mapOldValToNew[baseUse->get()] = inst;
-
- auto lastInstInChain = inst;
+ ctx->cloneEnv.mapOldValToNew[baseUse->get()] = inst;
IRBuilder chainBuilder(builder->getModule());
setInsertAfterOrdinaryInst(&chainBuilder, inst);
chain.reverse();
+ chain.removeLast();
// Clone the rest of the chain.
for (auto& use : chain)
{
- lastInstInChain = cloneInst(&env, &chainBuilder, use->get());
+ ctx->cloneInstOutOfOrder(&chainBuilder, use->get());
}
- // Replace the base use.
- builder->replaceOperand(chain.getLast(), lastInstInChain);
+ // We won't actually replace the final use, because if there are multiple chains
+ // it can cause problems. The parent UseGraph will handle that.
chain.clear();
}
@@ -1380,13 +1387,93 @@ struct UseChain
}
};
+struct UseGraph // All UseChains found from the uses of one base inst, grouped by the use each chain ends at.
+{
+ // Maps a final use (the first entry of a UseChain, see `from`) to every linear chain that
+ // connects it to the base use. Note that some nodes may be common to multiple chains.
+ //
+ OrderedDictionary<IRUse*, List<UseChain>> chainSets;
+ // Builds the graph for `baseInst`; both predicates are forwarded unchanged to UseChain::from.
+ static UseGraph from(
+ IRInst* baseInst,
+ Func<bool, IRUse*> isRelevantUse,
+ Func<bool, IRInst*> passthroughInst)
+ {
+ UseGraph result;
+ for (auto use = baseInst->firstUse; use;) // walk the use list of baseInst
+ {
+ auto nextUse = use->nextUse; // successor is read before this use is processed
+
+ auto chains = UseChain::from(use, isRelevantUse, passthroughInst);
+ for (auto& chain : chains)
+ {
+ auto finalUse = chain.chain.getFirst(); // the first entry of the chain is the grouping key
+
+ if (!result.chainSets.containsKey(finalUse)) // first chain seen for this final use
+ {
+ result.chainSets[finalUse] = List<UseChain>();
+ }
+
+ result.chainSets[finalUse].getValue().add(chain);
+ }
+
+ use = nextUse;
+ }
+ return result;
+ }
+ // Replays every chain recorded for `use` on top of `inst`, then (non-trivial case) re-points `use`.
+ void replace(IRBuilder* builder, IRUse* use, IRInst* inst)
+ {
+ // Since chains may have common nodes, all of them share one out-of-order cloning context
+ // that can retroactively correct the uses as needed.
+ // NOTE(review): `use` is presumably one of the keys returned by getUniqueUses() -- confirm at call sites.
+ IROutOfOrderCloneContext ctx;
+ List<UseChain> chains = chainSets[use];
+ for (auto chain : chains) // by value: UseChain::replace pops from and clears the chain it is called on
+ {
+ chain.replace(&ctx, builder, inst);
+ }
+
+ if (!isTrivial()) // NOTE(review): presumably the trivial case is completed inside UseChain::replace -- confirm
+ {
+ builder->setInsertBefore(use->getUser());
+ auto lastInstInChain = ctx.cloneInstOutOfOrder(builder, use->get());
+
+ // Point the final use at the clone of the value it currently reads.
+ builder->replaceOperand(use, lastInstInChain);
+ }
+ }
+
+ bool isTrivial()
+ {
+ // Trivial means: exactly one final use is recorded, and exactly one chain leads to it.
+ if (chainSets.getCount() != 1)
+ return false;
+ // NOTE(review): `chain` is the List<UseChain> of the first entry; chain length is not checked here.
+ auto& chain = chainSets.getFirst().value;
+ return chain.getCount() == 1;
+ }
+ // Returns the keys of chainSets (the final uses), one entry per key.
+ List<IRUse*> getUniqueUses() const
+ {
+ List<IRUse*> result;
+
+ for (auto& pair : chainSets)
+ {
+ result.add(pair.key);
+ }
+
+ return result;
+ }
+};
+
// Trim defBlockIndices based on the indices of out of scope uses.
//
static List<IndexTrackingInfo> maybeTrimIndices(
const List<IndexTrackingInfo>& defBlockIndices,
const Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo,
- const List<UseChain>& outOfScopeUses)
+ const List<IRUse*>& outOfScopeUses)
{
// Go through uses, lookup the defBlockIndices, and remove any indices if they
// are not present in any of the uses. (This is sort of slow...)
@@ -1397,7 +1484,7 @@ static List<IndexTrackingInfo> maybeTrimIndices(
bool found = false;
for (const auto& use : outOfScopeUses)
{
- auto useInst = use.getUser();
+ auto useInst = use->getUser();
auto useBlock = useInst->getParent();
auto useBlockIndices = indexedBlockInfo.getValue(as<IRBlock>(useBlock));
if (useBlockIndices.contains(index))
@@ -1419,7 +1506,8 @@ bool canInstBeStored(IRInst* inst)
// stored into variables or context structs as normal values.
//
if (as<IRTypeType>(inst->getDataType()) || as<IRWitnessTableType>(inst->getDataType()) ||
- as<IRTypeKind>(inst->getDataType()) || as<IRFuncType>(inst->getDataType()))
+ as<IRTypeKind>(inst->getDataType()) || as<IRFuncType>(inst->getDataType()) ||
+ !inst->getDataType())
return false;
return true;
@@ -1577,6 +1665,9 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
//
auto isPassthroughInst = [&](IRInst* inst)
{
+ if (as<IRTerminatorInst>(inst))
+ return false;
+
if (!canInstBeStored(inst))
return true;
@@ -1590,16 +1681,9 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
return false;
};
- List<UseChain> outOfScopeUses;
- for (auto use = instToStore->firstUse; use;)
- {
- auto nextUse = use->nextUse;
+ UseGraph useGraph = UseGraph::from(instToStore, isRelevantUse, isPassthroughInst);
- List<UseChain> useChains = UseChain::from(use, isRelevantUse, isPassthroughInst);
- outOfScopeUses.addRange(useChains);
-
- use = nextUse;
- }
+ List<IRUse*> outOfScopeUses = useGraph.getUniqueUses();
if (outOfScopeUses.getCount() == 0)
{
@@ -1659,10 +1743,10 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
for (auto use : outOfScopeUses)
{
- setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use.getUser()));
+ setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser()));
List<IndexTrackingInfo>& useBlockIndices =
- indexedBlockInfo[getBlock(use.getUser())];
+ indexedBlockInfo[getBlock(use->getUser())];
IRInst* loadAddr = emitIndexedLoadAddressForVar(
&builder,
@@ -1670,7 +1754,8 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
defBlock,
defBlockIndices,
useBlockIndices);
- use.replace(&builder, loadAddr);
+
+ useGraph.replace(&builder, use, loadAddr);
}
if (!isRecomputeInst)
@@ -1729,11 +1814,13 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
for (auto use : outOfScopeUses)
{
+ // TODO: Prevent terminator insts from being treated as passthrough..
List<IndexTrackingInfo> useBlockIndices =
- indexedBlockInfo[getBlock(use.getUser())];
- setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use.getUser()));
- use.replace(
+ indexedBlockInfo[getBlock(use->getUser())];
+ setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser()));
+ useGraph.replace(
&builder,
+ use,
loadIndexedValue(
&builder,
localVar,