diff options
| author | Yong He <yonghe@outlook.com> | 2022-10-10 15:59:45 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-10 15:59:45 -0700 |
| commit | 768e62f6c7541439e2edc18dad5fb3846d2e05f9 (patch) | |
| tree | 8c68424ee65905b77d3ecb4c7659c5fdcc6ab948 /source | |
| parent | 8487678d6504459935fec07886d2e53ed688ac2f (diff) | |
Support multi-level break + single-return conversion + general inline. (#2436)
* Support multi-level break.
* Single return.
* Add test for inlining `void` return-type functions.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-emit.cpp | 14 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-dominators.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang-ir-eliminate-multilevel-break.cpp | 308 | ||||
| -rw-r--r-- | source/slang/slang-ir-eliminate-multilevel-break.h | 12 | ||||
| -rw-r--r-- | source/slang/slang-ir-eliminate-phis.cpp | 16 | ||||
| -rw-r--r-- | source/slang/slang-ir-eliminate-phis.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-inline.cpp | 47 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-pass-base.h | 26 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-simplify-cfg.cpp | 18 | ||||
| -rw-r--r-- | source/slang/slang-ir-simplify-cfg.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-single-return.cpp | 103 | ||||
| -rw-r--r-- | source/slang/slang-ir-single-return.h | 12 | ||||
| -rw-r--r-- | source/slang/slang-ir-ssa.cpp | 3 |
15 files changed, 529 insertions, 56 deletions
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index a8d3390f0..a916d0d63 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -13,6 +13,7 @@ #include "slang-ir-dll-export.h" #include "slang-ir-dll-import.h" #include "slang-ir-eliminate-phis.h" +#include "slang-ir-eliminate-multilevel-break.h" #include "slang-ir-entry-point-uniforms.h" #include "slang-ir-entry-point-raw-ptr-params.h" #include "slang-ir-explicit-global-context.h" @@ -48,7 +49,6 @@ #include "slang-ir-wrap-structured-buffers.h" #include "slang-ir-liveness.h" #include "slang-ir-glsl-liveness.h" - #include "slang-legalize-types.h" #include "slang-lower-to-ir.h" #include "slang-mangle.h" @@ -376,6 +376,8 @@ Result linkAndOptimizeIR( if (sink->getErrorCount() != 0) return SLANG_FAIL; + eliminateMultiLevelBreak(irModule); + // TODO(DG): There are multiple DCE steps here, which need to be changed // so that they don't just throw out any non-entry point code // Debugging code for IR transformations... @@ -784,7 +786,7 @@ Result linkAndOptimizeIR( { // We only want to accumulate locations if liveness tracking is enabled. - eliminatePhis(codeGenContext, livenessMode, irModule); + eliminatePhis(livenessMode, irModule); #if 0 dumpIRIfEnabled(codeGenContext, irModule, "PHIS ELIMINATED"); #endif @@ -934,7 +936,7 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outAr linkingAndOptimizationOptions.sourceEmitter = sourceEmitter; - switch( sourceLanguage ) + switch (sourceLanguage) { default: break; @@ -962,7 +964,11 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outAr // TODO: do we want to emit directly from IR, or translate the // IR back into AST for emission? #if 0 - dumpIR(compileRequest, irModule, "PRE-EMIT"); + { + StringBuilder sb; + StringWriter writer(&sb, Slang::WriterFlag::AutoFlush); + dumpIR(irModule, getIRDumpOptions(), sourceManager, &writer); + } #endif sourceEmitter->emitModule(irModule, sink); } diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 554a407ee..5eee13d5e 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -5,6 +5,7 @@ #include "slang-ir-insts.h" #include "slang-ir-clone.h" #include "slang-ir-dce.h" +#include "slang-ir-eliminate-phis.h" namespace Slang { @@ -1308,7 +1309,8 @@ struct JVPDerivativeContext IRFunc* emitJVPFunction(IRBuilder* builder, IRFunc* primalFn) { - + eliminatePhisInFunc(LivenessMode::Disabled, module, primalFn); + builder->setInsertBefore(primalFn->getNextInst()); auto jvpFn = builder->createFunc(); diff --git a/source/slang/slang-ir-dominators.cpp b/source/slang/slang-ir-dominators.cpp index bae9f772f..72b156228 100644 --- a/source/slang/slang-ir-dominators.cpp +++ b/source/slang/slang-ir-dominators.cpp @@ -305,6 +305,18 @@ void computePostorder(IRGlobalValueWithCode* code, List<IRBlock*>& outOrder) PostorderComputationContext context; context.order = &outOrder; context.walk(code); + + // Append unvisited blocks (unreachable blocks) to the begining of postOrder. + List<IRBlock*> prefix; + for (auto block : code->getBlocks()) + { + if (!context.visited.Contains(block)) + { + prefix.add(block); + } + } + prefix.addRange(outOrder); + outOrder = _Move(prefix); } // diff --git a/source/slang/slang-ir-eliminate-multilevel-break.cpp b/source/slang/slang-ir-eliminate-multilevel-break.cpp new file mode 100644 index 000000000..269b74aad --- /dev/null +++ b/source/slang/slang-ir-eliminate-multilevel-break.cpp @@ -0,0 +1,308 @@ +// slang-ir-eliminate-multilevel-break.cpp +#include "slang-ir-eliminate-multilevel-break.h" +#include "slang-ir.h" +#include "slang-ir-clone.h" +#include "slang-ir-insts.h" +#include "slang-ir-eliminate-phis.h" +#include "slang-ir-dominators.h" + +namespace Slang +{ + +struct EliminateMultiLevelBreakContext +{ + IRModule* irModule; + + struct LoopInfo : RefObject + { + LoopInfo* parent = nullptr; + int level = 0; + IRLoop* loopInst; + List<IRBlock*> blocks; + HashSet<IRBlock*> blockSet; + List<RefPtr<LoopInfo>> childLoops; + IRBlock* getBreakBlock() { return loopInst->getBreakBlock(); } + template<typename Func> + void forEach(const Func& f) + { + f(this); + for (auto child : childLoops) + child->forEach(f); + } + }; + + struct MultiLevelBreakInfo + { + IRUnconditionalBranch* breakInst; + LoopInfo* currentLoop; + LoopInfo* breakTargetLoop; + }; + + struct FuncContext + { + List<RefPtr<LoopInfo>> loops; + HashSet<IRBlock*> breakBlocks; + Dictionary<IRBlock*, LoopInfo*> mapBreakBlockToLoop; + Dictionary<IRBlock*, LoopInfo*> mapBlockToLoop; + HashSet<IRBlock*> processedBlocks; + List<MultiLevelBreakInfo> multiLevelBreaks; + + void collectLoopBlocks(LoopInfo& info) + { + auto startBlock = info.loopInst->getTargetBlock(); + info.blockSet.Add(startBlock); + info.blocks.add(startBlock); + breakBlocks.Add(info.loopInst->getBreakBlock()); + for (Index i = 0; i < info.blocks.getCount(); i++) + { + auto block = info.blocks[i]; + if (!processedBlocks.Add(block)) + continue; + if (auto loopInst = as<IRLoop>(block->getTerminator())) + { + RefPtr<LoopInfo> childLoop = new LoopInfo(); + childLoop->loopInst = loopInst; + childLoop->parent = &info; + childLoop->level = info.level + 1; + collectLoopBlocks(*childLoop); + info.childLoops.add(childLoop); + block = loopInst->getBreakBlock(); + if (info.blockSet.Add(block)) + { + info.blocks.add(block); + } + continue; + } + for (auto succ : block->getSuccessors()) + { + if (!breakBlocks.Contains(succ)) + { + if (info.blockSet.Add(succ)) + info.blocks.add(succ); + } + } + } + breakBlocks.Remove(info.loopInst->getBreakBlock()); + } + + void gatherInfo(IRGlobalValueWithCode* func) + { + for (auto block : func->getBlocks()) + { + if (processedBlocks.Contains(block)) + continue; + auto terminator = block->getTerminator(); + if (auto loop = as<IRLoop>(terminator)) + { + RefPtr<LoopInfo> loopInfo = new LoopInfo(); + loopInfo->loopInst = loop; + collectLoopBlocks(*loopInfo); + loops.add(loopInfo); + } + } + + for (auto& l : loops) + { + l->forEach( + [&](LoopInfo* loop) + { + mapBreakBlockToLoop.Add(loop->loopInst->getBreakBlock(), loop); + for (auto block : loop->blocks) + mapBlockToLoop.Add(block, loop); + }); + } + + for (auto block : func->getBlocks()) + { + auto terminator = block->getTerminator(); + if (auto branch = as<IRUnconditionalBranch>(terminator)) + { + if (as<IRLoop>(terminator)) + continue; + LoopInfo* breakLoop = nullptr; + LoopInfo* currentLoop = nullptr; + if (!mapBreakBlockToLoop.TryGetValue(branch->getTargetBlock(), breakLoop)) + continue; + if (mapBlockToLoop.TryGetValue(block, currentLoop)) + { + if (currentLoop != breakLoop) + { + MultiLevelBreakInfo breakInfo; + breakInfo.breakInst = branch; + breakInfo.breakTargetLoop = breakLoop; + breakInfo.currentLoop = currentLoop; + multiLevelBreaks.add(breakInfo); + } + } + } + } + } + }; + + void processFunc(IRGlobalValueWithCode* func) + { + // If func does not have any multi-level breaks, return. + { + FuncContext funcInfo; + funcInfo.gatherInfo(func); + + if (funcInfo.multiLevelBreaks.getCount() == 0) + return; + } + + // To make things easy, eliminate Phis before perform transformations. + eliminatePhisInFunc(LivenessMode::Disabled, irModule, func); + + // Before modifying the cfg, we gather all required info from the existing cfg. + FuncContext funcInfo; + funcInfo.gatherInfo(func); + + if (funcInfo.multiLevelBreaks.getCount() == 0) + return; + + SharedIRBuilder sharedBuilder; + sharedBuilder.init(irModule); + IRBuilder builder(&sharedBuilder); + builder.setInsertInto(func); + + OrderedHashSet<LoopInfo*> skippedOverLoops; + auto unreachableBlock = builder.emitBlock(); + builder.setInsertInto(unreachableBlock); + builder.emitUnreachable(); + builder.setInsertInto(func); + + // Rewrite multi-level breaks with single level break + target level argument. + for (auto breakInfo : funcInfo.multiLevelBreaks) + { + auto loop = breakInfo.currentLoop; + while (loop) + { + skippedOverLoops.Add(loop); + loop = loop->parent; + if (loop == breakInfo.breakTargetLoop) + break; + } + builder.setInsertBefore(breakInfo.breakInst); + auto targetLevelInst = builder.getIntValue(builder.getIntType(), breakInfo.breakTargetLoop->level); + builder.emitBranch(breakInfo.currentLoop->getBreakBlock(), 1, &targetLevelInst); + breakInfo.breakInst->removeAndDeallocate(); + } + + // Rewrite skipped-over break blocks to accept a target level argument. + builder.setInsertInto(func); + OrderedDictionary<IRBlock*, int> mapNewBreakBlockToLoopLevel; + for (auto skippedLoop : skippedOverLoops) + { + auto breakBlock = skippedLoop->getBreakBlock(); + + // The existing break block cannot have parameters. We assume that PHI-elimination is + // run before this pass. + SLANG_RELEASE_ASSERT(breakBlock->getFirstParam() == nullptr); + + // The new CFG structure will be: newBreakBlock --> newBreakBodyBlock { IfElse (-->oldBreakBlock, -->outerBreakBlock) } + // `newBreakBlock` defines the `IRParam` for the break target, then immediately jumps to `newBreakBodyBlock` for the actual branch. We need this + // separation to avoid introducing critical edge to the CFG (blocks cannot have more + // than 1 predecessors and more than 1 successors at the same time). + auto jumpToOuterBlock = builder.createBlock(); + auto newBreakBodyBlock = builder.createBlock(); + auto newBreakBlock = builder.createBlock(); + newBreakBlock->insertBefore(breakBlock); + newBreakBodyBlock->insertAfter(breakBlock); + jumpToOuterBlock->insertAfter(newBreakBlock); + mapNewBreakBlockToLoopLevel[newBreakBlock] = skippedLoop->level; + breakBlock->replaceUsesWith(newBreakBlock); + builder.setInsertInto(newBreakBlock); + auto targetLevelParam = builder.emitParam(builder.getIntType()); + builder.emitBranch(newBreakBodyBlock); + builder.setInsertInto(newBreakBodyBlock); + auto levelNeq = builder.emitNeq(targetLevelParam, builder.getIntValue(builder.getIntType(), skippedLoop->level)); + builder.emitIfElse(levelNeq, jumpToOuterBlock, breakBlock, unreachableBlock); + builder.setInsertInto(jumpToOuterBlock); + if (skippedOverLoops.Contains(skippedLoop->parent)) + { + builder.emitBranch(skippedLoop->parent->getBreakBlock(), 1, (IRInst**)&targetLevelParam); + } + else + { + builder.emitBranch(skippedLoop->parent->getBreakBlock()); + } + } + + // Once we have rewritten loops' break blocks with additional targetLevel parameter, all + // original branches into that block without a parameter will now need to provide a default + // value equal to the level of its corresponding loop. + for (auto breakBlockKV : mapNewBreakBlockToLoopLevel) + { + auto breakBlock = breakBlockKV.Key; + auto level = breakBlockKV.Value; + IRInst* levelInst = nullptr; + List<IRUse*> uses; + for (auto use = breakBlock->firstUse; use; use = use->nextUse) + { + uses.add(use); + } + for (auto use : uses) + { + auto user = use->getUser(); + switch (user->getOp()) + { + case kIROp_conditionalBranch: + case kIROp_ifElse: + case kIROp_Switch: + // For complex branches, insert an intermediate block so we can specify the + // target index argument. + { + builder.setInsertInto(func); + auto tmpBlock = builder.createBlock(); + tmpBlock->insertAfter(user->getParent()); + builder.setInsertInto(tmpBlock); + if (!levelInst) + levelInst = builder.getIntValue(builder.getIntType(), level); + builder.emitBranch(breakBlock, 1, &levelInst); + use->set(tmpBlock); + } + break; + case kIROp_loop: + // Ignore. + continue; + case kIROp_unconditionalBranch: + { + auto originalBranch = as<IRUnconditionalBranch>(user); + if (originalBranch->getOperandCount() == 1) + { + builder.setInsertBefore(originalBranch); + if (!levelInst) + levelInst = builder.getIntValue(builder.getIntType(), level); + builder.emitBranch(breakBlock, 1, &levelInst); + originalBranch->removeAndDeallocate(); + } + } + break; + } + + } + } + } +}; + +void eliminateMultiLevelBreak(IRModule* irModule) +{ + EliminateMultiLevelBreakContext context; + context.irModule = irModule; + for (auto globalInst : irModule->getGlobalInsts()) + { + if (auto codeInst = as<IRGlobalValueWithCode>(globalInst)) + { + context.processFunc(codeInst); + } + } +} + +void eliminateMultiLevelBreakForFunc(IRModule* irModule, IRGlobalValueWithCode* func) +{ + EliminateMultiLevelBreakContext context; + context.irModule = irModule; + context.processFunc(func); +} + +} // namespace Slang diff --git a/source/slang/slang-ir-eliminate-multilevel-break.h b/source/slang/slang-ir-eliminate-multilevel-break.h new file mode 100644 index 000000000..f6210bc12 --- /dev/null +++ b/source/slang/slang-ir-eliminate-multilevel-break.h @@ -0,0 +1,12 @@ +// slang-ir-eliminate-multi-level-break.h +#pragma once + +namespace Slang +{ + struct IRModule; + struct IRGlobalValueWithCode; + + void eliminateMultiLevelBreak(IRModule* module); + void eliminateMultiLevelBreakForFunc(IRModule* module, IRGlobalValueWithCode* func); + +} diff --git a/source/slang/slang-ir-eliminate-phis.cpp b/source/slang/slang-ir-eliminate-phis.cpp index 9aac7de3a..c4f0b8b9d 100644 --- a/source/slang/slang-ir-eliminate-phis.cpp +++ b/source/slang/slang-ir-eliminate-phis.cpp @@ -64,15 +64,13 @@ struct PhiEliminationContext // At the top level, our pas needs to have access to the IR module, and needs // a builder it can use to generate code. // - CodeGenContext* m_codeGenContext = nullptr; IRModule* m_module = nullptr; SharedIRBuilder m_sharedBuilder; IRBuilder m_builder; LivenessMode m_livenessMode; - PhiEliminationContext(CodeGenContext* codeGenContext, LivenessMode livenessMode, IRModule* module) - : m_codeGenContext(codeGenContext) - , m_module(module) + PhiEliminationContext(LivenessMode livenessMode, IRModule* module) + : m_module(module) , m_sharedBuilder(module) , m_builder(m_sharedBuilder) , m_livenessMode(livenessMode) @@ -900,10 +898,16 @@ struct PhiEliminationContext } }; -void eliminatePhis(CodeGenContext* codeGenContext, LivenessMode livenessMode, IRModule* module) +void eliminatePhis(LivenessMode livenessMode, IRModule* module) { - PhiEliminationContext context(codeGenContext, livenessMode, module); + PhiEliminationContext context(livenessMode, module); context.eliminatePhisInModule(); } +void eliminatePhisInFunc(LivenessMode livenessMode, IRModule* module, IRGlobalValueWithCode* func) +{ + PhiEliminationContext context(livenessMode, module); + context.eliminatePhisInFunc(func); +} + } diff --git a/source/slang/slang-ir-eliminate-phis.h b/source/slang/slang-ir-eliminate-phis.h index b21363cc8..ff81d5b38 100644 --- a/source/slang/slang-ir-eliminate-phis.h +++ b/source/slang/slang-ir-eliminate-phis.h @@ -15,5 +15,7 @@ namespace Slang /// are not themselves based on an SSA representation. /// /// If livenessMode is enabled LiveRangeStarts will be inserted into the module. - void eliminatePhis(CodeGenContext* context, LivenessMode livenessMode, IRModule* module); + void eliminatePhis(LivenessMode livenessMode, IRModule* module); + + void eliminatePhisInFunc(LivenessMode livenessMode, IRModule* module, IRGlobalValueWithCode* func); } diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp index 77691d169..70d98d10c 100644 --- a/source/slang/slang-ir-inline.cpp +++ b/source/slang/slang-ir-inline.cpp @@ -2,6 +2,7 @@ #include "slang-ir-inline.h" #include "slang-ir-ssa-simplification.h" +#include "slang-ir-single-return.h" // This file provides general facilities for inlining function calls. @@ -320,48 +321,9 @@ struct InliningPassBase // For now, our inlining pass only handles the case where // the callee is a "single-return" function, which means the callee // function contains only one return at the end of the body. - if (isSingleReturnFunc(callee)) - { - inlineSingleReturnFuncBody(callSite, &env, &builder); - } - else - { - // Running into any non-trivial function to be inlined - // is currently an internal compiler error. - // - SLANG_UNIMPLEMENTED_X("general case of inlining"); - } - } - - /// Check if `func` represents a simple callee that has only a single `return`. - bool isSingleReturnFunc(IRFunc* func) - { - auto firstBlock = func->getFirstBlock(); - - // If the body block is decorated (for some reason), then the function is non-trivial. - // - if( firstBlock->getFirstDecoration() ) - return false; - - // If the body has more than one returns, we cannot inline it now. - bool returnFound = false; - for (auto block : func->getBlocks()) - { - for (auto inst : block->getChildren()) - { - if (inst->getOp() == kIROp_Return) - { - // If the return is not at the end of the block, we cannot handle it. - if (inst != block->getTerminator()) - return false; - // If there is already a return found, this function cannot be simple. - if (returnFound) - return false; - returnFound = true; - } - } - } - return true; + + convertFuncToSingleReturnForm(m_module, callSite.callee); + inlineSingleReturnFuncBody(callSite, &env, &builder); } // When instructions are cloned, with cloneInst no sourceLoc information is copied over by default. @@ -527,6 +489,7 @@ struct InliningPassBase // call->removeAndDeallocate(); } + }; /// An inlining pass that inlines calls to `[unsafeForceInlineEarly]` functions diff --git a/source/slang/slang-ir-inst-pass-base.h b/source/slang/slang-ir-inst-pass-base.h index ec4506272..2e251e46d 100644 --- a/source/slang/slang-ir-inst-pass-base.h +++ b/source/slang/slang-ir-inst-pass-base.h @@ -56,6 +56,32 @@ namespace Slang } } + template <typename InstType, typename Func> + void processChildInstsOfType(IROp instOp, IRInst* parent, const Func& f) + { + workList.clear(); + workListSet.Clear(); + + addToWorkList(parent); + + while (workList.getCount() != 0) + { + IRInst* inst = workList.getLast(); + + workList.removeLast(); + workListSet.Remove(inst); + if (inst->getOp() == instOp) + { + f(as<InstType>(inst)); + } + + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + addToWorkList(child); + } + } + } + template <typename Func> void processAllInsts(const Func& f) { diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 76dbe55f9..e09d60c3e 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -1543,6 +1543,8 @@ struct IRContinue : IRUnconditionalBranch {}; // about the loop structure: struct IRLoop : IRUnconditionalBranch { + IR_LEAF_ISA(loop); + // The next block after the loop, which // is where we expect control flow to // re-converge, and also where a diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp index af0e7c0ce..db0274d32 100644 --- a/source/slang/slang-ir-simplify-cfg.cpp +++ b/source/slang/slang-ir-simplify-cfg.cpp @@ -6,7 +6,7 @@ namespace Slang { -bool processFunc(IRFunc* func) +bool processFunc(IRGlobalValueWithCode* func) { auto firstBlock = func->getFirstBlock(); if (!firstBlock) @@ -23,6 +23,15 @@ bool processFunc(IRFunc* func) workList.fastRemoveAt(0); while (block) { + if (auto loop = as<IRLoop>(block->getTerminator())) + { + auto continueBlock = loop->getContinueBlock(); + if (continueBlock && !continueBlock->hasMoreThanOneUse()) + { + loop->continueBlock.set(loop->getTargetBlock()); + continueBlock->removeAndDeallocate(); + } + } // If `block` does not end with an unconditional branch, bail. if (block->getTerminator()->getOp() != kIROp_unconditionalBranch) break; @@ -33,6 +42,8 @@ bool processFunc(IRFunc* func) // merge point in CFG. Such blocks will have more than one use. if (successor->hasMoreThanOneUse()) break; + if (block->hasMoreThanOneUse()) + break; changed = true; Index paramIndex = 0; auto inst = successor->getFirstDecorationOrChild(); @@ -79,4 +90,9 @@ bool simplifyCFG(IRModule* module) return changed; } +bool simplifyCFG(IRGlobalValueWithCode* func) +{ + return processFunc(func); +} + } // namespace Slang diff --git a/source/slang/slang-ir-simplify-cfg.h b/source/slang/slang-ir-simplify-cfg.h index 3d8729274..6bfa6e2bf 100644 --- a/source/slang/slang-ir-simplify-cfg.h +++ b/source/slang/slang-ir-simplify-cfg.h @@ -4,9 +4,13 @@ namespace Slang { struct IRModule; + struct IRGlobalValueWithCode; /// Simplifies control flow graph by merging basic blocks that /// forms a simple linear chain. /// Returns true if changed. bool simplifyCFG(IRModule* module); + + bool simplifyCFG(IRGlobalValueWithCode* func); + } diff --git a/source/slang/slang-ir-single-return.cpp b/source/slang/slang-ir-single-return.cpp new file mode 100644 index 000000000..a00066556 --- /dev/null +++ b/source/slang/slang-ir-single-return.cpp @@ -0,0 +1,103 @@ +// slang-ir-single-return.cpp +#include "slang-ir-single-return.h" +#include "slang-ir.h" +#include "slang-ir-clone.h" +#include "slang-ir-insts.h" +#include "slang-ir-inst-pass-base.h" +#include "slang-ir-eliminate-multilevel-break.h" +#include "slang-ir-simplify-cfg.h" + +namespace Slang +{ + +struct SingleReturnContext : public InstPassBase +{ + SingleReturnContext(IRModule* inModule) + : InstPassBase(inModule) + {} + void processFunc(IRGlobalValueWithCode* func) + { + SharedIRBuilder sharedBuilder; + sharedBuilder.init(module); + IRBuilder builder(&sharedBuilder); + simplifyCFG(func); + + // We make use of the `eliminate-multi-level-break` pass to implement the transformation. + // To be able to do that, we need to prepare `func` so that the entire function body + // is wrapped in a trivial loop and turn all `return`s into `break`s out of the outter most + // loop. + builder.setInsertInto(func); + auto breakBlock = builder.emitBlock(); + auto returnBlock = builder.emitBlock(); + builder.setInsertInto(breakBlock); + auto resultType = as<IRFuncType>(func->getDataType())->getResultType(); + + IRInst* retValParam = nullptr; + if (resultType->getOp() != kIROp_VoidType) + { + retValParam = builder.emitParam(resultType); + } + builder.emitBranch(returnBlock); + + auto originalStartBlock = func->getFirstBlock(); + auto loopHeaderBlock = builder.createBlock(); + loopHeaderBlock->insertBefore(originalStartBlock); + builder.setInsertInto(loopHeaderBlock); + + // Move all params into `loopHeaderBlock`. + List<IRParam*> params; + for (auto param : originalStartBlock->getParams()) + { + params.add(param); + } + for (auto param : params) + { + loopHeaderBlock->addParam(param); + } + auto loopInst = (IRLoop*)builder.emitLoop(originalStartBlock, breakBlock, originalStartBlock); + + // Now replace all return insts as break insts. + processChildInstsOfType<IRReturn>(kIROp_Return, func, [&](IRReturn* returnInst) + { + IRInst* retVal = nullptr; + if (returnInst->getOperandCount() == 0) + retVal = builder.getVoidValue(); + else + retVal = returnInst->getVal(); + builder.setInsertBefore(returnInst); + if (resultType->getOp()==kIROp_VoidType) + { + builder.emitBranch(breakBlock); + } + else + { + builder.emitBranch(breakBlock, 1, &retVal); + } + returnInst->removeAndDeallocate(); + }); + + builder.setInsertInto(returnBlock); + if (retValParam) + builder.emitReturn(retValParam); + else + builder.emitReturn(); + + // Now run the multi-level-break pass. + eliminateMultiLevelBreakForFunc(module, func); + + // Now remove the trivial loop header. + SLANG_RELEASE_ASSERT(loopInst->getContinueBlock() == loopInst->getTargetBlock()); + auto targetBlock = loopInst->getTargetBlock(); + for (auto param : params) + targetBlock->addParam(param); + loopHeaderBlock->removeAndDeallocate(); + } +}; + +void convertFuncToSingleReturnForm(IRModule* irModule, IRGlobalValueWithCode* func) +{ + SingleReturnContext context(irModule); + context.processFunc(func); +} + +} // namespace Slang diff --git a/source/slang/slang-ir-single-return.h b/source/slang/slang-ir-single-return.h new file mode 100644 index 000000000..2ddfa280b --- /dev/null +++ b/source/slang/slang-ir-single-return.h @@ -0,0 +1,12 @@ +// slang-ir-single-return.h +#pragma once + +namespace Slang +{ + struct IRModule; + struct IRGlobalValueWithCode; + + // Convert the CFG of `func` to have only a single `return` at the end. + void convertFuncToSingleReturnForm(IRModule* module, IRGlobalValueWithCode* func); + +} diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp index 797fcb25c..a496db3a8 100644 --- a/source/slang/slang-ir-ssa.cpp +++ b/source/slang/slang-ir-ssa.cpp @@ -531,12 +531,13 @@ IRInst* addPhiOperands( auto block = blockInfo->block; List<IRInst*> operandValues; + auto predecessorCount = block->getPredecessors().getCount(); for (auto predBlock : block->getPredecessors()) { // Precondition: if we have multiple predecessors, then // each must have only one successor (no critical edges). // - SLANG_ASSERT(predBlock->getSuccessors().getCount() == 1); + SLANG_RELEASE_ASSERT(predecessorCount <= 1 || predBlock->getSuccessors().getCount() == 1); auto predInfo = *context->blockInfos.TryGetValue(predBlock); |
