diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-10 14:19:20 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-10 14:19:20 -0800 |
| commit | 0b05fe33c82ee301c134f5b9a87a596aa47121c8 (patch) | |
| tree | 61869daaf5cad2609efcdf239f31c203d64f39b1 /source/slang/slang-ir-simplify-cfg.cpp | |
| parent | 10834e69b1e483be4116d85b00d4bc0b861da822 (diff) | |
Fix inlining pass. (#2506)
* Fix inlining pass.
* Add more check against corner cases.
* Revise comments.
* Fixes.
* Fix premake script.
* Fixes.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-simplify-cfg.cpp')
| -rw-r--r-- | source/slang/slang-ir-simplify-cfg.cpp | 104 |
1 files changed, 100 insertions, 4 deletions
diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp
index a1bc38b64..1e247d1d9 100644
--- a/source/slang/slang-ir-simplify-cfg.cpp
+++ b/source/slang/slang-ir-simplify-cfg.cpp
@@ -2,16 +2,109 @@
 
 #include "slang-ir-insts.h"
 #include "slang-ir.h"
+#include "slang-ir-dominators.h"
+#include "slang-ir-restructure.h"
 
 namespace Slang
 {
 
-bool processFunc(IRGlobalValueWithCode* func)
+struct CFGSimplificationContext
+{
+    RefPtr<RegionTree> regionTree;
+    RefPtr<IRDominatorTree> domTree;
+};
+
+static BreakableRegion* findBreakableRegion(Region* region)
+{
+    for (;;)
+    {
+        if (auto b = as<BreakableRegion>(region))
+            return b;
+        region = region->getParent();
+        if (!region)
+            return nullptr;
+    }
+}
+
+// Returns true if `loop` is trivial: it runs a single iteration (no back edge into its target) and
+// its only exit is one break at the very end. Lazily computes `inoutContext.domTree` and
+// `inoutContext.regionTree` once the cheap checks pass; returns false whenever in doubt.
+static bool isTrivialSingleIterationLoop(
+    IRGlobalValueWithCode* func,
+    IRLoop* loop,
+    CFGSimplificationContext& inoutContext)
+{
+    auto targetBlock = loop->getTargetBlock();
+    if (targetBlock->getPredecessors().getCount() != 1) return false;
+    if (*targetBlock->getPredecessors().begin() != loop->getParent()) return false;
+
+    int useCount = 0;
+    for (auto use = loop->getBreakBlock()->firstUse; use; use = use->nextUse)
+    {
+        if (use->getUser() == loop)
+            continue;
+        useCount++;
+        if (useCount > 1)
+            return false;
+    }
+
+    // The loop has passed the cheap tests.
+    //
+    // Now verify that it is really trivial: no multi-level break may skip out of this loop,
+    // and the break out of this loop must not be issued from a nested breakable region.
+
+    if (!inoutContext.domTree)
+        inoutContext.domTree = computeDominatorTree(func);
+    if (!inoutContext.regionTree)
+        inoutContext.regionTree = generateRegionTreeForFunc(func, nullptr);
+
+    SimpleRegion* targetBlockRegion = nullptr;
+    if (!inoutContext.regionTree->mapBlockToRegion.TryGetValue(targetBlock, targetBlockRegion))
+        return false;
+    BreakableRegion* loopBreakableRegion = findBreakableRegion(targetBlockRegion);
+    LoopRegion* loopRegion = as<LoopRegion>(loopBreakableRegion);
+    if (!loopRegion)
+        return false;
+    for (auto block : func->getBlocks())
+    {
+        if (!inoutContext.domTree->dominates(loop->getTargetBlock(), block))
+            continue;
+        if (inoutContext.domTree->dominates(loop->getBreakBlock(), block))
+            continue;
+        SimpleRegion* region = nullptr;
+        if (!inoutContext.regionTree->mapBlockToRegion.TryGetValue(block, region))
+            return false;
+
+        for (auto branchTarget : block->getSuccessors())
+        {
+            SimpleRegion* targetRegion = nullptr;
+            if (!inoutContext.regionTree->mapBlockToRegion.TryGetValue(branchTarget, targetRegion))
+                return false;
+            // Edges staying inside the loop are fine; an edge leaving it to anywhere but this loop's break block is a multi-level break.
+            if (targetRegion->isDescendentOf(loopRegion))
+                continue;
+            if (branchTarget != loop->getBreakBlock())
+                return false;
+            if (findBreakableRegion(region) != loopRegion)
+            {
+                // If the break is initiated from a nested breakable region, the loop is not trivial.
+                return false;
+            }
+        }
+    }
+
+    return true;
+}
+
+static bool processFunc(IRGlobalValueWithCode* func)
 {
     auto firstBlock = func->getFirstBlock();
     if (!firstBlock)
         return false;
 
+    // Lazily generated dominator tree and region tree.
+    CFGSimplificationContext simplificationContext;
+
     SharedIRBuilder sharedBuilder(func->getModule());
     IRBuilder builder(&sharedBuilder);
 
@@ -35,10 +128,12 @@ bool processFunc(IRGlobalValueWithCode* func)
                     loop->continueBlock.set(loop->getTargetBlock());
                     continueBlock->removeAndDeallocate();
                 }
-                // If there isn't any actual back jumps into loop target, remove the header
-                // and turn it into a normal branch.
+
+                // If there aren't any actual back jumps into the loop target and there is a
+                // trivial break at the end of the loop, we can remove the header and turn it
+                // into a normal branch.
                 auto targetBlock = loop->getTargetBlock();
-                if (targetBlock->getPredecessors().getCount() == 1 && *targetBlock->getPredecessors().begin() == block)
+                if (isTrivialSingleIterationLoop(func, loop, simplificationContext))
                 {
                     builder.setInsertBefore(loop);
                     List<IRInst*> args;
@@ -50,6 +145,7 @@ bool processFunc(IRGlobalValueWithCode* func)
                     loop->removeAndDeallocate();
                 }
             }
+
             // If `block` does not end with an unconditional branch, bail.
             if (block->getTerminator()->getOp() != kIROp_unconditionalBranch)
                 break;
