diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2025-07-28 15:29:58 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-07-28 19:29:58 +0000 |
| commit | 5f8475bee2589b8e851c856135cf10758e859e72 (patch) | |
| tree | 98517a1b532144cbab9000b40a3424321615d7c1 /source | |
| parent | 4ca545ed9e98fa49740b3537473e02b950c23a99 (diff) | |
Fix issue in multi-level break elimination by handling multi-level continue statements (#7953)
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-eliminate-multilevel-break.cpp | 184 |
1 files changed, 149 insertions, 35 deletions
diff --git a/source/slang/slang-ir-eliminate-multilevel-break.cpp b/source/slang/slang-ir-eliminate-multilevel-break.cpp index cbbd1c5c8..3507e83ba 100644 --- a/source/slang/slang-ir-eliminate-multilevel-break.cpp +++ b/source/slang/slang-ir-eliminate-multilevel-break.cpp @@ -5,6 +5,7 @@ #include "slang-ir-dominators.h" #include "slang-ir-eliminate-phis.h" #include "slang-ir-insts.h" +#include "slang-ir-loop-unroll.h" #include "slang-ir-util.h" #include "slang-ir.h" @@ -28,6 +29,10 @@ struct EliminateMultiLevelBreakContext List<IRBlock*> blocks; HashSet<IRBlock*> blockSet; List<RefPtr<BreakableRegionInfo>> childRegions; + + // Track exit blocks for this region (break block and continue block for loops) + List<IRBlock*> exitBlocks; + IRBlock* getBreakBlock() { switch (headerInst->getOp()) @@ -41,6 +46,31 @@ struct EliminateMultiLevelBreakContext } } + IRBlock* getContinueBlock() + { + switch (headerInst->getOp()) + { + case kIROp_Loop: + return as<IRLoop>(headerInst)->getContinueBlock(); + case kIROp_Switch: + return nullptr; // Switches don't have continue blocks + default: + SLANG_UNREACHABLE("Unknown breakable inst"); + } + } + + void populateExitBlocks() + { + exitBlocks.clear(); + exitBlocks.add(getBreakBlock()); + + // If this is a loop, add any non-trivial continue block to the exit blocks + if (auto loop = as<IRLoop>(headerInst)) + if (auto continueBlock = getContinueBlock()) + if (continueBlock != loop->getTargetBlock()) + exitBlocks.add(continueBlock); + } + void replaceBreakBlock(IRBuilder* builder, IRBlock* block) { switch (headerInst->getOp()) @@ -65,32 +95,36 @@ struct EliminateMultiLevelBreakContext } }; - struct MultiLevelBreakInfo + struct MultiLevelBranchInfo { - IRUnconditionalBranch* breakInst; + IRUnconditionalBranch* branchInst; BreakableRegionInfo* currentRegion; - BreakableRegionInfo* breakTargetRegion; + BreakableRegionInfo* branchTargetRegion; }; struct FuncContext { List<RefPtr<BreakableRegionInfo>> regions; - HashSet<IRBlock*> breakBlocks; - Dictionary<IRBlock*, BreakableRegionInfo*> mapBreakBlockToRegion; + HashSet<IRBlock*> exitBlocks; + Dictionary<IRBlock*, BreakableRegionInfo*> mapExitBlockToRegion; Dictionary<IRBlock*, BreakableRegionInfo*> mapBlockToRegion; HashSet<IRBlock*> processedBlocks; - List<MultiLevelBreakInfo> multiLevelBreaks; + List<MultiLevelBranchInfo> multiLevelBranches; + + // Track how many multi-level branches target each exit block + Dictionary<IRBlock*, Count> exitBlockMultiLevelBranchCount; void collectBreakableRegionBlocks(BreakableRegionInfo& info) { - // Push break block to a stack so we can easily check if a block is a break block in its - // parent regions. - breakBlocks.add(info.getBreakBlock()); + // Push all exit blocks to a stack so we can easily check if a block is an exit block in + // its parent regions. + for (auto exitBlock : info.exitBlocks) + exitBlocks.add(exitBlock); auto successors = as<IRBlock>(info.headerInst->getParent())->getSuccessors(); for (auto successor : successors) { - if (!breakBlocks.add(successor)) + if (exitBlocks.contains(successor)) continue; if (info.blockSet.add(successor)) info.blocks.add(successor); @@ -111,6 +145,7 @@ struct EliminateMultiLevelBreakContext childRegion->headerInst = block->getTerminator(); childRegion->parent = &info; childRegion->level = info.level + 1; + childRegion->populateExitBlocks(); collectBreakableRegionBlocks(*childRegion); info.childRegions.add(childRegion); block = childRegion->getBreakBlock(); @@ -125,7 +160,7 @@ struct EliminateMultiLevelBreakContext } for (auto succ : block->getSuccessors()) { - if (!breakBlocks.contains(succ)) + if (!exitBlocks.contains(succ)) { if (info.blockSet.add(succ)) info.blocks.add(succ); @@ -133,8 +168,9 @@ struct EliminateMultiLevelBreakContext } } - // Pop the break block from stack since we are no longer processing the region. - breakBlocks.remove(info.getBreakBlock()); + // Pop the exit blocks. + for (auto exitBlock : info.exitBlocks) + exitBlocks.remove(exitBlock); } void gatherInfo(IRGlobalValueWithCode* func) @@ -151,6 +187,7 @@ struct EliminateMultiLevelBreakContext { RefPtr<BreakableRegionInfo> regionInfo = new BreakableRegionInfo(); regionInfo->headerInst = terminator; + regionInfo->populateExitBlocks(); collectBreakableRegionBlocks(*regionInfo); regions.add(regionInfo); } @@ -164,13 +201,29 @@ struct EliminateMultiLevelBreakContext l->forEach( [&](BreakableRegionInfo* region) { - if (!isUnreachableRootBlock(region->getBreakBlock())) - mapBreakBlockToRegion.add(region->getBreakBlock(), region); + for (auto exitBlock : region->exitBlocks) + if (!isUnreachableRootBlock(exitBlock)) + mapExitBlockToRegion.add(exitBlock, region); + for (auto block : region->blocks) mapBlockToRegion.add(block, region); }); } + // Initialize exit block multi-level branch counts + for (auto& l : regions) + { + l->forEach( + [&](BreakableRegionInfo* region) + { + for (auto exitBlock : region->exitBlocks) + { + if (!isUnreachableRootBlock(exitBlock)) + exitBlockMultiLevelBranchCount[exitBlock] = 0; + } + }); + } + for (auto block : func->getBlocks()) { auto terminator = block->getTerminator(); @@ -178,26 +231,44 @@ struct EliminateMultiLevelBreakContext { if (as<IRLoop>(terminator)) continue; - BreakableRegionInfo* breakTargetRegion = nullptr; + BreakableRegionInfo* targetRegion = nullptr; BreakableRegionInfo* currentRegion = nullptr; - if (!mapBreakBlockToRegion.tryGetValue( - branch->getTargetBlock(), - breakTargetRegion)) + + // Check if the target is an exit block of any region + if (!mapExitBlockToRegion.tryGetValue(branch->getTargetBlock(), targetRegion)) continue; if (mapBlockToRegion.tryGetValue(block, currentRegion)) { - if (currentRegion != breakTargetRegion) + if (currentRegion != targetRegion) { - MultiLevelBreakInfo breakInfo; - breakInfo.breakInst = branch; - breakInfo.breakTargetRegion = breakTargetRegion; - breakInfo.currentRegion = currentRegion; - multiLevelBreaks.add(breakInfo); + MultiLevelBranchInfo branchInfo; + branchInfo.branchInst = branch; + branchInfo.branchTargetRegion = targetRegion; + branchInfo.currentRegion = currentRegion; + multiLevelBranches.add(branchInfo); + + // Increment the count for this exit block + exitBlockMultiLevelBranchCount[branch->getTargetBlock()]++; } } } } } + + ShortList<IRBlock*, 2> getMultiLevelExitBlocks(BreakableRegionInfo* region) + { + ShortList<IRBlock*, 2> result; + for (auto exitBlock : region->exitBlocks) + { + Count branchCount = 0; + if (exitBlockMultiLevelBranchCount.tryGetValue(exitBlock, branchCount) && + branchCount > 0) + { + result.add(exitBlock); + } + } + return result; + } }; @@ -312,8 +383,31 @@ struct EliminateMultiLevelBreakContext FuncContext funcInfo; funcInfo.gatherInfo(func); - if (funcInfo.multiLevelBreaks.getCount() == 0) + if (funcInfo.multiLevelBranches.getCount() == 0) return; + + // Check if each region has a single exit block with multi-level branches + // and if it is the break block. If not, eliminate continue blocks first. + bool needsContinueElimination = false; + for (auto& region : funcInfo.regions) + region->forEach( + [&](BreakableRegionInfo* region) + { + // Ensure that each region has a unique exit block with multi-level branches + ShortList<IRBlock*, 2> multiLevelExitBlocks = + funcInfo.getMultiLevelExitBlocks(region); + if (multiLevelExitBlocks.getCount() == 0) + return; + + if (multiLevelExitBlocks.getCount() == 1 && + multiLevelExitBlocks[0] == region->getBreakBlock()) + return; + + needsContinueElimination = true; + }); + + if (needsContinueElimination) + eliminateContinueBlocksInFunc(irModule, func); } // To make things easy, eliminate Phis before perform transformations. @@ -327,9 +421,29 @@ struct EliminateMultiLevelBreakContext FuncContext funcInfo; funcInfo.gatherInfo(func); - if (funcInfo.multiLevelBreaks.getCount() == 0) + if (funcInfo.multiLevelBranches.getCount() == 0) return; + // Verify that the only multi-level branches we have to handle are into break blocks. + for (auto& region : funcInfo.regions) + region->forEach( + [&](BreakableRegionInfo* region) + { + // Ensure that each region has a unique exit block with multi-level branches + ShortList<IRBlock*, 2> multiLevelExitBlocks = + funcInfo.getMultiLevelExitBlocks(region); + if (multiLevelExitBlocks.getCount() == 0) + return; + + if (multiLevelExitBlocks.getCount() == 1 && + multiLevelExitBlocks[0] == region->getBreakBlock()) + return; + + SLANG_UNEXPECTED( + "Multi-level break elimination failed: unique exit block is not the break " + "block"); + }); + // Duplicate unreachable break blocks so that each break block is only mapped to a single duplicateUnreachableBreakBlocks(&funcInfo); @@ -342,22 +456,22 @@ struct EliminateMultiLevelBreakContext builder.emitUnreachable(); builder.setInsertInto(func); - // Rewrite multi-level breaks with single level break + target level argument. - for (auto breakInfo : funcInfo.multiLevelBreaks) + // Rewrite multi-level branches with single level "break" + target-level argument. + for (auto branchInfo : funcInfo.multiLevelBranches) { - auto region = breakInfo.currentRegion; + auto region = branchInfo.currentRegion; while (region) { skippedOverRegions.add(region); region = region->parent; - if (region == breakInfo.breakTargetRegion) + if (region == branchInfo.branchTargetRegion) break; } - builder.setInsertBefore(breakInfo.breakInst); + builder.setInsertBefore(branchInfo.branchInst); auto targetLevelInst = - builder.getIntValue(builder.getIntType(), breakInfo.breakTargetRegion->level); - builder.emitBranch(breakInfo.currentRegion->getBreakBlock(), 1, &targetLevelInst); - breakInfo.breakInst->removeAndDeallocate(); + builder.getIntValue(builder.getIntType(), branchInfo.branchTargetRegion->level); + builder.emitBranch(branchInfo.currentRegion->getBreakBlock(), 1, &targetLevelInst); + branchInfo.branchInst->removeAndDeallocate(); } // Rewrite skipped-over break blocks to accept a target level argument. |
