summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-simplify-cfg.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-10 14:19:20 -0800
committerGitHub <noreply@github.com>2022-11-10 14:19:20 -0800
commit0b05fe33c82ee301c134f5b9a87a596aa47121c8 (patch)
tree61869daaf5cad2609efcdf239f31c203d64f39b1 /source/slang/slang-ir-simplify-cfg.cpp
parent10834e69b1e483be4116d85b00d4bc0b861da822 (diff)
Fix inlining pass. (#2506)
* Fix inlining pass. * Add more check against corner cases. * Revise comments. * Fixes. * Fix premake script. * Fixes. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-simplify-cfg.cpp')
-rw-r--r--source/slang/slang-ir-simplify-cfg.cpp104
1 files changed, 100 insertions, 4 deletions
diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp
index a1bc38b64..1e247d1d9 100644
--- a/source/slang/slang-ir-simplify-cfg.cpp
+++ b/source/slang/slang-ir-simplify-cfg.cpp
@@ -2,16 +2,109 @@
#include "slang-ir-insts.h"
#include "slang-ir.h"
+#include "slang-ir-dominators.h"
+#include "slang-ir-restructure.h"
namespace Slang
{
-bool processFunc(IRGlobalValueWithCode* func)
+struct CFGSimplificationContext // Per-function analysis results reused across the loops examined by the simplifier.
+{
+ RefPtr<RegionTree> regionTree; // Region tree of the function; filled in by isTrivialSingleIterationLoop the first time it is needed.
+ RefPtr<IRDominatorTree> domTree; // Dominator tree of the function; likewise computed on first use and then reused.
+};
+
+// Walks from `region` up through its parents and returns the first region on
+// that path (starting with `region` itself) that is a `BreakableRegion`.
+// Returns nullptr when the parent chain ends without reaching one.
+// Callers are expected to pass a non-null region.
+static BreakableRegion* findBreakableRegion(Region* region)
+{
+    Region* current = region;
+    do
+    {
+        if (auto breakable = as<BreakableRegion>(current))
+            return breakable;
+        current = current->getParent();
+    } while (current);
+    return nullptr;
+}
+
+// Tests whether `loop` is trivial: a trivial loop runs for a single iteration without any
+// back edges, and the only way out of it is a single break at the very end.
+//
+// `func` is the function that contains `loop`. `inoutContext` caches the dominator tree and
+// region tree of `func`; both are computed here on first use and reused by later calls.
+//
+// Returns true only when every check below passes. Any block that cannot be found in the
+// region tree makes the answer a conservative `false`.
+static bool isTrivialSingleIterationLoop(
+    IRGlobalValueWithCode* func,
+    IRLoop* loop,
+    CFGSimplificationContext& inoutContext)
+{
+    auto targetBlock = loop->getTargetBlock();
+    auto breakBlock = loop->getBreakBlock();
+
+    // The loop target must have exactly one predecessor, and that predecessor must be the
+    // block holding the `loop` instruction. Anything else means there is a back edge.
+    if (targetBlock->getPredecessors().getCount() != 1) return false;
+    if (*targetBlock->getPredecessors().begin() != loop->getParent()) return false;
+
+    // Apart from the `loop` instruction itself, the break block may be used at most once.
+    int breakUseCount = 0;
+    for (auto use = breakBlock->firstUse; use; use = use->nextUse)
+    {
+        if (use->getUser() == loop)
+            continue;
+        breakUseCount++;
+        if (breakUseCount > 1)
+            return false;
+    }
+
+    // The loop has passed the simple tests.
+    //
+    // We still need to verify that this is a trivial loop by checking that no multi-level
+    // break skips out of it.
+
+    if (!inoutContext.domTree)
+        inoutContext.domTree = computeDominatorTree(func);
+    if (!inoutContext.regionTree)
+        inoutContext.regionTree = generateRegionTreeForFunc(func, nullptr);
+
+    SimpleRegion* targetBlockRegion = nullptr;
+    if (!inoutContext.regionTree->mapBlockToRegion.TryGetValue(targetBlock, targetBlockRegion))
+        return false;
+    BreakableRegion* loopBreakableRegion = findBreakableRegion(targetBlockRegion);
+    LoopRegion* loopRegion = as<LoopRegion>(loopBreakableRegion);
+    if (!loopRegion)
+        return false;
+
+    for (auto block : func->getBlocks())
+    {
+        // Only inspect blocks dominated by the loop target but not by the break block.
+        if (!inoutContext.domTree->dominates(targetBlock, block))
+            continue;
+        if (inoutContext.domTree->dominates(breakBlock, block))
+            continue;
+        SimpleRegion* region = nullptr;
+        if (!inoutContext.regionTree->mapBlockToRegion.TryGetValue(block, region))
+            return false;
+
+        for (auto branchTarget : block->getSuccessors())
+        {
+            SimpleRegion* targetRegion = nullptr;
+            if (!inoutContext.regionTree->mapBlockToRegion.TryGetValue(branchTarget, targetRegion))
+                return false;
+
+            // Edges that stay inside the loop region are fine.
+            if (targetRegion->isDescendentOf(loopRegion))
+                continue;
+
+            // This edge leaves the loop region. If it lands anywhere other than this loop's
+            // own break block, it is a multi-level break that skips over this loop, so the
+            // loop is not trivial. The comparison must be made against the successor being
+            // examined, not against the loop target.
+            if (branchTarget != breakBlock)
+                return false;
+
+            // If the break is initiated from a nested breakable region, this is not trivial.
+            if (findBreakableRegion(region) != loopRegion)
+                return false;
+        }
+    }
+
+    return true;
+}
+
+static bool processFunc(IRGlobalValueWithCode* func)
{
auto firstBlock = func->getFirstBlock();
if (!firstBlock)
return false;
+ // Lazily generated region tree.
+ CFGSimplificationContext simplificationContext;
+
SharedIRBuilder sharedBuilder(func->getModule());
IRBuilder builder(&sharedBuilder);
@@ -35,10 +128,12 @@ bool processFunc(IRGlobalValueWithCode* func)
loop->continueBlock.set(loop->getTargetBlock());
continueBlock->removeAndDeallocate();
}
- // If there isn't any actual back jumps into loop target, remove the header
- // and turn it into a normal branch.
+
+ // If there aren't any actual back jumps into the loop target and the only exit is a
+ // trivial break at the end of the loop, we can remove the header and turn it into
+ // a normal branch.
auto targetBlock = loop->getTargetBlock();
- if (targetBlock->getPredecessors().getCount() == 1 && *targetBlock->getPredecessors().begin() == block)
+ if (isTrivialSingleIterationLoop(func, loop, simplificationContext))
{
builder.setInsertBefore(loop);
List<IRInst*> args;
@@ -50,6 +145,7 @@ bool processFunc(IRGlobalValueWithCode* func)
loop->removeAndDeallocate();
}
}
+
// If `block` does not end with an unconditional branch, bail.
if (block->getTerminator()->getOp() != kIROp_unconditionalBranch)
break;