diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-01-30 11:46:36 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-01-30 08:46:36 -0800 |
| commit | 134dd7eb26fc7988ae13559d276cbf337b4b9d27 (patch) | |
| tree | 35bd06e6bebb4518bca805e14e85f8f9ef4341c6 /source | |
| parent | 4a66e9729175a89833e5db784bb64e6a7f60cdf2 (diff) | |
Overhauled reverse-mode control flow handling (#2608)
* Added switch-case support; fixed non-diff parameter transposition
* Made region propagation much more robust. Partial loop unzip implementation
* WIP: Added most loop handling code, and a test. Still untested
* Added CFG Normalization pass + CFG Reversal Pass + Loop Unzipping + most loop transcription
* Add single-iter-loop test.
* proj files
* removed comments
* Update reverse-loop.slang
* Removed out-of-date code
* Disabled IR validation during constructSSA phase of normalizeCFG. constructSSA now reuses sharedBuilder
* Moved normalizeCFG() call to prepareFuncForBackwardDiff()
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-cfg-norm.cpp | 629 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-cfg-norm.h | 26 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 694 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 532 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 21 | ||||
| -rw-r--r-- | source/slang/slang-ir-ssa.cpp | 23 | ||||
| -rw-r--r-- | source/slang/slang-ir-ssa.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-validate.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 26 |
12 files changed, 1683 insertions, 287 deletions
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp new file mode 100644 index 000000000..4e0a413db --- /dev/null +++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp @@ -0,0 +1,629 @@ +// slang-ir-autodiff-cfg-norm.cpp +#include "slang-ir-autodiff-cfg-norm.h" +#include "slang-ir-eliminate-phis.h" +#include "slang-ir-ssa.h" + +#include "slang-ir-validate.h" + +namespace Slang +{ + +struct RegionEndpoint +{ + bool inBreakRegion = false; + bool inBaseRegion = false; + + IRBlock* exitBlock = nullptr; + + bool isRegionEmpty = false; + + RegionEndpoint(IRBlock* exitBlock, bool inBreakRegion, bool inBaseRegion) : + exitBlock(exitBlock), + inBreakRegion(inBreakRegion), + inBaseRegion(inBaseRegion), + isRegionEmpty(false) + { } + + RegionEndpoint( + IRBlock* exitBlock, + bool inBreakRegion, + bool inBaseRegion, + bool isRegionEmpty) : + exitBlock(exitBlock), + inBreakRegion(inBreakRegion), + inBaseRegion(inBaseRegion), + isRegionEmpty(isRegionEmpty) + { } + + RegionEndpoint() + { } +}; + +struct BreakableRegionInfo +{ + IRVar* breakVar; + IRBlock* breakBlock; +}; + +struct CFGNormalizationContext +{ + SharedIRBuilder* sharedBuilder; + DiagnosticSink* sink; +}; + + +IRBlock* getOrCreateTopLevelCondition(IRLoop* loopInst) +{ + // For now, we're going to naively assume the next block is the condition block. + // Add in more support for more cases as necessary. 
+ // + + auto firstBlock = loopInst->getTargetBlock(); + + auto ifElse = as<IRIfElse>(firstBlock->getTerminator()); + SLANG_RELEASE_ASSERT(ifElse); + + return firstBlock; +} + +struct CFGNormalizationPass +{ + CFGNormalizationContext cfgContext; + + CFGNormalizationPass(CFGNormalizationContext ctx) : + cfgContext(ctx) + { } + + void replaceBreakWithAfterBlock( + IRBuilder* builder, + BreakableRegionInfo* info, + IRBlock* currBlock, + IRBlock* afterBlock, + IRBlock* parentAfterBlock) + { + SLANG_ASSERT(as<IRUnconditionalBranch>(currBlock->getTerminator())); + + currBlock->getTerminator()->removeAndDeallocate(); + + builder->setInsertInto(currBlock); + + builder->emitStore(info->breakVar, builder->getBoolValue(false)); + builder->emitBranch(afterBlock); + + // Is after-block unreachable? + if (auto unreachInst = as<IRUnreachable>(afterBlock->getFirstOrdinaryInst())) + { + // Link it to the parentAfterBlock. + builder->setInsertInto(afterBlock); + unreachInst->removeAndDeallocate(); + + /* + HashSet<IRBlock*> predecessorSet; + for (auto predecessor : parentAfterBlock->getPredecessors()) + predecessorSet.Add(predecessor); + + SLANG_ASSERT(predecessorSet.Count() <= 1); + */ + + builder->emitBranch(parentAfterBlock); + } + } + + IRBlock* getUnconditionalTarget(RegionEndpoint endpoint) + { + if (!endpoint.isRegionEmpty) + { + auto branchInst = as<IRUnconditionalBranch>(endpoint.exitBlock->getTerminator()); + SLANG_ASSERT(branchInst); + + return branchInst->getTargetBlock(); + } + else + { + return endpoint.exitBlock; + } + } + + IRBlock* maybeGetUnconditionalTarget(IRBlock* block) + { + auto branchInst = as<IRUnconditionalBranch>(block->getTerminator()); + + return branchInst ? 
branchInst->getTargetBlock() : nullptr; + } + + + bool isSuccessorBlock(IRBlock* baseBlock, IRBlock* succBlock) + { + for (auto successor : baseBlock->getSuccessors()) + if (successor == succBlock) + return true; + + return false; + } + + + RegionEndpoint getNormalizedRegionEndpoint( + BreakableRegionInfo* parentRegion, + IRBlock* entryBlock, + List<IRBlock*> afterBlocks) + { + IRBlock* currentBlock = entryBlock; + + // By default a region starts off with the 'base' control flow + // and not in the 'break' control flow + // It is the job of the *caller* to make sure the break flow + // does not reach this point. + // + bool currBreakRegion = false; + bool currBaseRegion = true; + + // Detect the trivial case. The current block is alredy + // in the next region => this region is empty. + // + if (afterBlocks.contains(currentBlock)) + return RegionEndpoint(currentBlock, currBreakRegion, currBaseRegion, true); + + IRBuilder builder(cfgContext.sharedBuilder); + + List<IRBlock*> pendingAfterBlocks; + + IRBlock* parentAfterBlock = afterBlocks[0]; + + // Follow this thread of execution till we hit an + // acceptable after block. + // + while (!afterBlocks.contains(maybeGetUnconditionalTarget(currentBlock))) + { + // Check the terminator. + auto terminator = currentBlock->getTerminator(); + switch (terminator->getOp()) + { + case kIROp_unconditionalBranch: + { + auto targetBlock = as<IRUnconditionalBranch>(terminator)->getTargetBlock(); + currentBlock = targetBlock; + break; + } + + case kIROp_ifElse: + { + auto ifElse = as<IRIfElse>(terminator); + + // Special case. One of the branches will + // lead back to the condition. 
+ // + SLANG_ASSERT(ifElse->getAfterBlock() != parentRegion->breakBlock); + + auto trueEndPoint = getNormalizedRegionEndpoint( + parentRegion, + ifElse->getTrueBlock(), + List<IRBlock*>(ifElse->getAfterBlock(), parentRegion->breakBlock)); + + auto falseEndPoint = getNormalizedRegionEndpoint( + parentRegion, + ifElse->getFalseBlock(), + List<IRBlock*>(ifElse->getAfterBlock(), parentRegion->breakBlock)); + + auto trueTargetBlock = getUnconditionalTarget(trueEndPoint); + auto falseTargetBlock = getUnconditionalTarget(falseEndPoint); + + auto afterBlock = ifElse->getAfterBlock(); + + // Trivial case, both end-points branch into the after block + if (trueTargetBlock == afterBlock && + falseTargetBlock == afterBlock) + { + currentBlock = afterBlock; + break; + } + + auto afterBreakRegion = false; + auto afterBaseRegion = false; + + if (trueTargetBlock == parentRegion->breakBlock) + { + // Branch into after block (and set break variable) + replaceBreakWithAfterBlock( + &builder, + parentRegion, + trueEndPoint.exitBlock, + afterBlock, + parentAfterBlock); + + // If this branch breaks, then the after-block + // definitely has break-flow. + // + afterBreakRegion = true; + } + else + { + // If this branch naturally branches into our + // after-block, copy whatever flags the endpoints + // have. + // + afterBreakRegion = afterBreakRegion || trueEndPoint.inBreakRegion; + afterBaseRegion = afterBaseRegion || trueEndPoint.inBaseRegion; + } + + if (falseTargetBlock == parentRegion->breakBlock) + { + // Branch into after block (and set break variable) + replaceBreakWithAfterBlock( + &builder, + parentRegion, + falseEndPoint.exitBlock, + afterBlock, + parentAfterBlock); + + // If this branch breaks, then the after-block + // definitely has break-flow. + // + afterBreakRegion = true; + } + else + { + // If this branch naturally branches into our + // after-block, copy whatever flags the endpoints + // have. 
+ // + afterBreakRegion = afterBreakRegion || falseEndPoint.inBreakRegion; + afterBaseRegion = afterBaseRegion || falseEndPoint.inBaseRegion; + } + + // TODO: For now, we're being overly cautious and assuming + // the after region might have something to execute. + // Ideally, we should check if the block is empty, and + // hold off on splitting until we encounter non-empty + // blocks. + // + afterBaseRegion = true; + + // Do we need to split the after region? + if (afterBaseRegion && afterBreakRegion) + { + // We could arrive at the after-block before or + // after encountering a break statement. + // To handle this, we'll split the flow by checking the break flag + // + builder.setInsertAfter(afterBlock); + + auto preAfterSplitBlock = builder.emitBlock(); + preAfterSplitBlock->insertBefore(afterBlock); + + auto afterSplitBlock = builder.emitBlock(); + afterSplitBlock->insertBefore(afterBlock); + + afterBlock->replaceUsesWith(preAfterSplitBlock); + + builder.setInsertInto(preAfterSplitBlock); + builder.emitBranch(afterSplitBlock); + + // Converging block for the split that we're making. + auto afterSplitAfterBlock = builder.emitBlock(); + + builder.setInsertInto(afterSplitBlock); + auto breakFlagValue = builder.emitLoad(parentRegion->breakVar); + + builder.emitIfElse( + breakFlagValue, + afterBlock, + afterSplitAfterBlock, + afterSplitAfterBlock); + + // At this point, we need to place afterSplitAfterBlock between + // at the _end_ of this region, but we aren't there yet (and + // don't know which block is the end of this region) + // Therefore, we'll defer this step and add it to a list for later. + // + pendingAfterBlocks.add(afterSplitAfterBlock); + + // Update current block. 
+ currentBlock = afterBlock; + afterBreakRegion = false; + afterBaseRegion = true; + } + + currentBlock = afterBlock; + currBreakRegion = afterBreakRegion; + currBaseRegion = afterBaseRegion; + break; + } + + case kIROp_loop: + { + auto breakBlock = normalizeBreakableRegion(terminator); + + // Advance to the break block (no updates to the control flags) + currentBlock = breakBlock; + break; + } + + default: + // Do proper diagnosing + SLANG_UNEXPECTED("Unhandled control flow inst"); + break; + } + } + + // Resolve all intermediate after-blocks + pendingAfterBlocks.reverse(); + + for (auto block : pendingAfterBlocks) + { + builder.setInsertInto(block); + auto nextRegionBlock = maybeGetUnconditionalTarget(currentBlock); + SLANG_ASSERT(nextRegionBlock); + + builder.emitBranch(nextRegionBlock); + + builder.setInsertInto(currentBlock); + currentBlock->getTerminator()->removeAndDeallocate(); + builder.emitBranch(block); + + block->insertAfter(currentBlock); + + currentBlock = block; + currBaseRegion = true; + currBreakRegion = true; + } + + return RegionEndpoint(currentBlock, currBreakRegion, currBaseRegion); + } + + HashSet<IRBlock*> getPredecessorSet(IRBlock* block) + { + HashSet<IRBlock*> predecessorSet; + for (auto predecessor : block->getPredecessors()) + predecessorSet.Add(predecessor); + + return predecessorSet; + } + + bool isLoopTrivial(IRLoop* loop) + { + // Get 'looping' block (first block in loop) + auto firstLoopBlock = loop->getTargetBlock(); + + // If we only have one predecessor, the loop is trivial. + return (getPredecessorSet(firstLoopBlock).Count() == 1); + } + + IRBlock* normalizeBreakableRegion( + IRInst* branchInst) + { + IRBuilder builder(cfgContext.sharedBuilder); + + switch (branchInst->getOp()) + { + case kIROp_loop: + { + BreakableRegionInfo info; + info.breakBlock = as<IRLoop>(branchInst)->getBreakBlock(); + + // Emit var into parent block. 
+ builder.setInsertBefore( + as<IRBlock>(branchInst->getParent())->getTerminator()); + + // Create and initialize break var to true + // true -> no break yet. + // false -> atleast one break statement hit. + // + info.breakVar = builder.emitVar(builder.getBoolType()); + builder.emitStore(info.breakVar, builder.getBoolValue(true)); + + // If the loop is trivial (i.e. single iteration, with no + // edges actually in a loop), we're just going to remove + // it.. (we can do this, because the normalization pass + // will transform any break and continue statements) + // + if (isLoopTrivial(as<IRLoop>(branchInst))) + { + auto firstLoopBlock = as<IRLoop>(branchInst)->getTargetBlock(); + auto terminator = firstLoopBlock->getTerminator(); + + // We really shouldn't see a conditional branch on a trivial loop + // but if we hit this assert, handle this case. + // + SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>(terminator)); + + // Normalize the region from the first loop block till break. + auto preBreakEndPoint = getNormalizedRegionEndpoint( + &info, + firstLoopBlock, + List<IRBlock*>(info.breakBlock)); + + // Should not be empty.. but check anyway + SLANG_RELEASE_ASSERT(!preBreakEndPoint.isRegionEmpty); + + // Quick consistency check.. preBreakEndPoint should be + // branching into break block. + SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>( + preBreakEndPoint.exitBlock->getTerminator())->getTargetBlock() == info.breakBlock); + + auto currentBlock = branchInst->getParent(); + + // Now get rid of the loop inst and replace with unconditional branch. 
+ branchInst->removeAndDeallocate(); + builder.setInsertInto(currentBlock); + builder.emitBranch(firstLoopBlock); + + return info.breakBlock; + } + + auto condBlock = getOrCreateTopLevelCondition(as<IRLoop>(branchInst)); + + auto ifElse = as<IRIfElse>(condBlock->getTerminator()); + + auto trueEndPoint = getNormalizedRegionEndpoint( + &info, + ifElse->getTrueBlock(), + List<IRBlock*>(condBlock, info.breakBlock)); + + auto falseEndPoint = getNormalizedRegionEndpoint( + &info, + ifElse->getFalseBlock(), + List<IRBlock*>(condBlock, info.breakBlock)); + + RegionEndpoint loopEndPoint; + bool isLoopOnTrueSide = true; + + // First figure out which side belongs to the loop body. + if (isSuccessorBlock(trueEndPoint.exitBlock, condBlock)) + { + loopEndPoint = trueEndPoint; + isLoopOnTrueSide = true; + } + + if (isSuccessorBlock(falseEndPoint.exitBlock, condBlock)) + { + loopEndPoint = falseEndPoint; + isLoopOnTrueSide = false; + } + + SLANG_RELEASE_ASSERT(loopEndPoint.exitBlock); + + // Special case.. the if-else of a loop needs it's + // after block to be pointing at the last block before + // it loops back to the if-else. + // + // ifElse->afterBlock.set(loopEndPoint.exitBlock); + + // Does the loop endpoint have both 'break' and 'base' + // control flows? + // + if (loopEndPoint.inBaseRegion && loopEndPoint.inBreakRegion) + { + // Add a test for the break variable into the condition. + auto cond = ifElse->getCondition(); + + builder.setInsertAfter(cond); + auto breakFlagVal = builder.emitLoad(info.breakVar); + + // Need to invert the break flag if the loop is + // on the false side. 
+ // + if (!isLoopOnTrueSide) + { + IRInst* args[1] = {breakFlagVal}; + breakFlagVal = builder.emitIntrinsicInst( + builder.getBoolType(), + kIROp_Not, + 1, + args); + } + + IRInst* args[2] = {cond, breakFlagVal}; + + // If break-var = true, direct flow to the loop + // otherwise, direct flow to break + // + auto complexCond = builder.emitIntrinsicInst( + builder.getBoolType(), + kIROp_And, + 2, + args); + + ifElse->condition.set(complexCond); + } + + return info.breakBlock; + } + case kIROp_Switch: + { + auto switchInst = as<IRSwitch>(branchInst); + + // SLANG_UNEXPECTED("Switch-case normalization not implemented yet."); + BreakableRegionInfo info; + info.breakBlock = as<IRSwitch>(branchInst)->getBreakLabel(); + + // Emit var into parent block. + builder.setInsertBefore( + as<IRBlock>(branchInst->getParent())->getTerminator()); + + // Create and initialize break var to true + // true -> no break yet. + // false -> atleast one break statement hit. + // + info.breakVar = builder.emitVar(builder.getBoolType()); + builder.emitStore(info.breakVar, builder.getBoolValue(true)); + + // Go over case labels and normalize all sub-regions. 
+ for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii++) + { + auto caseBlock = switchInst->getCaseLabel(ii); + auto caseEndPoint = getNormalizedRegionEndpoint( + &info, + caseBlock, + List<IRBlock*>(info.breakBlock)).exitBlock; + + // Consistency check (if this case hits, it's probably + // because the switch has fall-through, which we don't support) + SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>( + caseEndPoint->getTerminator())->getTargetBlock() == info.breakBlock); + } + + auto defaultEndPoint = getNormalizedRegionEndpoint( + &info, + switchInst->getDefaultLabel(), + List<IRBlock*>(info.breakBlock)).exitBlock; + + // Consistency check (if this case hits, it's probably + // because the switch has fall-through, which we don't support) + SLANG_RELEASE_ASSERT(as<IRUnconditionalBranch>( + defaultEndPoint->getTerminator())->getTargetBlock() == info.breakBlock); + + return info.breakBlock; + } + default: + break; + } + + SLANG_UNEXPECTED("Unhandled control-flow inst"); + } +}; + +void normalizeCFG( + IRGlobalValueWithCode* func, + IRCFGNormalizationPass const& options) +{ + // Remove phis to simplify our pass. We'll add them back in later + // with constructSSA. 
+ // + eliminatePhisInFunc(LivenessMode::Disabled, func->getModule(), func); + + SharedIRBuilder sharedBuilder(func->getModule()); + sharedBuilder.deduplicateAndRebuildGlobalNumberingMap(); + CFGNormalizationContext context = {&sharedBuilder, options.sink}; + CFGNormalizationPass cfgPass(context); + + List<IRBlock*> workList; + workList.add(func->getFirstBlock()); + + while (workList.getCount() > 0) + { + auto block = workList.getLast(); + workList.removeLast(); + + if (auto loop = as<IRLoop>(block->getTerminator())) + { + auto breakBlock = cfgPass.normalizeBreakableRegion(loop); + workList.add(breakBlock); + } + else if (auto switchCase = as<IRSwitch>(block->getTerminator())) + { + auto breakBlock = cfgPass.normalizeBreakableRegion(switchCase); + workList.add(breakBlock); + } + else + { + for (auto successor : block->getSuccessors()) + workList.add(successor); + } + } + + disableIRValidationAtInsert(); + constructSSA(&sharedBuilder, func); + enableIRValidationAtInsert(); +} + +}
\ No newline at end of file diff --git a/source/slang/slang-ir-autodiff-cfg-norm.h b/source/slang/slang-ir-autodiff-cfg-norm.h new file mode 100644 index 000000000..2a39f7695 --- /dev/null +++ b/source/slang/slang-ir-autodiff-cfg-norm.h @@ -0,0 +1,26 @@ +// slang-ir-autodiff-cfg-norm.h +#pragma once + +#include "slang-ir-insts.h" + +namespace Slang +{ + struct IRModule; + + struct IRCFGNormalizationPass + { + DiagnosticSink* sink; + }; + + /// Eliminate "break" statements from breakable regions + /// (loops, switch-case). This will use temporary booleans + /// instead of a break statement, in order to ensure all + /// branches inside the breakable region always have a valid + /// "after" block. + /// + void normalizeCFG( + IRGlobalValueWithCode* func, + IRCFGNormalizationPass const& options = IRCFGNormalizationPass()); + + IRBlock* getOrCreateTopLevelCondition(IRLoop* loopInst); +} diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index fce2043eb..6f18a3d8a 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -3,6 +3,7 @@ #include "slang-ir-clone.h" #include "slang-ir-dce.h" #include "slang-ir-eliminate-phis.h" +#include "slang-ir-autodiff-cfg-norm.h" #include "slang-ir-util.h" #include "slang-ir-inst-pass-base.h" #include "slang-ir-ssa-simplification.h" @@ -16,7 +17,7 @@ namespace Slang IRFuncType* BackwardDiffTranscriberBase::differentiateFunctionTypeImpl(IRBuilder* builder, IRFuncType* funcType, IRInst* intermeidateType) { List<IRType*> newParameterTypes; - IRType* diffReturnType; + IRType* diffReturnType; for (UIndex i = 0; i < funcType->getParamCount(); i++) { @@ -509,6 +510,9 @@ namespace Slang } eliminateMultiLevelBreakForFunc(func->getModule(), func); + IRCFGNormalizationPass cfgPass = {this->getSink()}; + normalizeCFG(func); + AutoDiffAddressConversionPolicy cvtPolicty; cvtPolicty.diffTypeContext = &diffTypeContext; auto result = eliminateAddressInsts(sharedBuilder, 
&cvtPolicty, func, sink); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index f43206333..05a5f8f56 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -960,8 +960,8 @@ InstPair AutoDiffTranscriberBase::transcribeInst(IRBuilder* builder, IRInst* ori } else { - auto diffType = _differentiateTypeImpl(builder, origType); IRInst* primal = maybeCloneForPrimalInst(builder, origType); + auto diffType = _differentiateTypeImpl(builder, origType); result = InstPair(primal, diffType); } } diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 0d45c6a84..901649f3c 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -7,6 +7,7 @@ #include "slang-ir-autodiff.h" #include "slang-ir-autodiff-fwd.h" +#include "slang-ir-autodiff-cfg-norm.h" namespace Slang { @@ -96,37 +97,384 @@ struct DiffTransposePass fwdBlock(fwdBlock), phiGrads(phiGrads) {} }; - - struct Region + + bool isBlockLastInRegion(IRBlock* block, List<IRBlock*> endBlocks) { - IRBlock* exitBlock; - IRBlock* originBlock; + if (auto branchInst = as<IRUnconditionalBranch>(block->getTerminator())) + { + if (endBlocks.contains(branchInst->getTargetBlock())) + return true; + else + return false; + } + else if (as<IRReturn>(block->getTerminator())) + { + return true; + } - Region* parent; + return false; + } + + List<IRInst*> getPhiGrads(IRBlock* block) + { + if (!phiGradsMap.ContainsKey(block)) + return List<IRInst*>(); + + return phiGradsMap[block]; + } - Region() : - exitBlock(nullptr), - originBlock(nullptr), - parent(nullptr) + struct RegionEntryPoint + { + IRBlock* revEntry; + IRBlock* fwdEndPoint; + bool isTrivial; + + RegionEntryPoint(IRBlock* revEntry, IRBlock* fwdEndPoint) : + revEntry(revEntry), + fwdEndPoint(fwdEndPoint), + isTrivial(false) { } - Region(IRBlock* 
exitBlock, Region* parent) : - exitBlock(exitBlock), - originBlock(nullptr), - parent(parent) + RegionEntryPoint(IRBlock* revEntry, IRBlock* fwdEndPoint, bool isTrivial) : + revEntry(revEntry), + fwdEndPoint(fwdEndPoint), + isTrivial(isTrivial) { } + }; + + IRBlock* getUniquePredecessor(IRBlock* block) + { + HashSet<IRBlock*> predecessorSet; + for (auto predecessor : block->getPredecessors()) + predecessorSet.Add(predecessor); + + SLANG_ASSERT(predecessorSet.Count() == 1); + + return (*predecessorSet.begin()); + } + + RegionEntryPoint reverseCFGRegion(IRBlock* block, List<IRBlock*> endBlocks) + { + IRBlock* revBlock = revBlockMap[block]; - void finish(IRBlock* block) + if (endBlocks.contains(block)) { - SLANG_ASSERT(!this->originBlock); - this->originBlock = block; + return RegionEntryPoint(revBlock, block, true); + } + + // We shouldn't already have a terminator for this block + SLANG_ASSERT(revBlock->getTerminator() == nullptr); + + IRBuilder builder(autodiffContext->sharedBuilder); + + auto currentBlock = block; + while (!isBlockLastInRegion(currentBlock, endBlocks)) + { + auto terminator = currentBlock->getTerminator(); + switch(terminator->getOp()) + { + case kIROp_Return: + return RegionEntryPoint(revBlockMap[currentBlock], nullptr); + + case kIROp_unconditionalBranch: + { + auto branchInst = as<IRUnconditionalBranch>(terminator); + auto nextBlock = as<IRBlock>(branchInst->getTargetBlock()); + IRBlock* nextRevBlock = revBlockMap[nextBlock]; + IRBlock* currRevBlock = revBlockMap[currentBlock]; + + SLANG_ASSERT(nextRevBlock->getTerminator() == nullptr); + builder.setInsertInto(nextRevBlock); + + builder.emitBranch(currRevBlock, + getPhiGrads(nextBlock).getCount(), + getPhiGrads(nextBlock).getBuffer()); + + + currentBlock = nextBlock; + break; + } + + case kIROp_ifElse: + { + auto ifElse = as<IRIfElse>(terminator); + + auto trueBlock = ifElse->getTrueBlock(); + auto falseBlock = ifElse->getFalseBlock(); + auto afterBlock = ifElse->getAfterBlock(); + + auto 
revTrueRegionInfo = reverseCFGRegion( + trueBlock, + List<IRBlock*>(afterBlock)); + auto revFalseRegionInfo = reverseCFGRegion( + falseBlock, + List<IRBlock*>(afterBlock)); + //bool isTrueTrivial = (trueBlock == afterBlock); + //bool isFalseTrivial = (falseBlock == afterBlock); + + IRBlock* revCondBlock = revBlockMap[afterBlock]; + SLANG_ASSERT(revCondBlock->getTerminator() == nullptr); + + + IRBlock* revTrueEntryBlock = revTrueRegionInfo.revEntry; + IRBlock* revFalseEntryBlock = revFalseRegionInfo.revEntry; + + IRBlock* revTrueExitBlock = revBlockMap[trueBlock]; + IRBlock* revFalseExitBlock = revBlockMap[falseBlock]; + + auto phiGrads = getPhiGrads(afterBlock); + if (phiGrads.getCount() > 0) + { + revTrueEntryBlock = insertPhiBlockBefore(revTrueEntryBlock, phiGrads); + revFalseEntryBlock = insertPhiBlockBefore(revFalseEntryBlock, phiGrads); + } + + IRBlock* revAfterBlock = revBlockMap[currentBlock]; + + builder.setInsertInto(revCondBlock); + builder.emitIfElse( + ifElse->getCondition(), + revTrueEntryBlock, + revFalseEntryBlock, + revAfterBlock); + + if (!revTrueRegionInfo.isTrivial) + { + builder.setInsertInto(revTrueExitBlock); + SLANG_ASSERT(revTrueExitBlock->getTerminator() == nullptr); + builder.emitBranch( + revAfterBlock, + getPhiGrads(trueBlock).getCount(), + getPhiGrads(trueBlock).getBuffer()); + } + + if (!revFalseRegionInfo.isTrivial) + { + builder.setInsertInto(revFalseExitBlock); + SLANG_ASSERT(revFalseExitBlock->getTerminator() == nullptr); + builder.emitBranch( + revAfterBlock, + getPhiGrads(falseBlock).getCount(), + getPhiGrads(falseBlock).getBuffer()); + } + + currentBlock = afterBlock; + break; + } + + case kIROp_loop: + { + auto loop = as<IRLoop>(terminator); + + auto firstLoopBlock = loop->getTargetBlock(); + auto breakBlock = loop->getBreakBlock(); + + auto condBlock = getOrCreateTopLevelCondition(loop); + + auto ifElse = as<IRIfElse>(condBlock->getTerminator()); + + auto trueBlock = ifElse->getTrueBlock(); + auto falseBlock = 
ifElse->getFalseBlock(); + + auto trueRegionInfo = reverseCFGRegion( + trueBlock, + List<IRBlock*>(breakBlock, condBlock)); + + auto falseRegionInfo = reverseCFGRegion( + falseBlock, + List<IRBlock*>(breakBlock, condBlock)); + + auto preCondRegionInfo = reverseCFGRegion( + firstLoopBlock, + List<IRBlock*>(condBlock)); + + // assume loop[next] -> cond can be a region and reverse it. + // assume cond[false] -> break can be a region and reverse it. + // assume cond[true] -> cond can be a region and reverse it. + // rev-loop = rev[break] + // rev-cond = rev[cond] + // rev-cond[true] -> entry of (cond[true] -> cond) + // rev-cond[false] -> entry of (loop[next] -> cond) + // exit of (cond[false]->break) branches into rev-cond + // rev-loop[next] -> entry of (cond[false] -> break) + // exit of (cond[true] -> cond) branches into rev-cond + // exit of (loop[next] -> cond) branches into rev[loop] (rev-break) + + // For now, we'll assume the loop is always on the 'true' side + // If this assert fails, add in the case where the loop + // may be on the 'false' side. + // + SLANG_RELEASE_ASSERT(trueRegionInfo.fwdEndPoint == condBlock); + + auto revTrueBlock = trueRegionInfo.revEntry; + auto revFalseBlock = (preCondRegionInfo.isTrivial) ? + revBlockMap[currentBlock] : preCondRegionInfo.revEntry; + + // The block that will become target of the new loop inst + // (the old false-region) This _could_ be the condition itself + // + IRBlock* revPreCondBlock = (falseRegionInfo.isTrivial) ? + revBlockMap[condBlock] : falseRegionInfo.revEntry; + + // Old cond block remains new cond block. + IRBlock* revCondBlock = revBlockMap[condBlock]; + + // Old cond block becomes new pre-break block. + IRBlock* revBreakBlock = revBlockMap[currentBlock]; + + // Old true-side starting block becomes loop end block. 
+ IRBlock* revLoopEndBlock = revBlockMap[trueBlock]; + builder.setInsertInto(revLoopEndBlock); + builder.emitBranch( + revCondBlock, + getPhiGrads(trueBlock).getCount(), + getPhiGrads(trueBlock).getBuffer()); + + // Old false-side starting block becomes end block + // for the new pre-cond region (which could be empty) + // + IRBlock* revPreCondEndBlock = revBlockMap[falseBlock]; + if (!falseRegionInfo.isTrivial) + { + builder.setInsertInto(revPreCondEndBlock); + builder.emitBranch( + revCondBlock, + getPhiGrads(falseBlock).getCount(), + getPhiGrads(falseBlock).getBuffer()); + } + + IRBlock* revBreakRegionExitBlock = revBlockMap[firstLoopBlock]; + if (!preCondRegionInfo.isTrivial) + { + builder.setInsertInto(revBreakRegionExitBlock); + builder.emitBranch( + revBreakBlock, + getPhiGrads(firstLoopBlock).getCount(), + getPhiGrads(firstLoopBlock).getBuffer()); + } + + // Emit condition into the new cond block. + builder.setInsertInto(revCondBlock); + builder.emitIfElse( + ifElse->getCondition(), + revTrueBlock, + revFalseBlock, + revLoopEndBlock); + + // Emit loop into rev-version of the break block. 
+ auto revLoopBlock = revBlockMap[breakBlock]; + builder.setInsertInto(revLoopBlock); + builder.emitLoop( + revPreCondBlock, + revBreakBlock, + revLoopEndBlock, + getPhiGrads(breakBlock).getCount(), + getPhiGrads(breakBlock).getBuffer()); + + currentBlock = breakBlock; + break; + } + + case kIROp_Switch: + { + auto switchInst = as<IRSwitch>(terminator); + + auto breakBlock = switchInst->getBreakLabel(); + + IRBlock* revBreakBlock = revBlockMap[currentBlock]; + + // Reverse each case label + List<IRInst*> reverseSwitchArgs; + Dictionary<IRBlock*, IRBlock*> reverseLabelEntryBlocks; + + for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii++) + { + reverseSwitchArgs.add(switchInst->getCaseValue(ii)); + + auto caseLabel = switchInst->getCaseLabel(ii); + if (!reverseLabelEntryBlocks.ContainsKey(caseLabel)) + { + auto labelRegionInfo = reverseCFGRegion( + caseLabel, + List<IRBlock*>(breakBlock)); + + // Handle this case eventually. + SLANG_ASSERT(!labelRegionInfo.isTrivial); + + // Wire the exit to the break block + IRBlock* revLabelExit = revBlockMap[caseLabel]; + SLANG_ASSERT(revLabelExit->getTerminator() == nullptr); + + builder.setInsertInto(revLabelExit); + builder.emitBranch(revBreakBlock); + + reverseLabelEntryBlocks[caseLabel] = labelRegionInfo.revEntry; + reverseSwitchArgs.add(labelRegionInfo.revEntry); + } + else + { + reverseSwitchArgs.add(reverseLabelEntryBlocks[caseLabel]); + } + } + + auto defaultRegionInfo = reverseCFGRegion( + switchInst->getDefaultLabel(), + List<IRBlock*>(breakBlock)); + SLANG_ASSERT(!defaultRegionInfo.isTrivial); + + auto revDefaultRegionEntry = defaultRegionInfo.revEntry; + + builder.setInsertInto(revBlockMap[switchInst->getDefaultLabel()]); + builder.emitBranch(revBreakBlock); + + auto phiGrads = getPhiGrads(breakBlock); + if (phiGrads.getCount() > 0) + { + for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii++) + { + reverseSwitchArgs[ii * 2 + 1] = + insertPhiBlockBefore(as<IRBlock>(reverseSwitchArgs[ii * 2 + 1]), phiGrads); + 
} + revDefaultRegionEntry = + insertPhiBlockBefore(as<IRBlock>(revDefaultRegionEntry), phiGrads); + } + + auto revSwitchBlock = revBlockMap[breakBlock]; + builder.setInsertInto(revSwitchBlock); + builder.emitSwitch( + switchInst->getCondition(), + revBreakBlock, + revDefaultRegionEntry, + reverseSwitchArgs.getCount(), + reverseSwitchArgs.getBuffer()); + + currentBlock = breakBlock; + break; + } + + } } - bool isComplete() + if (auto branchInst = as<IRUnconditionalBranch>(currentBlock->getTerminator())) { - return (this->originBlock != nullptr); + return RegionEntryPoint( + revBlockMap[currentBlock], + branchInst->getTargetBlock(), + false); } - }; + else if (auto returnInst = as<IRReturn>(currentBlock->getTerminator())) + { + return RegionEntryPoint( + revBlockMap[currentBlock], + nullptr, + true); + } + else + { + // Regions should _really_ not end on a conditional branch (I think) + SLANG_UNEXPECTED("Unexpected: Region ended on a conditional branch"); + } + } void transposeDiffBlocksInFunc( IRFunc* revDiffFunc, @@ -140,11 +488,6 @@ struct DiffTransposePass auto terminalPrimalBlocks = getTerminalPrimalBlocks(revDiffFunc); auto terminalDiffBlocks = getTerminalDiffBlocks(revDiffFunc); - // Add a top-level null region entry for the terminal diff block. - regionMap[terminalDiffBlocks[0]] = nullptr; - - buildAfterBlockMap(revDiffFunc); - // Traverse all instructions/blocks in reverse (starting from the terminator inst) // look for insts/blocks marked with IRDifferentialInstDecoration, // and transpose them in the revDiffFunc. @@ -184,7 +527,7 @@ struct DiffTransposePass // Keep track of first diff block, since this is where // we'll emit temporary vars to hold per-block derivatives. 
// - firstRevDiffBlockMap[revDiffFunc] = revBlockMap[workList[0]]; + firstRevDiffBlockMap[revDiffFunc] = revBlockMap[terminalDiffBlocks[0]]; IRInst* retVal = nullptr; @@ -201,17 +544,14 @@ struct DiffTransposePass this->transposeBlock(block, revBlock); } - // Some blocks may not have their control flow - // insts completed. Do them now that we have - // more information. + // At this point all insts have been transposed, but the blocks + // have no control flow. + // reverseCFG will use fwd-mode blocks as reference, and + // wire the corresponding rev-mode blocks in reverse. // - for (auto pendingBlockInfo : pendingBlocks) - { - builder.setInsertInto(revBlockMap[pendingBlockInfo.fwdBlock]); - completeEmitTerminator(&builder, pendingBlockInfo.fwdBlock, pendingBlockInfo.phiGrads); - } - - pendingBlocks.clear(); + auto branchInst = as<IRUnconditionalBranch>(terminalPrimalBlocks[0]->getTerminator()); + auto firstFwdDiffBlock = branchInst->getTargetBlock(); + reverseCFGRegion(firstFwdDiffBlock, List<IRBlock*>()); // Link the last differential fwd-mode block (which will be the first // rev-mode block) as the successor to the last primal block. @@ -223,7 +563,7 @@ struct DiffTransposePass SLANG_ASSERT(terminalDiffBlocks.getCount() == 1); auto terminalPrimalBlock = terminalPrimalBlocks[0]; - auto terminalRevBlock = as<IRBlock>(revBlockMap[terminalDiffBlocks[0]]); + auto firstRevBlock = as<IRBlock>(revBlockMap[terminalDiffBlocks[0]]); terminalPrimalBlock->getTerminator()->removeAndDeallocate(); @@ -231,9 +571,9 @@ struct DiffTransposePass subBuilder.setInsertInto(terminalPrimalBlock); // There should be no parameters in the first reverse-mode block. 
- SLANG_ASSERT(terminalRevBlock->getFirstParam() == nullptr); + SLANG_ASSERT(firstRevBlock->getFirstParam() == nullptr); - auto branch = subBuilder.emitBranch(terminalRevBlock); + auto branch = subBuilder.emitBranch(firstRevBlock); if (!retVal) { @@ -247,13 +587,20 @@ struct DiffTransposePass subBuilder.addBackwardDerivativePrimalReturnDecoration(branch, retVal); } + // At this point, the only block left without terminator insts + // should be the last one. Add a void return to complete it. + // + IRBlock* lastRevBlock = revBlockMap[firstFwdDiffBlock]; + SLANG_ASSERT(lastRevBlock->getTerminator() == nullptr); + + builder.setInsertInto(lastRevBlock); + builder.emitReturn(); + // Remove fwd-mode blocks. for (auto block : workList) { block->removeAndDeallocate(); } - - cleanupRegionInfo(); } // Fetch or create a gradient accumulator var @@ -385,6 +732,17 @@ struct DiffTransposePass List<IRInst*> phiParamRevGradInsts; for (IRParam* param = fwdBlock->getFirstParam(); param; param = param->getNextParam()) { + // This param might be used outside this block. + // If so, add/get an accumulator. + // + if (isInstUsedOutsideParentBlock(param)) + { + auto accVar = getOrCreateAccumulatorVar(param); + addRevGradientForFwdInst( + param, + RevGradient(param, builder.emitLoad(accVar), nullptr)); + } + if (hasRevGradients(param)) { auto gradients = popRevGradients(param); @@ -396,6 +754,11 @@ struct DiffTransposePass phiParamRevGradInsts.add(gradInst); } + else + { + phiParamRevGradInsts.add( + emitDZeroOfDiffInstType(&builder, tryGetPrimalTypeFromDiffInst(param))); + } } // Also handle any remaining gradients for insts that appear in prior blocks. @@ -448,13 +811,9 @@ struct DiffTransposePass // We _should_ be completely out of gradients to process at this point. SLANG_ASSERT(gradientsMap.Count() == 0); - if (!tryEmitTerminator(&builder, fwdBlock, phiParamRevGradInsts)) - { - // If we couldn't emit a terminator right away, defer for later. 
- pendingBlocks.add(PendingBlockTerminatorEntry( - fwdBlock, - phiParamRevGradInsts)); - } + // Record any phi gradients for the CFG reversal pass. + phiGradsMap[fwdBlock] = phiParamRevGradInsts; + } void transposeInst(IRBuilder* builder, IRInst* inst) @@ -467,6 +826,15 @@ struct DiffTransposePass break; } + // Some special instructions simply need to be copied over. + // These do not deal with differentials. + // + if (inst->findDecoration<IRLoopCounterDecoration>()) + { + inst->insertAtEnd(builder->getBlock()); + return; + } + // Look for gradient entries for this inst. List<RevGradient> gradients; if (hasRevGradients(inst)) @@ -787,234 +1155,6 @@ struct DiffTransposePass return phiBlock; } - - // Create a region to track control flow from the - // the point of convergence (fwdConvBlock) back to the point of - // divergence, along one specific path (fwdExitBlock) - // - void pushRegion(IRBlock* fwdConvBlock, IRBlock* fwdExitBlock) - { - SLANG_ASSERT(!regionMap.ContainsKey(fwdExitBlock)); - SLANG_ASSERT(regionMap.ContainsKey(fwdConvBlock)); - - Region* newRegion = new Region(fwdExitBlock, regionMap[fwdConvBlock]); - regions.add(newRegion); - - regionMap[fwdExitBlock] = newRegion; - } - - // If we have a conditional-branch from fwdBlock to fwdNextBlock - // complete the region, and remove from stack - // otherwise, copy the region over. 
- // - void propagateRegion(IRBlock* fwdNextBlock, IRBlock* fwdBlock) - { - if (as<IRConditionalBranch>(fwdBlock->getTerminator())) - { - Region* currentRegion = regionMap[fwdNextBlock]; - currentRegion->finish(fwdNextBlock); - - regionMap[fwdBlock] = currentRegion->parent; - } - else if (as<IRUnconditionalBranch>(fwdBlock->getTerminator()) || - as<IRReturn>(fwdBlock->getTerminator())) - { - regionMap[fwdBlock] = regionMap[fwdNextBlock]; - } - } - - // Deallocate regions - void cleanupRegionInfo() - { - for (auto region : regions) - { - delete region; - } - - regions.clear(); - regionMap.Clear(); - } - - bool tryEmitTerminator(IRBuilder* builder, IRBlock* fwdBlockInst, List<IRInst*> phiParamGrads) - { - // If this block has no differential predecessors, add a return statement. - if (!doesBlockHaveDifferentialPredecessors(fwdBlockInst)) - { - // Emit a void return. - builder->emitReturn(); - return true; - } - - List<IRBlock*> fwdPredecesorBlocks; - // Check for predecessors count. - for (auto predecessor : fwdBlockInst->getPredecessors()) - { - if (!fwdPredecesorBlocks.contains(predecessor)) - fwdPredecesorBlocks.add(predecessor); - } - - SLANG_ASSERT(fwdPredecesorBlocks.getCount() > 0); - - // If we have just one, we simply need the reverse-mode block to - // branch into the reverse-mode version of the predecessor block. - // (along with the appropriate phi args) - // - if (fwdPredecesorBlocks.getCount() == 1) - { - builder->emitBranch( - revBlockMap[fwdPredecesorBlocks[0]], - phiParamGrads.getCount(), - phiParamGrads.getBuffer()); - - propagateRegion(fwdBlockInst, fwdPredecesorBlocks[0]); - return true; - } - - // If we have more than one, then control flow 'converges' at this point. - // By convention, this block must be the after block for _some_ conditional - // control flow statement. - // If not, we are dealing with an inconsistent graph. 
- // - // Rather than actually emitting the terminator here, we're going to - // defer to a pass after all the blocks have been transposed. - // This is because, while we know that this block is the point of convergence - // we don't know which predecessor belong to which side of the branch. - // We will instead create 'regions' to track each predecessor for every - // branch, and by the time all blocks are seen at-least once, we should have - // resolved the 'start' points for every predecessor. - // - - if (fwdPredecesorBlocks.getCount() > 1) - { - SLANG_ASSERT(afterBlockMap.ContainsKey(fwdBlockInst)); - - for (auto predecessor : fwdPredecesorBlocks) - { - // Trivial case when the predecessor itself is the point - // of divergence. - // - if (getAfterBlock(predecessor) == fwdBlockInst) - continue; - - pushRegion(fwdBlockInst, predecessor); - } - } - - return false; - } - - bool completeEmitTerminator(IRBuilder* builder, IRBlock* fwdBlockInst, List<IRInst*> phiParamGrads) - { - IRBlock* revBlock = revBlockMap[fwdBlockInst]; - - // If we already have a terminator, we've probably resolved it during - // tryEmitTerminator() - // - if (revBlock->getTerminator() != nullptr) - return true; - - auto terminatorInst = as<IRInst>(afterBlockMap[fwdBlockInst]); - switch (terminatorInst->getOp()) - { - case kIROp_ifElse: - { - auto ifElseInst = as<IRIfElse>(terminatorInst); - - auto condition = ifElseInst->getCondition(); - SLANG_ASSERT(!isDifferentialInst(condition)); - - // fwd origin block is the reverse 'after' block. - auto revAfterBlock = as<IRBlock>( - revBlockMap[as<IRBlock>(ifElseInst->getParent())]); - - // Find region, and find the reverse-mode version of the - // exit block. 
- Region* trueRegion = regionMap[ifElseInst->getTrueBlock()]; - IRBlock* revTrueBlock = revBlockMap[trueRegion->exitBlock]; - - Region* falseRegion = regionMap[ifElseInst->getFalseBlock()]; - IRBlock* revFalseBlock = revBlockMap[falseRegion->exitBlock]; - - // If we have phi derivatives to pass on, - // we need to add dummy blocks to pass them using - // an unconditional branch. - // - if (phiParamGrads.getCount() > 0) - { - revTrueBlock = insertPhiBlockBefore(revTrueBlock, phiParamGrads); - revFalseBlock = insertPhiBlockBefore(revFalseBlock, phiParamGrads); - - // Putting the phi blocks just after our current reverse-mode block - // is not necessary. Just to make intermediate IR easier to follow. - // - revTrueBlock->insertAfter(revBlock); - revFalseBlock->insertAfter(revBlock); - } - - builder->emitIfElse(condition, revTrueBlock, revFalseBlock, revAfterBlock); - return true; - } - case kIROp_Switch: - { - auto switchInst = as<IRSwitch>(terminatorInst); - - auto condition = switchInst->getCondition(); - SLANG_ASSERT(!isDifferentialInst(condition)); - - // fwd origin block is the reverse 'break' block. - auto revAfterBlock = as<IRBlock>( - revBlockMap[as<IRBlock>(switchInst->getParent())]); - - // Find regions for every branch, and find the reverse-mode - // version of the each exit block. - Region* defaultRegion = regionMap[switchInst->getDefaultLabel()]; - IRBlock* revDefaultBlock = revBlockMap[defaultRegion->exitBlock]; - - List<IRBlock*> revCaseBlocks; - for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii ++) - { - Region* caseRegion = regionMap[switchInst->getCaseLabel(ii)]; - IRBlock* revCaseBlock = revBlockMap[caseRegion->exitBlock]; - revCaseBlocks.add(revCaseBlock); - } - - // If we have phi derivatives to pass on, - // we need to add dummy blocks to pass them using - // an unconditional branch. 
- // - if (phiParamGrads.getCount() > 0) - { - revDefaultBlock = insertPhiBlockBefore(revDefaultBlock, phiParamGrads); - revDefaultBlock->insertAfter(revBlock); - - for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii ++) - { - revCaseBlocks[ii] = insertPhiBlockBefore(revCaseBlocks[ii], phiParamGrads); - revCaseBlocks[ii]->insertAfter(revBlock); - } - } - - List<IRInst*> revCaseArgs; - for (UIndex ii = 0; ii < switchInst->getCaseCount(); ii ++) - { - revCaseArgs.add(switchInst->getCaseValue(ii)); - revCaseArgs.add(revCaseBlocks[ii]); - } - - builder->emitSwitch( - condition, - revAfterBlock, - revDefaultBlock, - revCaseArgs.getCount(), - revCaseArgs.getBuffer()); - - return true; - } - default: - SLANG_UNIMPLEMENTED_X("Unhandled control flow inst during transposition"); - } - return false; - } TranspositionResult transposeInst(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue) { @@ -1972,9 +2112,7 @@ struct DiffTransposePass List<PendingBlockTerminatorEntry> pendingBlocks; - Dictionary<IRBlock*, Region*> regionMap; - - List<Region*> regions; + Dictionary<IRBlock*, List<IRInst*>> phiGradsMap; }; diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index c525191a3..d808cbb5e 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -33,6 +33,77 @@ struct DiffUnzipPass // might run into an issue here? IRBlock* firstDiffBlock; + struct IndexedRegion + { + // Parent indexed region (for nested loops) + IndexedRegion* parent = nullptr; + + // Intializer block for the index. + IRBlock* initBlock = nullptr; + + // Index 'starts' at the first loop block (included) + IRBlock* firstBlock = nullptr; + + // Index stops at the break block (not included) + IRBlock* breakBlock = nullptr; + + // Block where index updates happen. 
+ IRBlock* continueBlock = nullptr; + + // After lowering, store references to the count + // variables associated with this region + // + IRVar* primalCountVar = nullptr; + IRVar* diffCountVar = nullptr; + + enum CountStatus + { + Unresolved, + Dynamic, + Static + }; + + CountStatus status = CountStatus::Unresolved; + + // Inferred maximum number of iterations. + Count maxIters = -1; + + IndexedRegion() : + parent(nullptr), + initBlock(nullptr), + firstBlock(nullptr), + breakBlock(nullptr), + continueBlock(nullptr), + primalCountVar(nullptr), + diffCountVar(nullptr), + status(CountStatus::Unresolved), + maxIters(-1) + { } + + IndexedRegion( + IndexedRegion* parent, + IRBlock* initBlock, + IRBlock* firstBlock, + IRBlock* breakBlock, + IRBlock* continueBlock) : + parent(parent), + initBlock(initBlock), + firstBlock(firstBlock), + breakBlock(breakBlock), + continueBlock(continueBlock), + primalCountVar(nullptr), + diffCountVar(nullptr), + status(CountStatus::Unresolved), + maxIters(-1) + { } + }; + + // Keep track of indexed blocks and their corresponding index hierarchy. + Dictionary<IRBlock*, IndexedRegion*> indexRegionMap; + + List<IndexedRegion*> indexRegions; + + DiffUnzipPass( AutoDiffSharedContext* autodiffContext) : autodiffContext(autodiffContext) @@ -73,8 +144,8 @@ struct DiffUnzipPass // SLANG_ASSERT(unzippedFunc->getFirstBlock() != nullptr); SLANG_ASSERT(unzippedFunc->getFirstBlock()->getNextBlock() != nullptr); - - IRBlock* firstBlock = unzippedFunc->getFirstBlock()->getNextBlock(); + + IRBlock* firstBlock = as<IRUnconditionalBranch>(unzippedFunc->getFirstBlock()->getTerminator())->getTargetBlock(); List<IRBlock*> mixedBlocks; for (IRBlock* block = firstBlock; block; block = block->getNextBlock()) @@ -122,9 +193,42 @@ struct DiffUnzipPass splitBlock(block, as<IRBlock>(primalMap[block]), as<IRBlock>(diffMap[block])); } + // Propagate indexed region information. + propagateAllIndexRegions(); + + // Try to infer maximum counts for all regions.
+ // (only regions whose intermediates are used outside their region + // require a maximum count, so we may see some unresolved regions + // without any issues) + // + for (auto region : indexRegions) + { + tryInferMaxIndex(region); + } + + // Emit counter variables and other supporting + // instructions for all regions. + // + lowerIndexedRegions(); + + // Process intermediate insts in indexed blocks + // into array loads/stores. + // + for (auto block : mixedBlocks) + { + auto primalBlock = primalMap[block]; + + if (isBlockIndexed(block)) + { + processIndexedFwdBlock(block); + } + } + // Swap the first block's occurences out for the first primal block. firstBlock->replaceUsesWith(firstPrimalBlock); + cleanupIndexRegionInfo(); + // Remove old blocks. for (auto block : mixedBlocks) block->removeAndDeallocate(); @@ -132,6 +236,239 @@ struct DiffUnzipPass return unzippedFunc; } + IRBlock* getInitializerBlock(IndexedRegion* region) + { + return region->initBlock; + } + + IRBlock* getUpdateBlock(IndexedRegion* region) + { + return region->continueBlock; + } + + void tryInferMaxIndex(IndexedRegion* region) + { + if (region->status != IndexedRegion::CountStatus::Unresolved) + return; + + // We're going to fix this at an arbitrary constant + // for now, and later add some basic inference + a user-defined decoration. + // + region->maxIters = 5; + region->status = IndexedRegion::CountStatus::Static; + } + + // Make a primal value *available* to the differential block.
+ // This can get quite involved, and we're going to rely on + // constructSSA to do most of the heavy-lifting & optimization + // For now, we'll simply create a variable in the top-most + // primal block, then load it in the last primal block + // + //void hoistValue(IRInst* primalInst) + //{ + // IRBlock* terminalPrimalBlock = getTerminalPrimalBlock(); + // IRBlock* firstPrimalBlock = getFirstPrimalBlock(); + //} + + void lowerIndexedRegions() + { + IRBuilder builder(autodiffContext->sharedBuilder); + + + for (auto region : indexRegions) + { + + IRBlock* initializerBlock = getInitializerBlock(region); + + // Grab first primal block. + auto firstPrimalBlock = primalMap[region->breakBlock->getParent()->getFirstBlock()->getNextBlock()]; + + // Make variable in the top-most block (so it's visible to diff blocks) + builder.setInsertInto(firstPrimalBlock); + region->primalCountVar = builder.emitVar(builder.getUIntType()); + + // Make another variable in the diff block initialized to the + // final value of the primal counter. + // + builder.setInsertInto(diffMap[initializerBlock]); + auto primalCounterValue = builder.emitLoad(region->primalCountVar); + region->diffCountVar = builder.emitVar(builder.getUIntType()); + builder.emitStore(region->diffCountVar, primalCounterValue); + + IRBlock* updateBlock = getUpdateBlock(region); + + { + // TODO: Figure out if the counter update needs to go before or after + // the rest of the update block. 
+ // + builder.setInsertBefore(as<IRBlock>(primalMap[updateBlock])->getTerminator()); + + auto counterVal = builder.emitLoad(region->primalCountVar); + auto incCounterVal = builder.emitAdd( + builder.getUIntType(), + counterVal, + builder.getIntValue(builder.getUIntType(), 1)); + + auto incStore = builder.emitStore(region->primalCountVar, incCounterVal); + + builder.addLoopCounterDecoration(counterVal); + builder.addLoopCounterDecoration(incCounterVal); + builder.addLoopCounterDecoration(incStore); + } + + { + // NOTE: This is a hacky shortcut we're taking here. + // Technically the unzip pass should not affect the + // correctness (it must still compute the proper fwd-mode derivative) + // However, we're currently making the loop counter go backwards to + // make it easier on the transposition pass, so the output from + // the unzip pass is neither fwd-mode nor rev-mode until the transposition + // step is complete. + // + // TODO: Ideally this needs to be replaced with a small inversion step + // within the transposition pass. + // + + builder.setInsertBefore(as<IRBlock>(diffMap[updateBlock])->getTerminator()); + + // Step the differential counter down by one per iteration, mirroring + // the +1 applied to the primal counter in the primal update block. + // + auto counterVal = builder.emitLoad(region->diffCountVar); + auto decCounterVal = builder.emitSub( + builder.getUIntType(), + counterVal, + builder.getIntValue(builder.getUIntType(), 1)); + + auto decStore = builder.emitStore(region->diffCountVar, decCounterVal); + + // Mark insts as loop counter insts to avoid removing them. + // + builder.addLoopCounterDecoration(counterVal); + builder.addLoopCounterDecoration(decCounterVal); + builder.addLoopCounterDecoration(decStore); + } + + } + } + + void processIndexedFwdBlock(IRBlock* fwdBlock) + { + if (!isBlockIndexed(fwdBlock)) + return; + + // Grab first primal block. + IRBlock* firstPrimalBlock = as<IRBlock>(primalMap[fwdBlock->getParent()->getFirstBlock()->getNextBlock()]); + + // Scan through instructions and identify those that are used + // outside the local block.
+ // + IRBlock* primalBlock = as<IRBlock>(primalMap[fwdBlock]); + + List<IRInst*> primalInsts; + for (auto child = primalBlock->getFirstChild(); child; child = child->getNextInst()) + primalInsts.add(child); + + IRBuilder builder(autodiffContext->sharedBuilder); + + // Build list of indices that this block is affected by. + List<IndexedRegion*> regions; + { + IndexedRegion* region = indexRegionMap[fwdBlock]; + for (; region; region = region->parent) + regions.add(region); + } + + for (auto inst : primalInsts) + { + // 1. Check if we need to store inst (is it used in a differential block?) + + bool shouldStore = false; + for (auto use = inst->firstUse; use; use = use->nextUse) + { + IRBlock* useBlock = as<IRBlock>(use->getUser()->getParent()); + + if (isDifferentialInst(useBlock)) + { + shouldStore = true; + } + } + + if (!shouldStore) continue; + + // 2. Emit an array to top-level to allocate space. + + builder.setInsertBefore(firstPrimalBlock->getTerminator()); + + IRType* arrayType = inst->getDataType(); + SLANG_ASSERT(!as<IRPtrTypeBase>(arrayType)); // can't store pointers. + + for (auto region : regions) + { + SLANG_ASSERT(region->status == IndexedRegion::CountStatus::Static); + SLANG_ASSERT(region->maxIters >= 0); + + arrayType = builder.getArrayType( + arrayType, + builder.getIntValue( + builder.getUIntType(), + region->maxIters)); + } + + // Reverse the list since the indices needs to be + // emitted in reverse order. + // + regions.reverse(); + + auto storageVar = builder.emitVar(arrayType); + + // 3. Store current value into the array and replace uses with a load. + { + builder.setInsertAfter(inst); + + IRInst* storeAddr = storageVar; + IRType* currType = storageVar->getDataType(); + + for (auto region : regions) + { + currType = as<IRArrayType>(currType)->getElementType(); + + storeAddr = builder.emitElementAddress( + currType, + storeAddr, + region->primalCountVar); + } + + builder.emitStore(storeAddr, inst); + } + + // 4. 
Replace uses in differential blocks with loads from the array. + { + for (auto use = inst->firstUse; use; use = use->nextUse) + { + IRBlock* useBlock = as<IRBlock>(use->getUser()->getParent()); + + if (isDifferentialInst(useBlock)) + { + builder.setInsertBefore(use->getUser()); + + IRInst* loadAddr = storageVar; + IRType* currType = storageVar->getDataType(); + + for (auto region : regions) + { + currType = as<IRArrayType>(currType)->getElementType(); + + loadAddr = builder.emitElementAddress( + currType, + loadAddr, + region->diffCountVar); + } + + use->set(builder.emitLoad(loadAddr)); + } + } + } + } + } + IRFunc* extractPrimalFunc(IRFunc* func, IRFunc* originalFunc, IRInst*& intermediateType); bool isRelevantDifferentialPair(IRType* type) @@ -327,6 +664,188 @@ struct DiffUnzipPass return InstPair(primalBranch, returnInst); } + bool isBlockIndexed(IRBlock* block) + { + return indexRegionMap.ContainsKey(block) && indexRegionMap[block] != nullptr; + } + + void addNewIndex(IRLoop* targetLoop) + { + // Create indexed region without a parent for now. + // The parent will be filled in during propagation. + // + IndexedRegion* region = new IndexedRegion( + nullptr, + as<IRBlock>(targetLoop->getParent()), + targetLoop->getTargetBlock(), + targetLoop->getBreakBlock(), + targetLoop->getContinueBlock()); + + indexRegionMap[targetLoop->getTargetBlock()] = region; + indexRegions.add(region); + } + + // Deallocate regions + void cleanupIndexRegionInfo() + { + for (auto region : indexRegions) + { + delete region; + } + + indexRegions.clear(); + indexRegionMap.Clear(); + } + + void propagateAllIndexRegions() + { + + + // Load up the starting block of every region into + // initial worklist. 
+ // + List<IRBlock*> workList; + HashSet<IRBlock*> workSet; + for (auto region : indexRegions) + { + workList.add(region->firstBlock); + workSet.Add(region->firstBlock); + } + + // Keep propagating from initial work list to successors + // Add blocks to work list if their region assignment has changed + // Add the break block of a region if that region's parent has changed. + // + while (workList.getCount() > 0) + { + auto block = workList.getLast(); + workList.removeLast(); + workSet.Remove(block); + + HashSet<IRBlock*> successors; + + for (auto successor : block->getSuccessors()) + { + if (successors.Contains(successor)) + continue; + + if (propagateIndexRegion(block, successor)) + { + if (!workSet.Contains(successor)) + { + workList.add(successor); + workSet.Add(successor); + } + + // Do we have an index region for the successor, which is + // also the starting block of that region? + // Then the change might have been the addition of + // a parent node. Add the break block so the + // change can be propagated further. + // + if (isBlockIndexed(successor)) + { + IndexedRegion* succRegion = indexRegionMap[successor]; + if (succRegion->firstBlock == successor) + { + if (!workSet.Contains(succRegion->breakBlock)) + { + workList.add(succRegion->breakBlock); + workSet.Add(succRegion->breakBlock); + } + } + } + } + + successors.Add(successor); + } + } + } + + bool setIndexRegion(IRBlock* block, IndexedRegion* region) + { + if (!region) return false; + + if (indexRegionMap.ContainsKey(block) + && indexRegionMap[block] == region) + return false; + + indexRegionMap[block] = region; + return true; + } + + bool propagateIndexRegion(IRBlock* srcBlock, IRBlock* nextBlock) + { + // Is the current region indexed? + // If not, there's nothing to propagate + // + if (!isBlockIndexed(srcBlock)) + return false; + + IndexedRegion* region = indexRegionMap[srcBlock]; + + // If the target's index is already resolved, + // check if it's a sub-region.
+ // + if (isBlockIndexed(nextBlock)) + { + IndexedRegion* nextRegion = indexRegionMap[nextBlock]; + + // If we're at the first block of a region, + // set current region as continue-region's + // parent. + // + if (nextBlock == nextRegion->firstBlock && nextRegion != region) + { + nextRegion->parent = region; + return true; + } + + return false; + } + + // If we're at the break block, move up to the parent index. + if (nextBlock == region->breakBlock) + return setIndexRegion(nextBlock, region->parent); + + // If none of the special cases hit, copy the + // current region to the next block. + // + return setIndexRegion(nextBlock, region); + } + + // Splitting a loop is one of the trickiest parts of the unzip pass. + // Thus far, we've been dealing with blocks that are only run once, so we + // could arbitrarily move intermediate instructions to other blocks since they are + // generated and consumed at-most one time. + // + // Intermediate instructions in a loop can take on a different value each iteration + // and thus need to be stored explicitly to an array. + // + // We also need to ascertain an upper limit on the iteration count. + // With very few exceptions, this is a fundamental requirement. + // + InstPair splitLoop(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRLoop* mixedLoop) + { + + auto breakBlock = mixedLoop->getBreakBlock(); + auto continueBlock = mixedLoop->getContinueBlock(); + auto nextBlock = mixedLoop->getTargetBlock(); + + // Push a new index. 
+ addNewIndex(mixedLoop); + + return InstPair( + primalBuilder->emitLoop( + as<IRBlock>(primalMap[nextBlock]), + as<IRBlock>(primalMap[breakBlock]), + as<IRBlock>(primalMap[continueBlock])), + diffBuilder->emitLoop( + as<IRBlock>(diffMap[nextBlock]), + as<IRBlock>(diffMap[breakBlock]), + as<IRBlock>(diffMap[continueBlock]))); + } + InstPair splitControlFlow(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRInst* branchInst) { switch (branchInst->getOp()) @@ -430,6 +949,9 @@ struct DiffUnzipPass diffCaseArgs.getBuffer())); } + case kIROp_loop: + return splitLoop(primalBuilder, diffBuilder, as<IRLoop>(branchInst)); + default: SLANG_UNEXPECTED("Unhandled instruction"); } @@ -544,11 +1066,13 @@ struct DiffUnzipPass (use->getUser()->getParent() != diffBlock)); } - inst->removeAndDeallocate(); + // Leave terminator in to keep CFG info. + if (!as<IRTerminatorInst>(inst)) + inst->removeAndDeallocate(); } // Nothing should be left in the original block. - SLANG_ASSERT(block->getFirstChild() == nullptr); + SLANG_ASSERT(block->getFirstChild() == block->getTerminator()); // Branch from primal to differential block. // Functionally, the new blocks should produce the same output as the diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 6b6b3924a..f2294671e 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -749,6 +749,8 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(BackwardDerivativePrimalContextDecoration, BackwardDerivativePrimalContextDecoration, 1, 0) INST(BackwardDerivativePrimalReturnDecoration, BackwardDerivativePrimalReturnDecoration, 1, 0) + INST(LoopCounterDecoration, loopCounterDecoration, 0, 0) + /// Used by the auto-diff pass to mark insts that compute /// a differential value. 
INST(DifferentialInstDecoration, diffInstDecoration, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 8b30a02dd..5669a12d7 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -663,6 +663,15 @@ struct IRBackwardDerivativeDecoration : IRDecoration IRInst* getBackwardDerivativeFunc() { return getOperand(0); } }; +struct IRLoopCounterDecoration : IRDecoration +{ + enum + { + kOp = kIROp_LoopCounterDecoration + }; + IR_LEAF_ISA(LoopCounterDecoration) +}; + struct IRDifferentialInstDecoration : IRDecoration { enum @@ -3243,6 +3252,13 @@ public: IRBlock* target, IRBlock* breakBlock, IRBlock* continueBlock); + + IRInst* emitLoop( + IRBlock* target, + IRBlock* breakBlock, + IRBlock* continueBlock, + Int argCount, + IRInst*const* args); IRInst* emitBranch( IRInst* val, @@ -3590,6 +3606,11 @@ public: addDecoration(value, kIROp_BackwardDerivativePrimalContextDecoration, ctx); } + void addLoopCounterDecoration(IRInst* value) + { + addDecoration(value, kIROp_LoopCounterDecoration); + } + void markInstAsDifferential(IRInst* value) { addDecoration(value, kIROp_DifferentialInstDecoration, nullptr); diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp index 0bd5c6e9f..ee55a6546 100644 --- a/source/slang/slang-ir-ssa.cpp +++ b/source/slang/slang-ir-ssa.cpp @@ -84,7 +84,7 @@ struct ConstructSSAContext Dictionary<IRBlock*, RefPtr<SSABlockInfo>> blockInfos; // IR building state to use during the operation - SharedIRBuilder sharedBuilder; + SharedIRBuilder* sharedBuilder; // Instructions to remove during cleanup List<IRInst*> instsToRemove; @@ -1043,7 +1043,7 @@ static void breakCriticalEdges( for (auto edge : criticalEdges) { - context->sharedBuilder.insertBlockAlongEdge(edge); + context->sharedBuilder->insertBlockAlongEdge(edge); } } @@ -1205,7 +1205,8 @@ bool constructSSA(IRModule* module, IRGlobalValueWithCode* globalVal) ConstructSSAContext context; context.globalVal = globalVal; - 
context.sharedBuilder.init(module); + SharedIRBuilder sharedBuilder(module); + context.sharedBuilder = &sharedBuilder; context.builder.init(context.sharedBuilder); context.builder.setInsertInto(module); @@ -1213,6 +1214,22 @@ bool constructSSA(IRModule* module, IRGlobalValueWithCode* globalVal) return constructSSA(&context); } +// Construct SSA form for a global value with code and reuse +// an existing sharedBuilder +// +bool constructSSA(SharedIRBuilder* sharedBuilder, IRGlobalValueWithCode* globalVal) +{ + ConstructSSAContext context; + context.globalVal = globalVal; + + context.sharedBuilder = sharedBuilder; + + context.builder.init(sharedBuilder); + context.builder.setInsertInto(sharedBuilder->getModule()); + + return constructSSA(&context); +} + bool constructSSA(IRModule* module, IRInst* globalVal) { switch (globalVal->getOp()) diff --git a/source/slang/slang-ir-ssa.h b/source/slang/slang-ir-ssa.h index d455439df..02c9c4831 100644 --- a/source/slang/slang-ir-ssa.h +++ b/source/slang/slang-ir-ssa.h @@ -6,7 +6,9 @@ namespace Slang struct IRModule; struct IRGlobalValueWithCode; struct IRInst; + struct SharedIRBuilder; bool constructSSA(IRModule* module, IRGlobalValueWithCode* globalVal); + bool constructSSA(SharedIRBuilder* sharedBuilder, IRGlobalValueWithCode* globalVal); bool constructSSA(IRModule* module); bool constructSSA(IRInst* globalVal); } diff --git a/source/slang/slang-ir-validate.cpp b/source/slang/slang-ir-validate.cpp index a49eda322..03db96ac5 100644 --- a/source/slang/slang-ir-validate.cpp +++ b/source/slang/slang-ir-validate.cpp @@ -199,6 +199,13 @@ namespace Slang if(inst->getFullType()) validateIRInstOperand(context, inst, &inst->typeUse); + // Avoid validating decoration operands + // since they don't have to conform to inst visibility + // constraints. 
+ // + if (as<IRDecoration>(inst)) + return; + UInt operandCount = inst->getOperandCount(); for (UInt ii = 0; ii < operandCount; ++ii) { diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 845232ae6..e72ba8c9f 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -4764,6 +4764,32 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitLoop( + IRBlock* target, + IRBlock* breakBlock, + IRBlock* continueBlock, + Int argCount, + IRInst*const* args) + { + List<IRInst*> argList; + + argList.add(target); + argList.add(breakBlock); + argList.add(continueBlock); + + for (Count ii = 0; ii < argCount; ii++) + argList.add(args[ii]); + + auto inst = createInst<IRLoop>( + this, + kIROp_loop, + nullptr, + argList.getCount(), + argList.getBuffer()); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitBranch( IRInst* val, IRBlock* trueBlock, |
