summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-09-03 15:57:15 -0400
committerGitHub <noreply@github.com>2023-09-03 12:57:15 -0700
commit641f7bdc4ea4f75385c30d833cce4619a411ec67 (patch)
treeab87024966738aba53b44b9b4117ef59c712589f
parent1d4b5b6fd2433a10cc7ab87626cb560f54b0acbb (diff)
Loop inversion: Handle case where loop can have additional inner breaks (#3178)
* Loop inversion: Handle case where loop can have additional inner breaks - We now have another critical-edge-breaking block `e4` that is the target of inner breaks. - Both `e4` and `e1` (the break branch from the loop condition) branch to the loop's new break block `b2`. - `b2` is a clone of the old break block `b`, and it simply branches to the old break block. This fixes an IR validation issue in `tests/autodiff/reverse-while-loop-2.slang` * Delete region-wave-masks.slang --------- Co-authored-by: Yong He <yonghe@outlook.com>
-rw-r--r--source/slang/slang-ir-loop-inversion.cpp72
1 files changed, 59 insertions, 13 deletions
diff --git a/source/slang/slang-ir-loop-inversion.cpp b/source/slang/slang-ir-loop-inversion.cpp
index 8967272dd..9d2f13877 100644
--- a/source/slang/slang-ir-loop-inversion.cpp
+++ b/source/slang/slang-ir-loop-inversion.cpp
@@ -172,23 +172,52 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop)
if(p)
c1Params.add(i);
}
+
+ // Create another break block b2 that will act as the new break block for the
+ // loop. The original break block b will become the merge point for the outer condition.
+ //
+ auto b2 = builder.emitBlock();
+ b2->insertBefore(b);
+
+ // Create a copy of the parameters in b. b2 will simply pass these to b.
+ List<IRInst*> b2Params;
+ for(auto p : b->getParams())
+ {
+ auto q = duplicateToParamWithDecorations(builder, cloneEnv, p);
+ b2Params.add(q);
+ }
+ builder.setInsertInto(b2);
+ builder.emitBranch(b, b2Params.getCount(), b2Params.getBuffer());
+
auto e1 = builder.emitBlock();
e1->insertAfter(c1);
- builder.emitBranch(b, c1Params.getCount(), c1Params.getBuffer());
+ builder.emitBranch(b2, c1Params.getCount(), c1Params.getBuffer());
c1bUse.set(e1);
c1Terminator->afterBlock.set(d);
- // Similarly, we have to replace any existing 'break's to break via e1
+
+ // We create another block e4 to handle other breaks from inside the loop, and
+ // rewrite existing breaks to jump to it.
+ // We keep e4 and e1 distinct since after the inversion step, insts in e4 will
+ // be re-written to use values from the loop entry, while e1 will use values from
+ // c1.
+ //
+ auto e4 = builder.emitBlock();
+ e4->insertBefore(c1);
+ builder.emitBranch(b2, c1Params.getCount(), c1Params.getBuffer());
traverseUses(b, [&](IRUse* u){
auto userBlock = u->getUser()->getParent();
// Restrict this to just those blocks within this loop
- if(userBlock != e1 && domTree->dominates(s, userBlock) && !domTree->dominates(b, userBlock))
- u->set(e1);
+ if(userBlock != e4 && userBlock != e1 && userBlock != b2 && domTree->dominates(s, userBlock) && !domTree->dominates(b, userBlock))
+ u->set(e4);
});
+
// We now have
// s: ...1 loop break=b next=c1
+ // e4: goto b2
// c1: if x then goto e1 else goto d (merge at d)
- // e1: goto b
+ // e1: goto b2
// d: goto c1
+ // b2: goto b
// b: ...2
// Duplicate c1 into c2, and using the same cloneEnv, duplicate e1 into e2
@@ -203,27 +232,36 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop)
builder.setInsertAfter(c2Terminator);
const auto newC2Terminator = builder.emitIfElse(c2Terminator->getCondition(), c2Terminator->getTrueBlock(), c2Terminator->getFalseBlock(), b);
c2Terminator->removeAndDeallocate();
+ // The cloned e2 will branch into b2 by default, rewrite it to branch to b, the correct merge point.
+ SLANG_ASSERT(cast<IRUnconditionalBranch>(e2->getTerminator())->getTargetBlock() == b2);
+ cast<IRUnconditionalBranch>(e2->getTerminator())->block.set(b);
+
// We now have
// s: ...1 loop break=b next=c1
+ // e4: goto b2
// c2: if x then goto e2 else goto d (merge at b)
// e2: goto b
// c1: if x then goto e1 else goto d (merge at d)
- // e1: goto b
+ // e1: goto b2
// d: goto c1
+ // b2: goto b
// b: ...2
// move the loop instruction to its own block, l
const auto l = builder.emitBlock();
l->insertAfter(e2);
loop->insertAtEnd(l);
+ e4->insertBefore(c1);
// We now have
// s: ...1 no-terminator
// c2: if x then goto e2 else goto d (merge at b)
// e2: goto b
// l: loop break=b next=c1
+ // e4: goto b2
// c1: if x then goto e1 else goto d (merge at d)
- // e1: goto b
+ // e1: goto b2
// d: goto c1
+ // b2: goto b
// b: ...2
// add a new terminator to s. A jump to c2, our outer conditional. retain
@@ -238,9 +276,11 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop)
// c2: if x then goto e2 else goto d (merge at b)
// e2: goto b
// l: loop break=b next=c1
+ // e4: goto b2
// c1: if x then goto e1 else goto d (merge at d)
- // e1: goto b
+ // e1: goto b2
// d: goto c1
+ // b2: goto b
// b: ...2
// modify c2 to jump to the new loop
@@ -251,9 +291,11 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop)
// c2: if x then goto e2 else goto l (merge at b)
// e2: goto b
// l: loop break=b next=c1
+ // e4: goto b2
// c1: if x then goto e1 else goto d (merge at d)
- // e1: goto b
+ // e1: goto b2
// d: goto c1
+ // b2: goto b
// b: ...2
//
@@ -271,12 +313,12 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop)
// from c1 also which we will accomplish using a new block, e3,
//
loop->block.set(d);
- loop->breakBlock.set(e1);
+ loop->breakBlock.set(b2);
// TODO: This really upsets a few later passes, why isn't it ok to do given
// our "irrelevant continue" condition?
// loop->continueBlock.set(loop->getTargetBlock());
SLANG_ASSERT(d->getFirstParam() == nullptr);
- c1->insertBefore(b);
+ c1->insertBefore(b2);
e1->insertAfter(c1);
List<IRInst*> ps;
for(const auto p : c1->getParams())
@@ -293,6 +335,8 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop)
}
const auto e3 = builder.emitBlock();
e3->insertAfter(c1);
+ b2->insertBefore(b);
+ e4->insertAfter(c1);
builder.emitBranch(d, ps.getCount(), ps.getBuffer());
c1dUse.set(e3);
c1Terminator->afterBlock.set(e3);
@@ -300,11 +344,13 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop)
// s: ...1, goto c2
// c2: if x then goto e2 else goto l (merge at b)
// e2: goto b
- // l: loop break=e1 next=d
+ // l: loop break=b2 next=d
// d: goto c1
+ // e4: goto b2
// c1: if x then goto e1 else goto e3 (merge at e3)
// e3: goto d
- // e1: goto b
+ // e1: goto b2
+ // b2: goto b
// b: ...2
}