diff options
| author | Ellie Hermaszewska <ellieh@nvidia.com> | 2023-08-18 05:57:57 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-08-17 14:57:57 -0700 |
| commit | 80c8f13e369b0bf0b86d2b19a4902594e6d67e5c (patch) | |
| tree | a7ee0c6eaa286ce3f2fd208b2df8a849fa325287 /source | |
| parent | a0a9c04625d37d44ead80d574131063c6eb75d0d (diff) | |
Be more careful about merge blocks during loop inversion (#3136)
* fix block eater
* Be more careful about merge blocks during loop inversion
* Restrict loop inversion to loops without continue jumps
* Remove multiple back-edges from loops for SPIR-V
* Wiggle cross compile test output
* Make loop-inversion more conservative
* Allow loops on false side in cfg normalization
* Do not set loop continue block during inversion
* Add loop inversion test to failing test list for spirv
* Simplify single use continue blocks in spirv legalization
* wobble expected failure list
---------
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-cfg-norm.cpp | 23 | ||||
| -rw-r--r-- | source/slang/slang-ir-loop-inversion.cpp | 60 | ||||
| -rw-r--r-- | source/slang/slang-ir-loop-unroll.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang-ir-spirv-legalize.cpp | 121 |
4 files changed, 187 insertions, 29 deletions
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp index 0a22e112e..d46fba5e3 100644 --- a/source/slang/slang-ir-autodiff-cfg-norm.cpp +++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp @@ -540,10 +540,27 @@ struct CFGNormalizationPass } // Right now, we only support loops where the loop is on the true side of - // the condition. If we ever encounter the other case, fill in logic to - // flip the condition. + // the condition. If we encounter the other case, flip the condition. // - SLANG_RELEASE_ASSERT(isLoopOnTrueSide); + if(!isLoopOnTrueSide) + { + IRBuilderInsertLocScope locScope{&builder}; + // Invert the cond + builder.setInsertBefore(ifElse); + const auto c = ifElse->getCondition(); + const auto negatedCond = c->getOp() == kIROp_Not + ? c->getOperand(0) + : builder.emitNot(builder.getBoolType(), c); + ifElse->condition.set(negatedCond); + const auto t = ifElse->getTrueBlock(); + const auto f = ifElse->getFalseBlock(); + ifElse->trueBlock.set(f); + ifElse->falseBlock.set(t); + + // Invert our discovered state + std::swap(trueEndPoint, falseEndPoint); + isLoopOnTrueSide = true; + } // Expect atleast one basic block (other than the condition block), in // the loop. diff --git a/source/slang/slang-ir-loop-inversion.cpp b/source/slang/slang-ir-loop-inversion.cpp index 5050ce9d6..8967272dd 100644 --- a/source/slang/slang-ir-loop-inversion.cpp +++ b/source/slang/slang-ir-loop-inversion.cpp @@ -34,11 +34,20 @@ static bool isSmallBlock(IRBlock* c) return true; } +static bool hasIrrelevantContinueBlock(IRLoop* loop) +{ + const auto c = loop->getContinueBlock(); + return c == loop->getTargetBlock() || c->getPredecessors().getCount() <= 1; +} + // Loops are suitable for inversion if: // - The loop jumps to a conditional branch which has the break block as one of // its successors (or a trivial break block which we erase) and the other // successor is empty // - The conditional block is "small", because we will be duplicating it +// - The loop's continue block is irrelevant, because we'll need to change it, +// either by being the loop header already or by having only a single use +// within the loop body static bool isSuitableForInversion(IRLoop* loop) { const auto nextBlock = loop->getTargetBlock(); @@ -52,6 +61,9 @@ static bool isSuitableForInversion(IRLoop* loop) if(!isSmallBlock(nextBlock)) return false; + if(!hasIrrelevantContinueBlock(loop)) + return false; + const auto t = branch->getTrueBlock(); const auto f = branch->getFalseBlock(); const auto a = branch->getAfterBlock(); @@ -98,19 +110,19 @@ static IRParam* duplicateToParamWithDecorations(IRBuilder& builder, IRCloneEnv& // Given // s: ...1 loop break=b next=c1 -// c1: if x then goto b else goto d +// c1: if x then goto b else goto d (merge at d) // d: goto c1 // b: ...2 // // Produce: -// s: ...1 goto c1 -// c1: if x then goto e1 else goto l -// e1: goto b +// s: ...1 goto c2 +// c2: if x then goto e2 else goto l (merge at b) +// e2: goto b // l: loop break=b next=d -// d: goto c2: -// c2: if x then goto e2 else goto e3 +// d: goto c1: +// c1: if x then goto e1 else goto e3 (merge at e3) // e3: goto d -// e2: goto b +// e1: goto b // b: ...2 // // s is the Start block @@ -126,7 +138,7 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop) auto domTree = computeDominatorTree(s->getParent()); SLANG_ASSERT(s); const auto c1 = loop->getTargetBlock(); - const auto c1Terminator = as<IRConditionalBranch>(c1->getTerminator()); + const auto c1Terminator = as<IRIfElse>(c1->getTerminator()); SLANG_ASSERT(c1Terminator); const auto b = loop->getBreakBlock(); auto& c1dUse = c1Terminator->getTrueBlock() == b ? c1Terminator->falseBlock : c1Terminator->trueBlock; @@ -164,6 +176,7 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop) e1->insertAfter(c1); builder.emitBranch(b, c1Params.getCount(), c1Params.getBuffer()); c1bUse.set(e1); + c1Terminator->afterBlock.set(d); // Similarly, we have to replace any existing 'break's to break via e1 traverseUses(b, [&](IRUse* u){ auto userBlock = u->getUser()->getParent(); @@ -173,7 +186,7 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop) }); // We now have // s: ...1 loop break=b next=c1 - // c1: if x then goto e1 else goto d + // c1: if x then goto e1 else goto d (merge at d) // e1: goto b // d: goto c1 // b: ...2 @@ -192,9 +205,9 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop) c2Terminator->removeAndDeallocate(); // We now have // s: ...1 loop break=b next=c1 - // c2: if x then goto e2 else goto d + // c2: if x then goto e2 else goto d (merge at b) // e2: goto b - // c1: if x then goto e1 else goto d + // c1: if x then goto e1 else goto d (merge at d) // e1: goto b // d: goto c1 // b: ...2 @@ -205,10 +218,10 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop) loop->insertAtEnd(l); // We now have // s: ...1 no-termiator - // c2: if x then goto e2 else goto d + // c2: if x then goto e2 else goto d (merge at b) // e2: goto b // l: loop break=b next=c1 - // c1: if x then goto e1 else goto d + // c1: if x then goto e1 else goto d (merge at d) // e1: goto b // d: goto c1 // b: ...2 @@ -222,10 +235,10 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop) builder.emitBranch(c2, as.getCount(), as.getBuffer()); // We now have // s: ...1, goto c2 - // c2: if x then goto e2 else goto d + // c2: if x then goto e2 else goto d (merge at b) // e2: goto b // l: loop break=b next=c1 - // c1: if x then goto e1 else goto d + // c1: if x then goto e1 else goto d (merge at d) // e1: goto b // d: goto c1 // b: ...2 @@ -235,10 +248,10 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop) c2dUse.set(l); // We now have // s: ...1, goto c2 - // c2: if x then goto e2 else goto l + // c2: if x then goto e2 else goto l (merge at b) // e2: goto b // l: loop break=b next=c1 - // c1: if x then goto e1 else goto d + // c1: if x then goto e1 else goto d (merge at d) // e1: goto b // d: goto c1 // b: ...2 @@ -248,12 +261,20 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop) // conditional, d, as we know that it won't break out of the loop on the // first iteration // + // Since we're only here if the continue block is irrelevant (either the + // target block already or has a single predecessor) we can set it to the + // loop header. + // // Beyond just retargeting the loop instruction, we need to make sure any // parameters the loop instruction is passing to c1 are instead passed to // 'd', and because we've added parameters to 'd' we need to forward them // from c1 also which we will accomplish using a new block, e3, + // loop->block.set(d); loop->breakBlock.set(e1); + // TODO: This really upsets a few later passes, why isn't it ok to do given + // our "irrelevant continue" condition? + // loop->continueBlock.set(loop->getTargetBlock()); SLANG_ASSERT(d->getFirstParam() == nullptr); c1->insertBefore(b); e1->insertAfter(c1); @@ -274,13 +295,14 @@ static void invertLoop(IRBuilder& builder, IRLoop* loop) e3->insertAfter(c1); builder.emitBranch(d, ps.getCount(), ps.getBuffer()); c1dUse.set(e3); + c1Terminator->afterBlock.set(e3); // We now have the desired output // s: ...1, goto c2 - // c2: if x then goto e2 else goto l + // c2: if x then goto e2 else goto l (merge at b) // e2: goto b // l: loop break=e1 next=d // d: goto c1 - // c1: if x then goto e1 else goto e3 + // c1: if x then goto e1 else goto e3 (merge at e3) // e3: goto d // e1: goto b // b: ...2 diff --git a/source/slang/slang-ir-loop-unroll.cpp b/source/slang/slang-ir-loop-unroll.cpp index 36a2abbfc..c9ac4191b 100644 --- a/source/slang/slang-ir-loop-unroll.cpp +++ b/source/slang/slang-ir-loop-unroll.cpp @@ -174,11 +174,8 @@ static void _foldAndSimplifyLoopIteration( auto b = clonedBlocks[i]; if (b) { - if (i != insertIndex) - { - clonedBlocks[insertIndex] = b; - insertIndex++; - } + clonedBlocks[insertIndex] = b; + insertIndex++; } } clonedBlocks.setCount(insertIndex); @@ -554,11 +551,12 @@ void eliminateContinueBlocks(IRModule* module, IRLoop* loopInst) // where a continue is replaced with a "break" into breakableRegionBreakBlock. // - if (loopInst->getContinueBlock() == loopInst->getTargetBlock()) + auto continueBlock = loopInst->getContinueBlock(); + + if (continueBlock == loopInst->getTargetBlock()) return; // If the continue block is not reachable, remove it. - auto continueBlock = loopInst->getContinueBlock(); if (continueBlock && !continueBlock->hasMoreThanOneUse()) { loopInst->continueBlock.set(loopInst->getTargetBlock()); diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index 6afdcf102..637592357 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -3,6 +3,7 @@ #include "slang-ir-glsl-legalize.h" +#include "slang-ir-clone.h" #include "slang-ir.h" #include "slang-ir-insts.h" #include "slang-emit-base.h" @@ -372,6 +373,123 @@ struct SPIRVLegalizationContext : public SourceEmitterBase addUsersToWorkList(ptrType); } + void processLoop(IRLoop* loop) + { + + // 2.11.1. Rules for Structured Control-flow Declarations + // Structured control flow declarations must satisfy the following + // rules: + // - the merge block declared by a header block must not be a merge + // block declared by any other header block + // - each header block must strictly structurally dominate its merge + // block + // - all back edges must branch to a loop header, with each loop + // header having exactly one back edge branching to it + // - for a given loop header, its merge block, OpLoopMerge Continue + // Target, and corresponding back-edge block: + // - the Continue Target and merge block must be different blocks + // - the loop header must structurally dominate the Continue + // Target + // - the Continue Target must structurally dominate the back-edge + // block + // - the back-edge block must structurally post dominate the + // Continue Target + + // If the continue block has only a single predecessor, pretend like it + // is just ordinary control flow + // + // TODO: could this fail in cases like this, where it had a single + // predecessor, but it's still nested inside a region? + // do{ + // if(x) + // continue; + // unreachable + // } while(foo) + const auto t = loop->getTargetBlock(); + auto c = loop->getContinueBlock(); + if(c->getPredecessors().getCount() <= 1) + { + c = t; + loop->continueBlock.set(c); + } + + // Our IR allows multiple back-edges to a loop header if this is also + // the loop continue block. SPIR-V does not so replace them with a + // single intermediate block + if(c == t) + { + // Subtract one predecessor for the loop entry + const auto numBackEdges = c->getPredecessors().getCount() - 1; + + // If we have multiple back-edges, make a new block at the end of + // the loop to be the new continue block which jumps straight to + // the loop header. + // + // If we have a single back-edge, we still may need to perform this + // transformation to make sure that the back-edge block + // structurally post-dominates the continue target. For example + // consider the loop: + // + // int i = 0; + // while(true) + // if(foo()) break; + // + // If we translate this to + // loop target=t break=b, continue=t + // t: if foo goto x else goto y + // x: goto b -- break + // y: goto t + // b: ... + // + // The back edge block, y, does not post-dominate the continue target, t. + // + // So we transform this to: + // + // loop target=t break=b, continue=c + // t: if foo goto x else goto y + // x: goto b -- break + // y: goto c + // c: goto t + // b: ... + // + // Now the back edge block and the continue target are one and the + // same, so the condition trivially holds. + // + // TODO: We don't need to always perform this, we could replace the + // below condition with `numBackEdges > 1 || + // !postDominates(backJumpingBlock, c)` + if(numBackEdges > 0) + { + IRBuilder builder(m_sharedContext->m_irModule); + builder.setInsertInto(loop->getParent()); + IRCloneEnv cloneEnv; + cloneEnv.squashChildrenMapping = true; + + // Insert a new continue block at the end of the loop + const auto newContinueBlock = builder.emitBlock(); + newContinueBlock->insertBefore(loop->getBreakBlock()); + + // This block simply branches to the loop header, forwarding + // any params + List<IRInst*> ps; + for(const auto p : c->getParams()) + { + const auto q = cast<IRParam>(cloneInst(&cloneEnv, &builder, p)); + newContinueBlock->addParam(q); + ps.add(q); + } + // Replace all jumps to our loop header/old continue block + c->replaceUsesWith(newContinueBlock); + + // Restore the target block + loop->block.set(t); + + // Branch to the target in our new continue block + builder.emitBranch(t, ps.getCount(), ps.getBuffer()); + } + } + } + void processModule() { addToWorkList(m_module->getModuleInst()); @@ -415,6 +533,9 @@ struct SPIRVLegalizationContext : public SourceEmitterBase case kIROp_HLSLRWStructuredBufferType: processStructuredBufferType(as<IRHLSLStructuredBufferTypeBase>(inst)); break; + case kIROp_loop: + processLoop(as<IRLoop>(inst)); + break; default: for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) { |
