summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-04-21 14:28:57 -0700
committerGitHub <noreply@github.com>2023-04-21 14:28:57 -0700
commit957a4d3eb0a14a9d57bbb325ef0e1d458df2d2b9 (patch)
treefabc9317b1595c9f74f5b25ee83d16f4260a19d3
parent69a327a98e3f9504863f9ecb623aa93036ac43db (diff)
Refactor checkpointing policy and availability pass. (#2826)
Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/slang-ir-autodiff-cfg-norm.cpp357
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.cpp895
-rw-r--r--source/slang/slang-ir-autodiff-primal-hoist.h62
-rw-r--r--source/slang/slang-ir-autodiff-region.h25
-rw-r--r--source/slang/slang-ir-autodiff-rev.cpp21
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h560
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp9
-rw-r--r--source/slang/slang-ir-autodiff-unzip.h224
-rw-r--r--source/slang/slang-ir-autodiff.cpp48
-rw-r--r--source/slang/slang-ir-autodiff.h4
-rw-r--r--source/slang/slang-ir-dce.cpp24
-rw-r--r--source/slang/slang-ir-inst-defs.h6
-rw-r--r--source/slang/slang-ir-insts.h21
-rw-r--r--source/slang/slang-ir-loop-unroll.cpp21
-rw-r--r--source/slang/slang-ir-redundancy-removal.cpp8
-rw-r--r--source/slang/slang-ir-util.cpp20
-rw-r--r--source/slang/slang-ir-util.h2
-rw-r--r--source/slang/slang-ir.cpp7
18 files changed, 1150 insertions, 1164 deletions
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp
index 80ee37988..a67b7f167 100644
--- a/source/slang/slang-ir-autodiff-cfg-norm.cpp
+++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp
@@ -4,6 +4,7 @@
#include "slang-ir-ssa.h"
#include "slang-ir-validate.h"
+#include "slang-ir-util.h"
namespace Slang
{
@@ -17,31 +18,26 @@ struct RegionEndpoint
bool isRegionEmpty = false;
- RegionEndpoint(IRBlock* exitBlock, bool inBreakRegion, bool inBaseRegion) :
- exitBlock(exitBlock),
- inBreakRegion(inBreakRegion),
- inBaseRegion(inBaseRegion),
- isRegionEmpty(false)
- { }
-
- RegionEndpoint(
- IRBlock* exitBlock,
- bool inBreakRegion,
- bool inBaseRegion,
- bool isRegionEmpty) :
- exitBlock(exitBlock),
- inBreakRegion(inBreakRegion),
- inBaseRegion(inBaseRegion),
- isRegionEmpty(isRegionEmpty)
- { }
-
- RegionEndpoint()
- { }
+ RegionEndpoint(IRBlock* exitBlock, bool inBreakRegion, bool inBaseRegion)
+ : exitBlock(exitBlock)
+ , inBreakRegion(inBreakRegion)
+ , inBaseRegion(inBaseRegion)
+ , isRegionEmpty(false)
+ {}
+
+ RegionEndpoint(IRBlock* exitBlock, bool inBreakRegion, bool inBaseRegion, bool isRegionEmpty)
+ : exitBlock(exitBlock)
+ , inBreakRegion(inBreakRegion)
+ , inBaseRegion(inBaseRegion)
+ , isRegionEmpty(isRegionEmpty)
+ {}
+
+ RegionEndpoint() {}
};
struct BreakableRegionInfo
{
- IRVar* breakVar;
+ IRVar* breakVar;
IRBlock* breakBlock;
IRBlock* headerBlock;
};
@@ -49,16 +45,15 @@ struct BreakableRegionInfo
struct CFGNormalizationContext
{
IRModule* module;
- DiagnosticSink* sink;
+ DiagnosticSink* sink;
};
-
IRBlock* getOrCreateTopLevelCondition(IRLoop* loopInst)
{
// For now, we're going to naively assume the next block is the condition block.
// Add in more support for more cases as necessary.
- //
-
+ //
+
auto firstBlock = loopInst->getTargetBlock();
if (as<IRIfElse>(firstBlock->getTerminator()))
@@ -72,7 +67,7 @@ IRBlock* getOrCreateTopLevelCondition(IRLoop* loopInst)
//
IRBuilder condBuilder(loopInst->getModule());
-
+
auto condBlock = condBuilder.emitBlock();
condBlock->insertAfter(as<IRBlock>(loopInst->getParent()));
@@ -81,22 +76,17 @@ IRBlock* getOrCreateTopLevelCondition(IRLoop* loopInst)
// Emit a condition: true side goes to the loop body, and
// false side goes into the break block.
- //
+ //
condBuilder.setInsertInto(condBlock);
auto ifElse = as<IRIfElse>(condBuilder.emitIfElse(
- condBuilder.getBoolValue(true),
- firstBlock,
- loopInst->getBreakBlock(),
- firstBlock));
-
+ condBuilder.getBoolValue(true), firstBlock, loopInst->getBreakBlock(), firstBlock));
+
// We'll insert a blank block between the condition and the
// break block, since otherwise, we might trip up the later
// parts of this pass.
//
- condBuilder.insertBlockAlongEdge(
- loopInst->getModule(),
- IREdge(&ifElse->falseBlock));
-
+ condBuilder.insertBlockAlongEdge(loopInst->getModule(), IREdge(&ifElse->falseBlock));
+
return condBlock;
}
}
@@ -105,9 +95,9 @@ struct CFGNormalizationPass
{
CFGNormalizationContext cfgContext;
- CFGNormalizationPass(CFGNormalizationContext ctx) :
- cfgContext(ctx)
- { }
+ CFGNormalizationPass(CFGNormalizationContext ctx)
+ : cfgContext(ctx)
+ {}
void replaceBreakWithAfterBlock(
IRBuilder* builder,
@@ -158,13 +148,12 @@ struct CFGNormalizationPass
return branchInst ? branchInst->getTargetBlock() : nullptr;
}
-
bool isSuccessorBlock(IRBlock* baseBlock, IRBlock* succBlock)
{
for (auto successor : baseBlock->getSuccessors())
if (successor == succBlock)
return true;
-
+
return false;
}
@@ -184,9 +173,7 @@ struct CFGNormalizationPass
}
RegionEndpoint getNormalizedRegionEndpoint(
- BreakableRegionInfo* parentRegion,
- IRBlock* entryBlock,
- List<IRBlock*> afterBlocks)
+ BreakableRegionInfo* parentRegion, IRBlock* entryBlock, List<IRBlock*> afterBlocks)
{
IRBlock* currentBlock = entryBlock;
_moveVarsToRegionHeader(parentRegion, currentBlock);
@@ -195,7 +182,7 @@ struct CFGNormalizationPass
// and not in the 'break' control flow
// It is the job of the *caller* to make sure the break flow
// does not reach this point.
- //
+ //
bool currBreakRegion = false;
bool currBaseRegion = true;
@@ -204,7 +191,7 @@ struct CFGNormalizationPass
//
if (afterBlocks.contains(currentBlock))
return RegionEndpoint(currentBlock, currBreakRegion, currBaseRegion, true);
-
+
IRBuilder builder(cfgContext.module);
List<IRBlock*> pendingAfterBlocks;
@@ -216,7 +203,7 @@ struct CFGNormalizationPass
// We could arrive at the after-block before or
// after encountering a break statement.
// To handle this, we'll split the flow by checking the break flag
- //
+ //
builder.setInsertAfter(block);
auto preAfterSplitBlock = builder.emitBlock();
@@ -229,28 +216,24 @@ struct CFGNormalizationPass
builder.setInsertInto(preAfterSplitBlock);
builder.emitBranch(afterSplitBlock);
-
+
// Converging block for the split that we're making.
auto afterSplitAfterBlock = builder.emitBlock();
builder.setInsertInto(afterSplitBlock);
auto breakFlagValue = builder.emitLoad(parentRegion->breakVar);
- builder.emitIfElse(
- breakFlagValue,
- block,
- afterSplitAfterBlock,
- afterSplitAfterBlock);
+ builder.emitIfElse(breakFlagValue, block, afterSplitAfterBlock, afterSplitAfterBlock);
// At this point, we need to place afterSplitAfterBlock between
- // at the _end_ of this region, but we aren't there yet (and
+ // at the _end_ of this region, but we aren't there yet (and
// don't know which block is the end of this region)
// Therefore, we'll defer this step and add it to a list for later.
- //
+ //
pendingAfterBlocks.add(afterSplitAfterBlock);
};
- // Follow this thread of execution till we hit an
+ // Follow this thread of execution till we hit an
// acceptable after block.
//
while (!afterBlocks.contains(maybeGetUnconditionalTarget(currentBlock)))
@@ -259,14 +242,14 @@ struct CFGNormalizationPass
auto terminator = currentBlock->getTerminator();
switch (terminator->getOp())
{
- case kIROp_unconditionalBranch:
+ case kIROp_unconditionalBranch:
{
auto targetBlock = as<IRUnconditionalBranch>(terminator)->getTargetBlock();
currentBlock = targetBlock;
break;
}
-
- case kIROp_ifElse:
+
+ case kIROp_ifElse:
{
auto ifElse = as<IRIfElse>(terminator);
@@ -274,24 +257,24 @@ struct CFGNormalizationPass
// lead back to the condition.
//
SLANG_ASSERT(ifElse->getAfterBlock() != parentRegion->breakBlock);
-
+
auto trueEndPoint = getNormalizedRegionEndpoint(
parentRegion,
ifElse->getTrueBlock(),
List<IRBlock*>(ifElse->getAfterBlock(), parentRegion->breakBlock));
-
+
auto falseEndPoint = getNormalizedRegionEndpoint(
parentRegion,
ifElse->getFalseBlock(),
List<IRBlock*>(ifElse->getAfterBlock(), parentRegion->breakBlock));
-
+
auto trueTargetBlock = getUnconditionalTarget(trueEndPoint);
auto falseTargetBlock = getUnconditionalTarget(falseEndPoint);
-
+
auto afterBlock = ifElse->getAfterBlock();
// Trivial case, both end-points branch into the after block
- /*if (trueTargetBlock == afterBlock &&
+ /*if (trueTargetBlock == afterBlock &&
falseTargetBlock == afterBlock)
{
if ()
@@ -308,7 +291,7 @@ struct CFGNormalizationPass
{
// Branch into after block (and set break variable)
replaceBreakWithAfterBlock(
- &builder,
+ &builder,
parentRegion,
trueEndPoint.exitBlock,
afterBlock,
@@ -321,10 +304,10 @@ struct CFGNormalizationPass
}
else
{
- // If this branch naturally branches into our
+ // If this branch naturally branches into our
// after-block, copy whatever flags the endpoints
// have.
- //
+ //
afterBreakRegion = afterBreakRegion || trueEndPoint.inBreakRegion;
afterBaseRegion = afterBaseRegion || trueEndPoint.inBaseRegion;
}
@@ -346,10 +329,10 @@ struct CFGNormalizationPass
}
else
{
- // If this branch naturally branches into our
+ // If this branch naturally branches into our
// after-block, copy whatever flags the endpoints
// have.
- //
+ //
afterBreakRegion = afterBreakRegion || falseEndPoint.inBreakRegion;
afterBaseRegion = afterBaseRegion || falseEndPoint.inBaseRegion;
}
@@ -365,12 +348,12 @@ struct CFGNormalizationPass
// Do we need to split the after region?
if (afterBaseRegion && afterBreakRegion)
{
- // Before we split the afterBlock, we
+ // Before we split the afterBlock, we
// want to make sure the afterBlock is
// firmly _inside_ the current region.
- // If it's part of the parent, add a
+ // If it's part of the parent, add a
// dummy block.
- //
+ //
if (afterBlocks.contains(afterBlock))
{
auto newAfterBlock = builder.emitBlock();
@@ -382,15 +365,17 @@ struct CFGNormalizationPass
// condition block. (This eventually causes cloneInst to fail,
// since it is currently order-dependent)
// Remove this once cloneInst is order-independent.
- //
+ //
// newAfterBlock->insertBefore(afterBlock);
newAfterBlock->insertAfter(falseEndPoint.exitBlock);
builder.emitBranch(afterBlock);
-
+
ifElse->afterBlock.set(newAfterBlock);
- as<IRUnconditionalBranch>(trueEndPoint.exitBlock->getTerminator())->block.set(newAfterBlock);
- as<IRUnconditionalBranch>(falseEndPoint.exitBlock->getTerminator())->block.set(newAfterBlock);
+ as<IRUnconditionalBranch>(trueEndPoint.exitBlock->getTerminator())
+ ->block.set(newAfterBlock);
+ as<IRUnconditionalBranch>(falseEndPoint.exitBlock->getTerminator())
+ ->block.set(newAfterBlock);
afterBlock = newAfterBlock;
}
@@ -402,15 +387,15 @@ struct CFGNormalizationPass
afterBreakRegion = false;
afterBaseRegion = true;
}
-
+
currentBlock = afterBlock;
currBreakRegion = afterBreakRegion;
currBaseRegion = afterBaseRegion;
break;
}
- case kIROp_loop:
- case kIROp_Switch:
+ case kIROp_loop:
+ case kIROp_Switch:
{
auto breakBlock = normalizeBreakableRegion(terminator);
@@ -419,10 +404,10 @@ struct CFGNormalizationPass
break;
}
- default:
- // Do proper diagnosing
- SLANG_UNEXPECTED("Unhandled control flow inst");
- break;
+ default:
+ // Do proper diagnosing
+ SLANG_UNEXPECTED("Unhandled control flow inst");
+ break;
}
_moveVarsToRegionHeader(parentRegion, currentBlock);
@@ -438,7 +423,7 @@ struct CFGNormalizationPass
SLANG_ASSERT(nextRegionBlock);
builder.emitBranch(nextRegionBlock);
-
+
builder.setInsertInto(currentBlock);
currentBlock->getTerminator()->removeAndDeallocate();
builder.emitBranch(block);
@@ -458,7 +443,7 @@ struct CFGNormalizationPass
HashSet<IRBlock*> predecessorSet;
for (auto predecessor : block->getPredecessors())
predecessorSet.Add(predecessor);
-
+
return predecessorSet;
}
@@ -466,29 +451,27 @@ struct CFGNormalizationPass
{
// Get 'looping' block (first block in loop)
auto firstLoopBlock = loop->getTargetBlock();
-
+
// If we only have one predecessor, the loop is trivial.
return (getPredecessorSet(firstLoopBlock).Count() == 1);
}
- IRBlock* normalizeBreakableRegion(
- IRInst* branchInst)
+ IRBlock* normalizeBreakableRegion(IRInst* branchInst)
{
IRBuilder builder(cfgContext.module);
switch (branchInst->getOp())
{
- case kIROp_loop:
+ case kIROp_loop:
{
BreakableRegionInfo info;
info.breakBlock = as<IRLoop>(branchInst)->getBreakBlock();
info.headerBlock = as<IRBlock>(branchInst->getParent());
// Emit var into parent block.
- builder.setInsertBefore(
- as<IRBlock>(branchInst->getParent())->getTerminator());
-
- // Create and initialize break var to true
+ builder.setInsertBefore(as<IRBlock>(branchInst->getParent())->getTerminator());
+
+ // Create and initialize break var to true
// true -> no break yet.
// false -> atleast one break statement hit.
//
@@ -500,24 +483,23 @@ struct CFGNormalizationPass
// edges actually in a loop), we're just going to remove
// it.. (we can do this, because the normalization pass
// will transform any break and continue statements)
- //
+ //
if (isLoopTrivial(as<IRLoop>(branchInst)))
{
auto firstLoopBlock = as<IRLoop>(branchInst)->getTargetBlock();
-
+
// Normalize the region from the first loop block till break.
auto preBreakEndPoint = getNormalizedRegionEndpoint(
- &info,
- firstLoopBlock,
- List<IRBlock*>(info.breakBlock));
-
+ &info, firstLoopBlock, List<IRBlock*>(info.breakBlock));
+
// Should not be empty.. but check anyway
SLANG_RELEASE_ASSERT(!preBreakEndPoint.isRegionEmpty);
- // Quick consistency check.. preBreakEndPoint should be
+ // Quick consistency check.. preBreakEndPoint should be
// branching into break block.
- SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>(
- preBreakEndPoint.exitBlock->getTerminator())->getTargetBlock() == info.breakBlock);
+ SLANG_RELEASE_ASSERT(
+ as<IRUnconditionalBranch>(preBreakEndPoint.exitBlock->getTerminator())
+ ->getTargetBlock() == info.breakBlock);
auto currentBlock = branchInst->getParent();
@@ -529,30 +511,27 @@ struct CFGNormalizationPass
return info.breakBlock;
}
- auto condBlock = getOrCreateTopLevelCondition(as<IRLoop>(branchInst));
+ auto condBlock =
+ getOrCreateTopLevelCondition(as<IRLoop>(branchInst));
auto ifElse = as<IRIfElse>(condBlock->getTerminator());
auto trueEndPoint = getNormalizedRegionEndpoint(
- &info,
- ifElse->getTrueBlock(),
- List<IRBlock*>(condBlock, info.breakBlock));
-
+ &info, ifElse->getTrueBlock(), List<IRBlock*>(condBlock, info.breakBlock));
+
auto falseEndPoint = getNormalizedRegionEndpoint(
- &info,
- ifElse->getFalseBlock(),
- List<IRBlock*>(condBlock, info.breakBlock));
-
+ &info, ifElse->getFalseBlock(), List<IRBlock*>(condBlock, info.breakBlock));
+
RegionEndpoint loopEndPoint;
bool isLoopOnTrueSide = true;
-
+
// First figure out which side belongs to the loop body.
if (isSuccessorBlock(trueEndPoint.exitBlock, condBlock))
{
loopEndPoint = trueEndPoint;
isLoopOnTrueSide = true;
}
-
+
if (isSuccessorBlock(falseEndPoint.exitBlock, condBlock))
{
loopEndPoint = falseEndPoint;
@@ -560,11 +539,11 @@ struct CFGNormalizationPass
}
// Right now, we only support loops where the loop is on the true side of
- // the condition. If we ever encounter the other case, fill in logic to
+ // the condition. If we ever encounter the other case, fill in logic to
// flip the condition.
//
SLANG_RELEASE_ASSERT(isLoopOnTrueSide);
-
+
// Expect atleast one basic block (other than the condition block), in
// the loop.
//
@@ -573,7 +552,7 @@ struct CFGNormalizationPass
// Does the loop endpoint have both 'break' and 'base'
// control flows?
- //
+ //
if (loopEndPoint.inBaseRegion && loopEndPoint.inBreakRegion)
{
// Add a test for the break variable into the condition.
@@ -582,36 +561,30 @@ struct CFGNormalizationPass
builder.setInsertBefore(ifElse);
auto breakFlagVal = builder.emitLoad(info.breakVar);
- // Need to invert the break flag if the loop is
+ // Need to invert the break flag if the loop is
// on the false side.
- //
+ //
if (!isLoopOnTrueSide)
{
IRInst* args[1] = {breakFlagVal};
- breakFlagVal = builder.emitIntrinsicInst(
- builder.getBoolType(),
- kIROp_Not,
- 1,
- args);
+ breakFlagVal =
+ builder.emitIntrinsicInst(builder.getBoolType(), kIROp_Not, 1, args);
}
IRInst* args[2] = {cond, breakFlagVal};
// If break-var = true, direct flow to the loop
// otherwise, direct flow to break
- //
- auto complexCond = builder.emitIntrinsicInst(
- builder.getBoolType(),
- kIROp_And,
- 2,
- args);
-
+ //
+ auto complexCond =
+ builder.emitIntrinsicInst(builder.getBoolType(), kIROp_And, 2, args);
+
ifElse->condition.set(complexCond);
}
-
+
return info.breakBlock;
}
- case kIROp_Switch:
+ case kIROp_Switch:
{
auto switchInst = as<IRSwitch>(branchInst);
@@ -620,10 +593,9 @@ struct CFGNormalizationPass
info.breakBlock = as<IRSwitch>(branchInst)->getBreakLabel();
// Emit var into parent block.
- builder.setInsertBefore(
- as<IRBlock>(branchInst->getParent())->getTerminator());
-
- // Create and initialize break var to true
+ builder.setInsertBefore(as<IRBlock>(branchInst->getParent())->getTerminator());
+
+ // Create and initialize break var to true
// true -> no break yet.
// false -> atleast one break statement hit.
//
@@ -635,30 +607,31 @@ struct CFGNormalizationPass
{
auto caseBlock = switchInst->getCaseLabel(ii);
auto caseEndPoint = getNormalizedRegionEndpoint(
- &info,
- caseBlock,
- List<IRBlock*>(info.breakBlock)).exitBlock;
+ &info, caseBlock, List<IRBlock*>(info.breakBlock))
+ .exitBlock;
// Consistency check (if this case hits, it's probably
// because the switch has fall-through, which we don't support)
- SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>(
- caseEndPoint->getTerminator())->getTargetBlock() == info.breakBlock);
+ SLANG_RELEASE_ASSERT(
+ as<IRUnconditionalBranch>(caseEndPoint->getTerminator())
+ ->getTargetBlock() == info.breakBlock);
}
- auto defaultEndPoint = getNormalizedRegionEndpoint(
- &info,
- switchInst->getDefaultLabel(),
- List<IRBlock*>(info.breakBlock)).exitBlock;
+ auto defaultEndPoint =
+ getNormalizedRegionEndpoint(
+ &info, switchInst->getDefaultLabel(), List<IRBlock*>(info.breakBlock))
+ .exitBlock;
// Consistency check (if this case hits, it's probably
// because the switch has fall-through, which we don't support)
- SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>(
- defaultEndPoint->getTerminator())->getTargetBlock() == info.breakBlock);
+ SLANG_RELEASE_ASSERT(
+ as<IRUnconditionalBranch>(defaultEndPoint->getTerminator())->getTargetBlock() ==
+ info.breakBlock);
return info.breakBlock;
}
- default:
- break;
+ default:
+ break;
}
SLANG_UNEXPECTED("Unhandled control-flow inst");
@@ -666,18 +639,16 @@ struct CFGNormalizationPass
};
void normalizeCFG(
- IRModule* module,
- IRGlobalValueWithCode* func,
- IRCFGNormalizationPass const& options)
+ IRModule* module, IRGlobalValueWithCode* func, IRCFGNormalizationPass const& options)
{
// Remove phis to simplify our pass. We'll add them back in later
// with constructSSA.
- //
+ //
eliminatePhisInFunc(LivenessMode::Disabled, func->getModule(), func);
- CFGNormalizationContext context = {module, options.sink};
+ CFGNormalizationContext context = {module, options.sink};
CFGNormalizationPass cfgPass(context);
-
+
List<IRBlock*> workList;
workList.add(func->getFirstBlock());
@@ -703,9 +674,83 @@ void normalizeCFG(
}
}
+ // If we created a new condition block for a loop, the local vars defined in
+ // the original loop body will no longer dominate the exit block of the
+ // loop. If there are any uses of these variables outside the loop, they
+ // will become invalid. Therefore we need to hoist the local variables to
+ // the loop header block.
+ HashSet<IRBlock*> workListSet;
+ for (auto block : func->getBlocks())
+ {
+ if (auto loop = as<IRLoop>(block->getTerminator()))
+ {
+ auto condBlock = loop->getTargetBlock();
+ auto ifElse = as<IRIfElse>(condBlock->getTerminator());
+ auto bodyBlock = ifElse->getTrueBlock();
+
+ // Collect loop body blocks.
+ workList.clear();
+ workListSet.Clear();
+ workList.add(bodyBlock);
+ workListSet.add(bodyBlock);
+ for (Index i = 0; i < workList.getCount(); i++)
+ {
+ auto b = workList[i];
+ for (auto succ : b->getSuccessors())
+ {
+ if (succ != loop->getTargetBlock() && succ != loop->getBreakBlock())
+ {
+ if (workListSet.add(succ))
+ workList.add(succ);
+ }
+ }
+ }
+ auto insertionPoint = loop;
+ IRBuilder builder(func);
+ for (auto b : workList)
+ {
+ for (auto inst : b->getChildren())
+ {
+ // If inst has uses outside the loop body, we need to hoist it.
+ IRVar* tempVar = nullptr;
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ auto userBlock = as<IRBlock>(use->getUser()->getParent());
+ if (userBlock && !workListSet.Contains(userBlock))
+ {
+ // Hoist the inst.
+ if (auto var = as<IRVar>(inst))
+ {
+ // If inst is an var, this is easy, we just move it to the
+ // loop header.
+ var->insertBefore(insertionPoint);
+ break;
+ }
+ else
+ {
+ // For all other insts, we need to create a local var for it.
+ if (!tempVar)
+ {
+ builder.setInsertBefore(insertionPoint);
+ tempVar = builder.emitVar(inst->getFullType());
+ builder.setInsertAfter(inst);
+ builder.emitStore(tempVar, inst);
+ }
+ // Replace the use with a load of tempVar.
+ builder.setInsertBefore(use->getUser());
+ auto load = builder.emitLoad(tempVar);
+ builder.replaceOperand(use, load);
+ }
+ break;
+ }
+ }
+ }
+ }
+ }
+ }
disableIRValidationAtInsert();
constructSSA(module, func);
enableIRValidationAtInsert();
}
-}
+} // namespace Slang
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp
index 363572c86..6a9b504a6 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.cpp
+++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp
@@ -1,19 +1,277 @@
#include "slang-ir-autodiff-primal-hoist.h"
#include "slang-ir-autodiff-region.h"
-namespace Slang
+namespace Slang
{
+void applyCheckpointSet(
+ CheckpointSetInfo* checkpointInfo,
+ IRGlobalValueWithCode* func,
+ HoistedPrimalsInfo* hoistInfo,
+ HashSet<IRUse*> pendingUses,
+ Dictionary<IRBlock*, IRBlock*>& mapPrimalBlockToRecomputeBlock);
+
bool containsOperand(IRInst* inst, IRInst* operand)
{
for (UIndex ii = 0; ii < inst->getOperandCount(); ii++)
if (inst->getOperand(ii) == operand)
return true;
-
+
return false;
}
-RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(IRGlobalValueWithCode* func, BlockSplitInfo* splitInfo)
+static bool isDifferentialInst(IRInst* inst)
+{
+ auto parent = inst->getParent();
+ if (parent->findDecoration<IRDifferentialInstDecoration>())
+ return true;
+ return inst->findDecoration<IRDifferentialInstDecoration>() != nullptr;
+}
+
+static bool isDifferentialBlock(IRBlock* block)
+{
+ return block->findDecoration<IRDifferentialInstDecoration>();
+}
+
+static Dictionary<IRBlock*, IRBlock*> reconstructDiffBlockMap(IRGlobalValueWithCode* func)
+{
+ Dictionary<IRBlock*, IRBlock*> diffBlockMap;
+ for (auto block : func->getBlocks())
+ {
+ if (auto diffDecor = block->findDecoration<IRDifferentialInstDecoration>())
+ {
+ if (diffDecor->getPrimalType())
+ diffBlockMap[as<IRBlock>(diffDecor->getPrimalInst())] = block;
+ }
+ }
+ return diffBlockMap;
+}
+
+static IRBlock* getLoopRegionBodyBlock(IRLoop* loop)
+{
+ auto condBlock = as<IRBlock>(loop->getTargetBlock());
+ // We assume the loop body always sit at the true side of the if-else.
+ if (auto ifElse = as<IRIfElse>(condBlock->getTerminator()))
+ {
+ return ifElse->getTrueBlock();
+ }
+ return nullptr;
+}
+
+static IRBlock* tryGetSubRegionEndBlock(IRInst* terminator)
+{
+ auto loop = as<IRLoop>(terminator);
+ if (!loop)
+ return nullptr;
+ return loop->getBreakBlock();
+}
+
+static Dictionary<IRBlock*, IRBlock*> createPrimalRecomputeBlocks(
+ IRGlobalValueWithCode* func,
+ Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo)
+{
+ IRBlock* firstDiffBlock = nullptr;
+ for (auto block : func->getBlocks())
+ {
+ if (isDifferentialBlock(block))
+ {
+ firstDiffBlock = block;
+ break;
+ }
+ }
+ if (!firstDiffBlock)
+ return Dictionary<IRBlock*, IRBlock*>();
+
+ Dictionary<IRLoop*, IRLoop*> mapPrimalLoopToDiffLoop;
+ for (auto block : func->getBlocks())
+ {
+ if (isDifferentialBlock(block))
+ {
+ if (auto diffLoop = as<IRLoop>(block->getTerminator()))
+ {
+ if (auto diffDecor = diffLoop->findDecoration<IRDifferentialInstDecoration>())
+ {
+ mapPrimalLoopToDiffLoop[as<IRLoop>(diffDecor->getPrimalInst())] = diffLoop;
+ }
+ }
+ }
+ }
+
+ IRBuilder builder(func);
+ Dictionary<IRBlock*, IRBlock*> recomputeBlockMap;
+
+ // Create the first recompute block right before the first diff block,
+ // and change all jumps into the diff block to the recompute block instead.
+ auto createRecomputeBlock = [&](IRBlock* primalBlock)
+ {
+ auto recomputeBlock = builder.createBlock();
+ recomputeBlock->insertAtEnd(func);
+ builder.addDecoration(recomputeBlock, kIROp_RecomputeBlockDecoration);
+ recomputeBlockMap.Add(primalBlock, recomputeBlock);
+ indexedBlockInfo[recomputeBlock] = indexedBlockInfo[primalBlock].GetValue();
+ return recomputeBlock;
+ };
+
+ auto firstRecomputeBlock = createRecomputeBlock(func->getFirstBlock());
+ firstRecomputeBlock->insertBefore(firstDiffBlock);
+ moveParams(firstRecomputeBlock, firstDiffBlock);
+ firstDiffBlock->replaceUsesWith(firstRecomputeBlock);
+
+ struct WorkItem
+ {
+ // The first primal block in this region.
+ IRBlock* primalBlock;
+
+ // The recompute block created for the first primal block in this region.
+ IRBlock* recomptueBlock;
+
+ // The end of primal block in tihs region.
+ IRBlock* regionEndBlock;
+
+ // The first diff block in this region.
+ IRBlock* firstDiffBlock;
+ };
+
+ List<WorkItem> workList;
+ WorkItem firstWorkItem = { func->getFirstBlock(), firstRecomputeBlock, firstRecomputeBlock, firstDiffBlock };
+ workList.add(firstWorkItem);
+
+ IRCloneEnv recomputeCloneEnv;
+ recomputeBlockMap[func->getFirstBlock()] = firstRecomputeBlock;
+
+ for (Index i = 0; i < workList.getCount(); i++)
+ {
+ auto workItem = workList[i];
+ auto primalBlock = workItem.primalBlock;
+ auto recomputeBlock = workItem.recomptueBlock;
+
+ List<IndexTrackingInfo>* thisBlockIndexInfo = indexedBlockInfo.TryGetValue(primalBlock);
+ if (!thisBlockIndexInfo)
+ continue;
+
+ builder.setInsertInto(recomputeBlock);
+ if (auto subRegionEndBlock = tryGetSubRegionEndBlock(primalBlock->getTerminator()))
+ {
+ // The terminal inst of primalBlock marks the start of a sub loop region?
+ // We need to queue work for both the next region after the loop at the current level,
+ // and for the sub region for the next level.
+ if (subRegionEndBlock == workItem.regionEndBlock)
+ {
+ // We have reached the end of top-level region, jump to first diff block.
+ builder.emitBranch(workItem.firstDiffBlock);
+ }
+ else
+ {
+ // Have we already created a recompute block for this target?
+ // If so, use it.
+ IRBlock* existingRecomputeBlock = nullptr;
+ if (recomputeBlockMap.TryGetValue(subRegionEndBlock, existingRecomputeBlock))
+ {
+ builder.emitBranch(existingRecomputeBlock);
+ }
+ else
+ {
+ // Queue work for the next region after the subregion at this level.
+ auto nextRegionRecomputeBlock = createRecomputeBlock(subRegionEndBlock);
+ nextRegionRecomputeBlock->insertAfter(recomputeBlock);
+ builder.emitBranch(nextRegionRecomputeBlock);
+
+ {
+ WorkItem newWorkItem = {
+ subRegionEndBlock,
+ nextRegionRecomputeBlock,
+ workItem.regionEndBlock,
+ workItem.firstDiffBlock };
+ workList.add(newWorkItem);
+ }
+ }
+ }
+ // Queue work for the subregion.
+ auto loop = as<IRLoop>(primalBlock->getTerminator());
+ auto bodyBlock = getLoopRegionBodyBlock(loop);
+ auto diffLoop = mapPrimalLoopToDiffLoop[loop].GetValue();
+ auto diffBodyBlock = getLoopRegionBodyBlock(diffLoop);
+ auto bodyRecomputeBlock = createRecomputeBlock(bodyBlock);
+ bodyRecomputeBlock->insertBefore(diffBodyBlock);
+ diffBodyBlock->replaceUsesWith(bodyRecomputeBlock);
+ moveParams(bodyRecomputeBlock, diffBodyBlock);
+ {
+ // After CFG normalization, the loop body will contain only jumps to the
+ // beginning of the loop.
+ // If we see such a jump, it means we have reached the end of current
+ // region in the loop.
+ // Therefore, we set the regionEndBlock for the sub-region as loop's target
+ // block.
+ WorkItem newWorkItem = {
+ bodyBlock, bodyRecomputeBlock, loop->getTargetBlock(), diffBodyBlock};
+ workList.add(newWorkItem);
+ }
+ }
+ else
+ {
+ // This is a normal control flow, just copy the CFG structure.
+ auto terminator = primalBlock->getTerminator();
+ IRInst* newTerminator = nullptr;
+ switch (terminator->getOp())
+ {
+ case kIROp_Switch:
+ case kIROp_ifElse:
+ newTerminator = cloneInst(&recomputeCloneEnv, &builder, primalBlock->getTerminator());
+ break;
+ case kIROp_unconditionalBranch:
+ newTerminator = builder.emitBranch(as<IRUnconditionalBranch>(terminator)->getTargetBlock());
+ break;
+ default:
+ SLANG_UNREACHABLE("terminator type");
+ }
+
+ // Modify jump targets in newTerminator to point to the right recompute block or firstDiffBlock.
+ for (UInt op = 0; op < newTerminator->getOperandCount(); op++)
+ {
+ auto target = as<IRBlock>(newTerminator->getOperand(op));
+ if (!target)
+ continue;
+ if (target == workItem.regionEndBlock)
+ {
+ // This jump target is the end of the current region, we will jump to
+ // firstDiffBlock instead.
+ newTerminator->setOperand(op, workItem.firstDiffBlock);
+ continue;
+ }
+
+ // Have we already created a recompute block for this target?
+ // If so, use it.
+ IRBlock* existingRecomputeBlock = nullptr;
+ if (recomputeBlockMap.TryGetValue(target, existingRecomputeBlock))
+ {
+ newTerminator->setOperand(op, existingRecomputeBlock);
+ continue;
+ }
+
+ // This jump target is a normal part of control flow, clone the next block.
+ auto targetRecomputeBlock = createRecomputeBlock(target);
+ targetRecomputeBlock->insertBefore(workItem.firstDiffBlock);
+
+ newTerminator->setOperand(op, targetRecomputeBlock);
+
+ // Queue work for the successor.
+ WorkItem newWorkItem = {
+ target,
+ targetRecomputeBlock,
+ workItem.regionEndBlock,
+ workItem.firstDiffBlock};
+ workList.add(newWorkItem);
+ }
+ }
+ }
+ // After this pass, all primal blocks except the condition block and the false block of a loop
+ // will have a corresponding recomputeBlock.
+ return recomputeBlockMap;
+}
+
+RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(
+ IRGlobalValueWithCode* func,
+ Dictionary<IRBlock*, IRBlock*>& mapDiffBlockToRecomputeBlock)
{
RefPtr<CheckpointSetInfo> checkpointInfo = new CheckpointSetInfo();
@@ -29,11 +287,11 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(IRGlobalVal
UIndex opIndex = 0;
for (auto operand = inst->getOperands(); opIndex < inst->getOperandCount(); operand++, opIndex++)
{
- if (!operand->get()->findDecoration<IRDifferentialInstDecoration>() &&
+ if (!isDifferentialInst(operand->get()) &&
!as<IRFunc>(operand->get()) &&
!as<IRBlock>(operand->get()) &&
!(as<IRModuleInst>(operand->get()->getParent())) &&
- !getBlock(operand->get())->findDecoration<IRDifferentialInstDecoration>())
+ !isDifferentialBlock(getBlock(operand->get())))
workList.add(operand);
}
@@ -44,7 +302,7 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(IRGlobalVal
//
if (inst->getDataType() && (getParentFunc(inst->getDataType()) == func))
{
- if (!getBlock(inst->getDataType())->findDecoration<IRDifferentialInstDecoration>())
+ if (!isDifferentialBlock(getBlock(inst->getDataType())))
workList.add(&inst->typeUse);
}
};
@@ -58,7 +316,7 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(IRGlobalVal
if (block == func->getFirstBlock())
continue;
- if (!block->findDecoration<IRDifferentialInstDecoration>())
+ if (!isDifferentialBlock(block))
continue;
for (auto child : block->getChildren())
@@ -111,7 +369,7 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(IRGlobalVal
SLANG_ASSERT(!checkpointInfo->storeSet.Contains(result.instToRecompute));
checkpointInfo->recomputeSet.Add(result.instToRecompute);
- if (use->getUser()->findDecoration<IRDifferentialInstDecoration>())
+ if (isDifferentialInst(use->getUser()))
usesToReplace.Add(use);
if (auto param = as<IRParam>(result.instToRecompute))
@@ -160,7 +418,7 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(IRGlobalVal
SLANG_RELEASE_ASSERT(containsOperand(instToInvert, use->getUser()));
SLANG_RELEASE_ASSERT(result.inversionInfo.targetInsts.contains(use->getUser()));
- if (use->getUser()->findDecoration<IRDifferentialInstDecoration>())
+ if (isDifferentialInst(use->getUser()))
usesToReplace.Add(use);
checkpointInfo->invertSet.Add(instToInvert);
@@ -178,7 +436,55 @@ RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(IRGlobalVal
}
}
- return applyCheckpointSet(checkpointInfo, func, splitInfo, usesToReplace);
+ // If a var or call is in recomputeSet, move any var/calls associated with the same call to
+ // recomputeSet.
+ List<IRInst*> instWorkList;
+ HashSet<IRInst*> instWorkListSet;
+ for (auto inst : checkpointInfo->recomputeSet)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_Call:
+ case kIROp_Var:
+ instWorkList.add(inst);
+ instWorkListSet.add(inst);
+ break;
+ }
+ }
+ for (Index i = 0; i < instWorkList.getCount(); i++)
+ {
+ auto inst = instWorkList[i];
+ if (auto var = as<IRVar>(inst))
+ {
+ for (auto use = var->firstUse; use; use = use->nextUse)
+ {
+ auto callUser = as<IRCall>(use->getUser());
+ if (!callUser)
+ continue;
+ checkpointInfo->recomputeSet.add(callUser);
+ checkpointInfo->storeSet.Remove(callUser);
+ if (instWorkListSet.add(callUser))
+ instWorkList.add(callUser);
+ }
+ }
+ else if (auto call = as<IRCall>(inst))
+ {
+ for (UInt j = 0; j < call->getArgCount(); j++)
+ {
+ if (auto varArg = as<IRVar>(call->getArg(j)))
+ {
+ checkpointInfo->recomputeSet.add(varArg);
+ checkpointInfo->storeSet.Remove(varArg);
+ if (instWorkListSet.add(varArg))
+ instWorkList.add(varArg);
+ }
+ }
+ }
+ }
+
+ RefPtr<HoistedPrimalsInfo> hoistInfo = new HoistedPrimalsInfo();
+ applyCheckpointSet(checkpointInfo, func, hoistInfo, usesToReplace, mapDiffBlockToRecomputeBlock);
+ return hoistInfo;
}
void applyToInst(
@@ -195,6 +501,11 @@ void applyToInst(
return;
}
+ if (hoistInfo->ignoreSet.Contains(inst))
+ {
+ return;
+ }
+
bool isInstRecomputed = checkpointInfo->recomputeSet.Contains(inst);
if (isInstRecomputed)
{
@@ -242,13 +553,15 @@ void applyToInst(
}
}
-RefPtr<HoistedPrimalsInfo> applyCheckpointSet(
+void applyCheckpointSet(
CheckpointSetInfo* checkpointInfo,
IRGlobalValueWithCode* func,
- BlockSplitInfo* splitInfo,
- HashSet<IRUse*> pendingUses)
+ HoistedPrimalsInfo* hoistInfo,
+ HashSet<IRUse*> pendingUses,
+ Dictionary<IRBlock*, IRBlock*>& mapPrimalBlockToRecomputeBlock)
{
- RefPtr<HoistedPrimalsInfo> hoistInfo = new HoistedPrimalsInfo();
+ // Reconstruct diff block map.
+ Dictionary<IRBlock*, IRBlock*> diffBlockMap = reconstructDiffBlockMap(func);
RefPtr<IROutOfOrderCloneContext> cloneCtx = new IROutOfOrderCloneContext();
@@ -264,7 +577,7 @@ RefPtr<HoistedPrimalsInfo> applyCheckpointSet(
UIndex opIndex = 0;
for (auto operand = inst->getOperands(); opIndex < inst->getOperandCount(); operand++, opIndex++)
{
- if (!operand->get()->findDecoration<IRDifferentialInstDecoration>())
+ if (!isDifferentialInst(operand->get()))
cloneCtx->pendingUses.Add(operand);
}
};
@@ -276,15 +589,15 @@ RefPtr<HoistedPrimalsInfo> applyCheckpointSet(
if (block == func->getFirstBlock())
continue;
- if (block->findDecoration<IRDifferentialInstDecoration>())
+ if (isDifferentialBlock(block))
+ continue;
+
+ if (block->findDecoration<IRRecomputeBlockDecoration>())
continue;
- auto diffBlock = as<IRBlock>(splitInfo->diffBlockMap[block]);
-
- auto firstDiffInst = as<IRBlock>(splitInfo->diffBlockMap[block])->getFirstOrdinaryInst();
+ auto diffBlock = as<IRBlock>(diffBlockMap[block]);
IRBuilder builder(func->getModule());
-
UIndex ii = 0;
for (auto param : block->getParams())
{
@@ -302,48 +615,58 @@ RefPtr<HoistedPrimalsInfo> applyCheckpointSet(
predecessorSet.Add(predecessor);
- auto diffPredecessor = as<IRBlock>(splitInfo->diffBlockMap[block]);
+ auto diffPredecessor = as<IRBlock>(diffBlockMap[block]);
if (checkpointInfo->recomputeSet.Contains(param))
+ {
+ IRInst* terminator = diffPredecessor->getTerminator();
addPhiOutputArg(&builder,
diffPredecessor,
+ terminator,
as<IRUnconditionalBranch>(predecessor->getTerminator())->getArg(ii));
+ }
if (checkpointInfo->invertSet.Contains(param))
+ {
+ IRInst* terminator = diffPredecessor->getTerminator();
+
addPhiOutputArg(&builder,
diffPredecessor,
+ terminator,
as<IRUnconditionalBranch>(predecessor->getTerminator())->getArg(ii));
+ }
}
ii++;
}
+ IRBlock* recomputeBlock = block;
+ mapPrimalBlockToRecomputeBlock.TryGetValue(block, recomputeBlock);
+ auto recomputeInsertBeforeInst = recomputeBlock->getFirstOrdinaryInst();
+
for (auto child : block->getChildren())
{
- builder.setInsertBefore(firstDiffInst);
-
+ builder.setInsertBefore(recomputeInsertBeforeInst);
applyToInst(&builder, checkpointInfo, hoistInfo, cloneCtx, child);
}
}
-
- return hoistInfo;
}
IRType* getTypeForLocalStorage(
IRBuilder* builder,
IRType* storageType,
- List<IndexTrackingInfo*> defBlockIndices)
+ const List<IndexTrackingInfo>& defBlockIndices)
{
- for (auto index : defBlockIndices)
+ for (auto& index : defBlockIndices)
{
- SLANG_ASSERT(index->status == IndexTrackingInfo::CountStatus::Static);
- SLANG_ASSERT(index->maxIters >= 0);
+ SLANG_ASSERT(index.status == IndexTrackingInfo::CountStatus::Static);
+ SLANG_ASSERT(index.maxIters >= 0);
storageType = builder->getArrayType(
storageType,
builder->getIntValue(
builder->getUIntType(),
- index->maxIters + 1));
+ index.maxIters + 1));
}
return storageType;
@@ -352,7 +675,7 @@ IRType* getTypeForLocalStorage(
IRVar* emitIndexedLocalVar(
IRBlock* varBlock,
IRType* baseType,
- List<IndexTrackingInfo*> defBlockIndices)
+ const List<IndexTrackingInfo>& defBlockIndices)
{
SLANG_RELEASE_ASSERT(!as<IRPtrTypeBase>(baseType));
@@ -370,19 +693,19 @@ IRVar* emitIndexedLocalVar(
IRInst* emitIndexedStoreAddressForVar(
IRBuilder* builder,
IRVar* localVar,
- List<IndexTrackingInfo*> defBlockIndices)
+ const List<IndexTrackingInfo>& defBlockIndices)
{
IRInst* storeAddr = localVar;
IRType* currType = as<IRPtrTypeBase>(localVar->getDataType())->getValueType();
- for (auto index : defBlockIndices)
+ for (auto& index : defBlockIndices)
{
currType = as<IRArrayType>(currType)->getElementType();
storeAddr = builder->emitElementAddress(
builder->getPtrType(currType),
storeAddr,
- index->primalCountParam);
+ index.primalCountParam);
}
return storeAddr;
@@ -392,8 +715,8 @@ IRInst* emitIndexedStoreAddressForVar(
IRInst* emitIndexedLoadAddressForVar(
IRBuilder* builder,
IRVar* localVar,
- List<IndexTrackingInfo*> defBlockIndices,
- List<IndexTrackingInfo*> useBlockIndices)
+ const List<IndexTrackingInfo>& defBlockIndices,
+ const List<IndexTrackingInfo>& useBlockIndices)
{
IRInst* loadAddr = localVar;
IRType* currType = as<IRPtrTypeBase>(localVar->getDataType())->getValueType();
@@ -406,7 +729,7 @@ IRInst* emitIndexedLoadAddressForVar(
// If the use-block is under the same region, use the
// differential counter variable
//
- auto diffCounterCurrValue = index->diffCountParam;
+ auto diffCounterCurrValue = index.diffCountParam;
loadAddr = builder->emitElementAddress(
builder->getPtrType(currType),
@@ -418,7 +741,7 @@ IRInst* emitIndexedLoadAddressForVar(
// If the use-block is outside this region, use the
// last available value (by indexing with primal counter minus 1)
//
- auto primalCounterCurrValue = builder->emitLoad(index->primalCountLastVar);
+ auto primalCounterCurrValue = index.primalCountParam;
auto primalCounterLastValue = builder->emitSub(
primalCounterCurrValue->getDataType(),
primalCounterCurrValue,
@@ -438,7 +761,7 @@ IRVar* storeIndexedValue(
IRBuilder* builder,
IRBlock* defaultVarBlock,
IRInst* instToStore,
- List<IndexTrackingInfo*> defBlockIndices)
+ const List<IndexTrackingInfo>& defBlockIndices)
{
IRVar* localVar = emitIndexedLocalVar(defaultVarBlock, instToStore->getDataType(), defBlockIndices);
@@ -452,8 +775,8 @@ IRVar* storeIndexedValue(
IRInst* loadIndexedValue(
IRBuilder* builder,
IRVar* localVar,
- List<IndexTrackingInfo*> defBlockIndices,
- List<IndexTrackingInfo*> useBlockIndices)
+ const List<IndexTrackingInfo>& defBlockIndices,
+ const List<IndexTrackingInfo>& useBlockIndices)
{
IRInst* addr = emitIndexedLoadAddressForVar(builder, localVar, defBlockIndices, useBlockIndices);
@@ -461,15 +784,15 @@ IRInst* loadIndexedValue(
}
bool areIndicesEqual(
- List<IndexTrackingInfo*> indicesA,
- List<IndexTrackingInfo*> indicesB)
+ const List<IndexTrackingInfo>& indicesA,
+ const List<IndexTrackingInfo>& indicesB)
{
if (indicesA.getCount() != indicesB.getCount())
return false;
for (Index ii = 0; ii < indicesA.getCount(); ii++)
{
- if (indicesA[ii] != indicesB[ii])
+ if (indicesA[ii].primalCountParam != indicesB[ii].primalCountParam)
return false;
}
@@ -477,31 +800,37 @@ bool areIndicesEqual(
}
bool areIndicesSubsetOf(
- List<IndexTrackingInfo*> indicesA,
- List<IndexTrackingInfo*> indicesB)
+ List<IndexTrackingInfo>& indicesA,
+ List<IndexTrackingInfo>& indicesB)
{
if (indicesA.getCount() > indicesB.getCount())
return false;
for (Index ii = 0; ii < indicesA.getCount(); ii++)
{
- if (indicesA[ii] != indicesB[ii])
+ if (indicesA[ii].primalCountParam != indicesB[ii].primalCountParam)
return false;
}
return true;
}
-
-bool isDifferentialBlock(IRBlock* block)
+static int getInstRegionNestLevel(
+ Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo,
+ IRBlock* defBlock,
+ IRInst* inst)
{
- return block->findDecoration<IRDifferentialInstDecoration>();
+ auto result = indexedBlockInfo[defBlock].GetValue().getCount();
+ // Loop counters are considered to not belong to the region started by the its loop.
+ if (result > 0 && inst->findDecoration<IRLoopCounterDecoration>())
+ result--;
+ return (int)result;
}
RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
HoistedPrimalsInfo* hoistInfo,
IRGlobalValueWithCode* func,
- Dictionary<IRBlock*, List<IndexTrackingInfo*>> indexedBlockInfo)
+ Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo)
{
RefPtr<IRDominatorTree> domTree = computeDominatorTree(func);
@@ -510,129 +839,369 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
SLANG_ASSERT(!isDifferentialBlock(defaultVarBlock));
- HashSet<IRInst*> processedStoreSet;
+ OrderedHashSet<IRInst*> processedStoreSet;
- // TODO: Also ensure availability of everything in the recompute set (for proper recompute support)
- for (auto instToStore : hoistInfo->storeSet)
+ auto ensureInstAvailable = [&](OrderedHashSet<IRInst*>& instSet)
{
- IRBlock* defBlock = nullptr;
- if (auto ptrInst = as<IRPtrTypeBase>(instToStore->getDataType()))
+ for (auto instToStore : instSet)
{
- auto varInst = as<IRVar>(instToStore);
- auto storeUse = findUniqueStoredVal(varInst);
+ if (!instSet.Contains(instToStore))
+ continue;
- defBlock = getBlock(storeUse->getUser());
- }
- else
- defBlock = getBlock(instToStore);
+ if (hoistInfo->ignoreSet.Contains(instToStore))
+ continue;
+ IRBlock* defBlock = nullptr;
+ if (auto ptrInst = as<IRPtrTypeBase>(instToStore->getDataType()))
+ {
+ auto varInst = as<IRVar>(instToStore);
+ auto storeUse = findUniqueStoredVal(varInst);
- SLANG_RELEASE_ASSERT(defBlock);
+ defBlock = getBlock(storeUse->getUser());
+ }
+ else
+ defBlock = getBlock(instToStore);
- List<IRUse*> outOfScopeUses;
- for (auto use = instToStore->firstUse; use;)
- {
- auto nextUse = use->nextUse;
-
- // Only consider uses in differential blocks.
- // This method is not responsible for other blocks.
- //
- IRBlock* userBlock = getBlock(use->getUser());
- if (userBlock->findDecoration<IRDifferentialInstDecoration>())
+ SLANG_RELEASE_ASSERT(defBlock);
+
+ List<IRUse*> outOfScopeUses;
+ for (auto use = instToStore->firstUse; use;)
{
- if (!domTree->dominates(defBlock, userBlock))
+ auto nextUse = use->nextUse;
+
+ // Only consider uses in differential blocks.
+ // This method is not responsible for other blocks.
+ //
+ IRBlock* userBlock = getBlock(use->getUser());
+ if (isDifferentialOrRecomputeBlock(userBlock))
+ {
+ if (!domTree->dominates(defBlock, userBlock))
+ {
+ outOfScopeUses.add(use);
+ }
+ else if (!areIndicesSubsetOf(indexedBlockInfo[defBlock], indexedBlockInfo[userBlock]))
+ {
+ outOfScopeUses.add(use);
+ }
+ else if (getInstRegionNestLevel(indexedBlockInfo, defBlock, instToStore) > 0 &&
+ !isDifferentialOrRecomputeBlock(defBlock))
+ {
+ outOfScopeUses.add(use);
+ }
+ else if (as<IRPtrTypeBase>(instToStore->getDataType()) &&
+ !isDifferentialOrRecomputeBlock(defBlock))
+ {
+ outOfScopeUses.add(use);
+ }
+ }
+
+ use = nextUse;
+ }
+
+ if (outOfScopeUses.getCount() == 0)
+ {
+ processedStoreSet.Add(instToStore);
+ continue;
+ }
+
+ if (auto ptrInst = as<IRPtrTypeBase>(instToStore->getDataType()))
+ {
+
+ IRVar* varToStore = as<IRVar>(instToStore);
+ SLANG_RELEASE_ASSERT(varToStore);
+
+ auto storeUse = findUniqueStoredVal(varToStore);
+
+ List<IndexTrackingInfo>& defBlockIndices = indexedBlockInfo[defBlock];
+
+ bool isIndexedStore = (storeUse && defBlockIndices.getCount() > 0);
+
+ // TODO: There's a slight hackiness here. (Ideally we might just want to emit
+ // additional vars when splitting a call)
+ //
+ if (!isIndexedStore && isDerivativeContextVar(varToStore))
{
- outOfScopeUses.add(use);
+ varToStore->insertBefore(defaultVarBlock->getFirstOrdinaryInst());
+ processedStoreSet.Add(varToStore);
+ continue;
}
- else if (!areIndicesSubsetOf(indexedBlockInfo[defBlock], indexedBlockInfo[userBlock]))
+
+ setInsertAfterOrdinaryInst(&builder, getInstInBlock(storeUse->getUser()));
+
+ IRVar* localVar = storeIndexedValue(
+ &builder,
+ defaultVarBlock,
+ builder.emitLoad(varToStore),
+ defBlockIndices);
+
+ for (auto use : outOfScopeUses)
{
- outOfScopeUses.add(use);
+ setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser()));
+
+ List<IndexTrackingInfo>& useBlockIndices = indexedBlockInfo[getBlock(use->getUser())];
+
+ IRInst* loadAddr = emitIndexedLoadAddressForVar(&builder, localVar, defBlockIndices, useBlockIndices);
+ builder.replaceOperand(use, loadAddr);
}
- else if (indexedBlockInfo[defBlock].GetValue().getCount() > 0 &&
- !isDifferentialBlock(defBlock))
+
+ processedStoreSet.Add(localVar);
+ }
+ else
+ {
+ // Handle the special case of loop counters.
+ // The only case where there will be a reference of primal loop counter from rev blocks
+ // is the start of a loop in the reverse code. Since loop counters are not considered a
+ // part of their loop region, so we remove the first index info.
+ List<IndexTrackingInfo> defBlockIndices = indexedBlockInfo[defBlock];
+ bool isLoopCounter = (instToStore->findDecoration<IRLoopCounterDecoration>() != nullptr);
+ if (isLoopCounter)
{
- outOfScopeUses.add(use);
+ defBlockIndices.removeAt(0);
}
- else if (as<IRPtrTypeBase>(instToStore->getDataType()) &&
- !isDifferentialBlock(defBlock))
+
+ setInsertAfterOrdinaryInst(&builder, instToStore);
+ auto localVar = storeIndexedValue(&builder, defaultVarBlock, instToStore, defBlockIndices);
+
+ for (auto use : outOfScopeUses)
{
- outOfScopeUses.add(use);
+ List<IndexTrackingInfo> useBlockIndices = indexedBlockInfo[getBlock(use->getUser())];
+ if (isLoopCounter)
+ {
+ // The use site of a primal loop counter should be right before we enter the
+ // loop, and therefore its index count should equal to defBlockIndices.getCount()
+ // after we remove the first index from defBlockIndices.
+ SLANG_RELEASE_ASSERT(useBlockIndices.getCount() == defBlockIndices.getCount());
+ }
+ setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser()));
+ builder.replaceOperand(use, loadIndexedValue(&builder, localVar, defBlockIndices, useBlockIndices));
}
- }
- use = nextUse;
+ processedStoreSet.Add(localVar);
+ }
}
+ };
- if (outOfScopeUses.getCount() == 0)
- {
- processedStoreSet.Add(instToStore);
- continue;
- }
+ ensureInstAvailable(hoistInfo->storeSet);
+
+ // Replace the old store set with the processed one.
+ hoistInfo->storeSet = processedStoreSet;
- if (auto ptrInst = as<IRPtrTypeBase>(instToStore->getDataType()))
- {
+ return hoistInfo;
+}
- IRVar* varToStore = as<IRVar>(instToStore);
- SLANG_RELEASE_ASSERT(varToStore);
-
- auto storeUse = findUniqueStoredVal(varToStore);
-
- List<IndexTrackingInfo*> defBlockIndices = indexedBlockInfo[defBlock];
- bool isIndexedStore = (storeUse && defBlockIndices.getCount() > 0);
+void tryInferMaxIndex(IndexedRegion* region, IndexTrackingInfo* info)
+{
+ if (info->status != IndexTrackingInfo::CountStatus::Unresolved)
+ return;
- // TODO: There's a slight hackiness here. (Ideally we might just want to emit
- // additional vars when splitting a call)
- //
- if (!isIndexedStore && isDerivativeContextVar(varToStore))
- {
- varToStore->insertBefore(defaultVarBlock->getFirstOrdinaryInst());
- processedStoreSet.Add(varToStore);
- continue;
- }
+ auto loop = as<IRLoop>(region->getInitializerBlock()->getTerminator());
+
+ if (auto maxItersDecoration = loop->findDecoration<IRLoopMaxItersDecoration>())
+ {
+ info->maxIters = (Count)maxItersDecoration->getMaxIters();
+ info->status = IndexTrackingInfo::CountStatus::Static;
+ }
+}
- setInsertAfterOrdinaryInst(&builder, getInstInBlock(storeUse->getUser()));
- IRVar* localVar = storeIndexedValue(
- &builder,
- defaultVarBlock,
- builder.emitLoad(varToStore),
- defBlockIndices);
+IRInst* addPhiInputParam(IRBuilder* builder, IRBlock* block, IRType* type)
+{
+ builder->setInsertInto(block);
+ return builder->emitParam(type);
+}
- for (auto use : outOfScopeUses)
- {
- setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser()));
-
- List<IndexTrackingInfo*> useBlockIndices = indexedBlockInfo[getBlock(use->getUser())];
+IRInst* addPhiInputParam(IRBuilder* builder, IRBlock* block, IRType* type, UIndex index)
+{
+ List<IRParam*> params;
+ for (auto param : block->getParams())
+ params.add(param);
- IRInst* loadAddr = emitIndexedLoadAddressForVar(&builder, localVar, defBlockIndices, useBlockIndices);
- builder.replaceOperand(use, loadAddr);
- }
+ SLANG_RELEASE_ASSERT(index == (UCount)params.getCount());
- processedStoreSet.Add(localVar);
- }
- else
- {
- setInsertAfterOrdinaryInst(&builder, instToStore);
+ return addPhiInputParam(builder, block, type);
+}
- List<IndexTrackingInfo*> defBlockIndices = indexedBlockInfo[defBlock];
- auto localVar = storeIndexedValue(&builder, defaultVarBlock, instToStore, defBlockIndices);
-
- for (auto use : outOfScopeUses)
- {
- setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser()));
+static IRBlock* getUpdateBlock(IRLoop* loop)
+{
+ auto initBlock = cast<IRBlock>(loop->getParent());
- List<IndexTrackingInfo*> useBlockIndices = indexedBlockInfo[getBlock(use->getUser())];
- builder.replaceOperand(use, loadIndexedValue(&builder, localVar, defBlockIndices, useBlockIndices));
- }
+ auto condBlock = loop->getTargetBlock();
+
+ IRBlock* lastLoopBlock = nullptr;
+
+ for (auto predecessor : condBlock->getPredecessors())
+ {
+ if (predecessor != initBlock)
+ lastLoopBlock = predecessor;
+ }
+
+ // Should find atleast one predecessor that is _not_ the
+ // init block (that contains the loop info). This
+ // predecessor would be the last block in the loop
+ // before looping back to the condition.
+ //
+ SLANG_RELEASE_ASSERT(lastLoopBlock);
+
+ return lastLoopBlock;
+}
+
+void lowerIndexedRegion(IRLoop*& primalLoop, IRLoop*& diffLoop, IRInst*& primalCountParam, IRInst*& diffCountParam)
+{
+ IRBuilder builder(primalLoop);
+ primalCountParam = nullptr;
+
+ // Grab first primal block.
+ IRBlock* primalInitBlock = as<IRBlock>(primalLoop->getParent());
+ builder.setInsertBefore(primalInitBlock->getTerminator());
+ {
+ auto primalCondBlock = as<IRUnconditionalBranch>(
+ primalInitBlock->getTerminator())->getTargetBlock();
+ builder.setInsertBefore(primalInitBlock->getTerminator());
+
+ auto phiCounterArgLoopEntryIndex = addPhiOutputArg(
+ &builder,
+ primalInitBlock,
+ *(IRInst**)&primalLoop,
+ builder.getIntValue(builder.getIntType(), 0));
+
+ builder.setInsertBefore(primalCondBlock->getTerminator());
+ primalCountParam = addPhiInputParam(
+ &builder,
+ primalCondBlock,
+ builder.getIntType(),
+ phiCounterArgLoopEntryIndex);
+ builder.addLoopCounterDecoration(primalCountParam);
+ builder.addNameHintDecoration(primalCountParam, UnownedStringSlice("_pc"));
+ builder.markInstAsPrimal(primalCountParam);
+
+ IRBlock* primalUpdateBlock = getUpdateBlock(primalLoop);
+ IRInst* terminator = primalUpdateBlock->getTerminator();
+ builder.setInsertBefore(primalUpdateBlock->getTerminator());
+
+ auto incCounterVal = builder.emitAdd(
+ builder.getIntType(),
+ primalCountParam,
+ builder.getIntValue(builder.getIntType(), 1));
+ builder.markInstAsPrimal(incCounterVal);
+
+ auto phiCounterArgLoopCycleIndex = addPhiOutputArg(&builder, primalUpdateBlock, terminator, incCounterVal);
+
+ SLANG_RELEASE_ASSERT(phiCounterArgLoopEntryIndex == phiCounterArgLoopCycleIndex);
+ }
+
+ {
+ IRBlock* diffInitBlock = as<IRBlock>(diffLoop->getParent());
+
+ auto diffCondBlock = as<IRUnconditionalBranch>(
+ diffInitBlock->getTerminator())->getTargetBlock();
+ builder.setInsertBefore(diffInitBlock->getTerminator());
+ auto revCounterInitVal = builder.emitSub(
+ builder.getIntType(),
+ primalCountParam,
+ builder.getIntValue(builder.getIntType(), 1));
+ auto phiCounterArgLoopEntryIndex = addPhiOutputArg(
+ &builder,
+ diffInitBlock,
+ *(IRInst**)&diffLoop,
+ revCounterInitVal);
+
+ builder.setInsertBefore(diffCondBlock->getTerminator());
+
+ diffCountParam = addPhiInputParam(
+ &builder,
+ diffCondBlock,
+ builder.getIntType(),
+ phiCounterArgLoopEntryIndex);
+ builder.addNameHintDecoration(diffCountParam, UnownedStringSlice("_dc"));
+ builder.markInstAsPrimal(diffCountParam);
+
+ IRBlock* diffUpdateBlock = getUpdateBlock(diffLoop);
+ builder.setInsertBefore(diffUpdateBlock->getTerminator());
+ IRInst* terminator = diffUpdateBlock->getTerminator();
+
+ auto decCounterVal = builder.emitSub(
+ builder.getIntType(),
+ diffCountParam,
+ builder.getIntValue(builder.getIntType(), 1));
+ builder.markInstAsPrimal(decCounterVal);
+
+ auto phiCounterArgLoopCycleIndex = addPhiOutputArg(&builder, diffUpdateBlock, terminator, decCounterVal);
+
+ auto ifElse = as<IRIfElse>(diffCondBlock->getTerminator());
+ builder.setInsertBefore(ifElse);
+ auto exitCondition = builder.emitGeq(diffCountParam, builder.getIntValue(builder.getIntType(), 0));
+ ifElse->condition.set(exitCondition);
+
+ SLANG_RELEASE_ASSERT(phiCounterArgLoopEntryIndex == phiCounterArgLoopCycleIndex);
+ }
+}
+
+void buildIndexedBlocks(
+ Dictionary<IRBlock*, List<IndexTrackingInfo>>& info,
+ IRGlobalValueWithCode* func)
+{
+ Dictionary<IRLoop*, IndexTrackingInfo> mapLoopToTrackingInfo;
- processedStoreSet.Add(localVar);
+ for (auto block : func->getBlocks())
+ {
+ auto loop = as<IRLoop>(block->getTerminator());
+ if (!loop) continue;
+ auto diffDecor = loop->findDecoration<IRDifferentialInstDecoration>();
+ if (!diffDecor) continue;
+ auto primalLoop = as<IRLoop>(diffDecor->getPrimalInst());
+ if (!primalLoop) continue;
+
+ IndexTrackingInfo indexInfo = {};
+ lowerIndexedRegion(primalLoop, loop, indexInfo.primalCountParam, indexInfo.diffCountParam);
+
+ SLANG_RELEASE_ASSERT(indexInfo.primalCountParam);
+ SLANG_RELEASE_ASSERT(indexInfo.diffCountParam);
+
+ mapLoopToTrackingInfo[loop] = indexInfo;
+ mapLoopToTrackingInfo[primalLoop] = indexInfo;
+ }
+
+ auto regionMap = buildIndexedRegionMap(func);
+
+ for (auto block : func->getBlocks())
+ {
+ List<IndexTrackingInfo> trackingInfos;
+ for (auto region : regionMap->getAllAncestorRegions(block))
+ {
+ IndexTrackingInfo trackingInfo;
+ if (mapLoopToTrackingInfo.TryGetValue(region->loop, trackingInfo))
+ {
+ tryInferMaxIndex(region, &trackingInfo);
+ trackingInfos.add(trackingInfo);
+ }
}
+ info[block] = trackingInfos;
}
-
- // Replace the old store set with the processed onne one.
- hoistInfo->storeSet = processedStoreSet;
+}
- return hoistInfo;
+RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(
+ IRGlobalValueWithCode* func, const List<IRInst*>& instsToIgnore)
+{
+ sortBlocksInFunc(func);
+
+ Dictionary<IRBlock*, List<IndexTrackingInfo>> indexedBlockInfo;
+ buildIndexedBlocks(indexedBlockInfo, func);
+
+ auto recomputeBlockMap = createPrimalRecomputeBlocks(func, indexedBlockInfo);
+
+ sortBlocksInFunc(func);
+
+ RefPtr<AutodiffCheckpointPolicyBase> chkPolicy = new DefaultCheckpointPolicy(func->getModule());
+ chkPolicy->preparePolicy(func);
+
+ auto primalsInfo = chkPolicy->processFunc(func, recomputeBlockMap);
+
+ for (auto propagateFuncSpecificInst : instsToIgnore)
+ {
+ primalsInfo->ignoreSet.add(propagateFuncSpecificInst);
+ }
+ primalsInfo = ensurePrimalAvailability(primalsInfo, func, indexedBlockInfo);
+ return primalsInfo;
}
void DefaultCheckpointPolicy::preparePolicy(IRGlobalValueWithCode* func)
@@ -716,7 +1285,7 @@ static bool shouldStoreVar(IRVar* var)
{
for (UInt i = 0; i < spec->getArgCount(); i++)
{
- if (!canTypeBeStored(spec->getArg(i)->getDataType()))
+ if (!canTypeBeStored(spec->getArg(i)))
return false;
}
}
@@ -772,6 +1341,30 @@ static bool shouldStoreInst(IRInst* inst)
case kIROp_ExtractExistentialWitnessTable:
case kIROp_undefined:
case kIROp_GetSequentialID:
+ case kIROp_Specialize:
+ case kIROp_LookupWitness:
+#if 0
+ case kIROp_Add:
+ case kIROp_Sub:
+ case kIROp_Mul:
+ case kIROp_Div:
+ case kIROp_Neg:
+ case kIROp_Geq:
+ case kIROp_Leq:
+ case kIROp_Neq:
+ case kIROp_Eql:
+ case kIROp_Greater:
+ case kIROp_Less:
+ case kIROp_And:
+ case kIROp_Or:
+ case kIROp_Not:
+ case kIROp_BitNot:
+ case kIROp_BitAnd:
+ case kIROp_BitOr:
+ case kIROp_BitXor:
+ case kIROp_Lsh:
+ case kIROp_Rsh:
+#endif
return false;
case kIROp_GetElement:
case kIROp_FieldExtract:
@@ -791,6 +1384,9 @@ static bool shouldStoreInst(IRInst* inst)
break;
}
+ if (as<IRType>(inst))
+ return false;
+
// Only store if the inst has differential inst user.
bool hasDiffUser = doesInstHaveDiffUse(inst);
if (!hasDiffUser)
@@ -801,22 +1397,11 @@ static bool shouldStoreInst(IRInst* inst)
bool canRecompute(IRDominatorTree* domTree, IRUse* use)
{
+ SLANG_UNUSED(domTree);
auto param = as<IRParam>(use->get());
if (!param)
return true;
- auto paramBlock = as<IRBlock>(param->getParent());
- for (auto predecessor : paramBlock->getPredecessors())
- {
- // If we hit this, the checkpoint policy is trying to recompute
- // values across a loop region boundary (we don't currently support this,
- // and in general this is quite inefficient in both compute & memory)
- //
- if (domTree->dominates(paramBlock, predecessor))
- {
- return false;
- }
- }
- return true;
+ return false;
}
HoistResult DefaultCheckpointPolicy::classify(IRUse* use)
diff --git a/source/slang/slang-ir-autodiff-primal-hoist.h b/source/slang/slang-ir-autodiff-primal-hoist.h
index bd2575172..3b3fb82b1 100644
--- a/source/slang/slang-ir-autodiff-primal-hoist.h
+++ b/source/slang/slang-ir-autodiff-primal-hoist.h
@@ -7,7 +7,6 @@
#include "slang-ir-autodiff-region.h"
#include "slang-ir-dominators.h"
-
namespace Slang
{
struct IROutOfOrderCloneContext : public RefObject
@@ -84,11 +83,11 @@ namespace Slang
struct HoistedPrimalsInfo : public RefObject
{
- HashSet<IRInst*> storeSet;
- HashSet<IRInst*> recomputeSet;
- HashSet<IRInst*> invertSet;
-
- HashSet<IRInst*> instsToInvert;
+ OrderedHashSet<IRInst*> storeSet;
+ OrderedHashSet<IRInst*> recomputeSet;
+ OrderedHashSet<IRInst*> invertSet;
+ OrderedHashSet<IRInst*> ignoreSet;
+ OrderedHashSet<IRInst*> instsToInvert;
Dictionary<IRInst*, InversionInfo> invertInfoMap;
@@ -130,6 +129,9 @@ namespace Slang
for (auto inst : info->invertSet)
invertSet.Add(inst);
+ for (auto inst : info->ignoreSet)
+ ignoreSet.add(inst);
+
for (auto inst : info->instsToInvert)
instsToInvert.Add(inst);
@@ -195,6 +197,31 @@ namespace Slang
}
};
+ struct IndexTrackingInfo : public RefObject
+ {
+ // After lowering, store references to the count
+ // variables associated with this region
+ //
+ IRInst* primalCountParam = nullptr;
+ IRInst* diffCountParam = nullptr;
+
+ enum CountStatus
+ {
+ Unresolved,
+ Dynamic,
+ Static
+ };
+
+ CountStatus status = CountStatus::Unresolved;
+
+ // Inferred maximum number of iterations.
+ Count maxIters = -1;
+
+ bool operator==(const IndexTrackingInfo& other) const
+ {
+ return primalCountParam == other.primalCountParam;
+ }
+ };
// Information on which insts are to be stored, recomputed
// and inverted within a single function.
@@ -210,6 +237,15 @@ namespace Slang
Dictionary<IRInst*, InversionInfo> invInfoMap;
};
+ // Information on a block after it has been split in the unzip step.
+ // After unzipping, every block in the original function will have
+ // two corresponding blocks in the new function:
+ // - A 'primal-recompute' block, which contains the original instructions
+ // from the original block, but located in the corresponding the reverse
+ // diff region so their results are accessible in the diff block for
+ // derivative computation.
+ // - A 'diff' block, which contains the transcribed instructions from the
+ // original block.
struct BlockSplitInfo : public RefObject
{
// Maps primal to differential blocks from the unzip step.
@@ -223,7 +259,9 @@ namespace Slang
AutodiffCheckpointPolicyBase(IRModule* module) : module(module)
{ }
- RefPtr<HoistedPrimalsInfo> processFunc(IRGlobalValueWithCode* func, BlockSplitInfo* info);
+ RefPtr<HoistedPrimalsInfo> processFunc(
+ IRGlobalValueWithCode* func,
+ Dictionary<IRBlock*, IRBlock*>& mapDiffBlockToRecomputeBlock);
// Do pre-processing on the function (mainly for
// 'global' checkpointing methods that consider the entire
@@ -252,15 +290,9 @@ namespace Slang
RefPtr<IRDominatorTree> domTree;
};
- RefPtr<HoistedPrimalsInfo> applyCheckpointSet(
- CheckpointSetInfo* checkpointInfo,
+ RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(
IRGlobalValueWithCode* func,
- BlockSplitInfo* splitInfo,
- HashSet<IRUse*> pendingUses);
+ const List<IRInst*>& instsToIgnore);
- RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
- HoistedPrimalsInfo* hoistInfo,
- IRGlobalValueWithCode* func,
- Dictionary<IRBlock*, List<IndexTrackingInfo*>> indexedBlockInfo);
};
diff --git a/source/slang/slang-ir-autodiff-region.h b/source/slang/slang-ir-autodiff-region.h
index a4618e257..59a977619 100644
--- a/source/slang/slang-ir-autodiff-region.h
+++ b/source/slang/slang-ir-autodiff-region.h
@@ -50,29 +50,6 @@ struct IndexedRegion : public RefObject
}
};
-struct IndexTrackingInfo : public RefObject
-{
- // After lowering, store references to the count
- // variables associated with this region
- //
- IRInst* primalCountParam = nullptr;
- IRInst* diffCountParam = nullptr;
-
- IRVar* primalCountLastVar = nullptr;
-
- enum CountStatus
- {
- Unresolved,
- Dynamic,
- Static
- };
-
- CountStatus status = CountStatus::Unresolved;
-
- // Inferred maximum number of iterations.
- Count maxIters = -1;
-};
-
struct IndexedRegionMap : public RefObject
{
Dictionary<IRBlock*, IndexedRegion*> map;
@@ -116,4 +93,4 @@ struct IndexedRegionMap : public RefObject
RefPtr<IndexedRegionMap> buildIndexedRegionMap(IRGlobalValueWithCode* func);
-}; \ No newline at end of file
+};
diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp
index 0bdc4a935..979eb6343 100644
--- a/source/slang/slang-ir-autodiff-rev.cpp
+++ b/source/slang/slang-ir-autodiff-rev.cpp
@@ -679,10 +679,7 @@ namespace Slang
//
// diffPropagationPass->propagateDiffInstDecoration(builder, fwdDiffFunc);
- // Copy primal insts to the first block of the unzipped function, copy diff insts to the
- // second block of the unzipped function.
- //
- RefPtr<HoistedPrimalsInfo> primalsInfo = diffUnzipPass->unzipDiffInsts(fwdDiffFunc);
+ diffUnzipPass->unzipDiffInsts(fwdDiffFunc);
IRFunc* unzippedFwdDiffFunc = fwdDiffFunc;
// Move blocks from `unzippedFwdDiffFunc` to the `diffPropagateFunc` shell.
@@ -709,10 +706,17 @@ namespace Slang
// Transpose differential blocks from unzippedFwdDiffFunc into diffFunc (with dOutParameter) representing the
// derivative of the return value.
- DiffTransposePass::FuncTranspositionInfo transposeInfo = { paramTransposeInfo.dOutParam, primalsInfo };
+ DiffTransposePass::FuncTranspositionInfo transposeInfo = { paramTransposeInfo.dOutParam };
diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, transposeInfo);
+ // Apply checkpointing policy to legalize cross-scope uses of primal values
+ // using either recompute or store strategies.
+ auto primalsInfo = applyCheckpointPolicy(
+ diffPropagateFunc, paramTransposeInfo.propagateFuncSpecificPrimalInsts);
+
+
eliminateDeadCode(diffPropagateFunc);
+
// Extracts the primal computations into its own func, and replace the primal insts
// with the intermediate results computed from the extracted func.
@@ -907,6 +911,7 @@ namespace Slang
// after transposition.
auto tempVar = nextBlockBuilder.emitVar(diffType);
copyNameHintDecoration(tempVar, fwdParam);
+ result.propagateFuncSpecificPrimalInsts.add(tempVar);
// Initialize the var with input diff param at start.
// Note that we insert the store in the primal block so it won't get transposed.
@@ -993,9 +998,11 @@ namespace Slang
// of the differential component of the pair.
auto newParamLoad = diffBuilder.emitLoad(propParam);
diffBuilder.markInstAsDifferential(newParamLoad, primalType);
+ result.propagateFuncSpecificPrimalInsts.add(newParamLoad);
diffRefReplacement = diffBuilder.emitDifferentialPairGetDifferential(diffType, newParamLoad);
diffBuilder.markInstAsDifferential(diffRefReplacement, primalType);
+ result.propagateFuncSpecificPrimalInsts.add(diffRefReplacement);
// Load the primal component from the prop param and use it as replacement for the
// primal param in the primal part of the prop func.
@@ -1031,7 +1038,10 @@ namespace Slang
// Load the inital diff value.
auto loadedParam = nextBlockBuilder.emitLoad(diffParam);
+ result.propagateFuncSpecificPrimalInsts.add(loadedParam);
+
auto initDiff = nextBlockBuilder.emitDifferentialPairGetDifferential(diffType, loadedParam);
+ result.propagateFuncSpecificPrimalInsts.add(initDiff);
// Create a local var for diff read access.
auto diffVar = nextBlockBuilder.emitVar(diffType);
@@ -1047,6 +1057,7 @@ namespace Slang
// Create a local var for diff write access.
auto diffWriteVar = nextBlockBuilder.emitVar(diffType);
+ result.propagateFuncSpecificPrimalInsts.add(diffWriteVar);
copyNameHintDecoration(diffWriteVar, fwdParam);
// Initialize write var to 0.
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 8c005a5c6..c7ac8c357 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -80,11 +80,6 @@ struct DiffTransposePass
// of the *output* of the function.
//
IRInst* dOutInst;
-
- // Information from the unzip pass on how primal insts
- // are split across the primal and differential blocks.
- //
- HoistedPrimalsInfo* hoistedPrimalsInfo;
};
struct PendingBlockTerminatorEntry
@@ -235,16 +230,12 @@ struct DiffTransposePass
builder.setInsertInto(revCondBlock);
- //hoistPrimalInst(&builder, ifElse->getCondition());
-
- auto newIfElse = builder.emitIfElse(
+ builder.emitIfElse(
ifElse->getCondition(),
revTrueEntryBlock,
revFalseEntryBlock,
revAfterBlock);
- hoistPrimalOperands(&builder, newIfElse);
-
if (!revTrueRegionInfo.isTrivial)
{
builder.setInsertInto(revTrueExitBlock);
@@ -358,21 +349,21 @@ struct DiffTransposePass
// Emit condition into the new cond block.
builder.setInsertInto(revCondBlock);
- // TODO: Need to defer this until after the CFG reversal is complete.
- //hoistPrimalInst(&builder, ifElse->getCondition());
-
- auto newIfElse = builder.emitIfElse(
+ builder.emitIfElse(
ifElse->getCondition(),
revTrueBlock,
revFalseBlock,
revTrueBlock);
-
- hoistPrimalOperands(&builder, newIfElse);
-
+
+ auto loopParentBlockDiffDecor = loop->getParent()->findDecoration<IRDifferentialInstDecoration>();
+ SLANG_RELEASE_ASSERT(loopParentBlockDiffDecor);
+ auto primalBlock = as<IRBlock>(loopParentBlockDiffDecor->getPrimalInst());
+ auto primalLoop = as<IRLoop>(primalBlock->getTerminator());
+ SLANG_RELEASE_ASSERT(primalLoop);
+
// Old false-side starting block becomes end block
// for the new pre-cond region (which could be empty)
//
-
if (!falseRegionInfo.isTrivial)
{
IRBlock* revPreCondEndBlock = revBlockMap[falseBlock];
@@ -384,7 +375,8 @@ struct DiffTransposePass
getPhiGrads(falseBlock).getCount(),
getPhiGrads(falseBlock).getBuffer());
loop->transferDecorationsTo(revLoop);
-
+ builder.markInstAsDifferential(revLoop, builder.getVoidType(), primalLoop);
+
auto revLoopStartBlock = revBlockMap[breakBlock];
builder.setInsertInto(revLoopStartBlock);
builder.emitBranch(
@@ -404,6 +396,7 @@ struct DiffTransposePass
getPhiGrads(breakBlock).getCount(),
getPhiGrads(breakBlock).getBuffer());
loop->transferDecorationsTo(revLoop);
+ builder.markInstAsDifferential(revLoop, builder.getVoidType(), primalLoop);
}
currentBlock = breakBlock;
@@ -478,17 +471,13 @@ struct DiffTransposePass
builder.setInsertInto(revSwitchBlock);
- // hoistPrimalInst(&builder, switchInst->getCondition());
-
- auto newSwitchInst = builder.emitSwitch(
+ builder.emitSwitch(
switchInst->getCondition(),
revBreakBlock,
revDefaultRegionEntry,
reverseSwitchArgs.getCount(),
reverseSwitchArgs.getBuffer());
- hoistPrimalOperands(&builder, newSwitchInst);
-
currentBlock = breakBlock;
break;
}
@@ -525,9 +514,7 @@ struct DiffTransposePass
// (i.e. not store per-func info in 'this')
// since it is reused for every reverse-mode call.
//
-
- hoistedPrimalsInfo = transposeInfo.hoistedPrimalsInfo;
-
+ primalVarsToHoist.clear();
// Grab all differentiable type information.
diffTypeContext.setFunc(revDiffFunc);
@@ -576,8 +563,10 @@ struct DiffTransposePass
// Emit empty rev-mode blocks for every fwd-mode block.
for (auto block : workList)
{
- revBlockMap[block] = builder.emitBlock();
- builder.markInstAsDifferential(revBlockMap[block]);
+ auto revBlock = builder.emitBlock();
+ revBlockMap[block] = revBlock;
+ if (auto diffDecor = block->findDecoration<IRDifferentialInstDecoration>())
+ builder.markInstAsDifferential(revBlockMap[block], builder.getBasicBlockType(), diffDecor->getPrimalInst());
}
// Keep track of first diff block, since this is where
@@ -637,20 +626,6 @@ struct DiffTransposePass
auto firstFwdDiffBlock = branchInst->getTargetBlock();
reverseCFGRegion(firstFwdDiffBlock, List<IRBlock*>());
- // Lower any loop-exit-value decorations into initializations for loop intermediate vals,
- // and convert loop initial values into terminating conditions.
- //
- // TODO: We need a way to confirm that all required vars have an initial value
- // (is there a built-in dataflow tool for this?)
- //
- for (auto block : workList)
- {
- if (auto loopInst = as<IRLoop>(block->getTerminator()))
- {
- invertLoopCondition(&builder, loopInst);
- }
- }
-
// Link the last differential fwd-mode block (which will be the first
// rev-mode block) as the successor to the last primal block.
// We assume that the original function is in single-return form
@@ -688,43 +663,9 @@ struct DiffTransposePass
for (auto block : workList)
block->removeFromParent();
- // Mark all primal operands for hoisting.
- // TODO: Can we just merge this with finishHoistingPrimalInsts?
- // TODO: Some of this logic is replicated in finishHoistingPrimalInsts. Merge it with the
- // maybeAddOperandsToWorkList logic there.
- //
- for (auto block : workList)
- {
- IRBlock* revBlock = revBlockMap[block];
-
- for (auto child = revBlock->getFirstChild(); child; child = child->getNextInst())
- {
- hoistPrimalOperands(&builder, child);
-
- for (auto decoration = child->getFirstDecoration(); decoration; decoration = decoration->getNextDecoration())
- {
- if (auto contextDecoration = as<IRBackwardDerivativePrimalContextDecoration>(decoration))
- hoistPrimalUse(&builder, &contextDecoration->primalContextVar);
-
- if (auto loopExitDecoration = as<IRLoopExitPrimalValueDecoration>(decoration))
- hoistPrimalUse(&builder, &loopExitDecoration->exitVal);
- }
-
- if (auto instType = child->getDataType())
- if (!as<IRModuleInst>(instType->getParent()))
- hoistPrimalUse(&builder, &child->typeUse);
- }
- }
-
finishHoistingPrimals(revDiffFunc);
- for (auto block : workList)
- {
- auto revBlock = as<IRBlock>(revBlockMap[block]);
- if (auto revLoop = as<IRLoop>(revBlock->getTerminator()))
- lowerLoopExitValues(&builder, revLoop);
- }
-
+
// At this point, the only block left without terminator insts
// should be the last one. Add a void return to complete it.
//
@@ -793,51 +734,6 @@ struct DiffTransposePass
return tempRevVar;
}
- IRVar* lookupInverseVar(IRInst* inst)
- {
- return inverseVarMap[inst];
- }
-
- IRVar* getOrCreateInverseVar(IRInst* primalInst, IRGlobalValueWithCode* func)
- {
- IRBlock* varBlock = firstRevDiffBlockMap[func];
- return getOrCreateInverseVar(primalInst, varBlock);
- }
-
- IRVar* getOrCreateInverseVar(IRInst* primalInst)
- {
- IRBlock* varBlock = firstRevDiffBlockMap[as<IRFunc>(primalInst->getParent()->getParent())];
- return getOrCreateInverseVar(primalInst, varBlock);
- }
-
- IRVar* getOrCreateInverseVar(IRInst* primalInst, IRBlock* varBlock)
- {
- // No need to store inverse values for constants.
- if (as<IRConstant>(primalInst))
- return nullptr;
-
- // Check if we have a var already.
- if (inverseVarMap.ContainsKey(primalInst))
- return inverseVarMap[primalInst];
-
- IRBuilder tempVarBuilder(autodiffContext->moduleInst);
-
- if (auto firstInst = varBlock->getFirstOrdinaryInst())
- tempVarBuilder.setInsertBefore(firstInst);
- else
- tempVarBuilder.setInsertInto(varBlock);
-
- auto primalType = primalInst->getDataType();
-
- // Emit a var in the top-level differential block to hold the inverse,
- // and initialize it.
- auto tempInvVar = tempVarBuilder.emitVar(primalType);
-
- inverseVarMap[primalInst] = tempInvVar;
-
- return tempInvVar;
- }
-
bool isInstUsedOutsideParentBlock(IRInst* inst)
{
auto currBlock = inst->getParent();
@@ -900,37 +796,9 @@ struct DiffTransposePass
revParam,
nullptr));
}
- else if (hasInverse(arg))
- {
- InversionInfo invInfo = this->hoistedPrimalsInfo->invertInfoMap[branchInst];
- if (invInfo.targetInsts.contains(arg))
- {
- SLANG_ASSERT(hasInverse(getParamAt(branchInst->getTargetBlock(), ii)));
-
- // If the output arg is a primal, emit a parameter
- // to accept it as an _input_ for the reverse-mode
- //
- auto primalType = arg->getDataType();
- auto primalInvParam = builder.emitParam(primalType);
-
- invBuilder.setInsertBefore(branchInst);
- setInverse(&invBuilder, fwdBlock, builder.getFunc(), arg, primalInvParam);
- }
- }
else
{
- if (hasInverse(getParamAt(branchInst->getTargetBlock(), ii)))
- {
- auto primalType = arg->getDataType();
- auto primalInvParam = builder.emitParam(primalType);
-
- invBuilder.setInsertBefore(branchInst);
- setInverse(&invBuilder, fwdBlock, builder.getFunc(), arg, primalInvParam);
- }
- else
- {
- SLANG_UNEXPECTED("Encountered phi-param is not differential and is not marked for inversion");
- }
+ SLANG_UNEXPECTED("Encountered phi-param is not differential and is not marked for inversion");
}
}
}
@@ -989,15 +857,6 @@ struct DiffTransposePass
if (isDifferentialInst(child))
transposeInst(&builder, child);
- else if (shouldInstBeInverted(child))
- {
- // We'll collect inverse insts in an orphaned block,
- // so disable IR validation temporarily.
- //
- disableIRValidationAtInsert();
- invertInst(&invBuilder, child);
- enableIRValidationAtInsert();
- }
}
// After processing the block's instructions, we 'flush' any remaining gradients
@@ -1046,10 +905,6 @@ struct DiffTransposePass
emitDZeroOfDiffInstType(&builder, tryGetPrimalTypeFromDiffInst(param)));
}
}
- else if (hasInverse(param))
- {
- phiParamRevGradInsts.add(param);
- }
else
{
SLANG_UNEXPECTED("param is neither differential inst nor marked for inversion");
@@ -1169,46 +1024,6 @@ struct DiffTransposePass
}
}
- // NOTE: This is a workaround for the fact that we expect inverses to use
- // single-use variables. The loop exit value will add a
- // second store to most inv-variables and mess with the primal hoisting mechanism.
- // Instead of emitting into the orphaned inverse block, we'll directly emit into
- // the reverse-mode block since we'll be running this _after_ the primal hoisting
- // pass.
- //
- // This workaround is fine for inverting loop counters, but when we want to
- // expand to supporting general-purpose adjoints, we would want to use per-region
- // inverse vars based on 'invInfo' (enforcing single-use vars)
- //
- void lowerLoopExitValues(IRBuilder* builder, IRLoop* revLoop)
- {
- List<IRDecoration*> processedDecorations;
- for (auto decoration : revLoop->getDecorations())
- {
- if (auto loopExitValueDecoration = as<IRLoopExitPrimalValueDecoration>(decoration))
- {
- builder->setInsertBefore(revLoop);
- setInverse(
- builder,
- nullptr,
- builder->getFunc(),
- loopExitValueDecoration->getTargetInst(),
- loopExitValueDecoration->getLoopExitValInst());
-
- processedDecorations.add(loopExitValueDecoration);
- }
- }
-
- for (auto decoration : processedDecorations)
- decoration->removeAndDeallocate();
- }
-
- void lowerLoopExitValues(IRBuilder* builder, IRBlock* block)
- {
- if (auto loopInst = as<IRLoop>(block->getTerminator()))
- lowerLoopExitValues(builder, loopInst);
- }
-
// Go through loop block phi-args, and look for loop counter
// arguments, which for a loop means inserting a check into
// loop condition block.
@@ -1253,41 +1068,9 @@ struct DiffTransposePass
loopCounterParam,
loopCounterInitVal).getBuffer());
- hoistPrimalOperands(builder, paramBoundsCheck);
-
as<IRIfElse>(revLoopCondBlock->getTerminator())->condition.set(paramBoundsCheck);
}
- List<InvInstPair> invertInst(IRBuilder* builder, IRInst* primalInst, InversionInfo invInfo)
- {
- switch (primalInst->getOp())
- {
- case kIROp_Add:
- case kIROp_Sub:
- return invertArithmetic(builder, primalInst, invInfo);
-
- default:
- SLANG_UNIMPLEMENTED_X("Unhandled inst type for inversion");
- }
- }
-
- bool hasInverse(IRInst* primalInst)
- {
- return this->hoistedPrimalsInfo->invertSet.Contains(primalInst);
- }
-
- IRInst* loadInverse(IRBuilder* builder, IRInst* primalInst)
- {
- // Note: There are other possible cases here, although not important
- // right now. For example, a value is available to load from the primal block.
- //
-
- if (auto invVar = getOrCreateInverseVar(primalInst, builder->getFunc()))
- return builder->emitLoad(invVar);
-
- return nullptr;
- }
-
IRInst* lookupInstInPrimalBlock(IRInst* invInst)
{
// Lookup the inst in the primal block whose value we can use as an operand
@@ -1296,37 +1079,7 @@ struct DiffTransposePass
// auto inversionInfo = this->hoistedPrimalsInfo->invertInfoMap[invInst];
return invInst;
}
-
- void setInverse(IRBuilder* builder, IRBlock* defBlock, IRGlobalValueWithCode* func, IRInst* inst, IRInst* invInst)
- {
- auto instBlock = as<IRBlock>(inst->getParent());
- if (!instBlock)
- return;
-
- disableIRValidationAtInsert();
- if (auto invVar = getOrCreateInverseVar(inst, func))
- {
- auto invStore = builder->emitStore(invVar, invInst);
- mapStoreToDefBlock[as<IRStore>(invStore)] = defBlock;
- }
- enableIRValidationAtInsert();
- }
-
- bool shouldInstBeInverted(IRInst* inst)
- {
-
- if (this->hoistedPrimalsInfo->instsToInvert.Contains(inst))
- return true;
-
- return false;
- }
-
- IRInst* hoistPrimalUse(IRBuilder*, IRUse* use)
- {
- primalUsesToHoist.add(use);
- return use->get();
- }
-
+
bool doesInstRequireHoisting(IRInst* inst)
{
if (as<IRModuleInst>(inst->getParent()))
@@ -1336,10 +1089,10 @@ struct DiffTransposePass
as<IRGlobalValueWithCode>(inst) ||
as<IRConstant>(inst))
return false;
-
+
if (as<IRTerminatorInst>(inst))
return false;
-
+
if (as<IRDecoration>(inst))
return doesInstRequireHoisting(getInstInBlock(inst));
@@ -1347,30 +1100,9 @@ struct DiffTransposePass
// that have not yet been moved to the 'active' blocks
// (i.e in diff blocks that do not have parents)
//
- return (!isDifferentialInst(inst) &&
- (isDifferentialInst(getBlock(inst)) || getBlock(inst) == tempInvBlock) &&
- getBlock(inst)->getParent() == nullptr);
- }
-
- // Builds a map from inst to a list of uses by primal _inverted_ insts.
- Dictionary<IRInst*, List<IRInst*>> buildInvOperandMap()
- {
- Dictionary<IRInst*, List<IRInst*>> invOperandMap;
- for (auto kvpair : this->hoistedPrimalsInfo->invertInfoMap)
- {
- InversionInfo invInfo = kvpair.Value;
-
- for (auto operand : invInfo.requiredOperands)
- {
- if (!invOperandMap.ContainsKey(operand))
- invOperandMap[operand] = List<IRInst*>();
-
- for (auto target : invInfo.targetInsts)
- invOperandMap[operand].GetValue().add(target);
- }
- }
-
- return invOperandMap;
+ return (!isDifferentialInst(inst) &&
+ (isDifferentialInst(getBlock(inst)) || getBlock(inst) == tempInvBlock) &&
+ getBlock(inst)->getParent() == nullptr);
}
IRBlock* walkToEndOfRegion(IRBlock* block)
@@ -1435,186 +1167,13 @@ struct DiffTransposePass
void finishHoistingPrimals(IRGlobalValueWithCode* func)
{
- List<IRInst*> workList;
-
- Dictionary<IRInst*, IRInst*> hoistedInstMap;
-
- RefPtr<IRDominatorTree> domTree = computeDominatorTree(func);
-
- Dictionary<IRInst*, List<IRInst*>> invOperandMap = buildInvOperandMap();
-
auto varBlock = func->getFirstBlock()->getNextBlock();
-
- // Load up pending insts into workList.
- for (auto use : primalUsesToHoist)
- workList.add(use->get());
-
- primalUsesToHoist.clear();
-
- auto maybeAddPrimalOperandsToWorkList = [&](IRInst* inst)
+ for (auto inst : primalVarsToHoist)
{
- UIndex opIndex = 0;
- for (auto operand = inst->getOperands(); opIndex < inst->getOperandCount(); operand++, opIndex++)
- {
- if (doesInstRequireHoisting(operand->get()) &&
- !hoistedInstMap.ContainsKey(operand->get()))
- {
- workList.add(operand->get());
- }
- }
-
- if (auto instType = inst->getDataType())
- {
- if (doesInstRequireHoisting(instType) &&
- !hoistedInstMap.ContainsKey(instType))
- workList.add(instType);
- }
- };
-
- auto maybeAddUsersToWorkList = [&](IRInst* inst)
- {
- for (auto use = inst->firstUse; use; use = use->nextUse)
- {
- if (doesInstRequireHoisting(use->getUser()))
- {
- if (as<IRVar>(inst) && as<IRStore>(use->getUser()))
- continue;
-
- // Uses that haven't already been hoisted into reverse-mode
- // blocks, and are not in the invert-set are pending uses.
- //
- if (!hoistedInstMap.ContainsKey(use->getUser()) && !hasInverse(use->getUser()))
- workList.add(use->getUser());
- }
- }
- };
-
- auto doesInstHavePendingUsers = [&](IRInst* inst)
- {
- for (auto use = inst->firstUse; use; use = use->nextUse)
- {
- if (doesInstRequireHoisting(use->getUser()))
- {
- if (as<IRVar>(inst) && as<IRStore>(use->getUser()))
- continue;
-
- // Users that haven't already been hoisted into reverse-mode
- // blocks are pending users.
- //
- if (!hoistedInstMap.ContainsKey(use->getUser()) && !hasInverse(use->getUser()))
- return true;
- }
- }
-
- return false;
- };
-
- auto isInstHoisted = [&](IRInst* inst)
- {
- return getBlock(inst)->getParent() != nullptr && isDifferentialInst(getBlock(inst));
- };
-
- while (workList.getCount() > 0)
- {
- // Pop work item
- auto inst = workList.getLast();
- workList.removeLast();
-
- // Already hoisted to reverse-mode block.
- // replace with mapped inst (in case it's different)
- // and continue on.. (this should actually never be hit)
- //
- if (hoistedInstMap.ContainsKey(inst))
- continue;
-
- if (invOperandMap.ContainsKey(inst))
- {
- List<IRInst*> pendingInvDependencies;
- for (auto dependency : invOperandMap[inst].GetValue())
- {
- if (doesInstRequireHoisting(dependency) &&
- !hoistedInstMap.ContainsKey(dependency))
- pendingInvDependencies.add(dependency);
- }
-
- if (pendingInvDependencies.getCount() > 0)
- {
- workList.add(inst);
- for (auto dependency : pendingInvDependencies)
- workList.add(dependency);
-
- // Skip until all the dependencies have been handled.
- continue;
- }
- }
-
- // Are the uses of this primal inst already hoisted into the reverse-mode
- // blocks? We cannot hoist this inst unless the uses are hoisted.
- //
- if (doesInstHavePendingUsers(inst))
- {
- // Add inst back to work list.
- workList.add(inst);
-
- // Then, add all the pending use to the top of
- // list, ensuring they are processed before we see
- // inst again.
- //
- maybeAddUsersToWorkList(inst);
-
- continue;
- }
-
- // The used inst is marked for inversion, lookup and load
- // an inverse.
- //
- if (this->hoistedPrimalsInfo->invertSet.Contains(inst))
- {
- // Replace with inverse.
- IRBuilder builder(func->getModule());
-
- for (auto use = inst->firstUse; use;)
- {
- auto nextUse = use->nextUse;
-
- if (!isInstHoisted(use->getUser()))
- {
- use = nextUse;
- continue;
- }
-
- // TODO: Hacky workaround to prevent the 'key' being overwritten,
- // avoid this by adding the decoration on the param instead of the loop
- //
- if (auto exitValDecoration = as<IRLoopExitPrimalValueDecoration>(use->getUser()))
- {
- if (&exitValDecoration->target == use)
- {
- use = nextUse;
- continue;
- }
- }
-
-
- builder.setInsertBefore(getInstInBlock(use->getUser()));
- use->set(loadInverse(&builder, inst));
-
- use = nextUse;
- }
-
- // If all uses of the invertible inst have been hoisted,
- // add the inv-var to the worklist.
- //
- workList.add(lookupInverseVar(inst));
- hoistedInstMap[inst] = nullptr;
-
+ if (!doesInstRequireHoisting(inst))
continue;
- }
-
- // Should not see an inst marked for inversion here.
- SLANG_RELEASE_ASSERT(!this->hoistedPrimalsInfo->invertSet.Contains(inst));
-
+
List<IRUse*> relevantUses;
IRBlock* defBlock = nullptr;
@@ -1641,7 +1200,7 @@ struct DiffTransposePass
if (!doesInstRequireHoisting(inst))
continue;
-
+
// Move this inst to after it's diff uses.
//
{
@@ -1662,62 +1221,9 @@ struct DiffTransposePass
inst->insertBefore(currTopBlock->getFirstOrdinaryInst());
enableIRValidationAtInsert();
}
-
- // Finish up..
- hoistedInstMap[inst] = inst;
- maybeAddPrimalOperandsToWorkList(inst);
- }
- }
-
- void hoistPrimalOperands(IRBuilder* revBuilder, IRInst* fwdInst)
- {
- for (UIndex ii = 0; ii < fwdInst->getOperandCount(); ii++)
- {
- // For now we'll only hoist primal operands that are
- // generated in differential blocks.
- // Eventually, we also want this method to move primal access
- // insts to the reverse-mode blocks (i.e. this method will
- // make sure all requried primal insts are moved to the right
- // place)
- //
- if (doesInstRequireHoisting(fwdInst->getOperand(ii)))
- {
- hoistPrimalUse(revBuilder, &fwdInst->getOperands()[ii]);
- }
}
}
- void invertInst(IRBuilder* builder, IRInst* primalInst)
- {
- // Look for an available inverse entry for this primalInst's *output*
- if (shouldInstBeInverted(primalInst))
- {
- // This logic is already handled in transposeBlock() so we skip
- // it here.
- //
- if (as<IRTerminatorInst>(primalInst))
- return;
-
- auto invInfo = this->hoistedPrimalsInfo->invertInfoMap[primalInst];
-
- IRBuilder invBuilder(builder->getModule());
- invBuilder.setInsertAfter(primalInst);
-
- auto invEntries = invertInst(&invBuilder, primalInst, invInfo);
-
- for (auto entry : invEntries)
- setInverse(
- &invBuilder,
- getBlock(primalInst),
- as<IRGlobalValueWithCode>(entry.inst->getParent()->getParent()),
- entry.inst,
- entry.invInst);
- }
- else
- {
- SLANG_UNEXPECTED("Could not find value for the output of inst. Unable to invert");
- }
- }
void transposeInst(IRBuilder* builder, IRInst* inst)
{
@@ -1880,7 +1386,8 @@ struct DiffTransposePass
auto pairType = as<IRPtrTypeBase>(arg->getDataType())->getValueType();
auto tempVar = builder->emitVar(pairType);
auto primalVal = builder->emitLoad(instPair->getPrimal());
- hoistPrimalOperands(builder, primalVal); // TODO(sai): Do we need to hoist other insts here?
+ auto primalVar = instPair->getPrimal();
+ primalVarsToHoist.add(primalVar);
auto diffVal = builder->emitLoad(instPair->getDiff());
auto pairVal = builder->emitMakeDifferentialPair(pairType, primalVal, diffVal);
@@ -1961,7 +1468,6 @@ struct DiffTransposePass
auto primalContextVar = primalContextDecor->getBackwardDerivativePrimalContextVar();
auto contextLoad = builder->emitLoad(primalContextVar);
- hoistPrimalOperands(builder, contextLoad);
args.add(contextLoad);
argTypes.add(as<IRPtrTypeBase>(
@@ -3477,7 +2983,7 @@ struct DiffTransposePass
DifferentialPairTypeBuilder pairBuilder;
- HoistedPrimalsInfo* hoistedPrimalsInfo;
+ List<IRInst*> primalVarsToHoist;
IRBlock* tempInvBlock;
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index c3ce32540..44e981404 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -249,7 +249,7 @@ struct ExtractPrimalFuncContext
List<IRBlock*> unusedBlocks;
for (auto block : func->getBlocks())
{
- if (isDiffInst(block))
+ if (isDiffInst(block) || block->findDecoration<IRRecomputeBlockDecoration>())
unusedBlocks.add(block);
}
for (auto block : unusedBlocks)
@@ -317,8 +317,11 @@ IRFunc* DiffUnzipPass::extractPrimalFunc(
// Remove propagate func specific primal insts from cloned func.
for (auto inst : paramInfo.propagateFuncSpecificPrimalInsts)
{
- auto newInst = subEnv.mapOldValToNew[inst].GetValue();
- newInst->removeAndDeallocate();
+ IRInst* newInst = nullptr;
+ if (subEnv.mapOldValToNew.TryGetValue(inst, newInst))
+ {
+ newInst->removeAndDeallocate();
+ }
}
HashSet<IRInst*> newPrimalParams;
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h
index 8b24b122e..65f45ece8 100644
--- a/source/slang/slang-ir-autodiff-unzip.h
+++ b/source/slang/slang-ir-autodiff-unzip.h
@@ -32,6 +32,7 @@ struct DiffUnzipPass
//
Dictionary<IRInst*, IRInst*> primalMap;
Dictionary<IRInst*, IRInst*> diffMap;
+ Dictionary<IRBlock*, IRBlock*> recomputeBlockMap;
// First diff block.
// TODO: Can the same pass object can be used for multiple functions?
@@ -40,8 +41,6 @@ struct DiffUnzipPass
RefPtr<IndexedRegionMap> indexRegionMap;
- Dictionary<IndexedRegion*, RefPtr<IndexTrackingInfo>> indexInfoMap;
-
DiffUnzipPass(
AutoDiffSharedContext* autodiffContext)
: autodiffContext(autodiffContext)
@@ -58,7 +57,7 @@ struct DiffUnzipPass
return diffMap[inst];
}
- RefPtr<HoistedPrimalsInfo> unzipDiffInsts(IRFunc* func)
+ void unzipDiffInsts(IRFunc* func)
{
diffTypeContext.setFunc(func);
@@ -138,7 +137,8 @@ struct DiffUnzipPass
// Mark the differential block as a differential inst
// (and add a reference to the primal block)
- builder->markInstAsDifferential(diffBlock, nullptr, primalMap[block]);
+ builder->markInstAsDifferential(
+ diffBlock, builder->getBasicBlockType(), primalMap[block]);
// Record the first differential (code) block,
// since we want all 'return' insts in primal blocks
@@ -154,16 +154,6 @@ struct DiffUnzipPass
splitBlock(block, as<IRBlock>(primalMap[block]), as<IRBlock>(diffMap[block]));
}
- // Emit counter variables and other supporting
- // instructions for all regions.
- //
- // TODO: Need to have maxIndex in _both_ IndexTrackingInfo & IndexedRegionInfo.
- // That way, we can do the various passes _before_ lowerIndexedRegions()
- // TODO: Remove the call to lowerIndexedRegions() once checkpointing works properly.
- //
- RefPtr<HoistedPrimalsInfo> primalsInfo = new HoistedPrimalsInfo();
- lowerIndexedRegions(primalsInfo);
-
// Copy regions from fwd-block to their split blocks
// to make it easier to do lookups.
//
@@ -189,217 +179,13 @@ struct DiffUnzipPass
firstBlock->replaceUsesWith(firstPrimalBlock);
RefPtr<BlockSplitInfo> splitInfo = new BlockSplitInfo();
+
for (auto block : mixedBlocks)
if (primalMap.ContainsKey(block))
splitInfo->diffBlockMap[as<IRBlock>(primalMap[block])] = as<IRBlock>(diffMap[block]);
-
- Dictionary<IRBlock*, List<IndexTrackingInfo*>> indexedBlocksInfo;
- for (auto block : mixedBlocks)
- {
- indexedBlocksInfo[as<IRBlock>(diffMap[block])] = getIndexInfoList(as<IRBlock>(diffMap[block]));
- indexedBlocksInfo[as<IRBlock>(primalMap[block])] = getIndexInfoList(as<IRBlock>(primalMap[block]));
- }
for (auto block : mixedBlocks)
block->removeAndDeallocate();
-
- // Run the three checkpointing passes to hoist/clone primal insts
- // to the right spots.
- //
- {
- RefPtr<AutodiffCheckpointPolicyBase> chkPolicy = new DefaultCheckpointPolicy(unzippedFunc->getModule());
- chkPolicy->preparePolicy(func);
-
- auto chkPrimalsInfo = chkPolicy->processFunc(func, splitInfo);
- primalsInfo->merge(chkPrimalsInfo);
-
- primalsInfo = ensurePrimalAvailability(primalsInfo, func, indexedBlocksInfo);
- }
-
- return primalsInfo;
- }
-
- void tryInferMaxIndex(IndexedRegion* region, IndexTrackingInfo* info)
- {
- if (info->status != IndexTrackingInfo::CountStatus::Unresolved)
- return;
-
- auto loop = as<IRLoop>(region->getInitializerBlock()->getTerminator());
-
- if (auto maxItersDecoration = loop->findDecoration<IRLoopMaxItersDecoration>())
- {
- info->maxIters = (Count) maxItersDecoration->getMaxIters();
- info->status = IndexTrackingInfo::CountStatus::Static;
- }
-
- if (info->status == IndexTrackingInfo::CountStatus::Unresolved)
- {
- SLANG_UNEXPECTED("Could not resolve max iters \
- for loop appearing in reverse-mode");
- }
- }
-
- IRInst* addPhiInputParam(IRBuilder* builder, IRBlock* block, IRType* type)
- {
- builder->setInsertInto(block);
- return builder->emitParam(type);
- }
-
- IRInst* addPhiInputParam(IRBuilder* builder, IRBlock* block, IRType* type, UIndex index)
- {
- List<IRParam*> params;
- for (auto param : block->getParams())
- params.add(param);
-
- SLANG_RELEASE_ASSERT(index == (UCount)params.getCount());
-
- return addPhiInputParam(builder, block, type);
- }
-
- void lowerIndexedRegions(HoistedPrimalsInfo* primalsInfo)
- {
- IRBuilder builder(autodiffContext->moduleInst->getModule());
-
- for (auto region : indexRegionMap->regions)
- {
- RefPtr<IndexTrackingInfo> info = new IndexTrackingInfo();
- indexInfoMap[region] = info;
-
- // Grab first primal block.
- IRBlock* primalInitBlock = as<IRBlock>(primalMap[region->getInitializerBlock()]);
- builder.setInsertBefore(primalInitBlock->getTerminator());
-
- // Make variable in the top-most block (so it's visible to diff blocks)
- info->primalCountLastVar = builder.emitVar(builder.getIntType());
- builder.addNameHintDecoration(info->primalCountLastVar, UnownedStringSlice("_pc_last_var"));
- primalsInfo->storeSet.Add(info->primalCountLastVar);
-
- {
- auto primalCondBlock = as<IRUnconditionalBranch>(
- primalInitBlock->getTerminator())->getTargetBlock();
- builder.setInsertBefore(primalCondBlock->getTerminator());
-
- auto phiCounterArgLoopEntryIndex = addPhiOutputArg(
- &builder,
- primalInitBlock,
- builder.getIntValue(builder.getIntType(), 0));
-
- info->primalCountParam = addPhiInputParam(
- &builder,
- primalCondBlock,
- builder.getIntType(),
- phiCounterArgLoopEntryIndex);
- builder.addNameHintDecoration(info->primalCountParam, UnownedStringSlice("_pc"));
- builder.addLoopCounterDecoration(info->primalCountParam);
- builder.markInstAsPrimal(info->primalCountParam);
-
- IRBlock* primalUpdateBlock = as<IRBlock>(primalMap[region->getUpdateBlock()]);
- builder.setInsertBefore(primalUpdateBlock->getTerminator());
-
- auto incCounterVal = builder.emitAdd(
- builder.getIntType(),
- info->primalCountParam,
- builder.getIntValue(builder.getIntType(), 1));
- builder.markInstAsPrimal(incCounterVal);
-
- auto phiCounterArgLoopCycleIndex = addPhiOutputArg(&builder, primalUpdateBlock, incCounterVal);
-
- SLANG_RELEASE_ASSERT(phiCounterArgLoopEntryIndex == phiCounterArgLoopCycleIndex);
-
- IRBlock* primalBreakBlock = as<IRBlock>(primalMap[region->getBreakBlock()]);
- builder.setInsertBefore(primalBreakBlock->getTerminator());
-
- builder.emitStore(info->primalCountLastVar, info->primalCountParam);
- }
-
- {
- IRBlock* diffInitBlock = as<IRBlock>(diffMap[region->getInitializerBlock()]);
-
- auto diffCondBlock = as<IRUnconditionalBranch>(
- diffInitBlock->getTerminator())->getTargetBlock();
- builder.setInsertBefore(diffCondBlock->getTerminator());
-
- auto phiCounterArgLoopEntryIndex = addPhiOutputArg(
- &builder,
- diffInitBlock,
- builder.getIntValue(builder.getIntType(), 0));
-
- info->diffCountParam = addPhiInputParam(
- &builder,
- diffCondBlock,
- builder.getIntType(),
- phiCounterArgLoopEntryIndex);
- builder.addNameHintDecoration(info->diffCountParam, UnownedStringSlice("_dc"));
- builder.addLoopCounterDecoration(info->diffCountParam);
- builder.markInstAsPrimal(info->diffCountParam);
-
- IRBlock* diffUpdateBlock = as<IRBlock>(diffMap[region->getUpdateBlock()]);
- builder.setInsertBefore(diffUpdateBlock->getTerminator());
-
- auto incCounterVal = builder.emitAdd(
- builder.getIntType(),
- info->diffCountParam,
- builder.getIntValue(builder.getIntType(), 1));
- builder.markInstAsPrimal(incCounterVal);
-
- auto phiCounterArgLoopCycleIndex = addPhiOutputArg(&builder, diffUpdateBlock, incCounterVal);
-
- SLANG_RELEASE_ASSERT(phiCounterArgLoopEntryIndex == phiCounterArgLoopCycleIndex);
-
- auto loopInst = as<IRLoop>(diffInitBlock->getTerminator());
-
- builder.setInsertBefore(loopInst);
-
- auto primalCounterLastVal = builder.emitLoad(info->primalCountLastVar);
- builder.markInstAsPrimal(primalCounterLastVal);
- builder.addPrimalValueAccessDecoration(primalCounterLastVal);
-
- builder.addLoopExitPrimalValueDecoration(loopInst, info->diffCountParam, primalCounterLastVal);
-
- // We'll be manually creating the inversion entries for the counters
- // TODO: This logic can be moved to the checkpointing alg.
- //
- primalsInfo->invertSet.Add(info->diffCountParam);
- primalsInfo->instsToInvert.Add(incCounterVal);
- primalsInfo->invertInfoMap[incCounterVal] = InversionInfo(
- incCounterVal,
- List<IRInst*>(incCounterVal),
- List<IRInst*>(info->diffCountParam));
-
- primalsInfo->invertSet.Add(incCounterVal);
- primalsInfo->instsToInvert.Add(diffUpdateBlock->getTerminator());
- primalsInfo->invertInfoMap[diffUpdateBlock->getTerminator()] = InversionInfo(
- diffUpdateBlock->getTerminator(),
- List<IRInst*>(diffUpdateBlock->getTerminator()),
- List<IRInst*>(incCounterVal));
- }
-
- // Try to infer maximum possible number of iterations.
- // (only regions whose intermediates are used outside their region
- // require a maximum count, so we may see some unresolved regions
- // without any issues)
- //
- tryInferMaxIndex(region, info);
- }
- }
-
- void tagNewParams(IRBuilder* builder, IRFunc* func)
- {
- for (auto block : func->getBlocks())
- {
- for (auto param = block->getFirstParam(); param; param = param->getNextParam())
- if (!param->findDecoration<IRAutodiffInstDecoration>())
- builder->markInstAsPrimal(param);
- }
- }
-
- List<IndexTrackingInfo*> getIndexInfoList(IRBlock* block)
- {
- List<IndexTrackingInfo*> indices;
- for (auto region : indexRegionMap->getAllAncestorRegions(block))
- indices.add((IndexTrackingInfo*) indexInfoMap[region].GetValue());
-
- return indices;
}
IRFunc* extractPrimalFunc(
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 9a7a42619..a8af148d9 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -774,7 +774,9 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent)
case kIROp_PrimalInstDecoration:
case kIROp_DifferentialInstDecoration:
case kIROp_MixedDifferentialInstDecoration:
- case kIROp_PrimalValueAccessDecoration:
+ case kIROp_RecomputeBlockDecoration:
+ case kIROp_LoopCounterDecoration:
+ case kIROp_LoopCounterUpdateDecoration:
case kIROp_BackwardDerivativeDecoration:
case kIROp_BackwardDerivativeIntermediateTypeDecoration:
case kIROp_BackwardDerivativePropagateDecoration:
@@ -814,6 +816,7 @@ void stripTempDecorations(IRInst* inst)
{
case kIROp_DifferentialInstDecoration:
case kIROp_MixedDifferentialInstDecoration:
+ case kIROp_RecomputeBlockDecoration:
case kIROp_AutoDiffOriginalValueDecoration:
case kIROp_BackwardDerivativePrimalReturnDecoration:
case kIROp_PrimalValueStructKeyDecoration:
@@ -902,8 +905,9 @@ bool canTypeBeStored(IRInst* type)
case kIROp_FloatType:
case kIROp_VectorType:
case kIROp_MatrixType:
- case kIROp_AttributedType:
return true;
+ case kIROp_AttributedType:
+ return canTypeBeStored(type->getOperand(0));
default:
return false;
}
@@ -1770,7 +1774,7 @@ IRInst* getInstInBlock(IRInst* inst)
return getInstInBlock(inst->getParent());
}
-UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst* arg)
+UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst*& inoutTerminatorInst, IRInst* arg)
{
SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>(block->getTerminator()));
@@ -1786,16 +1790,22 @@ UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst* arg)
switch (branchInst->getOp())
{
case kIROp_unconditionalBranch:
- builder->emitBranch(branchInst->getTargetBlock(), phiArgs.getCount(), phiArgs.getBuffer());
+ inoutTerminatorInst = builder->emitBranch(
+ branchInst->getTargetBlock(), phiArgs.getCount(), phiArgs.getBuffer());
break;
case kIROp_loop:
- builder->emitLoop(
- as<IRLoop>(branchInst)->getTargetBlock(),
- as<IRLoop>(branchInst)->getBreakBlock(),
- as<IRLoop>(branchInst)->getContinueBlock(),
- phiArgs.getCount(),
- phiArgs.getBuffer());
+ {
+ auto newLoop = builder->emitLoop(
+ as<IRLoop>(branchInst)->getTargetBlock(),
+ as<IRLoop>(branchInst)->getBreakBlock(),
+ as<IRLoop>(branchInst)->getContinueBlock(),
+ phiArgs.getCount(),
+ phiArgs.getBuffer());
+ branchInst->transferDecorationsTo(newLoop);
+ branchInst->replaceUsesWith(newLoop);
+ inoutTerminatorInst = newLoop;
+ }
break;
default:
@@ -1806,6 +1816,24 @@ UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst* arg)
return phiArgs.getCount() - 1;
}
+bool isDifferentialOrRecomputeBlock(IRBlock* block)
+{
+ if (!block)
+ return false;
+ for (auto decor : block->getDecorations())
+ {
+ switch (decor->getOp())
+ {
+ case kIROp_DifferentialInstDecoration:
+ case kIROp_RecomputeBlockDecoration:
+ return true;
+ default:
+ break;
+ }
+ }
+ return false;
+}
+
IRUse* findUniqueStoredVal(IRVar* var)
{
if (isDerivativeContextVar(var))
diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h
index 167aa2357..d7d6119d4 100644
--- a/source/slang/slang-ir-autodiff.h
+++ b/source/slang/slang-ir-autodiff.h
@@ -340,7 +340,7 @@ IRBlock* getBlock(IRInst* inst);
IRInst* getInstInBlock(IRInst* inst);
-UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst* arg);
+UIndex addPhiOutputArg(IRBuilder* builder, IRBlock* block, IRInst*& inoutTerminatorInst, IRInst* arg);
IRUse* findUniqueStoredVal(IRVar* var);
@@ -348,4 +348,6 @@ bool isDerivativeContextVar(IRVar* var);
bool isDiffInst(IRInst* inst);
+bool isDifferentialOrRecomputeBlock(IRBlock* block);
+
};
diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp
index 1fe88e780..364abe68c 100644
--- a/source/slang/slang-ir-dce.cpp
+++ b/source/slang/slang-ir-dce.cpp
@@ -327,29 +327,7 @@ bool shouldInstBeLiveIfParentIsLive(IRInst* inst, IRDeadCodeEliminationOptions o
//
if (inst->mightHaveSideEffects())
{
- // If the inst has side effect, we should keep it alive.
- // An exception is if we have a call to a pure function
- // that writes its output to a local variable, but we
- // don't have any uses of that local variable.
- auto call = as<IRCall>(inst);
- if (!call)
- return true;
- if (!getResolvedInstForDecorations(call->getCallee())->findDecoration<IRReadNoneDecoration>())
- return true;
- auto parentFunc = getParentFunc(inst);
- if (!parentFunc)
- return true;
- for (UInt i = 0; i < call->getArgCount(); i++)
- {
- auto arg = call->getArg(i);
- if (getParentFunc(arg) != parentFunc)
- return true;
- if (arg->getOp() != kIROp_Var)
- return true;
- if (arg->hasMoreThanOneUse())
- return true;
- }
- return false;
+ return true;
}
//
// The `mightHaveSideEffects` query is conservative, and will
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index a8ec5a66f..11143cebb 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -810,7 +810,7 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(BackwardDerivativePrimalReturnDecoration, BackwardDerivativePrimalReturnDecoration, 1, 0)
INST(LoopCounterDecoration, loopCounterDecoration, 0, 0)
- INST(PrimalValueAccessDecoration, primalValueAccessDecoration, 0, 0)
+ INST(LoopCounterUpdateDecoration, loopCounterUpdateDecoration, 0, 0)
/* Auto-diff inst decorations */
/// Used by the auto-diff pass to mark insts that compute
@@ -824,7 +824,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// Used by the auto-diff pass to mark insts that compute
/// BOTH a differential and a primal value.
INST(MixedDifferentialInstDecoration, mixedDiffInstDecoration, 1, 0)
- INST_RANGE(AutodiffInstDecoration, PrimalInstDecoration, MixedDifferentialInstDecoration)
+
+ INST(RecomputeBlockDecoration, RecomputeBlockDecoration, 0, 0)
+ INST_RANGE(AutodiffInstDecoration, PrimalInstDecoration, RecomputeBlockDecoration)
/// Used by the auto-diff pass to mark insts whose result is stored
/// in an intermediary struct for reuse in backward propagation phase.
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 356ccf4d6..f515baf8d 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -721,6 +721,16 @@ struct IRLoopCounterDecoration : IRDecoration
IR_LEAF_ISA(LoopCounterDecoration)
};
+struct IRLoopCounterUpdateDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_LoopCounterUpdateDecoration
+ };
+ IR_LEAF_ISA(LoopCounterUpdateDecoration)
+};
+
+
struct IRLoopExitPrimalValueDecoration : IRDecoration
{
enum
@@ -777,14 +787,14 @@ struct IRMixedDifferentialInstDecoration : IRAutodiffInstDecoration
IRType* getPairType() { return as<IRType>(getOperand(0)); }
};
-struct IRPrimalValueAccessDecoration : IRAutodiffInstDecoration
+struct IRRecomputeBlockDecoration : IRAutodiffInstDecoration
{
enum
{
- kOp = kIROp_PrimalValueAccessDecoration
+ kOp = kIROp_RecomputeBlockDecoration
};
- IR_LEAF_ISA(PrimalValueAccessDecoration)
+ IR_LEAF_ISA(RecomputeBlockDecoration)
};
struct IRPrimalValueStructKeyDecoration : IRDecoration
@@ -3532,6 +3542,7 @@ public:
IRInst* emitEql(IRInst* left, IRInst* right);
IRInst* emitNeq(IRInst* left, IRInst* right);
IRInst* emitLess(IRInst* left, IRInst* right);
+ IRInst* emitGeq(IRInst* left, IRInst* right);
IRInst* emitShr(IRType* type, IRInst* op0, IRInst* op1);
IRInst* emitShl(IRType* type, IRInst* op0, IRInst* op1);
@@ -3807,9 +3818,9 @@ public:
addDecoration(value, kIROp_LoopExitPrimalValueDecoration, primalInst, exitValue);
}
- void addPrimalValueAccessDecoration(IRInst* value)
+ void addLoopCounterUpdateDecoration(IRInst* value)
{
- addDecoration(value, kIROp_PrimalValueAccessDecoration);
+ addDecoration(value, kIROp_LoopCounterUpdateDecoration);
}
void markInstAsPrimal(IRInst* value)
diff --git a/source/slang/slang-ir-loop-unroll.cpp b/source/slang/slang-ir-loop-unroll.cpp
index 121665c85..a368ff8c8 100644
--- a/source/slang/slang-ir-loop-unroll.cpp
+++ b/source/slang/slang-ir-loop-unroll.cpp
@@ -529,23 +529,6 @@ bool unrollLoopsInModule(IRModule* module, DiagnosticSink* sink)
return true;
}
-static void _moveParams(IRBlock* dest, IRBlock* src)
-{
- for (auto param = src->getFirstChild(); param;)
- {
- auto nextInst = param->getNextInst();
- if (as<IRDecoration>(param) || as<IRParam>(param))
- {
- param->insertAtEnd(dest);
- }
- else
- {
- break;
- }
- param = nextInst;
- }
-}
-
void eliminateContinueBlocks(IRModule* module, IRLoop* loopInst)
{
// Eliminate the continue jumps by turning a loop in the form of:
@@ -599,7 +582,7 @@ void eliminateContinueBlocks(IRModule* module, IRLoop* loopInst)
targetBlock->replaceUsesWith(innerBreakableRegionHeader);
// Move decorations and params from original targetBlock to innerBreakableRegionHeader.
- _moveParams(innerBreakableRegionHeader, targetBlock);
+ moveParams(innerBreakableRegionHeader, targetBlock);
builder.setInsertInto(innerBreakableRegionHeader);
builder.emitLoop(targetBlock, innerBreakableRegionBreakBlock, targetBlock);
@@ -607,7 +590,7 @@ void eliminateContinueBlocks(IRModule* module, IRLoop* loopInst)
continueBlock->replaceUsesWith(innerBreakableRegionBreakBlock);
builder.setInsertInto(innerBreakableRegionBreakBlock);
- _moveParams(innerBreakableRegionBreakBlock, continueBlock);
+ moveParams(innerBreakableRegionBreakBlock, continueBlock);
builder.emitBranch(continueBlock);
// If the original loop can be executed up to N times, the new loop may be executed
diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp
index dd92630b3..99cae22f0 100644
--- a/source/slang/slang-ir-redundancy-removal.cpp
+++ b/source/slang/slang-ir-redundancy-removal.cpp
@@ -66,6 +66,14 @@ struct RedundancyRemovalContext
case kIROp_Leq:
case kIROp_Neq:
case kIROp_Eql:
+ case kIROp_ExtractExistentialType:
+ case kIROp_ExtractExistentialValue:
+ case kIROp_ExtractExistentialWitnessTable:
+ case kIROp_PtrType:
+ case kIROp_ArrayType:
+ case kIROp_FuncType:
+ case kIROp_InOutType:
+ case kIROp_OutType:
return true;
case kIROp_Call:
return isPureFunctionalCall(as<IRCall>(inst));
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 83f6735bd..03b74b36a 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -681,6 +681,9 @@ bool isPureFunctionalCall(IRCall* call)
// are not dependent on whatever we do in the call here.
continue;
default:
+ // Skip the call itself, since we are checking if the call has side effect.
+ if (use->getUser() == call)
+ continue;
// We have some other unknown use of the variable address, they can
// be loads, or calls using addresses derived from the variable,
// we will treat the call as having side effect to be safe.
@@ -721,6 +724,23 @@ IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key)
return nullptr;
}
+void moveParams(IRBlock* dest, IRBlock* src)
+{
+ for (auto param = src->getFirstChild(); param;)
+ {
+ auto nextInst = param->getNextInst();
+ if (as<IRDecoration>(param) || as<IRParam>(param))
+ {
+ param->insertAtEnd(dest);
+ }
+ else
+ {
+ break;
+ }
+ param = nextInst;
+ }
+}
+
struct GenericChildrenMigrationContextImpl
{
IRCloneEnv cloneEnv;
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index ef7ff47bb..e7d182604 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -198,6 +198,8 @@ void removeLinkageDecorations(IRGlobalValueWithCode* func);
IRInst* findInterfaceRequirement(IRInterfaceType* type, IRInst* key);
IRInst* findWitnessTableEntry(IRWitnessTable* table, IRInst* key);
+
+void moveParams(IRBlock* dest, IRBlock* src);
}
#endif
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 97109274f..558fd7796 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -5405,6 +5405,13 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitGeq(IRInst* left, IRInst* right)
+ {
+ auto inst = createInst<IRInst>(this, kIROp_Geq, getBoolType(), left, right);
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitMul(IRType* type, IRInst* left, IRInst* right)
{
auto inst = createInst<IRInst>(