diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-01-15 15:00:20 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-15 12:00:20 -0800 |
| commit | 2c437498d3a09b58de17a8865242814d9ea92fde (patch) | |
| tree | 3a8ff790aa82b2b8a9217d7c6870073e0e4842f7 /source | |
| parent | 1c9b33157322751c456bf7abbd386edccf4413c3 (diff) | |
Switched to a much simpler method to transpose control flow, nested control flow works now (#2595)
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 480 |
2 files changed, 276 insertions, 205 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 68a86bc00..fa3eb463e 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -798,7 +798,6 @@ InstPair ForwardDiffTranscriber::transcribeIfElse(IRBuilder* builder, IRIfElse* SLANG_ASSERT(diffTrueBlock); // Transcribe 'false' block (condition block branches into this if true) - // TODO (sai): What happens if there's no false block? auto diffFalseBlock = findOrTranscribeDiffInst(builder, origIfElse->getFalseBlock()); SLANG_ASSERT(diffFalseBlock); diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 78b8c5098..cbdb0a998 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -54,24 +54,6 @@ struct DiffTransposePass Flavor flavor; }; - struct BlockPredecessorEntry - { - // Previous block. - IRBlock* prevBlock; - - // Integer value corresponding to this predecessor. - IRIntegerValue* indexVal; - }; - - struct ControlFlowTranspositionInfo - { - // Variable used for recording control flow. - IRVar* controlVar; - - // Info about all possible predecessor blocks. - Dictionary<IRBlock*, IRInst*> predEntries; - }; - DiffTransposePass(AutoDiffSharedContext* autodiffContext) : autodiffContext(autodiffContext), pairBuilder(autodiffContext), diffTypeContext(autodiffContext) { } @@ -102,6 +84,50 @@ struct DiffTransposePass Dictionary<IRInst*, IRInst*>* primalsMap; }; + struct PendingBlockTerminatorEntry + { + IRBlock* fwdBlock; + List<IRInst*> phiGrads; + + PendingBlockTerminatorEntry() : fwdBlock(nullptr) + {} + + PendingBlockTerminatorEntry(IRBlock* fwdBlock, List<IRInst*> phiGrads) : + fwdBlock(fwdBlock), phiGrads(phiGrads) + {} + }; + + struct Region + { + IRBlock* exitBlock; + IRBlock* originBlock; + + Region* parent; + + Region() : + exitBlock(nullptr), + originBlock(nullptr), + parent(nullptr) + { } + + Region(IRBlock* exitBlock, Region* parent) : + exitBlock(exitBlock), + originBlock(nullptr), + parent(parent) + { } + + void finish(IRBlock* block) + { + SLANG_ASSERT(!this->originBlock); + this->originBlock = block; + } + + bool isComplete() + { + return (this->originBlock != nullptr); + } + }; + void transposeDiffBlocksInFunc( IRFunc* revDiffFunc, FuncTranspositionInfo transposeInfo) @@ -114,6 +140,11 @@ struct DiffTransposePass auto terminalPrimalBlocks = getTerminalPrimalBlocks(revDiffFunc); auto terminalDiffBlocks = getTerminalDiffBlocks(revDiffFunc); + // Add a top-level null region entry for the terminal diff block. + regionMap[terminalDiffBlocks[0]] = nullptr; + + buildAfterBlockMap(revDiffFunc); + // Traverse all instructions/blocks in reverse (starting from the terminator inst) // look for insts/blocks marked with IRDifferentialInstDecoration, // and transpose them in the revDiffFunc. @@ -170,6 +201,18 @@ struct DiffTransposePass this->transposeBlock(block, revBlock); } + // Some blocks may not have their control flow + // insts completed. Do them now that we have + // more information. + // + for (auto pendingBlockInfo : pendingBlocks) + { + builder.setInsertInto(revBlockMap[pendingBlockInfo.fwdBlock]); + completeEmitTerminator(&builder, pendingBlockInfo.fwdBlock, pendingBlockInfo.phiGrads); + } + + pendingBlocks.clear(); + // Link the last differential fwd-mode block (which will be the first // rev-mode block) as the successor to the last primal block. // We assume that the original function is in single-return form @@ -209,6 +252,8 @@ struct DiffTransposePass { block->removeAndDeallocate(); } + + cleanupRegionInfo(); } // Fetch or create a gradient accumulator var @@ -224,7 +269,11 @@ struct DiffTransposePass IRBuilder tempVarBuilder(autodiffContext->sharedBuilder); IRBlock* firstDiffBlock = firstRevDiffBlockMap[as<IRFunc>(fwdInst->getParent()->getParent())]; - tempVarBuilder.setInsertBefore(firstDiffBlock->getTerminator()); + + if (auto firstInst = firstDiffBlock->getFirstOrdinaryInst()) + tempVarBuilder.setInsertBefore(firstInst); + else + tempVarBuilder.setInsertInto(firstDiffBlock); auto primalType = tryGetPrimalTypeFromDiffInst(fwdInst); auto diffType = fwdInst->getDataType(); @@ -408,7 +457,13 @@ struct DiffTransposePass // We _should_ be completely out of gradients to process at this point. SLANG_ASSERT(gradientsMap.Count() == 0); - emitTerminator(&builder, fwdBlock, phiParamRevGradInsts); + if (!tryEmitTerminator(&builder, fwdBlock, phiParamRevGradInsts)) + { + // If we couldn't emit a terminator right away, defer for later. + pendingBlocks.add(PendingBlockTerminatorEntry( + fwdBlock, + phiParamRevGradInsts)); + } } void transposeInst(IRBuilder* builder, IRInst* inst) @@ -649,6 +704,54 @@ struct DiffTransposePass return terminalPrimalBlocks; } + IRBlock* getAfterBlock(IRBlock* block) + { + auto terminatorInst = block->getTerminator(); + switch (terminatorInst->getOp()) + { + case kIROp_unconditionalBranch: + case kIROp_Return: + return nullptr; + + case kIROp_ifElse: + return as<IRIfElse>(terminatorInst)->getAfterBlock(); + case kIROp_Switch: + return as<IRSwitch>(terminatorInst)->getBreakLabel(); + case kIROp_loop: + return as<IRLoop>(terminatorInst)->getBreakBlock(); + + default: + SLANG_UNIMPLEMENTED_X("Unhandled terminator inst when building after-block map"); + } + } + + void buildAfterBlockMap(IRGlobalValueWithCode* fwdFunc) + { + // Scan through a fwd-mode function, and build a list of blocks + // that appear as the 'after' block for any conditional control + // flow statement. + // + + for (auto block = fwdFunc->getFirstBlock(); block; block = block->getNextBlock()) + { + // Only need to process differential blocks. + if (!isDifferentialInst(block)) + continue; + + IRBlock* afterBlock = getAfterBlock(block); + + if (afterBlock) + { + // No block can by the after block for multiple control flow insts. + // + SLANG_ASSERT(!(afterBlockMap.ContainsKey(afterBlock) && \ + afterBlockMap[afterBlock] != block->getTerminator())); + + afterBlockMap[afterBlock] = block->getTerminator(); + } + } + } + List<IRBlock*> getTerminalDiffBlocks(IRGlobalValueWithCode* func) { // Terminal differential blocks are those with a return statement. @@ -663,244 +766,206 @@ struct DiffTransposePass return terminalDiffBlocks; } - IRInst* addPredecessorForBlock(IRBlock* block, IRBlock* predBlock) + bool doesBlockHaveDifferentialPredecessors(IRBlock* fwdBlock) { - if (!this->blockEntries.ContainsKey(block)) + for (auto block : fwdBlock->getPredecessors()) { - // We haven't encountered this block yet, create a var for this in the - // first code block. - auto firstCodeBlock = getFirstCodeBlock(block->getParent()); - - IRBuilder subBuilder(this->autodiffContext->sharedBuilder); - subBuilder.setInsertBefore(firstCodeBlock->getTerminator()); - auto controlVar = subBuilder.emitVar(subBuilder.getUIntType()); - - ControlFlowTranspositionInfo info; - info.controlVar = controlVar; - - this->blockEntries[block] = info; + if (isDifferentialInst(block)) + { + return true; + } } - auto info = this->blockEntries[block]; + return false; + } - // Does precessor block already exist? - if (info.GetValue().predEntries.ContainsKey(predBlock)) - { - return info.GetValue().predEntries[predBlock]; - } + IRBlock* insertPhiBlockBefore(IRBlock* revBlock, List<IRInst*> phiArgs) + { + IRBuilder phiBlockBuilder(autodiffContext->sharedBuilder); + phiBlockBuilder.setInsertBefore(revBlock); - // Otherwise, create an entry.. - auto uniqueIndex = info.GetValue().predEntries.Count(); + auto phiBlock = phiBlockBuilder.emitBlock(); - IRBuilder builder(this->autodiffContext->sharedBuilder); - auto uniqueIndexLiteral = builder.getIntValue(builder.getUIntType(), uniqueIndex); + if (isDifferentialInst(revBlock)) + phiBlockBuilder.markInstAsDifferential(phiBlock); - info.GetValue().predEntries[predBlock] = uniqueIndexLiteral; - - return uniqueIndexLiteral; + phiBlockBuilder.emitBranch( + revBlock, + phiArgs.getCount(), + phiArgs.getBuffer()); + + return phiBlock; } - IRVar* getControlVar(IRBlock* block) - { - return this->blockEntries[block].GetValue().controlVar; - } - - // Inserts a block between the branch from fwdPredecessorBlock to fwdBlock, which sets a control - // variable to a unique index. - // - IRInst* insertPreludeForPredecessor(IRBlock* fwdBlock, IRBlock* fwdPredecessorBlock) + // Create a region to track control flow from the + // the point of convergence (fwdConvBlock) back to the point of + // divergence, along one specific path (fwdExitBlock) + // + void pushRegion(IRBlock* fwdConvBlock, IRBlock* fwdExitBlock) { - // Get associated primal blocks for both the differential blocks. - auto primalPredecessorBlock = getPrimalBlock(fwdPredecessorBlock); - SLANG_ASSERT(primalPredecessorBlock); + SLANG_ASSERT(!regionMap.ContainsKey(fwdExitBlock)); + SLANG_ASSERT(regionMap.ContainsKey(fwdConvBlock)); - auto primalBlock = getPrimalBlock(fwdBlock); - SLANG_ASSERT(primalBlock); + Region* newRegion = new Region(fwdExitBlock, regionMap[fwdConvBlock]); + regions.add(newRegion); - // Add this block as a predecessor, and get an unique index (as an integer literal) - auto indexVal = addPredecessorForBlock(fwdBlock, fwdPredecessorBlock); - - IRBuilder subBuilder(this->autodiffContext->sharedBuilder); - subBuilder.setInsertInto(primalPredecessorBlock->getParent()); - - IRInst* preludeBlock = subBuilder.emitBlock(); - preludeBlock->insertAfter(primalPredecessorBlock); + regionMap[fwdExitBlock] = newRegion; + } - // Copy over phi parameters. - List<IRInst*> phiParams; - for (auto param = primalBlock->getFirstParam(); param; param = param->getNextParam()) + // If we have a conditional-branch from fwdBlock to fwdNextBlock + // complete the region, and remove from stack + // otherwise, copy the region over. + // + void propagateRegion(IRBlock* fwdNextBlock, IRBlock* fwdBlock) + { + if (as<IRConditionalBranch>(fwdBlock->getTerminator())) { - phiParams.add(subBuilder.emitParam(param->getDataType())); - } - - auto controlVar = getControlVar(fwdBlock); - subBuilder.emitStore(controlVar, indexVal); + Region* currentRegion = regionMap[fwdNextBlock]; + currentRegion->finish(fwdNextBlock); - // Branch into the successor block using all the same phi parameters. - subBuilder.emitBranch(primalBlock, phiParams.getCount(), phiParams.getBuffer()); - - // Scan through uses of primalBlock to find the ones that are in - // primalPredecessorBlock, and replace them with branches to - // preludeBlock. - // - List<IRUse*> relevantUses; - for (auto use = primalBlock->firstUse; use; use = use->nextUse) + regionMap[fwdBlock] = currentRegion->parent; + } + else if (as<IRUnconditionalBranch>(fwdBlock->getTerminator()) || + as<IRReturn>(fwdBlock->getTerminator())) { - if (use->getUser()->getParent() == primalPredecessorBlock) - relevantUses.add(use); + regionMap[fwdBlock] = regionMap[fwdNextBlock]; } - - for (auto use : relevantUses) - use->set(preludeBlock); - - return indexVal; } - bool doesBlockHaveDifferentialPredecessors(IRBlock* fwdBlock) + // Deallocate regions + void cleanupRegionInfo() { - for (auto block : fwdBlock->getPredecessors()) + for (auto region : regions) { - if (isDifferentialInst(block)) - { - return true; - } + delete region; } - return false; + regions.clear(); + regionMap.Clear(); } - void emitTerminator(IRBuilder* builder, IRBlock* fwdBlockInst, List<IRInst*> phiParamGrads) + bool tryEmitTerminator(IRBuilder* builder, IRBlock* fwdBlockInst, List<IRInst*> phiParamGrads) { // If this block has no differential predecessors, add a return statement. if (!doesBlockHaveDifferentialPredecessors(fwdBlockInst)) { // Emit a void return. builder->emitReturn(); - return; + return true; } + List<IRBlock*> fwdPredecesorBlocks; + // Check for predecessors count. for (auto predecessor : fwdBlockInst->getPredecessors()) { - // Insert code into the *primal* version of the predecessor block - // to set the control variable to indexVal before branching. - // - insertPreludeForPredecessor(fwdBlockInst, predecessor); + fwdPredecesorBlocks.add(predecessor); } - List<IRBlock*> revPredecessorBlocks; - List<IRInst*> indexVals; + SLANG_ASSERT(fwdPredecesorBlocks.getCount() > 0); - for (auto blockEntry : this->blockEntries[fwdBlockInst].GetValue().predEntries) + // If we have just one, we simply need the reverse-mode block to + // branch into the reverse-mode version of the predecessor block. + // (along with the appropriate phi args) + // + if (fwdPredecesorBlocks.getCount() == 1) { - revPredecessorBlocks.add(revBlockMap[blockEntry.Key]); - indexVals.add(blockEntry.Value); - } - - auto predCount = revPredecessorBlocks.getCount(); + builder->emitBranch( + revBlockMap[fwdPredecesorBlocks[0]], + phiParamGrads.getCount(), + phiParamGrads.getBuffer()); - SLANG_ASSERT(predCount > 0); - - List<IRBlock*> intermediateBranchBlocks; - - IRBuilder branchBlockBuilder(builder->getSharedBuilder()); - - branchBlockBuilder.setInsertInto(builder->getFunc()); + propagateRegion(fwdBlockInst, fwdPredecesorBlocks[0]); + return true; + } - // Make a block to unconditionally branch into predecessor-0 with the - // appropriate phi gradients. + // If we have more than one, then control flow 'converges' at this point. + // By convention, this block must be the after block for _some_ conditional + // control flow statement. + // If not, we are dealing with an inconsistent graph. // - auto firstBranchBlock = branchBlockBuilder.emitBlock(); - intermediateBranchBlocks.add(firstBranchBlock); - - branchBlockBuilder.markInstAsDifferential(firstBranchBlock); - branchBlockBuilder.emitBranch( - revPredecessorBlocks[0], - phiParamGrads.getCount(), - phiParamGrads.getBuffer()); - - // Create a builder to insert loads and comparison insts to figure - // out which block to branch into based on the control vars. - // This builder is set up to emit into the last _primal_ block. + // Rather than actually emitting the terminator here, we're going to + // defer to a pass after all the blocks have been transposed. + // This is because, while we know that this block is the point of convergence + // we don't know which predecessor belong to which side of the branch. + // We will instead create 'regions' to track each predecessor for every + // branch, and by the time all blocks are seen at-least once, we should have + // resolved the 'start' points for every predecessor. // - IRBuilder booleanIndicatorBuilder(builder->getSharedBuilder()); - auto terminalPrimalBlock = getTerminalPrimalBlocks(builder->getFunc())[0]; - booleanIndicatorBuilder.setInsertBefore(terminalPrimalBlock->getTerminator()); - - if (predCount == 1) + if (fwdPredecesorBlocks.getCount() > 1) { - builder->emitBranch(firstBranchBlock); + SLANG_ASSERT(afterBlockMap.ContainsKey(fwdBlockInst)); + + for (auto predecessor : fwdPredecesorBlocks) + { + // Trivial case when the predecessor itself is the point + // of divergence. + // + if (getAfterBlock(predecessor) == fwdBlockInst) + continue; + + pushRegion(fwdBlockInst, predecessor); + } } - else - { - IRBuilder ladderBlockBuilder(builder->getSharedBuilder()); - ladderBlockBuilder.setInsertInto(builder->getFunc()); - // TODO: For now, we're trivially setting 'afterBlock' to - // the first reverse block. This is not really optimal for the - // restructuring passes since the 'then' and 'else' regions - // can have significant overlap. - // - auto firstFwdDiffBlock = (*terminalPrimalBlock->getSuccessors().begin()); - SLANG_ASSERT(firstFwdDiffBlock); + return false; + } - auto defaultAfterBlock = revBlockMap[firstFwdDiffBlock]; + bool completeEmitTerminator(IRBuilder* builder, IRBlock* fwdBlockInst, List<IRInst*> phiParamGrads) + { + IRBlock* revBlock = revBlockMap[fwdBlockInst]; - auto nextLadderBlock = firstBranchBlock; - for (Index ii = 0; ii < predCount - 1; ii++) + // If we already have a terminator, we've resolved it during + // tryEmitTerminator() + // + if (revBlock->getTerminator() != nullptr) + return true; + + auto terminatorInst = as<IRInst>(afterBlockMap[fwdBlockInst]); + switch (terminatorInst->getOp()) + { + case kIROp_ifElse: { - // Make the 'leaf' block. This just branches into - // predecessor-i+1 with the appropriate phi args. - // - branchBlockBuilder.setInsertInto(branchBlockBuilder.getFunc()); - - auto thisIndexBlock = branchBlockBuilder.emitBlock(); - intermediateBranchBlocks.add(thisIndexBlock); - - branchBlockBuilder.markInstAsDifferential(thisIndexBlock); - branchBlockBuilder.emitBranch( - revPredecessorBlocks[ii+1], - phiParamGrads.getCount(), - phiParamGrads.getBuffer()); - - // Emit a boolean inst to represent whether we need to branch into - // block ii. - auto blockIndicatorInst = booleanIndicatorBuilder.emitEql( - booleanIndicatorBuilder.emitLoad(getControlVar(fwdBlockInst)), - indexVals[ii+1]); - - // Create a block to branch between i+1 and the rest of the ladder so far - // (0 ... i) - // - auto upperLadderBlock = ladderBlockBuilder.emitBlock(); - intermediateBranchBlocks.add(upperLadderBlock); - - ladderBlockBuilder.markInstAsDifferential(upperLadderBlock); - ladderBlockBuilder.emitIfElse( - blockIndicatorInst, - thisIndexBlock, - nextLadderBlock, - defaultAfterBlock); + auto ifElseInst = as<IRIfElse>(terminatorInst); - nextLadderBlock = upperLadderBlock; - } + auto condition = ifElseInst->getCondition(); + SLANG_ASSERT(!isDifferentialInst(condition)); - // Branch into the last ladder block. - builder->emitBranch(nextLadderBlock); - } + // fwd origin block is the reverse 'after' block. + auto revAfterBlock = as<IRBlock>( + revBlockMap[as<IRBlock>(ifElseInst->getParent())]); + + // Find region, and find the reverse-mode version of the + // exit block. + Region* trueRegion = regionMap[ifElseInst->getTrueBlock()]; + IRBlock* revTrueBlock = revBlockMap[trueRegion->exitBlock]; - // Insert all intermediate blocks in the order they were created, right after - // the current reverse block. - - auto revBlock = revBlockMap[fwdBlockInst]; - SLANG_ASSERT(revBlock); + Region* falseRegion = regionMap[ifElseInst->getFalseBlock()]; + IRBlock* revFalseBlock = revBlockMap[falseRegion->exitBlock]; - for (auto block : intermediateBranchBlocks) - { - block->insertAfter(revBlock); + // If we have phi derivatives to pass on, + // we need to add dummy blocks to pass them using + // an unconditional branch. + // + if (phiParamGrads.getCount() > 0) + { + revTrueBlock = insertPhiBlockBefore(revTrueBlock, phiParamGrads); + revFalseBlock = insertPhiBlockBefore(revFalseBlock, phiParamGrads); + + // Putting the phi blocks just after our current reverse-mode block + // is not necessary. Just to make intermediate IR easier to follow. + // + revTrueBlock->insertAfter(revBlock); + revFalseBlock->insertAfter(revBlock); + } + + builder->emitIfElse(condition, revTrueBlock, revFalseBlock, revAfterBlock); + break; + } + default: + SLANG_UNIMPLEMENTED_X("Unhandled control flow inst during transposition"); } - - return; + return false; } TranspositionResult transposeInst(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue) @@ -1622,12 +1687,19 @@ struct DiffTransposePass Dictionary<IRInst*, IRInst*>* primalsMap; List<IRInst*> usedPtrs; - - Dictionary<IRBlock*, ControlFlowTranspositionInfo> blockEntries; Dictionary<IRBlock*, IRBlock*> revBlockMap; Dictionary<IRGlobalValueWithCode*, IRBlock*> firstRevDiffBlockMap; + + Dictionary<IRBlock*, IRInst*> afterBlockMap; + + List<PendingBlockTerminatorEntry> pendingBlocks; + + Dictionary<IRBlock*, Region*> regionMap; + + List<Region*> regions; + }; |
