From 5f8475bee2589b8e851c856135cf10758e859e72 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Mon, 28 Jul 2025 15:29:58 -0400 Subject: Fix issue in multi-level break elimination by handling multi-level continue statements (#7953) --- .../slang/slang-ir-eliminate-multilevel-break.cpp | 184 +++++++++++++++++---- 1 file changed, 149 insertions(+), 35 deletions(-) (limited to 'source') diff --git a/source/slang/slang-ir-eliminate-multilevel-break.cpp b/source/slang/slang-ir-eliminate-multilevel-break.cpp index cbbd1c5c8..3507e83ba 100644 --- a/source/slang/slang-ir-eliminate-multilevel-break.cpp +++ b/source/slang/slang-ir-eliminate-multilevel-break.cpp @@ -5,6 +5,7 @@ #include "slang-ir-dominators.h" #include "slang-ir-eliminate-phis.h" #include "slang-ir-insts.h" +#include "slang-ir-loop-unroll.h" #include "slang-ir-util.h" #include "slang-ir.h" @@ -28,6 +29,10 @@ struct EliminateMultiLevelBreakContext List blocks; HashSet blockSet; List> childRegions; + + // Track exit blocks for this region (break block and continue block for loops) + List exitBlocks; + IRBlock* getBreakBlock() { switch (headerInst->getOp()) @@ -41,6 +46,31 @@ struct EliminateMultiLevelBreakContext } } + IRBlock* getContinueBlock() + { + switch (headerInst->getOp()) + { + case kIROp_Loop: + return as(headerInst)->getContinueBlock(); + case kIROp_Switch: + return nullptr; // Switches don't have continue blocks + default: + SLANG_UNREACHABLE("Unknown breakable inst"); + } + } + + void populateExitBlocks() + { + exitBlocks.clear(); + exitBlocks.add(getBreakBlock()); + + // If this is a loop, add any non-trivial continue block to the exit blocks + if (auto loop = as(headerInst)) + if (auto continueBlock = getContinueBlock()) + if (continueBlock != loop->getTargetBlock()) + exitBlocks.add(continueBlock); + } + void replaceBreakBlock(IRBuilder* builder, IRBlock* block) { switch (headerInst->getOp()) @@ -65,32 +95,36 @@ struct EliminateMultiLevelBreakContext } }; - struct MultiLevelBreakInfo + struct MultiLevelBranchInfo { - IRUnconditionalBranch* breakInst; + IRUnconditionalBranch* branchInst; BreakableRegionInfo* currentRegion; - BreakableRegionInfo* breakTargetRegion; + BreakableRegionInfo* branchTargetRegion; }; struct FuncContext { List> regions; - HashSet breakBlocks; - Dictionary mapBreakBlockToRegion; + HashSet exitBlocks; + Dictionary mapExitBlockToRegion; Dictionary mapBlockToRegion; HashSet processedBlocks; - List multiLevelBreaks; + List multiLevelBranches; + + // Track how many multi-level branches target each exit block + Dictionary exitBlockMultiLevelBranchCount; void collectBreakableRegionBlocks(BreakableRegionInfo& info) { - // Push break block to a stack so we can easily check if a block is a break block in its - // parent regions. - breakBlocks.add(info.getBreakBlock()); + // Push all exit blocks to a stack so we can easily check if a block is an exit block in + // its parent regions. + for (auto exitBlock : info.exitBlocks) + exitBlocks.add(exitBlock); auto successors = as(info.headerInst->getParent())->getSuccessors(); for (auto successor : successors) { - if (!breakBlocks.add(successor)) + if (exitBlocks.contains(successor)) continue; if (info.blockSet.add(successor)) info.blocks.add(successor); @@ -111,6 +145,7 @@ struct EliminateMultiLevelBreakContext childRegion->headerInst = block->getTerminator(); childRegion->parent = &info; childRegion->level = info.level + 1; + childRegion->populateExitBlocks(); collectBreakableRegionBlocks(*childRegion); info.childRegions.add(childRegion); block = childRegion->getBreakBlock(); @@ -125,7 +160,7 @@ struct EliminateMultiLevelBreakContext } for (auto succ : block->getSuccessors()) { - if (!breakBlocks.contains(succ)) + if (!exitBlocks.contains(succ)) { if (info.blockSet.add(succ)) info.blocks.add(succ); @@ -133,8 +168,9 @@ struct EliminateMultiLevelBreakContext } } - // Pop the break block from stack since we are no longer processing the region. - breakBlocks.remove(info.getBreakBlock()); + // Pop the exit blocks. + for (auto exitBlock : info.exitBlocks) + exitBlocks.remove(exitBlock); } void gatherInfo(IRGlobalValueWithCode* func) @@ -151,6 +187,7 @@ struct EliminateMultiLevelBreakContext { RefPtr regionInfo = new BreakableRegionInfo(); regionInfo->headerInst = terminator; + regionInfo->populateExitBlocks(); collectBreakableRegionBlocks(*regionInfo); regions.add(regionInfo); } @@ -164,13 +201,29 @@ struct EliminateMultiLevelBreakContext l->forEach( [&](BreakableRegionInfo* region) { - if (!isUnreachableRootBlock(region->getBreakBlock())) - mapBreakBlockToRegion.add(region->getBreakBlock(), region); + for (auto exitBlock : region->exitBlocks) + if (!isUnreachableRootBlock(exitBlock)) + mapExitBlockToRegion.add(exitBlock, region); + for (auto block : region->blocks) mapBlockToRegion.add(block, region); }); } + // Initialize exit block multi-level branch counts + for (auto& l : regions) + { + l->forEach( + [&](BreakableRegionInfo* region) + { + for (auto exitBlock : region->exitBlocks) + { + if (!isUnreachableRootBlock(exitBlock)) + exitBlockMultiLevelBranchCount[exitBlock] = 0; + } + }); + } + for (auto block : func->getBlocks()) { auto terminator = block->getTerminator(); @@ -178,26 +231,44 @@ struct EliminateMultiLevelBreakContext { if (as(terminator)) continue; - BreakableRegionInfo* breakTargetRegion = nullptr; + BreakableRegionInfo* targetRegion = nullptr; BreakableRegionInfo* currentRegion = nullptr; - if (!mapBreakBlockToRegion.tryGetValue( - branch->getTargetBlock(), - breakTargetRegion)) + + // Check if the target is an exit block of any region + if (!mapExitBlockToRegion.tryGetValue(branch->getTargetBlock(), targetRegion)) continue; if (mapBlockToRegion.tryGetValue(block, currentRegion)) { - if (currentRegion != breakTargetRegion) + if (currentRegion != targetRegion) { - MultiLevelBreakInfo breakInfo; - breakInfo.breakInst = branch; - breakInfo.breakTargetRegion = breakTargetRegion; - breakInfo.currentRegion = currentRegion; - multiLevelBreaks.add(breakInfo); + MultiLevelBranchInfo branchInfo; + branchInfo.branchInst = branch; + branchInfo.branchTargetRegion = targetRegion; + branchInfo.currentRegion = currentRegion; + multiLevelBranches.add(branchInfo); + + // Increment the count for this exit block + exitBlockMultiLevelBranchCount[branch->getTargetBlock()]++; } } } } } + + ShortList getMultiLevelExitBlocks(BreakableRegionInfo* region) + { + ShortList result; + for (auto exitBlock : region->exitBlocks) + { + Count branchCount = 0; + if (exitBlockMultiLevelBranchCount.tryGetValue(exitBlock, branchCount) && + branchCount > 0) + { + result.add(exitBlock); + } + } + return result; + } }; @@ -312,8 +383,31 @@ struct EliminateMultiLevelBreakContext FuncContext funcInfo; funcInfo.gatherInfo(func); - if (funcInfo.multiLevelBreaks.getCount() == 0) + if (funcInfo.multiLevelBranches.getCount() == 0) return; + + // Check if each region has a single exit block with multi-level branches + // and if it is the break block. If not, eliminate continue blocks first. + bool needsContinueElimination = false; + for (auto& region : funcInfo.regions) + region->forEach( + [&](BreakableRegionInfo* region) + { + // Ensure that each region has a unique exit block with multi-level branches + ShortList multiLevelExitBlocks = + funcInfo.getMultiLevelExitBlocks(region); + if (multiLevelExitBlocks.getCount() == 0) + return; + + if (multiLevelExitBlocks.getCount() == 1 && + multiLevelExitBlocks[0] == region->getBreakBlock()) + return; + + needsContinueElimination = true; + }); + + if (needsContinueElimination) + eliminateContinueBlocksInFunc(irModule, func); } // To make things easy, eliminate Phis before perform transformations. @@ -327,9 +421,29 @@ struct EliminateMultiLevelBreakContext FuncContext funcInfo; funcInfo.gatherInfo(func); - if (funcInfo.multiLevelBreaks.getCount() == 0) + if (funcInfo.multiLevelBranches.getCount() == 0) return; + // Verify that the only multi-level branches we have to handle are into break blocks. + for (auto& region : funcInfo.regions) + region->forEach( + [&](BreakableRegionInfo* region) + { + // Ensure that each region has a unique exit block with multi-level branches + ShortList multiLevelExitBlocks = + funcInfo.getMultiLevelExitBlocks(region); + if (multiLevelExitBlocks.getCount() == 0) + return; + + if (multiLevelExitBlocks.getCount() == 1 && + multiLevelExitBlocks[0] == region->getBreakBlock()) + return; + + SLANG_UNEXPECTED( + "Multi-level break elimination failed: unique exit block is not the break " + "block"); + }); + // Duplicate unreachable break blocks so that each break block is only mapped to a single duplicateUnreachableBreakBlocks(&funcInfo); @@ -342,22 +456,22 @@ struct EliminateMultiLevelBreakContext builder.emitUnreachable(); builder.setInsertInto(func); - // Rewrite multi-level breaks with single level break + target level argument. - for (auto breakInfo : funcInfo.multiLevelBreaks) + // Rewrite multi-level branches with single level "break" + target-level argument. + for (auto branchInfo : funcInfo.multiLevelBranches) { - auto region = breakInfo.currentRegion; + auto region = branchInfo.currentRegion; while (region) { skippedOverRegions.add(region); region = region->parent; - if (region == breakInfo.breakTargetRegion) + if (region == branchInfo.branchTargetRegion) break; } - builder.setInsertBefore(breakInfo.breakInst); + builder.setInsertBefore(branchInfo.branchInst); auto targetLevelInst = - builder.getIntValue(builder.getIntType(), breakInfo.breakTargetRegion->level); - builder.emitBranch(breakInfo.currentRegion->getBreakBlock(), 1, &targetLevelInst); - breakInfo.breakInst->removeAndDeallocate(); + builder.getIntValue(builder.getIntType(), branchInfo.branchTargetRegion->level); + builder.emitBranch(branchInfo.currentRegion->getBreakBlock(), 1, &targetLevelInst); + branchInfo.branchInst->removeAndDeallocate(); } // Rewrite skipped-over break blocks to accept a target level argument. -- cgit v1.2.3