diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-02-27 23:42:06 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-27 23:42:06 -0500 |
| commit | 10e2d9c7c532c204f26bb2c9f383f21b121b2ff2 (patch) | |
| tree | 9ae0dd84b505a7ecd3fb45de9dbde74f8dd1ebe9 /source | |
| parent | a3ba22b51c371d5a20d61aa4e35233ba4f4f68db (diff) | |
More fixes for reverse-mode on complicated loops (#2675)
* Multiple fixes to get various loop tests to pass.
* Create reverse-nested-loop.slang
* Fix for variables becoming inaccessible during cfg normalization
* Removed comments and moved break-branch-normalization to eliminateMultiLevelBreaks
* Fix.
* Override liveness tests
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-cfg-norm.cpp | 57 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 12 | ||||
| -rw-r--r-- | source/slang/slang-ir-eliminate-multilevel-break.cpp | 54 | ||||
| -rw-r--r-- | source/slang/slang-ir-loop-unroll.cpp | 10 |
5 files changed, 124 insertions, 13 deletions
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp index 9116f67e9..2199b0771 100644 --- a/source/slang/slang-ir-autodiff-cfg-norm.cpp +++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp @@ -43,6 +43,7 @@ struct BreakableRegionInfo { IRVar* breakVar; IRBlock* breakBlock; + IRBlock* headerBlock; }; struct CFGNormalizationContext @@ -57,13 +58,39 @@ IRBlock* getOrCreateTopLevelCondition(IRLoop* loopInst) // For now, we're going to naively assume the next block is the condition block. // Add in more support for more cases as necessary. // - + auto firstBlock = loopInst->getTargetBlock(); - auto ifElse = as<IRIfElse>(firstBlock->getTerminator()); - SLANG_RELEASE_ASSERT(ifElse); + if (as<IRIfElse>(firstBlock->getTerminator())) + { + return firstBlock; + } + else + { + // If there isn't a condition we need to make one with a dummy condition that + // always evaluates to true + // - return firstBlock; + IRBuilder condBuilder(loopInst->getModule()); + + auto condBlock = condBuilder.emitBlock(); + condBlock->insertAfter(as<IRBlock>(loopInst->getParent())); + + // Make loop go into the condition block + firstBlock->replaceUsesWith(condBlock); + + // Emit a condition: true side goes to the loop body, and + // false side goes into the break block. + // + condBuilder.setInsertInto(condBlock); + condBuilder.emitIfElse( + condBuilder.getBoolValue(true), + firstBlock, + loopInst->getBreakBlock(), + firstBlock); + + return condBlock; + } } struct CFGNormalizationPass @@ -133,6 +160,20 @@ struct CFGNormalizationPass return false; } + void _moveVarsToRegionHeader(BreakableRegionInfo* region, IRBlock* block) + { + for (auto child = block->getFirstChild(); child;) + { + auto nextChild = child->getNextInst(); + + if (as<IRVar>(child)) + { + child->insertBefore(region->headerBlock->getTerminator()); + } + + child = nextChild; + } + } RegionEndpoint getNormalizedRegionEndpoint( BreakableRegionInfo* parentRegion, @@ -140,6 +181,7 @@ struct CFGNormalizationPass List<IRBlock*> afterBlocks) { IRBlock* currentBlock = entryBlock; + _moveVarsToRegionHeader(parentRegion, currentBlock); // By default a region starts off with the 'base' control flow // and not in the 'break' control flow @@ -343,6 +385,8 @@ struct CFGNormalizationPass SLANG_UNEXPECTED("Unhandled control flow inst"); break; } + + _moveVarsToRegionHeader(parentRegion, currentBlock); } // Resolve all intermediate after-blocks @@ -399,6 +443,7 @@ struct CFGNormalizationPass { BreakableRegionInfo info; info.breakBlock = as<IRLoop>(branchInst)->getBreakBlock(); + info.headerBlock = as<IRBlock>(branchInst->getParent()); // Emit var into parent block. builder.setInsertBefore( @@ -426,7 +471,7 @@ struct CFGNormalizationPass &info, firstLoopBlock, List<IRBlock*>(info.breakBlock)); - + // Should not be empty.. but check anyway SLANG_RELEASE_ASSERT(!preBreakEndPoint.isRegionEmpty); @@ -495,7 +540,7 @@ struct CFGNormalizationPass // Add a test for the break variable into the condition. auto cond = ifElse->getCondition(); - builder.setInsertAfter(cond); + builder.setInsertBefore(ifElse); auto breakFlagVal = builder.emitLoad(info.breakVar); // Need to invert the break flag if the loop is diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 640f516ed..709968f77 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -14,6 +14,7 @@ #include "slang-ir-init-local-var.h" #include "slang-ir-redundancy-removal.h" #include "slang-ir-dominators.h" +#include "slang-ir-loop-unroll.h" namespace Slang { @@ -583,6 +584,9 @@ namespace Slang { convertFuncToSingleReturnForm(func->getModule(), func); } + + eliminateContinueBlocksInFunc(func->getModule(), func); + eliminateMultiLevelBreakForFunc(func->getModule(), func); IRCFGNormalizationPass cfgPass = {this->getSink()}; diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index a30826370..3678bd4b3 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -409,16 +409,14 @@ struct DiffUnzipPass for (auto region : indexRegions) { // Grab first primal block. - IRBlock* firstPrimalBlock = as<IRBlock>(primalMap[region->breakBlock->getParent()->getFirstBlock()->getNextBlock()]); - builder.setInsertBefore(firstPrimalBlock->getTerminator()); + IRBlock* primalInitBlock = as<IRBlock>(primalMap[region->initBlock]); + builder.setInsertBefore(primalInitBlock->getTerminator()); // Make variable in the top-most block (so it's visible to diff blocks) region->primalCountLastVar = builder.emitVar(builder.getIntType()); builder.addNameHintDecoration(region->primalCountLastVar, UnownedStringSlice("_pc_last_var")); - { - IRBlock* primalInitBlock = as<IRBlock>(primalMap[region->initBlock]); - + { auto primalCondBlock = as<IRUnconditionalBranch>( primalInitBlock->getTerminator())->getTargetBlock(); builder.setInsertBefore(primalCondBlock->getTerminator()); @@ -664,8 +662,8 @@ struct DiffUnzipPass storageVar); // 4. Store current value into the array and replace uses with a load. - // TODO: If an index is missing, use the 'last' value of the primal index. - + // If an index is missing, use the 'last' value of the primal index. + { if (!isIntermediateContext) setInsertAfterOrdinaryInst(&builder, valueToStore); diff --git a/source/slang/slang-ir-eliminate-multilevel-break.cpp b/source/slang/slang-ir-eliminate-multilevel-break.cpp index 3618e1326..e73fae982 100644 --- a/source/slang/slang-ir-eliminate-multilevel-break.cpp +++ b/source/slang/slang-ir-eliminate-multilevel-break.cpp @@ -175,8 +175,62 @@ struct EliminateMultiLevelBreakContext } }; + + void insertBlockBetween(IRBlock* block, IRBlock* successor) + { + IRBuilder builder(block->getModule()); + + List<IRUse*> relevantUses; + for (auto use = successor->firstUse; use; use = use->nextUse) + { + if (auto terminator = as<IRTerminatorInst>(use->getUser())) + { + if (as<IRBlock>(terminator->getParent()) == block) + { + relevantUses.add(use); + } + } + } + + SLANG_RELEASE_ASSERT(relevantUses.getCount() == 1); + + builder.insertBlockAlongEdge(block->getModule(), IREdge(relevantUses[0])); + } + + bool normalizeBranchesIntoBreakBlocks(IRGlobalValueWithCode* func) + { + bool changed = false; + + List<IRBlock*> workList; + + for (auto block : func->getBlocks()) + workList.add(block); + + for (auto block : workList) + { + if (auto loop = as<IRLoop>(block->getTerminator())) + { + auto breakBlock = loop->getBreakBlock(); + + for (auto predecessor : breakBlock->getPredecessors()) + { + if (!as<IRUnconditionalBranch>(predecessor->getTerminator())) + { + insertBlockBetween(predecessor, breakBlock); + changed = true; + } + } + } + } + + return changed; + } + void processFunc(IRGlobalValueWithCode* func) { + + normalizeBranchesIntoBreakBlocks(func); + // If func does not have any multi-level breaks, return. { FuncContext funcInfo; diff --git a/source/slang/slang-ir-loop-unroll.cpp b/source/slang/slang-ir-loop-unroll.cpp index 2f689ebde..4f9b8d272 100644 --- a/source/slang/slang-ir-loop-unroll.cpp +++ b/source/slang/slang-ir-loop-unroll.cpp @@ -603,6 +603,16 @@ void eliminateContinueBlocks(IRModule* module, IRLoop* loopInst) builder.setInsertInto(innerBreakableRegionBreakBlock); _moveParams(innerBreakableRegionBreakBlock, continueBlock); builder.emitBranch(continueBlock); + + // If the original loop can be executed up to N times, the new loop may be executed + // upto N+1 times (although most insts are skipped in the last traversal) + // + if (auto maxItersDecoration = loopInst->findDecoration<IRLoopMaxItersDecoration>()) + { + auto maxIters = maxItersDecoration->getMaxIters(); + maxItersDecoration->removeAndDeallocate(); + builder.addLoopMaxItersDecoration(loopInst, maxIters + 1); + } } void eliminateContinueBlocksInFunc(IRModule* module, IRGlobalValueWithCode* func) |
