summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-08-25 14:53:12 -0400
committerGitHub <noreply@github.com>2023-08-25 14:53:12 -0400
commit06f7ef354cdde4cf8e8797d8853ed2d9c3208b5b (patch)
tree43458d031c791b1e03b469f2b059391cf4a755b6 /source/slang
parentef4c9f1f1c297f1a33be95795a7a7561e0cc3bde (diff)
Fix various issues with trivial loops (#3149)
* Fix issue with trivial loop detection * Fix issue with unreachable blocks in break elimination Add logic to avoid eliminating loops with multi-level breaks. * Incorporate feedback - Use a boolean for multi-level break check - Use dominator trees for region check instead of exhaustive enumeration - Fix potential issue with enumerating parent break blocks. * fix
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-ir-dominators.cpp25
-rw-r--r--source/slang/slang-ir-dominators.h4
-rw-r--r--source/slang/slang-ir-eliminate-multilevel-break.cpp78
-rw-r--r--source/slang/slang-ir-loop-unroll.cpp43
-rw-r--r--source/slang/slang-ir-loop-unroll.h2
-rw-r--r--source/slang/slang-ir-simplify-cfg.cpp71
-rw-r--r--source/slang/slang-ir-util.cpp123
-rw-r--r--source/slang/slang-ir-util.h31
8 files changed, 314 insertions, 63 deletions
diff --git a/source/slang/slang-ir-dominators.cpp b/source/slang/slang-ir-dominators.cpp
index e57099321..8527fbf36 100644
--- a/source/slang/slang-ir-dominators.cpp
+++ b/source/slang/slang-ir-dominators.cpp
@@ -188,7 +188,7 @@ Int IRDominatorTree::getBlockIndex(IRBlock* block)
bool IRDominatorTree::isUnreachable(IRBlock* block)
{
- return !mapBlockToIndex.containsKey(block);
+ return !reachableSet.contains(block);
}
@@ -333,9 +333,24 @@ struct PostorderComputationContext : public DepthFirstSearchContext
}
};
+void computeReachableSet(IRGlobalValueWithCode* code, HashSet<IRBlock*>& outSet)
+{
+ DepthFirstSearchContext context;
+ if (code->getFirstBlock())
+ context.walk(code->getFirstBlock(), [](IRBlock* block) {return block->getSuccessors(); });
+ outSet = _Move(context.visited);
+}
+
/// Compute a postorder traversal of the blocks in `code`, writing the resulting order to `outOrder`.
void computePostorder(IRGlobalValueWithCode* code, List<IRBlock*>& outOrder)
{
+ HashSet<IRBlock*> reachableSet;
+ computePostorder(code, outOrder, reachableSet);
+}
+
+/// Compute a postorder traversal of the blocks in `code`, writing the resulting order to `outOrder`.
+void computePostorder(IRGlobalValueWithCode* code, List<IRBlock*>& outOrder, HashSet<IRBlock*>& outReachableSet)
+{
PostorderComputationContext context;
context.order = &outOrder;
if (code->getFirstBlock())
@@ -352,6 +367,7 @@ void computePostorder(IRGlobalValueWithCode* code, List<IRBlock*>& outOrder)
}
prefix.addRange(outOrder);
outOrder = _Move(prefix);
+ outReachableSet = _Move(context.visited);
}
void computePostorderOnReverseCFG(IRGlobalValueWithCode* code, List<IRBlock*>& outOrder)
@@ -397,6 +413,10 @@ struct DominatorTreeComputationContext
// traversal, so that we can look up a block based on its "name"
//
List<IRBlock*> postorder;
+ //
+ // Also maintain a set of reachable blocks.
+ //
+ HashSet<IRBlock*> reachableSet;
//
// We need a way to map our actual IR blocks to their names for
@@ -426,7 +446,7 @@ struct DominatorTreeComputationContext
void iterativelyComputeImmediateDominators(IRGlobalValueWithCode* code)
{
// First we compute the postorder traversal order for the blocks in the CFG.
- computePostorder(code, postorder);
+ computePostorder(code, postorder, reachableSet);
// We will initialize our map from the block objects to their "name"
// (index in the traversal order), before moving on.
@@ -746,6 +766,7 @@ struct DominatorTreeComputationContext
RefPtr<IRDominatorTree> dominatorTree = new IRDominatorTree();
dominatorTree->code = code;
dominatorTree->nodes.setCount(blockCount);
+ dominatorTree->reachableSet = _Move(reachableSet);
// We will iterate over all of the blocks, and fill in the corresponding
// dominator tree node for each.
diff --git a/source/slang/slang-ir-dominators.h b/source/slang/slang-ir-dominators.h
index 14e84eac6..dbeed2ccc 100644
--- a/source/slang/slang-ir-dominators.h
+++ b/source/slang/slang-ir-dominators.h
@@ -114,6 +114,9 @@ namespace Slang
/// Dictionary used to accelerate `getBlockIndex`
Dictionary<IRBlock*, Int> mapBlockToIndex;
+ /// Reachability information for the CFG
+ HashSet<IRBlock*> reachableSet;
+
//
// In order to accelerate queries on the tree structure, we will order the tree nodes
// carefully, so that all of the descendants of a node are contiguous, with all of
@@ -170,6 +173,7 @@ namespace Slang
RefPtr<IRDominatorTree> computeDominatorTree(IRGlobalValueWithCode* code);
void computePostorder(IRGlobalValueWithCode* code, List<IRBlock*>& outOrder);
+ void computePostorder(IRGlobalValueWithCode* code, List<IRBlock*>& outOrder, HashSet<IRBlock*>& outReachableSet);
void computePostorderOnReverseCFG(IRGlobalValueWithCode* code, List<IRBlock*>& outOrder);
inline List<IRBlock*> getPostorder(IRGlobalValueWithCode* code)
diff --git a/source/slang/slang-ir-eliminate-multilevel-break.cpp b/source/slang/slang-ir-eliminate-multilevel-break.cpp
index 19c95edfd..5ff71a248 100644
--- a/source/slang/slang-ir-eliminate-multilevel-break.cpp
+++ b/source/slang/slang-ir-eliminate-multilevel-break.cpp
@@ -8,6 +8,11 @@
namespace Slang
{
+
+bool isUnreachableRootBlock(IRBlock* block)
+{
+ return block->getPredecessors().getCount() == 0;
+}
struct EliminateMultiLevelBreakContext
{
@@ -34,6 +39,23 @@ struct EliminateMultiLevelBreakContext
}
}
+ void replaceBreakBlock(IRBuilder* builder, IRBlock* block)
+ {
+ switch (headerInst->getOp())
+ {
+ case kIROp_loop:
+ builder->replaceOperand(
+ &(as<IRLoop>(headerInst)->breakBlock), block);
+ break;
+ case kIROp_Switch:
+ builder->replaceOperand(
+ &(as<IRSwitch>(headerInst)->breakLabel), block);
+ break;
+ default:
+ SLANG_UNREACHABLE("Unknown breakable inst");
+ }
+ }
+
template<typename Func>
void forEach(const Func& f)
{
@@ -59,11 +81,6 @@ struct EliminateMultiLevelBreakContext
HashSet<IRBlock*> processedBlocks;
List<MultiLevelBreakInfo> multiLevelBreaks;
- bool isUnreachable(IRBlock* block)
- {
- return block->getPredecessors().getCount() == 0;
- }
-
void collectBreakableRegionBlocks(BreakableRegionInfo& info)
{
// Push break block to a stack so we can easily check if a block is a break block in its
@@ -97,7 +114,7 @@ struct EliminateMultiLevelBreakContext
collectBreakableRegionBlocks(*childRegion);
info.childRegions.add(childRegion);
block = childRegion->getBreakBlock();
- if (!isUnreachable(block) && info.blockSet.add(block))
+ if (!isUnreachableRootBlock(block) && info.blockSet.add(block))
{
info.blocks.add(block);
}
@@ -147,7 +164,7 @@ struct EliminateMultiLevelBreakContext
l->forEach(
[&](BreakableRegionInfo* region)
{
- if(!isUnreachable(region->getBreakBlock()))
+ if(!isUnreachableRootBlock(region->getBreakBlock()))
mapBreakBlockToRegion.add(region->getBreakBlock(), region);
for (auto block : region->blocks)
mapBlockToRegion.add(block, region);
@@ -240,6 +257,50 @@ struct EliminateMultiLevelBreakContext
return changed;
}
+ void duplicateUnreachableBreakBlocks(FuncContext* context)
+ {
+ Dictionary<IRBlock*, BreakableRegionInfo*> mapBreakBlocksToRegion;
+
+ // If we already have a region mapped for a break block, and the break block
+ // is unreachable, create a new unreachable block and map it.
+ //
+ for (auto& l : context->regions)
+ {
+ l->forEach(
+ [&](BreakableRegionInfo* region)
+ {
+ if (isUnreachableRootBlock(region->getBreakBlock()))
+ {
+ if (mapBreakBlocksToRegion.containsKey(region->getBreakBlock()))
+ {
+ if (mapBreakBlocksToRegion[region->getBreakBlock()] != region)
+ {
+ // We have a break block that is unreachable, and we have already
+ // mapped it to a region, and that region is not the current region.
+ //
+ // We need to create a new unreachable block, and map it to the
+ // current region.
+ //
+ IRBuilder builder(irModule);
+ builder.setInsertInto(region->getBreakBlock()->getParent());
+ auto newBreakBlock = builder.createBlock();
+ newBreakBlock->insertAfter(region->getBreakBlock());
+ builder.setInsertInto(newBreakBlock);
+ builder.emitUnreachable();
+ mapBreakBlocksToRegion.add(newBreakBlock, region);
+ region->replaceBreakBlock(&builder, newBreakBlock);
+ return;
+ }
+ }
+ else
+ mapBreakBlocksToRegion.add(region->getBreakBlock(), region);
+ }
+ else
+ mapBreakBlocksToRegion.add(region->getBreakBlock(), region);
+ });
+ }
+ }
+
void processFunc(IRGlobalValueWithCode* func)
{
@@ -264,6 +325,9 @@ struct EliminateMultiLevelBreakContext
if (funcInfo.multiLevelBreaks.getCount() == 0)
return;
+ // Duplicate unreachable break blocks so that each break block is only mapped to a single
+ duplicateUnreachableBreakBlocks(&funcInfo);
+
IRBuilder builder(irModule);
builder.setInsertInto(func);
diff --git a/source/slang/slang-ir-loop-unroll.cpp b/source/slang/slang-ir-loop-unroll.cpp
index c9ac4191b..c4ef1650c 100644
--- a/source/slang/slang-ir-loop-unroll.cpp
+++ b/source/slang/slang-ir-loop-unroll.cpp
@@ -48,45 +48,6 @@ static bool _eliminateDeadBlocks(List<IRBlock*>& blocks, IRBlock* unreachableBlo
return changed;
}
-List<IRBlock*> _collectBlocksInLoop(IRDominatorTree* dom, IRLoop* loopInst)
-{
- List<IRBlock*> loopBlocks;
- HashSet<IRBlock*> loopBlocksSet;
- auto addBlock = [&](IRBlock* block)
- {
- if (loopBlocksSet.add(block))
- loopBlocks.add(block);
- };
- auto firstBlock = as<IRBlock>(loopInst->block.get());
- auto breakBlock = as<IRBlock>(loopInst->breakBlock.get());
-
- addBlock(firstBlock);
- for (Index i = 0; i < loopBlocks.getCount(); i++)
- {
- auto block = loopBlocks[i];
- for (auto succ : block->getSuccessors())
- {
- if (succ == breakBlock)
- continue;
- if (!dom->dominates(firstBlock, succ))
- continue;
- if (!as<IRUnreachable>(breakBlock->getTerminator()))
- {
- if (dom->dominates(breakBlock, succ))
- continue;
- }
- addBlock(succ);
- }
- }
- return loopBlocks;
-}
-
-List<IRBlock*> collectBlocksInLoop(IRGlobalValueWithCode* func, IRLoop* loopInst)
-{
- auto dom = computeDominatorTree(func);
- return _collectBlocksInLoop(dom, loopInst);
-}
-
static int _getLoopMaxIterationsToUnroll(IRLoop* loopInst)
{
static constexpr int kMaxIterationsToAttempt = 4096;
@@ -440,7 +401,7 @@ static bool _unrollLoop(
firstIterationBreakBlock->removeAndDeallocateAllDecorationsAndChildren();
builder.setInsertInto(firstIterationBreakBlock);
- builder.emitBranch(unreachableBlock);
+ builder.emitUnreachable();
break;
}
@@ -487,7 +448,7 @@ bool unrollLoopsInFunc(
// Remove any continue jumps from the loop.
eliminateContinueBlocks(module, loop);
- auto blocks = collectBlocksInLoop(func, loop);
+ auto blocks = collectBlocksInRegion(func, loop);
auto loopLoc = loop->sourceLoc;
if (!_unrollLoop(module, loop, blocks))
{
diff --git a/source/slang/slang-ir-loop-unroll.h b/source/slang/slang-ir-loop-unroll.h
index 6f7a41192..90d530556 100644
--- a/source/slang/slang-ir-loop-unroll.h
+++ b/source/slang/slang-ir-loop-unroll.h
@@ -16,8 +16,6 @@ namespace Slang
bool unrollLoopsInModule(IRModule* module, DiagnosticSink* sink);
- List<IRBlock*> collectBlocksInLoop(IRGlobalValueWithCode* func, IRLoop* loop);
-
// Turn a loop with continue block into a loop with only back jumps and breaks.
// Each iteration will be wrapped in a breakable region, where everything before `continue`
// is within the breakable region, and everything after `continue` is outside the breakable
diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp
index e848d11c1..44a8909e4 100644
--- a/source/slang/slang-ir-simplify-cfg.cpp
+++ b/source/slang/slang-ir-simplify-cfg.cpp
@@ -29,6 +29,32 @@ static BreakableRegion* findBreakableRegion(Region* region)
}
}
+static bool isBlockInRegion(IRDominatorTree* domTree, IRTerminatorInst* regionHeader, IRBlock* block)
+{
+ auto headerBlock = cast<IRBlock>(regionHeader->getParent());
+ IRBlock* breakBlock = nullptr;
+ if (auto loop = as<IRLoop>(regionHeader))
+ breakBlock = loop->getBreakBlock();
+ else if (auto switchInst = as<IRSwitch>(regionHeader))
+ breakBlock = switchInst->getBreakLabel();
+
+ auto parentBreakBlocks = getParentBreakBlockSet(domTree, headerBlock);
+
+ if (!domTree->dominates(headerBlock, block))
+ return false;
+
+ if (domTree->dominates(breakBlock, block))
+ return false;
+
+ for (auto parentBreakBlock : parentBreakBlocks)
+ {
+ if (domTree->dominates(parentBreakBlock, block))
+ return false;
+ }
+
+ return true;
+}
+
// Test if a loop is trivial: a trivial loop runs for a single iteration without any back edges, and
// there is only one break out of the loop at the very end. The function generates `regionTree` if
// it is needed and hasn't been generated yet.
@@ -102,19 +128,36 @@ static bool isTrivialSingleIterationLoop(
// Track the break block backwards through the dominator tree, and see if we find a loop block
// that is not the current loop.
//
- auto currBlock = loop->getBreakBlock();
- for (;;)
+ auto breakPredList = loop->getBreakBlock()->getPredecessors();
+
+ if (breakPredList.getCount() > 0)
{
- auto parent = context.domTree->getImmediateDominator(currBlock);
- if (!parent)
- break;
- currBlock = parent;
- if (auto _loop = as<IRLoop>(currBlock->getTerminator()))
+ auto breakOriginBlock = *loop->getBreakBlock()->getPredecessors().begin();
+
+ for (auto currBlock = breakOriginBlock;
+ currBlock;
+ currBlock = context.domTree->getImmediateDominator(currBlock))
{
- if (loop != _loop)
- return false;
- if (loop == _loop)
+ auto terminator = currBlock->getTerminator();
+ if (terminator == loop)
+ break;
+
+ // Check if the break originated from an inner breakable region.
+ // If so, the outer loop cannot be trivially removed.
+ //
+ switch (terminator->getOp())
+ {
+ case kIROp_loop:
+ if (isBlockInRegion(context.domTree, as<IRLoop>(terminator), breakOriginBlock))
+ return false;
break;
+ case kIROp_Switch:
+ if (isBlockInRegion(context.domTree, as<IRSwitch>(terminator), breakOriginBlock))
+ return false;
+ break;
+ default:
+ break;
+ }
}
}
@@ -123,7 +166,13 @@ static bool isTrivialSingleIterationLoop(
static bool doesLoopHasSideEffect(IRGlobalValueWithCode* func, IRLoop* loopInst)
{
- auto blocks = collectBlocksInLoop(func, loopInst);
+ bool hasMultiLevelBreaks = false;
+ auto blocks = collectBlocksInRegion(func, loopInst, &hasMultiLevelBreaks);
+
+ // We'll currently not deal with loops that contain multi-level breaks.
+ if (hasMultiLevelBreaks)
+ return true;
+
HashSet<IRBlock*> loopBlocks;
for (auto b : blocks)
loopBlocks.add(b);
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 467580c83..5ead1a1f4 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -998,11 +998,134 @@ void resetScratchDataBit(IRInst* inst, int bitIndex)
}
}
+List<IRBlock*> collectBlocksInRegion(
+ IRDominatorTree* dom,
+ IRLoop* loop,
+ bool* outHasMultiLevelBreaks)
+{
+ return collectBlocksInRegion(dom, loop->getBreakBlock(), loop->getTargetBlock(), true, outHasMultiLevelBreaks);
+}
+
+List<IRBlock*> collectBlocksInRegion(
+ IRDominatorTree* dom,
+ IRLoop* loop)
+{
+ bool hasMultiLevelBreaks = false;
+ return collectBlocksInRegion(dom, loop->getBreakBlock(), loop->getTargetBlock(), true, &hasMultiLevelBreaks);
+}
+
+List<IRBlock*> collectBlocksInRegion(
+ IRDominatorTree* dom,
+ IRSwitch* switchInst,
+ bool* outHasMultiLevelBreaks)
+{
+ return collectBlocksInRegion(dom, switchInst->getBreakLabel(), as<IRBlock>(switchInst->getParent()), false, outHasMultiLevelBreaks);
+}
+
+List<IRBlock*> collectBlocksInRegion(
+ IRDominatorTree* dom,
+ IRSwitch* switchInst)
+{
+ bool hasMultiLevelBreaks = false;
+ return collectBlocksInRegion(dom, switchInst->getBreakLabel(), as<IRBlock>(switchInst->getParent()), false, &hasMultiLevelBreaks);
+}
+
+HashSet<IRBlock*> getParentBreakBlockSet(IRDominatorTree* dom, IRBlock* block)
+{
+ HashSet<IRBlock*> parentBreakBlocksSet;
+ for (IRBlock* currBlock = dom->getImmediateDominator(block);
+ currBlock;
+ currBlock = dom->getImmediateDominator(currBlock))
+ {
+ if (auto loopInst = as<IRLoop>(currBlock->getTerminator()))
+ if (!dom->dominates(loopInst->getBreakBlock(), block))
+ parentBreakBlocksSet.add(loopInst->getBreakBlock());
+ else if (auto switchInst = as<IRSwitch>(currBlock->getTerminator()))
+ if (!dom->dominates(switchInst->getBreakLabel(), block))
+ parentBreakBlocksSet.add(switchInst->getBreakLabel());
+ }
+
+ return parentBreakBlocksSet;
+}
+
+List<IRBlock*> collectBlocksInRegion(
+ IRDominatorTree* dom,
+ IRBlock* breakBlock,
+ IRBlock* firstBlock,
+ bool includeFirstBlock,
+ bool* outHasMultiLevelBreaks)
+{
+ List<IRBlock*> regionBlocks;
+ HashSet<IRBlock*> regionBlocksSet;
+ auto addBlock = [&](IRBlock* block)
+ {
+ if (regionBlocksSet.add(block))
+ regionBlocks.add(block);
+ };
+
+ // Use dominator tree heirarchy to find break blocks of
+ // all parent regions. We'll need to this to detect breaks
+ // to outer regions (particularly when our region has no reachable
+ // break block of its own)
+ //
+ HashSet<IRBlock*> parentBreakBlocksSet = getParentBreakBlockSet(dom, firstBlock);
+
+ *outHasMultiLevelBreaks = false;
+
+ addBlock(firstBlock);
+ for (Index i = 0; i < regionBlocks.getCount(); i++)
+ {
+ auto block = regionBlocks[i];
+ for (auto succ : block->getSuccessors())
+ {
+ if (parentBreakBlocksSet.contains(succ) && succ != breakBlock)
+ {
+ *outHasMultiLevelBreaks = true;
+ continue;
+ }
+
+ if (succ == breakBlock)
+ continue;
+ if (!dom->dominates(firstBlock, succ))
+ continue;
+ if (!as<IRUnreachable>(breakBlock->getTerminator()))
+ {
+ if (dom->dominates(breakBlock, succ))
+ continue;
+ }
+
+ addBlock(succ);
+ }
+ }
+
+ if (!includeFirstBlock)
+ {
+ regionBlocksSet.remove(firstBlock);
+ regionBlocks.remove(firstBlock);
+ }
+
+ return regionBlocks;
+}
+
+List<IRBlock *> collectBlocksInRegion(IRGlobalValueWithCode *func, IRLoop *loopInst, bool* outHasMultiLevelBreaks)
+{
+ auto dom = computeDominatorTree(func);
+ return collectBlocksInRegion(dom, loopInst, outHasMultiLevelBreaks);
+}
+
+List<IRBlock*> collectBlocksInRegion(IRGlobalValueWithCode* func, IRLoop* loopInst)
+{
+ auto dom = computeDominatorTree(func);
+ bool hasMultiLevelBreaks = false;
+ return collectBlocksInRegion(dom, loopInst, &hasMultiLevelBreaks);
+}
+
IRVarLayout* findVarLayout(IRInst* value)
{
if (auto layoutDecoration = value->findDecoration<IRLayoutDecoration>())
return as<IRVarLayout>(layoutDecoration->getLayout());
return nullptr;
+
}
UnownedStringSlice getBasicTypeNameHint(IRType* basicType)
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index c107ec24a..20bac0cbf 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -224,6 +224,37 @@ bool isOne(IRInst* inst);
void initializeScratchData(IRInst* inst);
void resetScratchDataBit(IRInst* inst, int bitIndex);
+List<IRBlock*> collectBlocksInRegion(
+ IRDominatorTree* dom,
+ IRLoop* loop);
+
+List<IRBlock*> collectBlocksInRegion(
+ IRDominatorTree* dom,
+ IRSwitch* switchInst);
+
+List<IRBlock*> collectBlocksInRegion(
+ IRDominatorTree* dom,
+ IRSwitch* switchInst,
+ bool* outHasMultilevelBreaks);
+
+List<IRBlock*> collectBlocksInRegion(
+ IRDominatorTree* dom,
+ IRLoop* loop,
+ bool* outHasMultilevelBreaks);
+
+List<IRBlock*> collectBlocksInRegion(
+ IRDominatorTree* dom,
+ IRBlock* breakBlock,
+ IRBlock* firstBlock,
+ bool includeFirstBlock,
+ bool* outHasMultilevelBreaks);
+
+List<IRBlock*> collectBlocksInRegion(IRGlobalValueWithCode* func, IRLoop* loopInst, bool* outHasMultilevelBreaks);
+
+List<IRBlock*> collectBlocksInRegion(IRGlobalValueWithCode* func, IRLoop* loopInst);
+
+HashSet<IRBlock*> getParentBreakBlockSet(IRDominatorTree* dom, IRBlock* block);
+
IRVarLayout* findVarLayout(IRInst* value);
// Run an operation over every block in a module