diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-09-03 15:57:15 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-09-03 12:57:15 -0700 |
| commit | 641f7bdc4ea4f75385c30d833cce4619a411ec67 (patch) | |
| tree | ab87024966738aba53b44b9b4117ef59c712589f | |
| parent | 1d4b5b6fd2433a10cc7ab87626cb560f54b0acbb (diff) | |
Loop inversion: Handle case where loop can have additional inner breaks (#3178)
* Loop inversion: Handle case where loop can have additional inner breaks
- We now have another critical-edge-breaking block `e4` that is the target of inner breaks.
- Both `e4` and `e1` (the break branch from the loop condition) branch to the loop's ne break block `b2`.
- `b2` is a clone of the old break block `b`, and it simply branches to the old break block.
This fixes an IR validation issue in `tests/autodiff/reverse-while-loop-2.slang`
* Delete region-wave-masks.slang
---------
Co-authored-by: Yong He <yonghe@outlook.com>
| -rw-r--r-- | source/slang/slang-ir-loop-inversion.cpp | 72 |
1 files changed, 59 insertions, 13 deletions
diff --git a/source/slang/slang-ir-loop-inversion.cpp b/source/slang/slang-ir-loop-inversion.cpp index 8967272dd..9d2f13877 100644 --- a/source/slang/slang-ir-loop-inversion.cpp +++ b/source/slang/slang-ir-loop-inversion.cpp @@ -172,23 +172,52 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop) if(p) c1Params.add(i); } + + // Create another break block b2 that will act as the new break block for the + // loop. The original break block b will become the merge point for the outer condition. + // + auto b2 = builder.emitBlock(); + b2->insertBefore(b); + + // Create a copy of the parameters in b. b2 will simply pass these to b. + List<IRInst*> b2Params; + for(auto p : b->getParams()) + { + auto q = duplicateToParamWithDecorations(builder, cloneEnv, p); + b2Params.add(q); + } + builder.setInsertInto(b2); + builder.emitBranch(b, b2Params.getCount(), b2Params.getBuffer()); + auto e1 = builder.emitBlock(); e1->insertAfter(c1); - builder.emitBranch(b, c1Params.getCount(), c1Params.getBuffer()); + builder.emitBranch(b2, c1Params.getCount(), c1Params.getBuffer()); c1bUse.set(e1); c1Terminator->afterBlock.set(d); - // Similarly, we have to replace any existing 'break's to break via e1 + + // We create another block e4 to handle other breaks from inside the loop, and + // rewrite existing breaks to jump to it. + // We keep e4 and e1 distinct since after the inversion step, insts in e4 will + // be re-written to use values from the loop entry, while e1 will use values from + // c1. + // + auto e4 = builder.emitBlock(); + e4->insertBefore(c1); + builder.emitBranch(b2, c1Params.getCount(), c1Params.getBuffer()); traverseUses(b, [&](IRUse* u){ auto userBlock = u->getUser()->getParent(); // Restrict this to just those blocks within this loop - if(userBlock != e1 && domTree->dominates(s, userBlock) && !domTree->dominates(b, userBlock)) - u->set(e1); + if(userBlock != e4 && userBlock != e1 && userBlock != b2 && domTree->dominates(s, userBlock) && !domTree->dominates(b, userBlock)) + u->set(e4); }); + // We now have // s: ...1 loop break=b next=c1 + // e4: goto b2 // c1: if x then goto e1 else goto d (merge at d) - // e1: goto b + // e1: goto b2 // d: goto c1 + // b2: goto b // b: ...2 // Duplicate c1 into c2, and using the same cloneEnv, duplicate e1 into e2 @@ -203,27 +232,36 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop) builder.setInsertAfter(c2Terminator); const auto newC2Terminator = builder.emitIfElse(c2Terminator->getCondition(), c2Terminator->getTrueBlock(), c2Terminator->getFalseBlock(), b); c2Terminator->removeAndDeallocate(); + // The cloned e2 will branch into b2 by default, rewrite it to branch to b, the correct merge point. + SLANG_ASSERT(cast<IRUnconditionalBranch>(e2->getTerminator())->getTargetBlock() == b2); + cast<IRUnconditionalBranch>(e2->getTerminator())->block.set(b); + // We now have // s: ...1 loop break=b next=c1 + // e4: goto b2 // c2: if x then goto e2 else goto d (merge at b) // e2: goto b // c1: if x then goto e1 else goto d (merge at d) - // e1: goto b + // e1: goto b2 // d: goto c1 + // b2: goto b // b: ...2 // move the loop instruction to its own block, l const auto l = builder.emitBlock(); l->insertAfter(e2); loop->insertAtEnd(l); + e4->insertBefore(c1); // We now have // s: ...1 no-termiator // c2: if x then goto e2 else goto d (merge at b) // e2: goto b // l: loop break=b next=c1 + // e4: goto b2 // c1: if x then goto e1 else goto d (merge at d) - // e1: goto b + // e1: goto b2 // d: goto c1 + // b2: goto b // b: ...2 // add a new terminator to s. A jump to c2, our outer conditional. retain @@ -238,9 +276,11 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop) // c2: if x then goto e2 else goto d (merge at b) // e2: goto b // l: loop break=b next=c1 + // e4: goto b2 // c1: if x then goto e1 else goto d (merge at d) - // e1: goto b + // e1: goto b2 // d: goto c1 + // b2: goto b // b: ...2 // modify c2 to jump to the new loop @@ -251,9 +291,11 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop) // c2: if x then goto e2 else goto l (merge at b) // e2: goto b // l: loop break=b next=c1 + // e4: goto b2 // c1: if x then goto e1 else goto d (merge at d) - // e1: goto b + // e1: goto b2 // d: goto c1 + // b2: goto b // b: ...2 // @@ -271,12 +313,12 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop) // from c1 also which we will accomplish using a new block, e3, // loop->block.set(d); - loop->breakBlock.set(e1); + loop->breakBlock.set(b2); // TODO: This really upsets a few later passes, why isn't it ok to do given // our "irrelevant continue" condition? // loop->continueBlock.set(loop->getTargetBlock()); SLANG_ASSERT(d->getFirstParam() == nullptr); - c1->insertBefore(b); + c1->insertBefore(b2); e1->insertAfter(c1); List<IRInst*> ps; for(const auto p : c1->getParams()) @@ -293,6 +335,8 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop) } const auto e3 = builder.emitBlock(); e3->insertAfter(c1); + b2->insertBefore(b); + e4->insertAfter(c1); builder.emitBranch(d, ps.getCount(), ps.getBuffer()); c1dUse.set(e3); c1Terminator->afterBlock.set(e3); @@ -300,11 +344,13 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop) // s: ...1, goto c2 // c2: if x then goto e2 else goto l (merge at b) // e2: goto b - // l: loop break=e1 next=d + // l: loop break=b2 next=d // d: goto c1 + // e4: goto b2 // c1: if x then goto e1 else goto e3 (merge at e3) // e3: goto d - // e1: goto b + // e1: goto b2 + // b2: goto b // b: ...2 } |
