diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-08-25 14:53:12 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-08-25 14:53:12 -0400 |
| commit | 06f7ef354cdde4cf8e8797d8853ed2d9c3208b5b (patch) | |
| tree | 43458d031c791b1e03b469f2b059391cf4a755b6 /source | |
| parent | ef4c9f1f1c297f1a33be95795a7a7561e0cc3bde (diff) | |
Fix various issues with trivial loops (#3149)
* Fix issue with trivial loop detection
* Fix issue with unreachable blocks in break elimination
Add logic to avoid eliminating loops with multi-level breaks.
* Incorporate feedback
- Use a boolean for multi-level break check
- Use dominator trees for region check instead of exhaustive enumeration
- Fix potential issue with enumerating parent break blocks.
* fix
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-dominators.cpp | 25 | ||||
| -rw-r--r-- | source/slang/slang-ir-dominators.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-eliminate-multilevel-break.cpp | 78 | ||||
| -rw-r--r-- | source/slang/slang-ir-loop-unroll.cpp | 43 | ||||
| -rw-r--r-- | source/slang/slang-ir-loop-unroll.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-simplify-cfg.cpp | 71 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.cpp | 123 | ||||
| -rw-r--r-- | source/slang/slang-ir-util.h | 31 |
8 files changed, 314 insertions, 63 deletions
diff --git a/source/slang/slang-ir-dominators.cpp b/source/slang/slang-ir-dominators.cpp index e57099321..8527fbf36 100644 --- a/source/slang/slang-ir-dominators.cpp +++ b/source/slang/slang-ir-dominators.cpp @@ -188,7 +188,7 @@ Int IRDominatorTree::getBlockIndex(IRBlock* block) bool IRDominatorTree::isUnreachable(IRBlock* block) { - return !mapBlockToIndex.containsKey(block); + return !reachableSet.contains(block); } @@ -333,9 +333,24 @@ struct PostorderComputationContext : public DepthFirstSearchContext } }; +void computeReachableSet(IRGlobalValueWithCode* code, HashSet<IRBlock*>& outSet) +{ + DepthFirstSearchContext context; + if (code->getFirstBlock()) + context.walk(code->getFirstBlock(), [](IRBlock* block) {return block->getSuccessors(); }); + outSet = _Move(context.visited); +} + /// Compute a postorder traversal of the blocks in `code`, writing the resulting order to `outOrder`. void computePostorder(IRGlobalValueWithCode* code, List<IRBlock*>& outOrder) { + HashSet<IRBlock*> reachableSet; + computePostorder(code, outOrder, reachableSet); +} + +/// Compute a postorder traversal of the blocks in `code`, writing the resulting order to `outOrder`. +void computePostorder(IRGlobalValueWithCode* code, List<IRBlock*>& outOrder, HashSet<IRBlock*>& outReachableSet) +{ PostorderComputationContext context; context.order = &outOrder; if (code->getFirstBlock()) @@ -352,6 +367,7 @@ void computePostorder(IRGlobalValueWithCode* code, List<IRBlock*>& outOrder) } prefix.addRange(outOrder); outOrder = _Move(prefix); + outReachableSet = _Move(context.visited); } void computePostorderOnReverseCFG(IRGlobalValueWithCode* code, List<IRBlock*>& outOrder) @@ -397,6 +413,10 @@ struct DominatorTreeComputationContext // traversal, so that we can look up a block based on its "name" // List<IRBlock*> postorder; + // + // Also maintain a set of reachable blocks. 
+ // + HashSet<IRBlock*> reachableSet; // // We need a way to map our actual IR blocks to their names for @@ -426,7 +446,7 @@ struct DominatorTreeComputationContext void iterativelyComputeImmediateDominators(IRGlobalValueWithCode* code) { // First we compute the postorder traversal order for the blocks in the CFG. - computePostorder(code, postorder); + computePostorder(code, postorder, reachableSet); // We will initialize our map from the block objects to their "name" // (index in the traversal order), before moving on. @@ -746,6 +766,7 @@ struct DominatorTreeComputationContext RefPtr<IRDominatorTree> dominatorTree = new IRDominatorTree(); dominatorTree->code = code; dominatorTree->nodes.setCount(blockCount); + dominatorTree->reachableSet = _Move(reachableSet); // We will iterate over all of the blocks, and fill in the corresponding // dominator tree node for each. diff --git a/source/slang/slang-ir-dominators.h b/source/slang/slang-ir-dominators.h index 14e84eac6..dbeed2ccc 100644 --- a/source/slang/slang-ir-dominators.h +++ b/source/slang/slang-ir-dominators.h @@ -114,6 +114,9 @@ namespace Slang /// Dictionary used to accelerate `getBlockIndex` Dictionary<IRBlock*, Int> mapBlockToIndex; + /// Reachability information for the CFG + HashSet<IRBlock*> reachableSet; + // // In order to accelerate queries on the tree structure, we will order the tree nodes // carefully, so that all of the descendants of a node are contiguous, with all of @@ -170,6 +173,7 @@ namespace Slang RefPtr<IRDominatorTree> computeDominatorTree(IRGlobalValueWithCode* code); void computePostorder(IRGlobalValueWithCode* code, List<IRBlock*>& outOrder); + void computePostorder(IRGlobalValueWithCode* code, List<IRBlock*>& outOrder, HashSet<IRBlock*>& outReachableSet); void computePostorderOnReverseCFG(IRGlobalValueWithCode* code, List<IRBlock*>& outOrder); inline List<IRBlock*> getPostorder(IRGlobalValueWithCode* code) diff --git a/source/slang/slang-ir-eliminate-multilevel-break.cpp 
b/source/slang/slang-ir-eliminate-multilevel-break.cpp index 19c95edfd..5ff71a248 100644 --- a/source/slang/slang-ir-eliminate-multilevel-break.cpp +++ b/source/slang/slang-ir-eliminate-multilevel-break.cpp @@ -8,6 +8,11 @@ namespace Slang { + +bool isUnreachableRootBlock(IRBlock* block) +{ + return block->getPredecessors().getCount() == 0; +} struct EliminateMultiLevelBreakContext { @@ -34,6 +39,23 @@ struct EliminateMultiLevelBreakContext } } + void replaceBreakBlock(IRBuilder* builder, IRBlock* block) + { + switch (headerInst->getOp()) + { + case kIROp_loop: + builder->replaceOperand( + &(as<IRLoop>(headerInst)->breakBlock), block); + break; + case kIROp_Switch: + builder->replaceOperand( + &(as<IRSwitch>(headerInst)->breakLabel), block); + break; + default: + SLANG_UNREACHABLE("Unknown breakable inst"); + } + } + template<typename Func> void forEach(const Func& f) { @@ -59,11 +81,6 @@ struct EliminateMultiLevelBreakContext HashSet<IRBlock*> processedBlocks; List<MultiLevelBreakInfo> multiLevelBreaks; - bool isUnreachable(IRBlock* block) - { - return block->getPredecessors().getCount() == 0; - } - void collectBreakableRegionBlocks(BreakableRegionInfo& info) { // Push break block to a stack so we can easily check if a block is a break block in its @@ -97,7 +114,7 @@ struct EliminateMultiLevelBreakContext collectBreakableRegionBlocks(*childRegion); info.childRegions.add(childRegion); block = childRegion->getBreakBlock(); - if (!isUnreachable(block) && info.blockSet.add(block)) + if (!isUnreachableRootBlock(block) && info.blockSet.add(block)) { info.blocks.add(block); } @@ -147,7 +164,7 @@ struct EliminateMultiLevelBreakContext l->forEach( [&](BreakableRegionInfo* region) { - if(!isUnreachable(region->getBreakBlock())) + if(!isUnreachableRootBlock(region->getBreakBlock())) mapBreakBlockToRegion.add(region->getBreakBlock(), region); for (auto block : region->blocks) mapBlockToRegion.add(block, region); @@ -240,6 +257,50 @@ struct EliminateMultiLevelBreakContext 
return changed; } + void duplicateUnreachableBreakBlocks(FuncContext* context) + { + Dictionary<IRBlock*, BreakableRegionInfo*> mapBreakBlocksToRegion; + + // If we already have a region mapped for a break block, and the break block + // is unreachable, create a new unreachable block and map it. + // + for (auto& l : context->regions) + { + l->forEach( + [&](BreakableRegionInfo* region) + { + if (isUnreachableRootBlock(region->getBreakBlock())) + { + if (mapBreakBlocksToRegion.containsKey(region->getBreakBlock())) + { + if (mapBreakBlocksToRegion[region->getBreakBlock()] != region) + { + // We have a break block that is unreachable, and we have already + // mapped it to a region, and that region is not the current region. + // + // We need to create a new unreachable block, and map it to the + // current region. + // + IRBuilder builder(irModule); + builder.setInsertInto(region->getBreakBlock()->getParent()); + auto newBreakBlock = builder.createBlock(); + newBreakBlock->insertAfter(region->getBreakBlock()); + builder.setInsertInto(newBreakBlock); + builder.emitUnreachable(); + mapBreakBlocksToRegion.add(newBreakBlock, region); + region->replaceBreakBlock(&builder, newBreakBlock); + return; + } + } + else + mapBreakBlocksToRegion.add(region->getBreakBlock(), region); + } + else + mapBreakBlocksToRegion.add(region->getBreakBlock(), region); + }); + } + } + void processFunc(IRGlobalValueWithCode* func) { @@ -264,6 +325,9 @@ struct EliminateMultiLevelBreakContext if (funcInfo.multiLevelBreaks.getCount() == 0) return; + // Duplicate unreachable break blocks so that each break block is only mapped to a single region. + duplicateUnreachableBreakBlocks(&funcInfo); + IRBuilder builder(irModule); builder.setInsertInto(func); diff --git a/source/slang/slang-ir-loop-unroll.cpp b/source/slang/slang-ir-loop-unroll.cpp index c9ac4191b..c4ef1650c 100644 --- a/source/slang/slang-ir-loop-unroll.cpp +++ b/source/slang/slang-ir-loop-unroll.cpp @@ -48,45 +48,6 @@ static bool
_eliminateDeadBlocks(List<IRBlock*>& blocks, IRBlock* unreachableBlo return changed; } -List<IRBlock*> _collectBlocksInLoop(IRDominatorTree* dom, IRLoop* loopInst) -{ - List<IRBlock*> loopBlocks; - HashSet<IRBlock*> loopBlocksSet; - auto addBlock = [&](IRBlock* block) - { - if (loopBlocksSet.add(block)) - loopBlocks.add(block); - }; - auto firstBlock = as<IRBlock>(loopInst->block.get()); - auto breakBlock = as<IRBlock>(loopInst->breakBlock.get()); - - addBlock(firstBlock); - for (Index i = 0; i < loopBlocks.getCount(); i++) - { - auto block = loopBlocks[i]; - for (auto succ : block->getSuccessors()) - { - if (succ == breakBlock) - continue; - if (!dom->dominates(firstBlock, succ)) - continue; - if (!as<IRUnreachable>(breakBlock->getTerminator())) - { - if (dom->dominates(breakBlock, succ)) - continue; - } - addBlock(succ); - } - } - return loopBlocks; -} - -List<IRBlock*> collectBlocksInLoop(IRGlobalValueWithCode* func, IRLoop* loopInst) -{ - auto dom = computeDominatorTree(func); - return _collectBlocksInLoop(dom, loopInst); -} - static int _getLoopMaxIterationsToUnroll(IRLoop* loopInst) { static constexpr int kMaxIterationsToAttempt = 4096; @@ -440,7 +401,7 @@ static bool _unrollLoop( firstIterationBreakBlock->removeAndDeallocateAllDecorationsAndChildren(); builder.setInsertInto(firstIterationBreakBlock); - builder.emitBranch(unreachableBlock); + builder.emitUnreachable(); break; } @@ -487,7 +448,7 @@ bool unrollLoopsInFunc( // Remove any continue jumps from the loop. 
eliminateContinueBlocks(module, loop); - auto blocks = collectBlocksInLoop(func, loop); + auto blocks = collectBlocksInRegion(func, loop); auto loopLoc = loop->sourceLoc; if (!_unrollLoop(module, loop, blocks)) { diff --git a/source/slang/slang-ir-loop-unroll.h b/source/slang/slang-ir-loop-unroll.h index 6f7a41192..90d530556 100644 --- a/source/slang/slang-ir-loop-unroll.h +++ b/source/slang/slang-ir-loop-unroll.h @@ -16,8 +16,6 @@ namespace Slang bool unrollLoopsInModule(IRModule* module, DiagnosticSink* sink); - List<IRBlock*> collectBlocksInLoop(IRGlobalValueWithCode* func, IRLoop* loop); - // Turn a loop with continue block into a loop with only back jumps and breaks. // Each iteration will be wrapped in a breakable region, where everything before `continue` // is within the breakable region, and everything after `continue` is outside the breakable diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp index e848d11c1..44a8909e4 100644 --- a/source/slang/slang-ir-simplify-cfg.cpp +++ b/source/slang/slang-ir-simplify-cfg.cpp @@ -29,6 +29,32 @@ static BreakableRegion* findBreakableRegion(Region* region) } } +static bool isBlockInRegion(IRDominatorTree* domTree, IRTerminatorInst* regionHeader, IRBlock* block) +{ + auto headerBlock = cast<IRBlock>(regionHeader->getParent()); + IRBlock* breakBlock = nullptr; + if (auto loop = as<IRLoop>(regionHeader)) + breakBlock = loop->getBreakBlock(); + else if (auto switchInst = as<IRSwitch>(regionHeader)) + breakBlock = switchInst->getBreakLabel(); + + auto parentBreakBlocks = getParentBreakBlockSet(domTree, headerBlock); + + if (!domTree->dominates(headerBlock, block)) + return false; + + if (domTree->dominates(breakBlock, block)) + return false; + + for (auto parentBreakBlock : parentBreakBlocks) + { + if (domTree->dominates(parentBreakBlock, block)) + return false; + } + + return true; +} + // Test if a loop is trivial: a trivial loop runs for a single iteration without any back edges, 
and // there is only one break out of the loop at the very end. The function generates `regionTree` if // it is needed and hasn't been generated yet. @@ -102,19 +128,36 @@ static bool isTrivialSingleIterationLoop( // Track the break block backwards through the dominator tree, and see if we find a loop block // that is not the current loop. // - auto currBlock = loop->getBreakBlock(); - for (;;) + auto breakPredList = loop->getBreakBlock()->getPredecessors(); + + if (breakPredList.getCount() > 0) { - auto parent = context.domTree->getImmediateDominator(currBlock); - if (!parent) - break; - currBlock = parent; - if (auto _loop = as<IRLoop>(currBlock->getTerminator())) + auto breakOriginBlock = *loop->getBreakBlock()->getPredecessors().begin(); + + for (auto currBlock = breakOriginBlock; + currBlock; + currBlock = context.domTree->getImmediateDominator(currBlock)) { - if (loop != _loop) - return false; - if (loop == _loop) + auto terminator = currBlock->getTerminator(); + if (terminator == loop) + break; + + // Check if the break originated from an inner breakable region. + // If so, the outer loop cannot be trivially removed. + // + switch (terminator->getOp()) + { + case kIROp_loop: + if (isBlockInRegion(context.domTree, as<IRLoop>(terminator), breakOriginBlock)) + return false; break; + case kIROp_Switch: + if (isBlockInRegion(context.domTree, as<IRSwitch>(terminator), breakOriginBlock)) + return false; + break; + default: + break; + } } } @@ -123,7 +166,13 @@ static bool isTrivialSingleIterationLoop( static bool doesLoopHasSideEffect(IRGlobalValueWithCode* func, IRLoop* loopInst) { - auto blocks = collectBlocksInLoop(func, loopInst); + bool hasMultiLevelBreaks = false; + auto blocks = collectBlocksInRegion(func, loopInst, &hasMultiLevelBreaks); + + // We'll currently not deal with loops that contain multi-level breaks. 
+ if (hasMultiLevelBreaks) + return true; + HashSet<IRBlock*> loopBlocks; for (auto b : blocks) loopBlocks.add(b); diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 467580c83..5ead1a1f4 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -998,11 +998,134 @@ void resetScratchDataBit(IRInst* inst, int bitIndex) } } +List<IRBlock*> collectBlocksInRegion( + IRDominatorTree* dom, + IRLoop* loop, + bool* outHasMultiLevelBreaks) +{ + return collectBlocksInRegion(dom, loop->getBreakBlock(), loop->getTargetBlock(), true, outHasMultiLevelBreaks); +} + +List<IRBlock*> collectBlocksInRegion( + IRDominatorTree* dom, + IRLoop* loop) +{ + bool hasMultiLevelBreaks = false; + return collectBlocksInRegion(dom, loop->getBreakBlock(), loop->getTargetBlock(), true, &hasMultiLevelBreaks); +} + +List<IRBlock*> collectBlocksInRegion( + IRDominatorTree* dom, + IRSwitch* switchInst, + bool* outHasMultiLevelBreaks) +{ + return collectBlocksInRegion(dom, switchInst->getBreakLabel(), as<IRBlock>(switchInst->getParent()), false, outHasMultiLevelBreaks); +} + +List<IRBlock*> collectBlocksInRegion( + IRDominatorTree* dom, + IRSwitch* switchInst) +{ + bool hasMultiLevelBreaks = false; + return collectBlocksInRegion(dom, switchInst->getBreakLabel(), as<IRBlock>(switchInst->getParent()), false, &hasMultiLevelBreaks); +} + +HashSet<IRBlock*> getParentBreakBlockSet(IRDominatorTree* dom, IRBlock* block) +{ + HashSet<IRBlock*> parentBreakBlocksSet; + for (IRBlock* currBlock = dom->getImmediateDominator(block); + currBlock; + currBlock = dom->getImmediateDominator(currBlock)) + { + if (auto loopInst = as<IRLoop>(currBlock->getTerminator())) + if (!dom->dominates(loopInst->getBreakBlock(), block)) + parentBreakBlocksSet.add(loopInst->getBreakBlock()); + else if (auto switchInst = as<IRSwitch>(currBlock->getTerminator())) + if (!dom->dominates(switchInst->getBreakLabel(), block)) + parentBreakBlocksSet.add(switchInst->getBreakLabel()); + } + 
+ return parentBreakBlocksSet; + } + +List<IRBlock*> collectBlocksInRegion( + IRDominatorTree* dom, + IRBlock* breakBlock, + IRBlock* firstBlock, + bool includeFirstBlock, + bool* outHasMultiLevelBreaks) +{ + List<IRBlock*> regionBlocks; + HashSet<IRBlock*> regionBlocksSet; + auto addBlock = [&](IRBlock* block) + { + if (regionBlocksSet.add(block)) + regionBlocks.add(block); + }; + + // Use dominator tree hierarchy to find break blocks of + // all parent regions. We'll need this to detect breaks + // to outer regions (particularly when our region has no reachable + // break block of its own) + // + HashSet<IRBlock*> parentBreakBlocksSet = getParentBreakBlockSet(dom, firstBlock); + + *outHasMultiLevelBreaks = false; + + addBlock(firstBlock); + for (Index i = 0; i < regionBlocks.getCount(); i++) + { + auto block = regionBlocks[i]; + for (auto succ : block->getSuccessors()) + { + if (parentBreakBlocksSet.contains(succ) && succ != breakBlock) + { + *outHasMultiLevelBreaks = true; + continue; + } + + if (succ == breakBlock) + continue; + if (!dom->dominates(firstBlock, succ)) + continue; + if (!as<IRUnreachable>(breakBlock->getTerminator())) + { + if (dom->dominates(breakBlock, succ)) + continue; + } + + addBlock(succ); + } + } + + if (!includeFirstBlock) + { + regionBlocksSet.remove(firstBlock); + regionBlocks.remove(firstBlock); + } + + return regionBlocks; +} + +List<IRBlock *> collectBlocksInRegion(IRGlobalValueWithCode *func, IRLoop *loopInst, bool* outHasMultiLevelBreaks) +{ + auto dom = computeDominatorTree(func); + return collectBlocksInRegion(dom, loopInst, outHasMultiLevelBreaks); +} + +List<IRBlock*> collectBlocksInRegion(IRGlobalValueWithCode* func, IRLoop* loopInst) +{ + auto dom = computeDominatorTree(func); + bool hasMultiLevelBreaks = false; + return collectBlocksInRegion(dom, loopInst, &hasMultiLevelBreaks); +} + IRVarLayout* findVarLayout(IRInst* value) { if (auto layoutDecoration = value->findDecoration<IRLayoutDecoration>()) return 
as<IRVarLayout>(layoutDecoration->getLayout()); return nullptr; + } UnownedStringSlice getBasicTypeNameHint(IRType* basicType) diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index c107ec24a..20bac0cbf 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -224,6 +224,37 @@ bool isOne(IRInst* inst); void initializeScratchData(IRInst* inst); void resetScratchDataBit(IRInst* inst, int bitIndex); +List<IRBlock*> collectBlocksInRegion( + IRDominatorTree* dom, + IRLoop* loop); + +List<IRBlock*> collectBlocksInRegion( + IRDominatorTree* dom, + IRSwitch* switchInst); + +List<IRBlock*> collectBlocksInRegion( + IRDominatorTree* dom, + IRSwitch* switchInst, + bool* outHasMultilevelBreaks); + +List<IRBlock*> collectBlocksInRegion( + IRDominatorTree* dom, + IRLoop* loop, + bool* outHasMultilevelBreaks); + +List<IRBlock*> collectBlocksInRegion( + IRDominatorTree* dom, + IRBlock* breakBlock, + IRBlock* firstBlock, + bool includeFirstBlock, + bool* outHasMultilevelBreaks); + +List<IRBlock*> collectBlocksInRegion(IRGlobalValueWithCode* func, IRLoop* loopInst, bool* outHasMultilevelBreaks); + +List<IRBlock*> collectBlocksInRegion(IRGlobalValueWithCode* func, IRLoop* loopInst); + +HashSet<IRBlock*> getParentBreakBlockSet(IRDominatorTree* dom, IRBlock* block); + IRVarLayout* findVarLayout(IRInst* value); // Run an operation over every block in a module |
