From 641f7bdc4ea4f75385c30d833cce4619a411ec67 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Sun, 3 Sep 2023 15:57:15 -0400 Subject: Loop inversion: Handle case where loop can have additional inner breaks (#3178) * Loop inversion: Handle case where loop can have additional inner breaks - We now have another critical-edge-breaking block `e4` that is the target of inner breaks. - Both `e4` and `e1` (the break branch from the loop condition) branch to the loop's ne break block `b2`. - `b2` is a clone of the old break block `b`, and it simply branches to the old break block. This fixes an IR validation issue in `tests/autodiff/reverse-while-loop-2.slang` * Delete region-wave-masks.slang --------- Co-authored-by: Yong He --- source/slang/slang-ir-loop-inversion.cpp | 72 ++++++++++++++++++++++++++------ 1 file changed, 59 insertions(+), 13 deletions(-) diff --git a/source/slang/slang-ir-loop-inversion.cpp b/source/slang/slang-ir-loop-inversion.cpp index 8967272dd..9d2f13877 100644 --- a/source/slang/slang-ir-loop-inversion.cpp +++ b/source/slang/slang-ir-loop-inversion.cpp @@ -172,23 +172,52 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop) if(p) c1Params.add(i); } + + // Create another break block b2 that will act as the new break block for the + // loop. The original break block b will become the merge point for the outer condition. + // + auto b2 = builder.emitBlock(); + b2->insertBefore(b); + + // Create a copy of the parameters in b. b2 will simply pass these to b. + List b2Params; + for(auto p : b->getParams()) + { + auto q = duplicateToParamWithDecorations(builder, cloneEnv, p); + b2Params.add(q); + } + builder.setInsertInto(b2); + builder.emitBranch(b, b2Params.getCount(), b2Params.getBuffer()); + auto e1 = builder.emitBlock(); e1->insertAfter(c1); - builder.emitBranch(b, c1Params.getCount(), c1Params.getBuffer()); + builder.emitBranch(b2, c1Params.getCount(), c1Params.getBuffer()); c1bUse.set(e1); c1Terminator->afterBlock.set(d); - // Similarly, we have to replace any existing 'break's to break via e1 + + // We create another block e4 to handle other breaks from inside the loop, and + // rewrite existing breaks to jump to it. + // We keep e4 and e1 distinct since after the inversion step, insts in e4 will + // be re-written to use values from the loop entry, while e1 will use values from + // c1. + // + auto e4 = builder.emitBlock(); + e4->insertBefore(c1); + builder.emitBranch(b2, c1Params.getCount(), c1Params.getBuffer()); traverseUses(b, [&](IRUse* u){ auto userBlock = u->getUser()->getParent(); // Restrict this to just those blocks within this loop - if(userBlock != e1 && domTree->dominates(s, userBlock) && !domTree->dominates(b, userBlock)) - u->set(e1); + if(userBlock != e4 && userBlock != e1 && userBlock != b2 && domTree->dominates(s, userBlock) && !domTree->dominates(b, userBlock)) + u->set(e4); }); + // We now have // s: ...1 loop break=b next=c1 + // e4: goto b2 // c1: if x then goto e1 else goto d (merge at d) - // e1: goto b + // e1: goto b2 // d: goto c1 + // b2: goto b // b: ...2 // Duplicate c1 into c2, and using the same cloneEnv, duplicate e1 into e2 @@ -203,27 +232,36 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop) builder.setInsertAfter(c2Terminator); const auto newC2Terminator = builder.emitIfElse(c2Terminator->getCondition(), c2Terminator->getTrueBlock(), c2Terminator->getFalseBlock(), b); c2Terminator->removeAndDeallocate(); + // The cloned e2 will branch into b2 by default, rewrite it to branch to b, the correct merge point. + SLANG_ASSERT(cast(e2->getTerminator())->getTargetBlock() == b2); + cast(e2->getTerminator())->block.set(b); + // We now have // s: ...1 loop break=b next=c1 + // e4: goto b2 // c2: if x then goto e2 else goto d (merge at b) // e2: goto b // c1: if x then goto e1 else goto d (merge at d) - // e1: goto b + // e1: goto b2 // d: goto c1 + // b2: goto b // b: ...2 // move the loop instruction to its own block, l const auto l = builder.emitBlock(); l->insertAfter(e2); loop->insertAtEnd(l); + e4->insertBefore(c1); // We now have // s: ...1 no-termiator // c2: if x then goto e2 else goto d (merge at b) // e2: goto b // l: loop break=b next=c1 + // e4: goto b2 // c1: if x then goto e1 else goto d (merge at d) - // e1: goto b + // e1: goto b2 // d: goto c1 + // b2: goto b // b: ...2 // add a new terminator to s. A jump to c2, our outer conditional. retain @@ -238,9 +276,11 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop) // c2: if x then goto e2 else goto d (merge at b) // e2: goto b // l: loop break=b next=c1 + // e4: goto b2 // c1: if x then goto e1 else goto d (merge at d) - // e1: goto b + // e1: goto b2 // d: goto c1 + // b2: goto b // b: ...2 // modify c2 to jump to the new loop @@ -251,9 +291,11 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop) // c2: if x then goto e2 else goto l (merge at b) // e2: goto b // l: loop break=b next=c1 + // e4: goto b2 // c1: if x then goto e1 else goto d (merge at d) - // e1: goto b + // e1: goto b2 // d: goto c1 + // b2: goto b // b: ...2 // @@ -271,12 +313,12 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop) // from c1 also which we will accomplish using a new block, e3, // loop->block.set(d); - loop->breakBlock.set(e1); + loop->breakBlock.set(b2); // TODO: This really upsets a few later passes, why isn't it ok to do given // our "irrelevant continue" condition? // loop->continueBlock.set(loop->getTargetBlock()); SLANG_ASSERT(d->getFirstParam() == nullptr); - c1->insertBefore(b); + c1->insertBefore(b2); e1->insertAfter(c1); List ps; for(const auto p : c1->getParams()) @@ -293,6 +335,8 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop) } const auto e3 = builder.emitBlock(); e3->insertAfter(c1); + b2->insertBefore(b); + e4->insertAfter(c1); builder.emitBranch(d, ps.getCount(), ps.getBuffer()); c1dUse.set(e3); c1Terminator->afterBlock.set(e3); @@ -300,11 +344,13 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop) // s: ...1, goto c2 // c2: if x then goto e2 else goto l (merge at b) // e2: goto b - // l: loop break=e1 next=d + // l: loop break=b2 next=d // d: goto c1 + // e4: goto b2 // c1: if x then goto e1 else goto e3 (merge at e3) // e3: goto d - // e1: goto b + // e1: goto b2 + // b2: goto b // b: ...2 } -- cgit v1.2.3