summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2025-07-28 15:29:58 -0400
committerGitHub <noreply@github.com>2025-07-28 19:29:58 +0000
commit5f8475bee2589b8e851c856135cf10758e859e72 (patch)
tree98517a1b532144cbab9000b40a3424321615d7c1 /source
parent4ca545ed9e98fa49740b3537473e02b950c23a99 (diff)
Fix issue in multi-level break elimination by handling multi-level continue statements (#7953)
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-eliminate-multilevel-break.cpp184
1 files changed, 149 insertions, 35 deletions
diff --git a/source/slang/slang-ir-eliminate-multilevel-break.cpp b/source/slang/slang-ir-eliminate-multilevel-break.cpp
index cbbd1c5c8..3507e83ba 100644
--- a/source/slang/slang-ir-eliminate-multilevel-break.cpp
+++ b/source/slang/slang-ir-eliminate-multilevel-break.cpp
@@ -5,6 +5,7 @@
#include "slang-ir-dominators.h"
#include "slang-ir-eliminate-phis.h"
#include "slang-ir-insts.h"
+#include "slang-ir-loop-unroll.h"
#include "slang-ir-util.h"
#include "slang-ir.h"
@@ -28,6 +29,10 @@ struct EliminateMultiLevelBreakContext
List<IRBlock*> blocks;
HashSet<IRBlock*> blockSet;
List<RefPtr<BreakableRegionInfo>> childRegions;
+
+ // Track exit blocks for this region (break block and continue block for loops)
+ List<IRBlock*> exitBlocks;
+
IRBlock* getBreakBlock()
{
switch (headerInst->getOp())
@@ -41,6 +46,31 @@ struct EliminateMultiLevelBreakContext
}
}
+ IRBlock* getContinueBlock()
+ {
+ switch (headerInst->getOp())
+ {
+ case kIROp_Loop:
+ return as<IRLoop>(headerInst)->getContinueBlock();
+ case kIROp_Switch:
+ return nullptr; // Switches don't have continue blocks
+ default:
+ SLANG_UNREACHABLE("Unknown breakable inst");
+ }
+ }
+
+ void populateExitBlocks()
+ {
+ exitBlocks.clear();
+ exitBlocks.add(getBreakBlock());
+
+ // If this is a loop, add any non-trivial continue block to the exit blocks
+ if (auto loop = as<IRLoop>(headerInst))
+ if (auto continueBlock = getContinueBlock())
+ if (continueBlock != loop->getTargetBlock())
+ exitBlocks.add(continueBlock);
+ }
+
void replaceBreakBlock(IRBuilder* builder, IRBlock* block)
{
switch (headerInst->getOp())
@@ -65,32 +95,36 @@ struct EliminateMultiLevelBreakContext
}
};
- struct MultiLevelBreakInfo
+ struct MultiLevelBranchInfo
{
- IRUnconditionalBranch* breakInst;
+ IRUnconditionalBranch* branchInst;
BreakableRegionInfo* currentRegion;
- BreakableRegionInfo* breakTargetRegion;
+ BreakableRegionInfo* branchTargetRegion;
};
struct FuncContext
{
List<RefPtr<BreakableRegionInfo>> regions;
- HashSet<IRBlock*> breakBlocks;
- Dictionary<IRBlock*, BreakableRegionInfo*> mapBreakBlockToRegion;
+ HashSet<IRBlock*> exitBlocks;
+ Dictionary<IRBlock*, BreakableRegionInfo*> mapExitBlockToRegion;
Dictionary<IRBlock*, BreakableRegionInfo*> mapBlockToRegion;
HashSet<IRBlock*> processedBlocks;
- List<MultiLevelBreakInfo> multiLevelBreaks;
+ List<MultiLevelBranchInfo> multiLevelBranches;
+
+ // Track how many multi-level branches target each exit block
+ Dictionary<IRBlock*, Count> exitBlockMultiLevelBranchCount;
void collectBreakableRegionBlocks(BreakableRegionInfo& info)
{
- // Push break block to a stack so we can easily check if a block is a break block in its
- // parent regions.
- breakBlocks.add(info.getBreakBlock());
+ // Push all exit blocks to a stack so we can easily check if a block is an exit block in
+ // its parent regions.
+ for (auto exitBlock : info.exitBlocks)
+ exitBlocks.add(exitBlock);
auto successors = as<IRBlock>(info.headerInst->getParent())->getSuccessors();
for (auto successor : successors)
{
- if (!breakBlocks.add(successor))
+ if (exitBlocks.contains(successor))
continue;
if (info.blockSet.add(successor))
info.blocks.add(successor);
@@ -111,6 +145,7 @@ struct EliminateMultiLevelBreakContext
childRegion->headerInst = block->getTerminator();
childRegion->parent = &info;
childRegion->level = info.level + 1;
+ childRegion->populateExitBlocks();
collectBreakableRegionBlocks(*childRegion);
info.childRegions.add(childRegion);
block = childRegion->getBreakBlock();
@@ -125,7 +160,7 @@ struct EliminateMultiLevelBreakContext
}
for (auto succ : block->getSuccessors())
{
- if (!breakBlocks.contains(succ))
+ if (!exitBlocks.contains(succ))
{
if (info.blockSet.add(succ))
info.blocks.add(succ);
@@ -133,8 +168,9 @@ struct EliminateMultiLevelBreakContext
}
}
- // Pop the break block from stack since we are no longer processing the region.
- breakBlocks.remove(info.getBreakBlock());
+ // Pop the exit blocks.
+ for (auto exitBlock : info.exitBlocks)
+ exitBlocks.remove(exitBlock);
}
void gatherInfo(IRGlobalValueWithCode* func)
@@ -151,6 +187,7 @@ struct EliminateMultiLevelBreakContext
{
RefPtr<BreakableRegionInfo> regionInfo = new BreakableRegionInfo();
regionInfo->headerInst = terminator;
+ regionInfo->populateExitBlocks();
collectBreakableRegionBlocks(*regionInfo);
regions.add(regionInfo);
}
@@ -164,13 +201,29 @@ struct EliminateMultiLevelBreakContext
l->forEach(
[&](BreakableRegionInfo* region)
{
- if (!isUnreachableRootBlock(region->getBreakBlock()))
- mapBreakBlockToRegion.add(region->getBreakBlock(), region);
+ for (auto exitBlock : region->exitBlocks)
+ if (!isUnreachableRootBlock(exitBlock))
+ mapExitBlockToRegion.add(exitBlock, region);
+
for (auto block : region->blocks)
mapBlockToRegion.add(block, region);
});
}
+ // Initialize exit block multi-level branch counts
+ for (auto& l : regions)
+ {
+ l->forEach(
+ [&](BreakableRegionInfo* region)
+ {
+ for (auto exitBlock : region->exitBlocks)
+ {
+ if (!isUnreachableRootBlock(exitBlock))
+ exitBlockMultiLevelBranchCount[exitBlock] = 0;
+ }
+ });
+ }
+
for (auto block : func->getBlocks())
{
auto terminator = block->getTerminator();
@@ -178,26 +231,44 @@ struct EliminateMultiLevelBreakContext
{
if (as<IRLoop>(terminator))
continue;
- BreakableRegionInfo* breakTargetRegion = nullptr;
+ BreakableRegionInfo* targetRegion = nullptr;
BreakableRegionInfo* currentRegion = nullptr;
- if (!mapBreakBlockToRegion.tryGetValue(
- branch->getTargetBlock(),
- breakTargetRegion))
+
+ // Check if the target is an exit block of any region
+ if (!mapExitBlockToRegion.tryGetValue(branch->getTargetBlock(), targetRegion))
continue;
if (mapBlockToRegion.tryGetValue(block, currentRegion))
{
- if (currentRegion != breakTargetRegion)
+ if (currentRegion != targetRegion)
{
- MultiLevelBreakInfo breakInfo;
- breakInfo.breakInst = branch;
- breakInfo.breakTargetRegion = breakTargetRegion;
- breakInfo.currentRegion = currentRegion;
- multiLevelBreaks.add(breakInfo);
+ MultiLevelBranchInfo branchInfo;
+ branchInfo.branchInst = branch;
+ branchInfo.branchTargetRegion = targetRegion;
+ branchInfo.currentRegion = currentRegion;
+ multiLevelBranches.add(branchInfo);
+
+ // Increment the count for this exit block
+ exitBlockMultiLevelBranchCount[branch->getTargetBlock()]++;
}
}
}
}
}
+
+ ShortList<IRBlock*, 2> getMultiLevelExitBlocks(BreakableRegionInfo* region)
+ {
+ ShortList<IRBlock*, 2> result;
+ for (auto exitBlock : region->exitBlocks)
+ {
+ Count branchCount = 0;
+ if (exitBlockMultiLevelBranchCount.tryGetValue(exitBlock, branchCount) &&
+ branchCount > 0)
+ {
+ result.add(exitBlock);
+ }
+ }
+ return result;
+ }
};
@@ -312,8 +383,31 @@ struct EliminateMultiLevelBreakContext
FuncContext funcInfo;
funcInfo.gatherInfo(func);
- if (funcInfo.multiLevelBreaks.getCount() == 0)
+ if (funcInfo.multiLevelBranches.getCount() == 0)
return;
+
+ // Check if each region has a single exit block with multi-level branches
+ // and if it is the break block. If not, eliminate continue blocks first.
+ bool needsContinueElimination = false;
+ for (auto& region : funcInfo.regions)
+ region->forEach(
+ [&](BreakableRegionInfo* region)
+ {
+ // Ensure that each region has a unique exit block with multi-level branches
+ ShortList<IRBlock*, 2> multiLevelExitBlocks =
+ funcInfo.getMultiLevelExitBlocks(region);
+ if (multiLevelExitBlocks.getCount() == 0)
+ return;
+
+ if (multiLevelExitBlocks.getCount() == 1 &&
+ multiLevelExitBlocks[0] == region->getBreakBlock())
+ return;
+
+ needsContinueElimination = true;
+ });
+
+ if (needsContinueElimination)
+ eliminateContinueBlocksInFunc(irModule, func);
}
// To make things easy, eliminate Phis before perform transformations.
@@ -327,9 +421,29 @@ struct EliminateMultiLevelBreakContext
FuncContext funcInfo;
funcInfo.gatherInfo(func);
- if (funcInfo.multiLevelBreaks.getCount() == 0)
+ if (funcInfo.multiLevelBranches.getCount() == 0)
return;
+ // Verify that the only multi-level branches we have to handle are into break blocks.
+ for (auto& region : funcInfo.regions)
+ region->forEach(
+ [&](BreakableRegionInfo* region)
+ {
+ // Ensure that each region has a unique exit block with multi-level branches
+ ShortList<IRBlock*, 2> multiLevelExitBlocks =
+ funcInfo.getMultiLevelExitBlocks(region);
+ if (multiLevelExitBlocks.getCount() == 0)
+ return;
+
+ if (multiLevelExitBlocks.getCount() == 1 &&
+ multiLevelExitBlocks[0] == region->getBreakBlock())
+ return;
+
+ SLANG_UNEXPECTED(
+ "Multi-level break elimination failed: unique exit block is not the break "
+ "block");
+ });
+
// Duplicate unreachable break blocks so that each break block is only mapped to a single
duplicateUnreachableBreakBlocks(&funcInfo);
@@ -342,22 +456,22 @@ struct EliminateMultiLevelBreakContext
builder.emitUnreachable();
builder.setInsertInto(func);
- // Rewrite multi-level breaks with single level break + target level argument.
- for (auto breakInfo : funcInfo.multiLevelBreaks)
+ // Rewrite multi-level branches with single level "break" + target-level argument.
+ for (auto branchInfo : funcInfo.multiLevelBranches)
{
- auto region = breakInfo.currentRegion;
+ auto region = branchInfo.currentRegion;
while (region)
{
skippedOverRegions.add(region);
region = region->parent;
- if (region == breakInfo.breakTargetRegion)
+ if (region == branchInfo.branchTargetRegion)
break;
}
- builder.setInsertBefore(breakInfo.breakInst);
+ builder.setInsertBefore(branchInfo.branchInst);
auto targetLevelInst =
- builder.getIntValue(builder.getIntType(), breakInfo.breakTargetRegion->level);
- builder.emitBranch(breakInfo.currentRegion->getBreakBlock(), 1, &targetLevelInst);
- breakInfo.breakInst->removeAndDeallocate();
+ builder.getIntValue(builder.getIntType(), branchInfo.branchTargetRegion->level);
+ builder.emitBranch(branchInfo.currentRegion->getBreakBlock(), 1, &targetLevelInst);
+ branchInfo.branchInst->removeAndDeallocate();
}
// Rewrite skipped-over break blocks to accept a target level argument.