summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-01-15 15:00:20 -0500
committerGitHub <noreply@github.com>2023-01-15 12:00:20 -0800
commit2c437498d3a09b58de17a8865242814d9ea92fde (patch)
tree3a8ff790aa82b2b8a9217d7c6870073e0e4842f7 /source
parent1c9b33157322751c456bf7abbd386edccf4413c3 (diff)
Switched to a much simpler method to transpose control flow, nested control flow works now (#2595)
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp1
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h480
2 files changed, 276 insertions, 205 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 68a86bc00..fa3eb463e 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -798,7 +798,6 @@ InstPair ForwardDiffTranscriber::transcribeIfElse(IRBuilder* builder, IRIfElse*
SLANG_ASSERT(diffTrueBlock);
// Transcribe 'false' block (condition block branches into this if true)
- // TODO (sai): What happens if there's no false block?
auto diffFalseBlock = findOrTranscribeDiffInst(builder, origIfElse->getFalseBlock());
SLANG_ASSERT(diffFalseBlock);
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 78b8c5098..cbdb0a998 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -54,24 +54,6 @@ struct DiffTransposePass
Flavor flavor;
};
- struct BlockPredecessorEntry
- {
- // Previous block.
- IRBlock* prevBlock;
-
- // Integer value corresponding to this predecessor.
- IRIntegerValue* indexVal;
- };
-
- struct ControlFlowTranspositionInfo
- {
- // Variable used for recording control flow.
- IRVar* controlVar;
-
- // Info about all possible predecessor blocks.
- Dictionary<IRBlock*, IRInst*> predEntries;
- };
-
DiffTransposePass(AutoDiffSharedContext* autodiffContext) :
autodiffContext(autodiffContext), pairBuilder(autodiffContext), diffTypeContext(autodiffContext)
{ }
@@ -102,6 +84,50 @@ struct DiffTransposePass
Dictionary<IRInst*, IRInst*>* primalsMap;
};
+ struct PendingBlockTerminatorEntry
+ {
+ IRBlock* fwdBlock;
+ List<IRInst*> phiGrads;
+
+ PendingBlockTerminatorEntry() : fwdBlock(nullptr)
+ {}
+
+ PendingBlockTerminatorEntry(IRBlock* fwdBlock, List<IRInst*> phiGrads) :
+ fwdBlock(fwdBlock), phiGrads(phiGrads)
+ {}
+ };
+
+ struct Region
+ {
+ IRBlock* exitBlock;
+ IRBlock* originBlock;
+
+ Region* parent;
+
+ Region() :
+ exitBlock(nullptr),
+ originBlock(nullptr),
+ parent(nullptr)
+ { }
+
+ Region(IRBlock* exitBlock, Region* parent) :
+ exitBlock(exitBlock),
+ originBlock(nullptr),
+ parent(parent)
+ { }
+
+ void finish(IRBlock* block)
+ {
+ SLANG_ASSERT(!this->originBlock);
+ this->originBlock = block;
+ }
+
+ bool isComplete()
+ {
+ return (this->originBlock != nullptr);
+ }
+ };
+
void transposeDiffBlocksInFunc(
IRFunc* revDiffFunc,
FuncTranspositionInfo transposeInfo)
@@ -114,6 +140,11 @@ struct DiffTransposePass
auto terminalPrimalBlocks = getTerminalPrimalBlocks(revDiffFunc);
auto terminalDiffBlocks = getTerminalDiffBlocks(revDiffFunc);
+ // Add a top-level null region entry for the terminal diff block.
+ regionMap[terminalDiffBlocks[0]] = nullptr;
+
+ buildAfterBlockMap(revDiffFunc);
+
// Traverse all instructions/blocks in reverse (starting from the terminator inst)
// look for insts/blocks marked with IRDifferentialInstDecoration,
// and transpose them in the revDiffFunc.
@@ -170,6 +201,18 @@ struct DiffTransposePass
this->transposeBlock(block, revBlock);
}
+ // Some blocks may not have their control flow
+ // insts completed. Do them now that we have
+ // more information.
+ //
+ for (auto pendingBlockInfo : pendingBlocks)
+ {
+ builder.setInsertInto(revBlockMap[pendingBlockInfo.fwdBlock]);
+ completeEmitTerminator(&builder, pendingBlockInfo.fwdBlock, pendingBlockInfo.phiGrads);
+ }
+
+ pendingBlocks.clear();
+
// Link the last differential fwd-mode block (which will be the first
// rev-mode block) as the successor to the last primal block.
// We assume that the original function is in single-return form
@@ -209,6 +252,8 @@ struct DiffTransposePass
{
block->removeAndDeallocate();
}
+
+ cleanupRegionInfo();
}
// Fetch or create a gradient accumulator var
@@ -224,7 +269,11 @@ struct DiffTransposePass
IRBuilder tempVarBuilder(autodiffContext->sharedBuilder);
IRBlock* firstDiffBlock = firstRevDiffBlockMap[as<IRFunc>(fwdInst->getParent()->getParent())];
- tempVarBuilder.setInsertBefore(firstDiffBlock->getTerminator());
+
+ if (auto firstInst = firstDiffBlock->getFirstOrdinaryInst())
+ tempVarBuilder.setInsertBefore(firstInst);
+ else
+ tempVarBuilder.setInsertInto(firstDiffBlock);
auto primalType = tryGetPrimalTypeFromDiffInst(fwdInst);
auto diffType = fwdInst->getDataType();
@@ -408,7 +457,13 @@ struct DiffTransposePass
// We _should_ be completely out of gradients to process at this point.
SLANG_ASSERT(gradientsMap.Count() == 0);
- emitTerminator(&builder, fwdBlock, phiParamRevGradInsts);
+ if (!tryEmitTerminator(&builder, fwdBlock, phiParamRevGradInsts))
+ {
+ // If we couldn't emit a terminator right away, defer for later.
+ pendingBlocks.add(PendingBlockTerminatorEntry(
+ fwdBlock,
+ phiParamRevGradInsts));
+ }
}
void transposeInst(IRBuilder* builder, IRInst* inst)
@@ -649,6 +704,54 @@ struct DiffTransposePass
return terminalPrimalBlocks;
}
+ IRBlock* getAfterBlock(IRBlock* block)
+ {
+ auto terminatorInst = block->getTerminator();
+ switch (terminatorInst->getOp())
+ {
+ case kIROp_unconditionalBranch:
+ case kIROp_Return:
+ return nullptr;
+
+ case kIROp_ifElse:
+ return as<IRIfElse>(terminatorInst)->getAfterBlock();
+ case kIROp_Switch:
+ return as<IRSwitch>(terminatorInst)->getBreakLabel();
+ case kIROp_loop:
+ return as<IRLoop>(terminatorInst)->getBreakBlock();
+
+ default:
+ SLANG_UNIMPLEMENTED_X("Unhandled terminator inst when building after-block map");
+ }
+ }
+
+ void buildAfterBlockMap(IRGlobalValueWithCode* fwdFunc)
+ {
+ // Scan through a fwd-mode function, and build a list of blocks
+ // that appear as the 'after' block for any conditional control
+ // flow statement.
+ //
+
+ for (auto block = fwdFunc->getFirstBlock(); block; block = block->getNextBlock())
+ {
+ // Only need to process differential blocks.
+ if (!isDifferentialInst(block))
+ continue;
+
+ IRBlock* afterBlock = getAfterBlock(block);
+
+ if (afterBlock)
+ {
+ // No block can by the after block for multiple control flow insts.
+ //
+ SLANG_ASSERT(!(afterBlockMap.ContainsKey(afterBlock) && \
+ afterBlockMap[afterBlock] != block->getTerminator()));
+
+ afterBlockMap[afterBlock] = block->getTerminator();
+ }
+ }
+ }
+
List<IRBlock*> getTerminalDiffBlocks(IRGlobalValueWithCode* func)
{
// Terminal differential blocks are those with a return statement.
@@ -663,244 +766,206 @@ struct DiffTransposePass
return terminalDiffBlocks;
}
- IRInst* addPredecessorForBlock(IRBlock* block, IRBlock* predBlock)
+ bool doesBlockHaveDifferentialPredecessors(IRBlock* fwdBlock)
{
- if (!this->blockEntries.ContainsKey(block))
+ for (auto block : fwdBlock->getPredecessors())
{
- // We haven't encountered this block yet, create a var for this in the
- // first code block.
- auto firstCodeBlock = getFirstCodeBlock(block->getParent());
-
- IRBuilder subBuilder(this->autodiffContext->sharedBuilder);
- subBuilder.setInsertBefore(firstCodeBlock->getTerminator());
- auto controlVar = subBuilder.emitVar(subBuilder.getUIntType());
-
- ControlFlowTranspositionInfo info;
- info.controlVar = controlVar;
-
- this->blockEntries[block] = info;
+ if (isDifferentialInst(block))
+ {
+ return true;
+ }
}
- auto info = this->blockEntries[block];
+ return false;
+ }
- // Does precessor block already exist?
- if (info.GetValue().predEntries.ContainsKey(predBlock))
- {
- return info.GetValue().predEntries[predBlock];
- }
+ IRBlock* insertPhiBlockBefore(IRBlock* revBlock, List<IRInst*> phiArgs)
+ {
+ IRBuilder phiBlockBuilder(autodiffContext->sharedBuilder);
+ phiBlockBuilder.setInsertBefore(revBlock);
- // Otherwise, create an entry..
- auto uniqueIndex = info.GetValue().predEntries.Count();
+ auto phiBlock = phiBlockBuilder.emitBlock();
- IRBuilder builder(this->autodiffContext->sharedBuilder);
- auto uniqueIndexLiteral = builder.getIntValue(builder.getUIntType(), uniqueIndex);
+ if (isDifferentialInst(revBlock))
+ phiBlockBuilder.markInstAsDifferential(phiBlock);
- info.GetValue().predEntries[predBlock] = uniqueIndexLiteral;
-
- return uniqueIndexLiteral;
+ phiBlockBuilder.emitBranch(
+ revBlock,
+ phiArgs.getCount(),
+ phiArgs.getBuffer());
+
+ return phiBlock;
}
- IRVar* getControlVar(IRBlock* block)
- {
- return this->blockEntries[block].GetValue().controlVar;
- }
-
- // Inserts a block between the branch from fwdPredecessorBlock to fwdBlock, which sets a control
- // variable to a unique index.
- //
- IRInst* insertPreludeForPredecessor(IRBlock* fwdBlock, IRBlock* fwdPredecessorBlock)
+ // Create a region to track control flow from the
+ // the point of convergence (fwdConvBlock) back to the point of
+ // divergence, along one specific path (fwdExitBlock)
+ //
+ void pushRegion(IRBlock* fwdConvBlock, IRBlock* fwdExitBlock)
{
- // Get associated primal blocks for both the differential blocks.
- auto primalPredecessorBlock = getPrimalBlock(fwdPredecessorBlock);
- SLANG_ASSERT(primalPredecessorBlock);
+ SLANG_ASSERT(!regionMap.ContainsKey(fwdExitBlock));
+ SLANG_ASSERT(regionMap.ContainsKey(fwdConvBlock));
- auto primalBlock = getPrimalBlock(fwdBlock);
- SLANG_ASSERT(primalBlock);
+ Region* newRegion = new Region(fwdExitBlock, regionMap[fwdConvBlock]);
+ regions.add(newRegion);
- // Add this block as a predecessor, and get an unique index (as an integer literal)
- auto indexVal = addPredecessorForBlock(fwdBlock, fwdPredecessorBlock);
-
- IRBuilder subBuilder(this->autodiffContext->sharedBuilder);
- subBuilder.setInsertInto(primalPredecessorBlock->getParent());
-
- IRInst* preludeBlock = subBuilder.emitBlock();
- preludeBlock->insertAfter(primalPredecessorBlock);
+ regionMap[fwdExitBlock] = newRegion;
+ }
- // Copy over phi parameters.
- List<IRInst*> phiParams;
- for (auto param = primalBlock->getFirstParam(); param; param = param->getNextParam())
+ // If we have a conditional-branch from fwdBlock to fwdNextBlock
+ // complete the region, and remove from stack
+ // otherwise, copy the region over.
+ //
+ void propagateRegion(IRBlock* fwdNextBlock, IRBlock* fwdBlock)
+ {
+ if (as<IRConditionalBranch>(fwdBlock->getTerminator()))
{
- phiParams.add(subBuilder.emitParam(param->getDataType()));
- }
-
- auto controlVar = getControlVar(fwdBlock);
- subBuilder.emitStore(controlVar, indexVal);
+ Region* currentRegion = regionMap[fwdNextBlock];
+ currentRegion->finish(fwdNextBlock);
- // Branch into the successor block using all the same phi parameters.
- subBuilder.emitBranch(primalBlock, phiParams.getCount(), phiParams.getBuffer());
-
- // Scan through uses of primalBlock to find the ones that are in
- // primalPredecessorBlock, and replace them with branches to
- // preludeBlock.
- //
- List<IRUse*> relevantUses;
- for (auto use = primalBlock->firstUse; use; use = use->nextUse)
+ regionMap[fwdBlock] = currentRegion->parent;
+ }
+ else if (as<IRUnconditionalBranch>(fwdBlock->getTerminator()) ||
+ as<IRReturn>(fwdBlock->getTerminator()))
{
- if (use->getUser()->getParent() == primalPredecessorBlock)
- relevantUses.add(use);
+ regionMap[fwdBlock] = regionMap[fwdNextBlock];
}
-
- for (auto use : relevantUses)
- use->set(preludeBlock);
-
- return indexVal;
}
- bool doesBlockHaveDifferentialPredecessors(IRBlock* fwdBlock)
+ // Deallocate regions
+ void cleanupRegionInfo()
{
- for (auto block : fwdBlock->getPredecessors())
+ for (auto region : regions)
{
- if (isDifferentialInst(block))
- {
- return true;
- }
+ delete region;
}
- return false;
+ regions.clear();
+ regionMap.Clear();
}
- void emitTerminator(IRBuilder* builder, IRBlock* fwdBlockInst, List<IRInst*> phiParamGrads)
+ bool tryEmitTerminator(IRBuilder* builder, IRBlock* fwdBlockInst, List<IRInst*> phiParamGrads)
{
// If this block has no differential predecessors, add a return statement.
if (!doesBlockHaveDifferentialPredecessors(fwdBlockInst))
{
// Emit a void return.
builder->emitReturn();
- return;
+ return true;
}
+ List<IRBlock*> fwdPredecesorBlocks;
+ // Check for predecessors count.
for (auto predecessor : fwdBlockInst->getPredecessors())
{
- // Insert code into the *primal* version of the predecessor block
- // to set the control variable to indexVal before branching.
- //
- insertPreludeForPredecessor(fwdBlockInst, predecessor);
+ fwdPredecesorBlocks.add(predecessor);
}
- List<IRBlock*> revPredecessorBlocks;
- List<IRInst*> indexVals;
+ SLANG_ASSERT(fwdPredecesorBlocks.getCount() > 0);
- for (auto blockEntry : this->blockEntries[fwdBlockInst].GetValue().predEntries)
+ // If we have just one, we simply need the reverse-mode block to
+ // branch into the reverse-mode version of the predecessor block.
+ // (along with the appropriate phi args)
+ //
+ if (fwdPredecesorBlocks.getCount() == 1)
{
- revPredecessorBlocks.add(revBlockMap[blockEntry.Key]);
- indexVals.add(blockEntry.Value);
- }
-
- auto predCount = revPredecessorBlocks.getCount();
+ builder->emitBranch(
+ revBlockMap[fwdPredecesorBlocks[0]],
+ phiParamGrads.getCount(),
+ phiParamGrads.getBuffer());
- SLANG_ASSERT(predCount > 0);
-
- List<IRBlock*> intermediateBranchBlocks;
-
- IRBuilder branchBlockBuilder(builder->getSharedBuilder());
-
- branchBlockBuilder.setInsertInto(builder->getFunc());
+ propagateRegion(fwdBlockInst, fwdPredecesorBlocks[0]);
+ return true;
+ }
- // Make a block to unconditionally branch into predecessor-0 with the
- // appropriate phi gradients.
+ // If we have more than one, then control flow 'converges' at this point.
+ // By convention, this block must be the after block for _some_ conditional
+ // control flow statement.
+ // If not, we are dealing with an inconsistent graph.
//
- auto firstBranchBlock = branchBlockBuilder.emitBlock();
- intermediateBranchBlocks.add(firstBranchBlock);
-
- branchBlockBuilder.markInstAsDifferential(firstBranchBlock);
- branchBlockBuilder.emitBranch(
- revPredecessorBlocks[0],
- phiParamGrads.getCount(),
- phiParamGrads.getBuffer());
-
- // Create a builder to insert loads and comparison insts to figure
- // out which block to branch into based on the control vars.
- // This builder is set up to emit into the last _primal_ block.
+ // Rather than actually emitting the terminator here, we're going to
+ // defer to a pass after all the blocks have been transposed.
+ // This is because, while we know that this block is the point of convergence
+ // we don't know which predecessor belong to which side of the branch.
+ // We will instead create 'regions' to track each predecessor for every
+ // branch, and by the time all blocks are seen at-least once, we should have
+ // resolved the 'start' points for every predecessor.
//
- IRBuilder booleanIndicatorBuilder(builder->getSharedBuilder());
- auto terminalPrimalBlock = getTerminalPrimalBlocks(builder->getFunc())[0];
- booleanIndicatorBuilder.setInsertBefore(terminalPrimalBlock->getTerminator());
-
- if (predCount == 1)
+ if (fwdPredecesorBlocks.getCount() > 1)
{
- builder->emitBranch(firstBranchBlock);
+ SLANG_ASSERT(afterBlockMap.ContainsKey(fwdBlockInst));
+
+ for (auto predecessor : fwdPredecesorBlocks)
+ {
+ // Trivial case when the predecessor itself is the point
+ // of divergence.
+ //
+ if (getAfterBlock(predecessor) == fwdBlockInst)
+ continue;
+
+ pushRegion(fwdBlockInst, predecessor);
+ }
}
- else
- {
- IRBuilder ladderBlockBuilder(builder->getSharedBuilder());
- ladderBlockBuilder.setInsertInto(builder->getFunc());
- // TODO: For now, we're trivially setting 'afterBlock' to
- // the first reverse block. This is not really optimal for the
- // restructuring passes since the 'then' and 'else' regions
- // can have significant overlap.
- //
- auto firstFwdDiffBlock = (*terminalPrimalBlock->getSuccessors().begin());
- SLANG_ASSERT(firstFwdDiffBlock);
+ return false;
+ }
- auto defaultAfterBlock = revBlockMap[firstFwdDiffBlock];
+ bool completeEmitTerminator(IRBuilder* builder, IRBlock* fwdBlockInst, List<IRInst*> phiParamGrads)
+ {
+ IRBlock* revBlock = revBlockMap[fwdBlockInst];
- auto nextLadderBlock = firstBranchBlock;
- for (Index ii = 0; ii < predCount - 1; ii++)
+ // If we already have a terminator, we've resolved it during
+ // tryEmitTerminator()
+ //
+ if (revBlock->getTerminator() != nullptr)
+ return true;
+
+ auto terminatorInst = as<IRInst>(afterBlockMap[fwdBlockInst]);
+ switch (terminatorInst->getOp())
+ {
+ case kIROp_ifElse:
{
- // Make the 'leaf' block. This just branches into
- // predecessor-i+1 with the appropriate phi args.
- //
- branchBlockBuilder.setInsertInto(branchBlockBuilder.getFunc());
-
- auto thisIndexBlock = branchBlockBuilder.emitBlock();
- intermediateBranchBlocks.add(thisIndexBlock);
-
- branchBlockBuilder.markInstAsDifferential(thisIndexBlock);
- branchBlockBuilder.emitBranch(
- revPredecessorBlocks[ii+1],
- phiParamGrads.getCount(),
- phiParamGrads.getBuffer());
-
- // Emit a boolean inst to represent whether we need to branch into
- // block ii.
- auto blockIndicatorInst = booleanIndicatorBuilder.emitEql(
- booleanIndicatorBuilder.emitLoad(getControlVar(fwdBlockInst)),
- indexVals[ii+1]);
-
- // Create a block to branch between i+1 and the rest of the ladder so far
- // (0 ... i)
- //
- auto upperLadderBlock = ladderBlockBuilder.emitBlock();
- intermediateBranchBlocks.add(upperLadderBlock);
-
- ladderBlockBuilder.markInstAsDifferential(upperLadderBlock);
- ladderBlockBuilder.emitIfElse(
- blockIndicatorInst,
- thisIndexBlock,
- nextLadderBlock,
- defaultAfterBlock);
+ auto ifElseInst = as<IRIfElse>(terminatorInst);
- nextLadderBlock = upperLadderBlock;
- }
+ auto condition = ifElseInst->getCondition();
+ SLANG_ASSERT(!isDifferentialInst(condition));
- // Branch into the last ladder block.
- builder->emitBranch(nextLadderBlock);
- }
+ // fwd origin block is the reverse 'after' block.
+ auto revAfterBlock = as<IRBlock>(
+ revBlockMap[as<IRBlock>(ifElseInst->getParent())]);
+
+ // Find region, and find the reverse-mode version of the
+ // exit block.
+ Region* trueRegion = regionMap[ifElseInst->getTrueBlock()];
+ IRBlock* revTrueBlock = revBlockMap[trueRegion->exitBlock];
- // Insert all intermediate blocks in the order they were created, right after
- // the current reverse block.
-
- auto revBlock = revBlockMap[fwdBlockInst];
- SLANG_ASSERT(revBlock);
+ Region* falseRegion = regionMap[ifElseInst->getFalseBlock()];
+ IRBlock* revFalseBlock = revBlockMap[falseRegion->exitBlock];
- for (auto block : intermediateBranchBlocks)
- {
- block->insertAfter(revBlock);
+ // If we have phi derivatives to pass on,
+ // we need to add dummy blocks to pass them using
+ // an unconditional branch.
+ //
+ if (phiParamGrads.getCount() > 0)
+ {
+ revTrueBlock = insertPhiBlockBefore(revTrueBlock, phiParamGrads);
+ revFalseBlock = insertPhiBlockBefore(revFalseBlock, phiParamGrads);
+
+ // Putting the phi blocks just after our current reverse-mode block
+ // is not necessary. Just to make intermediate IR easier to follow.
+ //
+ revTrueBlock->insertAfter(revBlock);
+ revFalseBlock->insertAfter(revBlock);
+ }
+
+ builder->emitIfElse(condition, revTrueBlock, revFalseBlock, revAfterBlock);
+ break;
+ }
+ default:
+ SLANG_UNIMPLEMENTED_X("Unhandled control flow inst during transposition");
}
-
- return;
+ return false;
}
TranspositionResult transposeInst(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue)
@@ -1622,12 +1687,19 @@ struct DiffTransposePass
Dictionary<IRInst*, IRInst*>* primalsMap;
List<IRInst*> usedPtrs;
-
- Dictionary<IRBlock*, ControlFlowTranspositionInfo> blockEntries;
Dictionary<IRBlock*, IRBlock*> revBlockMap;
Dictionary<IRGlobalValueWithCode*, IRBlock*> firstRevDiffBlockMap;
+
+ Dictionary<IRBlock*, IRInst*> afterBlockMap;
+
+ List<PendingBlockTerminatorEntry> pendingBlocks;
+
+ Dictionary<IRBlock*, Region*> regionMap;
+
+ List<Region*> regions;
+
};