summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir-simplify-cfg.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-02-24 10:01:47 -0800
committerGitHub <noreply@github.com>2023-02-24 10:01:47 -0800
commitbd6306cdaa4a49344658bd026721b6532e103d09 (patch)
treebb7f666d426e6cfc7777a3ccac0a1d628588eb39 /source/slang/slang-ir-simplify-cfg.cpp
parente8c08e7ecb1124f115a1d1042277776193122b57 (diff)
More control flow simplifications. (#2673)
* More control flow and Phi param simplifications. * Fix. * Fix gcc error. * Fix. * More IR cleanup. * Fix bug in phi param dce + ifelse simplify. * Propagate and DCE side-effect-free functions. * Enhance CFG simplifcation to remove loops with no side effects. * Fix. * Fixes. * Fix tests. Add [__AlwaysFoldIntoUseSite] for rayPayloadLocation. * More cleanup. * Fixes. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-simplify-cfg.cpp')
-rw-r--r--source/slang/slang-ir-simplify-cfg.cpp467
1 files changed, 451 insertions, 16 deletions
diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp
index 7e9e105e1..b814442fa 100644
--- a/source/slang/slang-ir-simplify-cfg.cpp
+++ b/source/slang/slang-ir-simplify-cfg.cpp
@@ -4,6 +4,8 @@
#include "slang-ir.h"
#include "slang-ir-dominators.h"
#include "slang-ir-restructure.h"
+#include "slang-ir-util.h"
+#include "slang-ir-loop-unroll.h"
namespace Slang
{
@@ -31,8 +33,7 @@ static BreakableRegion* findBreakableRegion(Region* region)
// it is needed and hasn't been generated yet.
static bool isTrivialSingleIterationLoop(
IRGlobalValueWithCode* func,
- IRLoop* loop,
- CFGSimplificationContext& inoutContext)
+ IRLoop* loop)
{
auto targetBlock = loop->getTargetBlock();
if (targetBlock->getPredecessors().getCount() != 1) return false;
@@ -52,14 +53,14 @@ static bool isTrivialSingleIterationLoop(
//
// We need to verify this is a trivial loop by checking if there is any multi-level breaks
// that skips out of this loop.
-
- if (!inoutContext.domTree)
- inoutContext.domTree = computeDominatorTree(func);
- if (!inoutContext.regionTree)
- inoutContext.regionTree = generateRegionTreeForFunc(func, nullptr);
+ CFGSimplificationContext context;
+ if (!context.domTree)
+ context.domTree = computeDominatorTree(func);
+ if (!context.regionTree)
+ context.regionTree = generateRegionTreeForFunc(func, nullptr);
SimpleRegion* targetBlockRegion = nullptr;
- if (!inoutContext.regionTree->mapBlockToRegion.TryGetValue(targetBlock, targetBlockRegion))
+ if (!context.regionTree->mapBlockToRegion.TryGetValue(targetBlock, targetBlockRegion))
return false;
BreakableRegion* loopBreakableRegion = findBreakableRegion(targetBlockRegion);
LoopRegion* loopRegion = as<LoopRegion>(loopBreakableRegion);
@@ -67,18 +68,18 @@ static bool isTrivialSingleIterationLoop(
return false;
for (auto block : func->getBlocks())
{
- if (!inoutContext.domTree->dominates(loop->getTargetBlock(), block))
+ if (!context.domTree->dominates(loop->getTargetBlock(), block))
continue;
- if (inoutContext.domTree->dominates(loop->getBreakBlock(), block))
+ if (context.domTree->dominates(loop->getBreakBlock(), block))
continue;
SimpleRegion* region = nullptr;
- if (!inoutContext.regionTree->mapBlockToRegion.TryGetValue(block, region))
+ if (!context.regionTree->mapBlockToRegion.TryGetValue(block, region))
return false;
for (auto branchTarget : block->getSuccessors())
{
SimpleRegion* targetRegion = nullptr;
- if (!inoutContext.regionTree->mapBlockToRegion.TryGetValue(branchTarget, targetRegion))
+ if (!context.regionTree->mapBlockToRegion.TryGetValue(branchTarget, targetRegion))
return false;
// If multi-level break out that skips over this loop exists, then this is not a trivial loop.
if (targetRegion->isDescendentOf(loopRegion))
@@ -96,6 +97,104 @@ static bool isTrivialSingleIterationLoop(
return true;
}
+static bool doesLoopHasSideEffect(IRGlobalValueWithCode* func, IRLoop* loopInst)
+{
+ auto blocks = collectBlocksInLoop(func, loopInst);
+ HashSet<IRBlock*> loopBlocks;
+ for (auto b : blocks)
+ loopBlocks.Add(b);
+ auto addressHasOutOfLoopUses = [&](IRInst* addr)
+ {
+ // The entire access chain of `addr` must have no uses out side the loop.
+ // The root variable must be a local var.
+ for (auto chainNode = addr; chainNode;)
+ {
+ if (getParentFunc(chainNode) != func)
+ return true;
+ for (auto use = chainNode->firstUse; use; use = use->nextUse)
+ {
+ if (!loopBlocks.Contains(as<IRBlock>(use->getUser()->getParent())))
+ return true;
+ }
+ switch (chainNode->getOp())
+ {
+ case kIROp_GetElementPtr:
+ case kIROp_FieldAddress:
+ chainNode = chainNode->getOperand(0);
+ continue;
+ case kIROp_Var:
+ break;
+ default:
+ return true;
+ }
+ break;
+ }
+ return false;
+ };
+
+ for (auto b : blocks)
+ {
+ for (auto inst : b->getChildren())
+ {
+ // Is this inst used anywhere outside the loop? If so the loop has side effect.
+ for (auto use = inst->firstUse; use; use = use->nextUse)
+ {
+ if (!loopBlocks.Contains(as<IRBlock>(use->getUser()->getParent())))
+ return true;
+ }
+
+ // The inst can't possibly have side effect? Skip it.
+ if (!inst->mightHaveSideEffects())
+ continue;
+
+ // This inst might have side effect, try to prove that the
+ // side effect does not leak beyond the scope of the loop.
+ if (auto call = as<IRCall>(inst))
+ {
+ auto callee = getResolvedInstForDecorations(call->getCallee());
+ if (!callee || !callee->findDecoration<IRReadNoneDecoration>())
+ return true;
+ // We are calling a pure function, check if any of the return
+ // variables are used outside the loop.
+ for (UInt i = 0; i < call->getArgCount(); i++)
+ {
+ auto arg = call->getArg(i);
+ if (!isValueType(arg->getDataType()))
+ {
+ if (addressHasOutOfLoopUses(arg))
+ return true;
+ }
+ }
+ }
+ else if (auto store = as<IRStore>(inst))
+ {
+ if (addressHasOutOfLoopUses(store->getPtr()))
+ return true;
+ }
+ else if (auto branch = as<IRUnconditionalBranch>(inst))
+ {
+ if (loopBlocks.Contains(branch->getTargetBlock()))
+ continue;
+ // Branching out of the loop with some argument is considered
+ // having a side effect.
+ if (branch->getArgCount() != 0)
+ return true;
+ }
+ else if (as<IRIfElse>(inst) || as<IRSwitch>(inst) || as<IRLoop>(inst))
+ {
+ // We are starting a sub control flow.
+ // This is considered side effect free.
+ }
+ else
+ {
+ // For all other insts, we assume it has a global side effect.
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
static bool removeDeadBlocks(IRGlobalValueWithCode* func)
{
bool changed = false;
@@ -142,15 +241,327 @@ static bool removeDeadBlocks(IRGlobalValueWithCode* func)
return changed;
}
+// Return the true of the if-else branch block if the branch is a trivial jump
+// to after block with no other insts.
+static bool isTrivialIfElseBranch(IRIfElse* condBranch, IRBlock* branchBlock)
+{
+ if (branchBlock != condBranch->getAfterBlock())
+ {
+ if (auto br = as<IRUnconditionalBranch>(branchBlock->getFirstOrdinaryInst()))
+ {
+ if (br->getTargetBlock() == condBranch->getAfterBlock() && br->getOp() == kIROp_unconditionalBranch)
+ {
+ return true;
+ }
+ }
+ }
+ else
+ {
+ return true;
+ }
+ return false;
+}
+
+static bool arePhiArgsEquivalentInBranches(IRIfElse* ifElse)
+{
+ // If one of the branch target is afterBlock itself, and the other branch
+ // is a trivial block that jumps into the afterBlock, this if-else is trivial.
+ // In this case the argCount must be 0 because a block with phi parameters can't
+ // be used as targets in a conditional branch.
+ auto branch1 = ifElse->getTrueBlock();
+ auto branch2 = ifElse->getFalseBlock();
+ auto afterBlock = ifElse->getAfterBlock();
+
+ if (branch1 == afterBlock) return true;
+ if (branch2 == afterBlock) return true;
+
+ auto branchInst1 = as<IRUnconditionalBranch>(branch1->getTerminator());
+ auto branchInst2 = as<IRUnconditionalBranch>(branch2->getTerminator());
+ if (!branchInst1) return false;
+ if (!branchInst2) return false;
+
+ // If both branches are trivial blocks, we must compare the arguments.
+ if (branchInst1->getArgCount() != branchInst2->getArgCount())
+ {
+ // This should never happen, return false now to be safe.
+ return false;
+ }
+
+ for (UInt i = 0; i < branchInst1->getArgCount(); i++)
+ {
+ if (branchInst1->getArg(i) != branchInst2->getArg(i))
+ {
+ // argument is different, the if-else is non-trivial.
+ return false;
+ }
+ }
+ return true;
+}
+
+static bool isTrivialIfElse(IRIfElse* condBranch, bool& isTrueBranchTrivial, bool& isFalseBranchTrivial)
+{
+ isTrueBranchTrivial = isTrivialIfElseBranch(condBranch, condBranch->getTrueBlock());
+ isFalseBranchTrivial = isTrivialIfElseBranch(condBranch, condBranch->getFalseBlock());
+ if (isTrueBranchTrivial && isFalseBranchTrivial)
+ {
+ if (arePhiArgsEquivalentInBranches(condBranch))
+ return true;
+ }
+ return false;
+}
+
+#if 0
+static bool tryMoveFalseBranchToTrueBranch(IRBuilder& builder, IRIfElse* ifElseInst)
+{
+ auto falseBlock = ifElseInst->getFalseBlock();
+ if (falseBlock == ifElseInst->getAfterBlock())
+ return false;
+ if (auto termInst = as<IRUnconditionalBranch>(falseBlock->getTerminator()))
+ {
+ // We can't fold a branch with arguments into the ifElse.
+ if (termInst->getArgCount() != 0)
+ return false;
+ }
+ ifElseInst->trueBlock.set(falseBlock);
+ ifElseInst->falseBlock.set(ifElseInst->getAfterBlock());
+ builder.setInsertBefore(ifElseInst);
+ auto newCondition = builder.emitNot(builder.getBoolType(), ifElseInst->getCondition());
+ ifElseInst->condition.set(newCondition);
+ return true;
+}
+#endif
+
+static bool tryEliminateFalseBranch(IRIfElse* ifElseInst)
+{
+ auto falseBlock = ifElseInst->getFalseBlock();
+ if (falseBlock == ifElseInst->getAfterBlock())
+ return false;
+ if (auto termInst = as<IRUnconditionalBranch>(falseBlock->getTerminator()))
+ {
+ // We can't fold a branch with arguments into the ifElse.
+ if (termInst->getArgCount() != 0)
+ return false;
+ }
+ ifElseInst->falseBlock.set(ifElseInst->getAfterBlock());
+ return true;
+}
+
+static bool trySimplifyIfElse(IRBuilder& builder, IRIfElse* ifElseInst)
+{
+ bool isTrueBranchTrivial = false;
+ bool isFalseBranchTrivial = false;
+ if (isTrivialIfElse(ifElseInst, isTrueBranchTrivial, isFalseBranchTrivial))
+ {
+ // If both branches of `if-else` are trivial jumps into after block,
+ // we can get rid of the entire conditional branch and replace it
+ // with a jump into the after block.
+ if (auto termInst = as<IRUnconditionalBranch>(ifElseInst->getTrueBlock()->getTerminator()))
+ {
+ List<IRInst*> args;
+ for (UInt i = 0; i < termInst->getArgCount(); i++)
+ args.add(termInst->getArg(i));
+ builder.setInsertBefore(ifElseInst);
+ builder.emitBranch(ifElseInst->getAfterBlock(), (Int)args.getCount(), args.getBuffer());
+ ifElseInst->removeAndDeallocate();
+ return true;
+ }
+ }
+ else if (isTrueBranchTrivial)
+ {
+ // If true branch is empty, we move false branch to true branch and invert the condition.
+ // TODO: diabled for now since our auto-diff pass can't handle loops whose body is on the false
+ // side of condition.
+ //return tryMoveFalseBranchToTrueBranch(builder, ifElseInst);
+ }
+ else if (isFalseBranchTrivial)
+ {
+ // If false branch is empty, we set it to afterBlock.
+ return tryEliminateFalseBranch(ifElseInst);
+ }
+ return false;
+}
+
+static bool isTrueLit(IRInst* lit)
+{
+ if (auto boolLit = as<IRBoolLit>(lit))
+ return boolLit->getValue();
+ return false;
+}
+static bool isFalseLit(IRInst* lit)
+{
+ if (auto boolLit = as<IRBoolLit>(lit))
+ return !boolLit->getValue();
+ return false;
+}
+
+static bool simplifyBoolPhiParam(IRIfElse* ifElse, Array<IRBlock*, 2>& preds, IRParam* param, UInt paramIndex)
+{
+ // For bool params where its value is assigned from the same `if-else` statement,
+ // we can simplify it into an expression of the condition of the source `if-else`.
+
+ if (!param->getDataType() || param->getDataType()->getOp() != kIROp_BoolType)
+ return false;
+
+ auto branch0 = as<IRUnconditionalBranch>(preds[0]->getTerminator());
+ if (!branch0)
+ return false;
+ if (branch0->getArgCount() <= paramIndex)
+ return false;
+ auto branch1 = as<IRUnconditionalBranch>(preds[1]->getTerminator());
+ if (!branch1)
+ return false;
+ if (branch1->getArgCount() <= paramIndex)
+ return false;
+
+ IRInst* replacement = nullptr;
+ if (isTrueLit(branch0->getArg(paramIndex)) && isFalseLit(branch1->getArg(paramIndex)))
+ {
+ replacement = ifElse->getCondition();
+ }
+ else if (isFalseLit(branch0->getArg(paramIndex)) && isTrueLit(branch1->getArg(paramIndex)))
+ {
+ IRBuilder builder(param);
+ setInsertBeforeOrdinaryInst(&builder, param);
+ replacement = builder.emitNot(builder.getBoolType(), ifElse->getCondition());
+ }
+ if (replacement)
+ {
+ param->replaceUsesWith(replacement);
+ param->removeAndDeallocate();
+ branch0->removeArgument(paramIndex);
+ branch1->removeArgument(paramIndex);
+ return true;
+ }
+ return false;
+}
+
+static bool simplifyBoolPhiParams(IRBlock* block)
+{
+ if (!block)
+ return false;
+
+ if (block->getPredecessors().getCount() != 2)
+ return false;
+
+ Array<IRBlock*, 2> preds;
+ for (auto pred : block->getPredecessors())
+ preds.add(pred);
+
+ IRBlock* ifElseBlock = nullptr;
+ if (preds[0]->getPredecessors().getCount() != 1)
+ return false;
+ ifElseBlock = *(preds[0]->getPredecessors().begin());
+ if (preds[1]->getPredecessors().getCount() != 1)
+ return false;
+ auto p = *(preds[1]->getPredecessors().begin());
+ if (p != ifElseBlock)
+ return false;
+
+ auto ifElse = as<IRIfElse>(ifElseBlock->getTerminator());
+ if (!ifElse)
+ return false;
+
+ if (ifElse->getTrueBlock() == preds[1])
+ {
+ Swap(preds[0], preds[1]);
+ }
+ SLANG_ASSERT(ifElse->getTrueBlock() == preds[0] && ifElse->getFalseBlock() == preds[1]);
+
+ List<IRParam*> params;
+ for (auto param : block->getParams())
+ params.add(param);
+ bool changed = false;
+ for (Index i = params.getCount() - 1; i >= 0; i--)
+ {
+ changed |= simplifyBoolPhiParam(ifElse, preds, params[i], (UInt)i);
+ }
+ return changed;
+}
+
+static bool removeTrivialPhiParams(IRBlock* block)
+{
+ // We can remove a phi parmeter if:
+ // 1. all arguments to a parameter is the same (not really a phi).
+ // 2. the arguments to the parameter is always the same as arguments to another existing parameter (duplicate phi).
+
+ bool changed = false;
+ List<IRParam*> params;
+ struct ParamState
+ {
+ bool areKnownValueSame = true;
+ IRInst* knownValue = nullptr;
+ OrderedHashSet<UInt> sameAsParamSet;
+ };
+ List<ParamState> args;
+ List<IRUnconditionalBranch*> termInsts;
+ for (auto param : block->getParams())
+ {
+ params.add(param);
+ args.add(ParamState());
+ }
+
+ if (!params.getCount())
+ return false;
+
+ for (UInt i = 1; i < (UInt)args.getCount(); i++)
+ for (UInt j = 0; j < i; j++)
+ args[i].sameAsParamSet.Add(j);
+
+ for (auto pred : block->getPredecessors())
+ {
+ auto termInst = as<IRUnconditionalBranch>(pred->getTerminator());
+ if (!termInst)
+ return false;
+ SLANG_ASSERT(termInst->getArgCount() == (UInt)args.getCount());
+ termInsts.add(termInst);
+ for (UInt i = 0; i < termInst->getArgCount(); i++)
+ {
+ if (args[i].areKnownValueSame)
+ {
+ if (args[i].knownValue == nullptr)
+ args[i].knownValue = termInst->getArg(i);
+ else if (args[i].knownValue != termInst->getArg(i))
+ args[i].areKnownValueSame = false;
+ }
+ for (UInt j = 0; j < i; j++)
+ {
+ if (termInst->getArg(i) != termInst->getArg(j))
+ {
+ args[i].sameAsParamSet.Remove(j);
+ }
+ }
+ }
+ }
+ for (Index i = args.getCount() - 1; i >= 0; i--)
+ {
+ IRInst* targetVal = nullptr;
+ if (args[i].areKnownValueSame)
+ {
+ targetVal = args[i].knownValue;
+ }
+ else if (args[i].sameAsParamSet.Count())
+ {
+ auto targetParamId = *args[i].sameAsParamSet.begin();
+ targetVal = params[targetParamId];
+ }
+ if (targetVal)
+ {
+ params[i]->replaceUsesWith(args[i].knownValue);
+ params[i]->removeAndDeallocate();
+ for (auto termInst : termInsts)
+ termInst->removeArgument((UInt)i);
+ changed = true;
+ }
+ }
+ return changed;
+}
+
static bool processFunc(IRGlobalValueWithCode* func)
{
auto firstBlock = func->getFirstBlock();
if (!firstBlock)
return false;
- // Lazily generated region tree.
- CFGSimplificationContext simplificationContext;
-
IRBuilder builder(func->getModule());
bool changed = false;
@@ -165,6 +576,14 @@ static bool processFunc(IRGlobalValueWithCode* func)
workList.fastRemoveAt(0);
while (block)
{
+ // If all arguments to a phi parameter are the known to be the same,
+ // we can safely replace the phi parameter with the argument.
+ if (block != func->getFirstBlock())
+ {
+ changed |= simplifyBoolPhiParams(block);
+ changed |= removeTrivialPhiParams(block);
+ }
+
if (auto loop = as<IRLoop>(block->getTerminator()))
{
// If continue block is unreachable, remove it.
@@ -179,7 +598,7 @@ static bool processFunc(IRGlobalValueWithCode* func)
// break at the end of the loop, we can remove the header and turn it into
// a normal branch.
auto targetBlock = loop->getTargetBlock();
- if (isTrivialSingleIterationLoop(func, loop, simplificationContext))
+ if (isTrivialSingleIterationLoop(func, loop))
{
builder.setInsertBefore(loop);
List<IRInst*> args;
@@ -189,7 +608,22 @@ static bool processFunc(IRGlobalValueWithCode* func)
}
builder.emitBranch(targetBlock, args.getCount(), args.getBuffer());
loop->removeAndDeallocate();
+ changed = true;
}
+ else if (!doesLoopHasSideEffect(func, loop))
+ {
+ // The loop isn't computing anything useful outside the loop.
+ // We can delete the entire loop.
+ builder.setInsertBefore(loop);
+ SLANG_ASSERT(loop->getBreakBlock()->getFirstParam() == nullptr);
+ builder.emitBranch(loop->getBreakBlock());
+ loop->removeAndDeallocate();
+ changed = true;
+ }
+ }
+ else if (auto condBranch = as<IRIfElse>(block->getTerminator()))
+ {
+ changed |= trySimplifyIfElse(builder, condBranch);
}
// If `block` does not end with an unconditional branch, bail.
@@ -225,6 +659,7 @@ static bool processFunc(IRGlobalValueWithCode* func)
branch->removeAndDeallocate();
assert(!successor->hasUses());
successor->removeAndDeallocate();
+ break;
}
for (auto successor : block->getSuccessors())
{