diff options
| author | Yong He <yonghe@outlook.com> | 2022-10-10 15:59:45 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-10 15:59:45 -0700 |
| commit | 768e62f6c7541439e2edc18dad5fb3846d2e05f9 (patch) | |
| tree | 8c68424ee65905b77d3ecb4c7659c5fdcc6ab948 | |
| parent | 8487678d6504459935fec07886d2e53ed688ac2f (diff) | |
Support multi-level break + single-return conversion + general inline. (#2436)
* Support multi-level break.
* Single return.
* Add test for inlining `void` return-type functions.
Co-authored-by: Yong He <yhe@nvidia.com>
25 files changed, 726 insertions, 113 deletions
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj index 9e343f197..7436f319e 100644 --- a/build/visual-studio/slang/slang.vcxproj +++ b/build/visual-studio/slang/slang.vcxproj @@ -345,6 +345,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla <ClInclude Include="..\..\..\source\slang\slang-ir-dll-export.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-dll-import.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-dominators.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-eliminate-multilevel-break.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-eliminate-phis.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-entry-point-pass.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-entry-point-raw-ptr-params.h" />
@@ -386,6 +387,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla <ClInclude Include="..\..\..\source\slang\slang-ir-restructure.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-sccp.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-simplify-cfg.h" />
+ <ClInclude Include="..\..\..\source\slang\slang-ir-single-return.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-specialize-arrays.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-specialize-buffer-load-arg.h" />
<ClInclude Include="..\..\..\source\slang\slang-ir-specialize-dispatch.h" />
@@ -512,6 +514,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla <ClCompile Include="..\..\..\source\slang\slang-ir-dll-export.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-dll-import.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-dominators.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-eliminate-multilevel-break.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-eliminate-phis.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-entry-point-pass.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-entry-point-raw-ptr-params.cpp" />
@@ -551,6 +554,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla <ClCompile Include="..\..\..\source\slang\slang-ir-restructure.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-sccp.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-simplify-cfg.cpp" />
+ <ClCompile Include="..\..\..\source\slang\slang-ir-single-return.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-specialize-arrays.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-specialize-buffer-load-arg.cpp" />
<ClCompile Include="..\..\..\source\slang\slang-ir-specialize-dispatch.cpp" />
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters index 1115360d5..02ef9de2e 100644 --- a/build/visual-studio/slang/slang.vcxproj.filters +++ b/build/visual-studio/slang/slang.vcxproj.filters @@ -168,6 +168,9 @@ <ClInclude Include="..\..\..\source\slang\slang-ir-dominators.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-eliminate-multilevel-break.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-eliminate-phis.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -291,6 +294,9 @@ <ClInclude Include="..\..\..\source\slang\slang-ir-simplify-cfg.h">
<Filter>Header Files</Filter>
</ClInclude>
+ <ClInclude Include="..\..\..\source\slang\slang-ir-single-return.h">
+ <Filter>Header Files</Filter>
+ </ClInclude>
<ClInclude Include="..\..\..\source\slang\slang-ir-specialize-arrays.h">
<Filter>Header Files</Filter>
</ClInclude>
@@ -665,6 +671,9 @@ <ClCompile Include="..\..\..\source\slang\slang-ir-dominators.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-eliminate-multilevel-break.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-eliminate-phis.cpp">
<Filter>Source Files</Filter>
</ClCompile>
@@ -782,6 +791,9 @@ <ClCompile Include="..\..\..\source\slang\slang-ir-simplify-cfg.cpp">
<Filter>Source Files</Filter>
</ClCompile>
+ <ClCompile Include="..\..\..\source\slang\slang-ir-single-return.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
<ClCompile Include="..\..\..\source\slang\slang-ir-specialize-arrays.cpp">
<Filter>Source Files</Filter>
</ClCompile>
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index a8d3390f0..a916d0d63 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -13,6 +13,7 @@ #include "slang-ir-dll-export.h" #include "slang-ir-dll-import.h" #include "slang-ir-eliminate-phis.h" +#include "slang-ir-eliminate-multilevel-break.h" #include "slang-ir-entry-point-uniforms.h" #include "slang-ir-entry-point-raw-ptr-params.h" #include "slang-ir-explicit-global-context.h" @@ -48,7 +49,6 @@ #include "slang-ir-wrap-structured-buffers.h" #include "slang-ir-liveness.h" #include "slang-ir-glsl-liveness.h" - #include "slang-legalize-types.h" #include "slang-lower-to-ir.h" #include "slang-mangle.h" @@ -376,6 +376,8 @@ Result linkAndOptimizeIR( if (sink->getErrorCount() != 0) return SLANG_FAIL; + eliminateMultiLevelBreak(irModule); + // TODO(DG): There are multiple DCE steps here, which need to be changed // so that they don't just throw out any non-entry point code // Debugging code for IR transformations... @@ -784,7 +786,7 @@ Result linkAndOptimizeIR( { // We only want to accumulate locations if liveness tracking is enabled. - eliminatePhis(codeGenContext, livenessMode, irModule); + eliminatePhis(livenessMode, irModule); #if 0 dumpIRIfEnabled(codeGenContext, irModule, "PHIS ELIMINATED"); #endif @@ -934,7 +936,7 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outAr linkingAndOptimizationOptions.sourceEmitter = sourceEmitter; - switch( sourceLanguage ) + switch (sourceLanguage) { default: break; @@ -962,7 +964,11 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outAr // TODO: do we want to emit directly from IR, or translate the // IR back into AST for emission? #if 0 - dumpIR(compileRequest, irModule, "PRE-EMIT"); + { + StringBuilder sb; + StringWriter writer(&sb, Slang::WriterFlag::AutoFlush); + dumpIR(irModule, getIRDumpOptions(), sourceManager, &writer); + } #endif sourceEmitter->emitModule(irModule, sink); } diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 554a407ee..5eee13d5e 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -5,6 +5,7 @@ #include "slang-ir-insts.h" #include "slang-ir-clone.h" #include "slang-ir-dce.h" +#include "slang-ir-eliminate-phis.h" namespace Slang { @@ -1308,7 +1309,8 @@ struct JVPDerivativeContext IRFunc* emitJVPFunction(IRBuilder* builder, IRFunc* primalFn) { - + eliminatePhisInFunc(LivenessMode::Disabled, module, primalFn); + builder->setInsertBefore(primalFn->getNextInst()); auto jvpFn = builder->createFunc(); diff --git a/source/slang/slang-ir-dominators.cpp b/source/slang/slang-ir-dominators.cpp index bae9f772f..72b156228 100644 --- a/source/slang/slang-ir-dominators.cpp +++ b/source/slang/slang-ir-dominators.cpp @@ -305,6 +305,18 @@ void computePostorder(IRGlobalValueWithCode* code, List<IRBlock*>& outOrder) PostorderComputationContext context; context.order = &outOrder; context.walk(code); + + // Append unvisited blocks (unreachable blocks) to the begining of postOrder. + List<IRBlock*> prefix; + for (auto block : code->getBlocks()) + { + if (!context.visited.Contains(block)) + { + prefix.add(block); + } + } + prefix.addRange(outOrder); + outOrder = _Move(prefix); } // diff --git a/source/slang/slang-ir-eliminate-multilevel-break.cpp b/source/slang/slang-ir-eliminate-multilevel-break.cpp new file mode 100644 index 000000000..269b74aad --- /dev/null +++ b/source/slang/slang-ir-eliminate-multilevel-break.cpp @@ -0,0 +1,308 @@ +// slang-ir-eliminate-multilevel-break.cpp +#include "slang-ir-eliminate-multilevel-break.h" +#include "slang-ir.h" +#include "slang-ir-clone.h" +#include "slang-ir-insts.h" +#include "slang-ir-eliminate-phis.h" +#include "slang-ir-dominators.h" + +namespace Slang +{ + +struct EliminateMultiLevelBreakContext +{ + IRModule* irModule; + + struct LoopInfo : RefObject + { + LoopInfo* parent = nullptr; + int level = 0; + IRLoop* loopInst; + List<IRBlock*> blocks; + HashSet<IRBlock*> blockSet; + List<RefPtr<LoopInfo>> childLoops; + IRBlock* getBreakBlock() { return loopInst->getBreakBlock(); } + template<typename Func> + void forEach(const Func& f) + { + f(this); + for (auto child : childLoops) + child->forEach(f); + } + }; + + struct MultiLevelBreakInfo + { + IRUnconditionalBranch* breakInst; + LoopInfo* currentLoop; + LoopInfo* breakTargetLoop; + }; + + struct FuncContext + { + List<RefPtr<LoopInfo>> loops; + HashSet<IRBlock*> breakBlocks; + Dictionary<IRBlock*, LoopInfo*> mapBreakBlockToLoop; + Dictionary<IRBlock*, LoopInfo*> mapBlockToLoop; + HashSet<IRBlock*> processedBlocks; + List<MultiLevelBreakInfo> multiLevelBreaks; + + void collectLoopBlocks(LoopInfo& info) + { + auto startBlock = info.loopInst->getTargetBlock(); + info.blockSet.Add(startBlock); + info.blocks.add(startBlock); + breakBlocks.Add(info.loopInst->getBreakBlock()); + for (Index i = 0; i < info.blocks.getCount(); i++) + { + auto block = info.blocks[i]; + if (!processedBlocks.Add(block)) + continue; + if (auto loopInst = as<IRLoop>(block->getTerminator())) + { + RefPtr<LoopInfo> childLoop = new LoopInfo(); + childLoop->loopInst = loopInst; + childLoop->parent = &info; + childLoop->level = info.level + 1; + collectLoopBlocks(*childLoop); + info.childLoops.add(childLoop); + block = loopInst->getBreakBlock(); + if (info.blockSet.Add(block)) + { + info.blocks.add(block); + } + continue; + } + for (auto succ : block->getSuccessors()) + { + if (!breakBlocks.Contains(succ)) + { + if (info.blockSet.Add(succ)) + info.blocks.add(succ); + } + } + } + breakBlocks.Remove(info.loopInst->getBreakBlock()); + } + + void gatherInfo(IRGlobalValueWithCode* func) + { + for (auto block : func->getBlocks()) + { + if (processedBlocks.Contains(block)) + continue; + auto terminator = block->getTerminator(); + if (auto loop = as<IRLoop>(terminator)) + { + RefPtr<LoopInfo> loopInfo = new LoopInfo(); + loopInfo->loopInst = loop; + collectLoopBlocks(*loopInfo); + loops.add(loopInfo); + } + } + + for (auto& l : loops) + { + l->forEach( + [&](LoopInfo* loop) + { + mapBreakBlockToLoop.Add(loop->loopInst->getBreakBlock(), loop); + for (auto block : loop->blocks) + mapBlockToLoop.Add(block, loop); + }); + } + + for (auto block : func->getBlocks()) + { + auto terminator = block->getTerminator(); + if (auto branch = as<IRUnconditionalBranch>(terminator)) + { + if (as<IRLoop>(terminator)) + continue; + LoopInfo* breakLoop = nullptr; + LoopInfo* currentLoop = nullptr; + if (!mapBreakBlockToLoop.TryGetValue(branch->getTargetBlock(), breakLoop)) + continue; + if (mapBlockToLoop.TryGetValue(block, currentLoop)) + { + if (currentLoop != breakLoop) + { + MultiLevelBreakInfo breakInfo; + breakInfo.breakInst = branch; + breakInfo.breakTargetLoop = breakLoop; + breakInfo.currentLoop = currentLoop; + multiLevelBreaks.add(breakInfo); + } + } + } + } + } + }; + + void processFunc(IRGlobalValueWithCode* func) + { + // If func does not have any multi-level breaks, return. + { + FuncContext funcInfo; + funcInfo.gatherInfo(func); + + if (funcInfo.multiLevelBreaks.getCount() == 0) + return; + } + + // To make things easy, eliminate Phis before perform transformations. + eliminatePhisInFunc(LivenessMode::Disabled, irModule, func); + + // Before modifying the cfg, we gather all required info from the existing cfg. + FuncContext funcInfo; + funcInfo.gatherInfo(func); + + if (funcInfo.multiLevelBreaks.getCount() == 0) + return; + + SharedIRBuilder sharedBuilder; + sharedBuilder.init(irModule); + IRBuilder builder(&sharedBuilder); + builder.setInsertInto(func); + + OrderedHashSet<LoopInfo*> skippedOverLoops; + auto unreachableBlock = builder.emitBlock(); + builder.setInsertInto(unreachableBlock); + builder.emitUnreachable(); + builder.setInsertInto(func); + + // Rewrite multi-level breaks with single level break + target level argument. + for (auto breakInfo : funcInfo.multiLevelBreaks) + { + auto loop = breakInfo.currentLoop; + while (loop) + { + skippedOverLoops.Add(loop); + loop = loop->parent; + if (loop == breakInfo.breakTargetLoop) + break; + } + builder.setInsertBefore(breakInfo.breakInst); + auto targetLevelInst = builder.getIntValue(builder.getIntType(), breakInfo.breakTargetLoop->level); + builder.emitBranch(breakInfo.currentLoop->getBreakBlock(), 1, &targetLevelInst); + breakInfo.breakInst->removeAndDeallocate(); + } + + // Rewrite skipped-over break blocks to accept a target level argument. + builder.setInsertInto(func); + OrderedDictionary<IRBlock*, int> mapNewBreakBlockToLoopLevel; + for (auto skippedLoop : skippedOverLoops) + { + auto breakBlock = skippedLoop->getBreakBlock(); + + // The existing break block cannot have parameters. We assume that PHI-elimination is + // run before this pass. + SLANG_RELEASE_ASSERT(breakBlock->getFirstParam() == nullptr); + + // The new CFG structure will be: newBreakBlock --> newBreakBodyBlock { IfElse (-->oldBreakBlock, -->outerBreakBlock) } + // `newBreakBlock` defines the `IRParam` for the break target, then immediately jumps to `newBreakBodyBlock` for the actual branch. We need this + // separation to avoid introducing critical edge to the CFG (blocks cannot have more + // than 1 predecessors and more than 1 successors at the same time). + auto jumpToOuterBlock = builder.createBlock(); + auto newBreakBodyBlock = builder.createBlock(); + auto newBreakBlock = builder.createBlock(); + newBreakBlock->insertBefore(breakBlock); + newBreakBodyBlock->insertAfter(breakBlock); + jumpToOuterBlock->insertAfter(newBreakBlock); + mapNewBreakBlockToLoopLevel[newBreakBlock] = skippedLoop->level; + breakBlock->replaceUsesWith(newBreakBlock); + builder.setInsertInto(newBreakBlock); + auto targetLevelParam = builder.emitParam(builder.getIntType()); + builder.emitBranch(newBreakBodyBlock); + builder.setInsertInto(newBreakBodyBlock); + auto levelNeq = builder.emitNeq(targetLevelParam, builder.getIntValue(builder.getIntType(), skippedLoop->level)); + builder.emitIfElse(levelNeq, jumpToOuterBlock, breakBlock, unreachableBlock); + builder.setInsertInto(jumpToOuterBlock); + if (skippedOverLoops.Contains(skippedLoop->parent)) + { + builder.emitBranch(skippedLoop->parent->getBreakBlock(), 1, (IRInst**)&targetLevelParam); + } + else + { + builder.emitBranch(skippedLoop->parent->getBreakBlock()); + } + } + + // Once we have rewritten loops' break blocks with additional targetLevel parameter, all + // original branches into that block without a parameter will now need to provide a default + // value equal to the level of its corresponding loop. + for (auto breakBlockKV : mapNewBreakBlockToLoopLevel) + { + auto breakBlock = breakBlockKV.Key; + auto level = breakBlockKV.Value; + IRInst* levelInst = nullptr; + List<IRUse*> uses; + for (auto use = breakBlock->firstUse; use; use = use->nextUse) + { + uses.add(use); + } + for (auto use : uses) + { + auto user = use->getUser(); + switch (user->getOp()) + { + case kIROp_conditionalBranch: + case kIROp_ifElse: + case kIROp_Switch: + // For complex branches, insert an intermediate block so we can specify the + // target index argument. + { + builder.setInsertInto(func); + auto tmpBlock = builder.createBlock(); + tmpBlock->insertAfter(user->getParent()); + builder.setInsertInto(tmpBlock); + if (!levelInst) + levelInst = builder.getIntValue(builder.getIntType(), level); + builder.emitBranch(breakBlock, 1, &levelInst); + use->set(tmpBlock); + } + break; + case kIROp_loop: + // Ignore. + continue; + case kIROp_unconditionalBranch: + { + auto originalBranch = as<IRUnconditionalBranch>(user); + if (originalBranch->getOperandCount() == 1) + { + builder.setInsertBefore(originalBranch); + if (!levelInst) + levelInst = builder.getIntValue(builder.getIntType(), level); + builder.emitBranch(breakBlock, 1, &levelInst); + originalBranch->removeAndDeallocate(); + } + } + break; + } + + } + } + } +}; + +void eliminateMultiLevelBreak(IRModule* irModule) +{ + EliminateMultiLevelBreakContext context; + context.irModule = irModule; + for (auto globalInst : irModule->getGlobalInsts()) + { + if (auto codeInst = as<IRGlobalValueWithCode>(globalInst)) + { + context.processFunc(codeInst); + } + } +} + +void eliminateMultiLevelBreakForFunc(IRModule* irModule, IRGlobalValueWithCode* func) +{ + EliminateMultiLevelBreakContext context; + context.irModule = irModule; + context.processFunc(func); +} + +} // namespace Slang diff --git a/source/slang/slang-ir-eliminate-multilevel-break.h b/source/slang/slang-ir-eliminate-multilevel-break.h new file mode 100644 index 000000000..f6210bc12 --- /dev/null +++ b/source/slang/slang-ir-eliminate-multilevel-break.h @@ -0,0 +1,12 @@ +// slang-ir-eliminate-multi-level-break.h +#pragma once + +namespace Slang +{ + struct IRModule; + struct IRGlobalValueWithCode; + + void eliminateMultiLevelBreak(IRModule* module); + void eliminateMultiLevelBreakForFunc(IRModule* module, IRGlobalValueWithCode* func); + +} diff --git a/source/slang/slang-ir-eliminate-phis.cpp b/source/slang/slang-ir-eliminate-phis.cpp index 9aac7de3a..c4f0b8b9d 100644 --- a/source/slang/slang-ir-eliminate-phis.cpp +++ b/source/slang/slang-ir-eliminate-phis.cpp @@ -64,15 +64,13 @@ struct PhiEliminationContext // At the top level, our pas needs to have access to the IR module, and needs // a builder it can use to generate code. // - CodeGenContext* m_codeGenContext = nullptr; IRModule* m_module = nullptr; SharedIRBuilder m_sharedBuilder; IRBuilder m_builder; LivenessMode m_livenessMode; - PhiEliminationContext(CodeGenContext* codeGenContext, LivenessMode livenessMode, IRModule* module) - : m_codeGenContext(codeGenContext) - , m_module(module) + PhiEliminationContext(LivenessMode livenessMode, IRModule* module) + : m_module(module) , m_sharedBuilder(module) , m_builder(m_sharedBuilder) , m_livenessMode(livenessMode) @@ -900,10 +898,16 @@ struct PhiEliminationContext } }; -void eliminatePhis(CodeGenContext* codeGenContext, LivenessMode livenessMode, IRModule* module) +void eliminatePhis(LivenessMode livenessMode, IRModule* module) { - PhiEliminationContext context(codeGenContext, livenessMode, module); + PhiEliminationContext context(livenessMode, module); context.eliminatePhisInModule(); } +void eliminatePhisInFunc(LivenessMode livenessMode, IRModule* module, IRGlobalValueWithCode* func) +{ + PhiEliminationContext context(livenessMode, module); + context.eliminatePhisInFunc(func); +} + } diff --git a/source/slang/slang-ir-eliminate-phis.h b/source/slang/slang-ir-eliminate-phis.h index b21363cc8..ff81d5b38 100644 --- a/source/slang/slang-ir-eliminate-phis.h +++ b/source/slang/slang-ir-eliminate-phis.h @@ -15,5 +15,7 @@ namespace Slang /// are not themselves based on an SSA representation. /// /// If livenessMode is enabled LiveRangeStarts will be inserted into the module. - void eliminatePhis(CodeGenContext* context, LivenessMode livenessMode, IRModule* module); + void eliminatePhis(LivenessMode livenessMode, IRModule* module); + + void eliminatePhisInFunc(LivenessMode livenessMode, IRModule* module, IRGlobalValueWithCode* func); } diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp index 77691d169..70d98d10c 100644 --- a/source/slang/slang-ir-inline.cpp +++ b/source/slang/slang-ir-inline.cpp @@ -2,6 +2,7 @@ #include "slang-ir-inline.h" #include "slang-ir-ssa-simplification.h" +#include "slang-ir-single-return.h" // This file provides general facilities for inlining function calls. @@ -320,48 +321,9 @@ struct InliningPassBase // For now, our inlining pass only handles the case where // the callee is a "single-return" function, which means the callee // function contains only one return at the end of the body. - if (isSingleReturnFunc(callee)) - { - inlineSingleReturnFuncBody(callSite, &env, &builder); - } - else - { - // Running into any non-trivial function to be inlined - // is currently an internal compiler error. - // - SLANG_UNIMPLEMENTED_X("general case of inlining"); - } - } - - /// Check if `func` represents a simple callee that has only a single `return`. - bool isSingleReturnFunc(IRFunc* func) - { - auto firstBlock = func->getFirstBlock(); - - // If the body block is decorated (for some reason), then the function is non-trivial. - // - if( firstBlock->getFirstDecoration() ) - return false; - - // If the body has more than one returns, we cannot inline it now. - bool returnFound = false; - for (auto block : func->getBlocks()) - { - for (auto inst : block->getChildren()) - { - if (inst->getOp() == kIROp_Return) - { - // If the return is not at the end of the block, we cannot handle it. - if (inst != block->getTerminator()) - return false; - // If there is already a return found, this function cannot be simple. - if (returnFound) - return false; - returnFound = true; - } - } - } - return true; + + convertFuncToSingleReturnForm(m_module, callSite.callee); + inlineSingleReturnFuncBody(callSite, &env, &builder); } // When instructions are cloned, with cloneInst no sourceLoc information is copied over by default. @@ -527,6 +489,7 @@ struct InliningPassBase // call->removeAndDeallocate(); } + }; /// An inlining pass that inlines calls to `[unsafeForceInlineEarly]` functions diff --git a/source/slang/slang-ir-inst-pass-base.h b/source/slang/slang-ir-inst-pass-base.h index ec4506272..2e251e46d 100644 --- a/source/slang/slang-ir-inst-pass-base.h +++ b/source/slang/slang-ir-inst-pass-base.h @@ -56,6 +56,32 @@ namespace Slang } } + template <typename InstType, typename Func> + void processChildInstsOfType(IROp instOp, IRInst* parent, const Func& f) + { + workList.clear(); + workListSet.Clear(); + + addToWorkList(parent); + + while (workList.getCount() != 0) + { + IRInst* inst = workList.getLast(); + + workList.removeLast(); + workListSet.Remove(inst); + if (inst->getOp() == instOp) + { + f(as<InstType>(inst)); + } + + for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) + { + addToWorkList(child); + } + } + } + template <typename Func> void processAllInsts(const Func& f) { diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 76dbe55f9..e09d60c3e 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -1543,6 +1543,8 @@ struct IRContinue : IRUnconditionalBranch {}; // about the loop structure: struct IRLoop : IRUnconditionalBranch { + IR_LEAF_ISA(loop); + // The next block after the loop, which // is where we expect control flow to // re-converge, and also where a diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp index af0e7c0ce..db0274d32 100644 --- a/source/slang/slang-ir-simplify-cfg.cpp +++ b/source/slang/slang-ir-simplify-cfg.cpp @@ -6,7 +6,7 @@ namespace Slang { -bool processFunc(IRFunc* func) +bool processFunc(IRGlobalValueWithCode* func) { auto firstBlock = func->getFirstBlock(); if (!firstBlock) @@ -23,6 +23,15 @@ bool processFunc(IRFunc* func) workList.fastRemoveAt(0); while (block) { + if (auto loop = as<IRLoop>(block->getTerminator())) + { + auto continueBlock = loop->getContinueBlock(); + if (continueBlock && !continueBlock->hasMoreThanOneUse()) + { + loop->continueBlock.set(loop->getTargetBlock()); + continueBlock->removeAndDeallocate(); + } + } // If `block` does not end with an unconditional branch, bail. if (block->getTerminator()->getOp() != kIROp_unconditionalBranch) break; @@ -33,6 +42,8 @@ bool processFunc(IRFunc* func) // merge point in CFG. Such blocks will have more than one use. if (successor->hasMoreThanOneUse()) break; + if (block->hasMoreThanOneUse()) + break; changed = true; Index paramIndex = 0; auto inst = successor->getFirstDecorationOrChild(); @@ -79,4 +90,9 @@ bool simplifyCFG(IRModule* module) return changed; } +bool simplifyCFG(IRGlobalValueWithCode* func) +{ + return processFunc(func); +} + } // namespace Slang diff --git a/source/slang/slang-ir-simplify-cfg.h b/source/slang/slang-ir-simplify-cfg.h index 3d8729274..6bfa6e2bf 100644 --- a/source/slang/slang-ir-simplify-cfg.h +++ b/source/slang/slang-ir-simplify-cfg.h @@ -4,9 +4,13 @@ namespace Slang { struct IRModule; + struct IRGlobalValueWithCode; /// Simplifies control flow graph by merging basic blocks that /// forms a simple linear chain. /// Returns true if changed. bool simplifyCFG(IRModule* module); + + bool simplifyCFG(IRGlobalValueWithCode* func); + } diff --git a/source/slang/slang-ir-single-return.cpp b/source/slang/slang-ir-single-return.cpp new file mode 100644 index 000000000..a00066556 --- /dev/null +++ b/source/slang/slang-ir-single-return.cpp @@ -0,0 +1,103 @@ +// slang-ir-single-return.cpp +#include "slang-ir-single-return.h" +#include "slang-ir.h" +#include "slang-ir-clone.h" +#include "slang-ir-insts.h" +#include "slang-ir-inst-pass-base.h" +#include "slang-ir-eliminate-multilevel-break.h" +#include "slang-ir-simplify-cfg.h" + +namespace Slang +{ + +struct SingleReturnContext : public InstPassBase +{ + SingleReturnContext(IRModule* inModule) + : InstPassBase(inModule) + {} + void processFunc(IRGlobalValueWithCode* func) + { + SharedIRBuilder sharedBuilder; + sharedBuilder.init(module); + IRBuilder builder(&sharedBuilder); + simplifyCFG(func); + + // We make use of the `eliminate-multi-level-break` pass to implement the transformation. + // To be able to do that, we need to prepare `func` so that the entire function body + // is wrapped in a trivial loop and turn all `return`s into `break`s out of the outter most + // loop. + builder.setInsertInto(func); + auto breakBlock = builder.emitBlock(); + auto returnBlock = builder.emitBlock(); + builder.setInsertInto(breakBlock); + auto resultType = as<IRFuncType>(func->getDataType())->getResultType(); + + IRInst* retValParam = nullptr; + if (resultType->getOp() != kIROp_VoidType) + { + retValParam = builder.emitParam(resultType); + } + builder.emitBranch(returnBlock); + + auto originalStartBlock = func->getFirstBlock(); + auto loopHeaderBlock = builder.createBlock(); + loopHeaderBlock->insertBefore(originalStartBlock); + builder.setInsertInto(loopHeaderBlock); + + // Move all params into `loopHeaderBlock`. + List<IRParam*> params; + for (auto param : originalStartBlock->getParams()) + { + params.add(param); + } + for (auto param : params) + { + loopHeaderBlock->addParam(param); + } + auto loopInst = (IRLoop*)builder.emitLoop(originalStartBlock, breakBlock, originalStartBlock); + + // Now replace all return insts as break insts. + processChildInstsOfType<IRReturn>(kIROp_Return, func, [&](IRReturn* returnInst) + { + IRInst* retVal = nullptr; + if (returnInst->getOperandCount() == 0) + retVal = builder.getVoidValue(); + else + retVal = returnInst->getVal(); + builder.setInsertBefore(returnInst); + if (resultType->getOp()==kIROp_VoidType) + { + builder.emitBranch(breakBlock); + } + else + { + builder.emitBranch(breakBlock, 1, &retVal); + } + returnInst->removeAndDeallocate(); + }); + + builder.setInsertInto(returnBlock); + if (retValParam) + builder.emitReturn(retValParam); + else + builder.emitReturn(); + + // Now run the multi-level-break pass. + eliminateMultiLevelBreakForFunc(module, func); + + // Now remove the trivial loop header. + SLANG_RELEASE_ASSERT(loopInst->getContinueBlock() == loopInst->getTargetBlock()); + auto targetBlock = loopInst->getTargetBlock(); + for (auto param : params) + targetBlock->addParam(param); + loopHeaderBlock->removeAndDeallocate(); + } +}; + +void convertFuncToSingleReturnForm(IRModule* irModule, IRGlobalValueWithCode* func) +{ + SingleReturnContext context(irModule); + context.processFunc(func); +} + +} // namespace Slang diff --git a/source/slang/slang-ir-single-return.h b/source/slang/slang-ir-single-return.h new file mode 100644 index 000000000..2ddfa280b --- /dev/null +++ b/source/slang/slang-ir-single-return.h @@ -0,0 +1,12 @@ +// slang-ir-single-return.h +#pragma once + +namespace Slang +{ + struct IRModule; + struct IRGlobalValueWithCode; + + // Convert the CFG of `func` to have only a single `return` at the end. + void convertFuncToSingleReturnForm(IRModule* module, IRGlobalValueWithCode* func); + +} diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp index 797fcb25c..a496db3a8 100644 --- a/source/slang/slang-ir-ssa.cpp +++ b/source/slang/slang-ir-ssa.cpp @@ -531,12 +531,13 @@ IRInst* addPhiOperands( auto block = blockInfo->block; List<IRInst*> operandValues; + auto predecessorCount = block->getPredecessors().getCount(); for (auto predBlock : block->getPredecessors()) { // Precondition: if we have multiple predecessors, then // each must have only one successor (no critical edges). // - SLANG_ASSERT(predBlock->getSuccessors().getCount() == 1); + SLANG_RELEASE_ASSERT(predecessorCount <= 1 || predBlock->getSuccessors().getCount() == 1); auto predInfo = *context->blockInfos.TryGetValue(predBlock); diff --git a/tests/experimental/liveness/liveness-3.slang.expected b/tests/experimental/liveness/liveness-3.slang.expected index 58f562d86..3c124ed92 100644 --- a/tests/experimental/liveness/liveness-3.slang.expected +++ b/tests/experimental/liveness/liveness-3.slang.expected @@ -87,17 +87,18 @@ int calcThing_0(int offset_0) livenessStart_1(_S7, 0); _S7 = _S10; } - idx_0[modRange_0] = idx_0[modRange_0] + (_S7 + i_0); + int _S11 = _S7 + i_0; + idx_0[modRange_0] = idx_0[modRange_0] + _S11; i_0 = i_0 + 1; livenessStart_1(_S5, 0); - int _S11 = _S7; + int _S12 = _S7; livenessEnd_0(_S7, 0); - _S5 = _S11; + _S5 = _S12; } livenessEnd_0(i_0, 0); livenessEnd_0(_S2, 0); - int _S12 = (k_0 + 7) % 5; - if(_S12 == 4) + int _S13 = (k_0 + 7) % 5; + if(_S13 == 4) { livenessEnd_0(_S5, 0); livenessEnd_1(idx_0, 0); @@ -105,39 +106,39 @@ int calcThing_0(int offset_0) livenessEnd_2(another_0, 0); return total_0; } - int _S13 = idx_0[0] + idx_0[1]; - int _S14 = idx_0[2]; + int _S14 = idx_0[0] + idx_0[1]; + int _S15 = idx_0[2]; livenessEnd_1(idx_0, 0); - int _S15 = _S13 + _S14; - int _S16 = total_0; + int _S16 = _S14 + _S15; + int _S17 = total_0; livenessEnd_0(total_0, 0); - int total_1 = _S16 + _S15; + int total_1 = _S17 + _S16; k_0 = k_0 + 1; livenessStart_1(_S2, 0); - int _S17 = _S5; + int _S18 = _S5; livenessEnd_0(_S5, 0); - _S2 = _S17; + _S2 = _S18; livenessStart_1(total_0, 0); total_0 = total_1; } livenessEnd_0(_S2, 0); livenessEnd_0(k_0, 0); livenessEnd_2(another_0, 0); - int _S18 = total_0; + int _S19 = total_0; livenessEnd_0(total_0, 0); - return - _S18; + return - _S19; } -layout(std430, binding = 0) buffer _S19 { +layout(std430, binding = 0) buffer _S20 { int _data[]; } outputBuffer_0; layout(local_size_x = 4, local_size_y = 1, local_size_z = 1) in; void main() { int index_0 = int(gl_GlobalInvocationID.x); - uint _S20 = uint(index_0); - int _S21 = calcThing_0(index_0); - ((outputBuffer_0)._data[(_S20)]) = _S21; + uint _S21 = uint(index_0); + int _S22 = calcThing_0(index_0); + ((outputBuffer_0)._data[(_S21)]) = _S22; return; } diff --git a/tests/experimental/liveness/liveness-4.slang.expected b/tests/experimental/liveness/liveness-4.slang.expected index 52c6ebb32..38f42c02a 100644 --- a/tests/experimental/liveness/liveness-4.slang.expected +++ b/tests/experimental/liveness/liveness-4.slang.expected @@ -48,12 +48,13 @@ int calcThing_0(int offset_0) { break; } - another_0[i_0 & 1] = another_0[i_0 & 1] + (k_0 + i_0); + int _S2 = k_0 + i_0; + another_0[i_0 & 1] = another_0[i_0 & 1] + _S2; i_0 = i_0 + 1; } livenessEnd_0(i_0, 0); - int _S2 = (k_0 + 7) % 5; - if(_S2 == 4) + int _S3 = (k_0 + 7) % 5; + if(_S3 == 4) { livenessEnd_0(k_0, 0); livenessEnd_1(another_0, 0); @@ -66,16 +67,16 @@ int calcThing_0(int offset_0) return -2; } -layout(std430, binding = 0) buffer _S3 { +layout(std430, binding = 0) buffer _S4 { int _data[]; } outputBuffer_0; layout(local_size_x = 4, local_size_y = 1, local_size_z = 1) in; void main() { int index_0 = int(gl_GlobalInvocationID.x); - uint _S4 = uint(index_0); - int _S5 = calcThing_0(index_0); - ((outputBuffer_0)._data[(_S4)]) = _S5; + uint _S5 = uint(index_0); + int _S6 = calcThing_0(index_0); + ((outputBuffer_0)._data[(_S5)]) = _S6; return; } diff --git a/tests/experimental/liveness/liveness-5.slang.expected b/tests/experimental/liveness/liveness-5.slang.expected index ea6e37036..920e05b59 100644 --- a/tests/experimental/liveness/liveness-5.slang.expected +++ b/tests/experimental/liveness/liveness-5.slang.expected @@ -51,16 +51,17 @@ int calcThing_0(int offset_0) { break; } - another_0[i_0 & 1] = another_0[i_0 & 1] + (k_0 + i_0); + int _S2 = k_0 + i_0; + another_0[i_0 & 1] = another_0[i_0 & 1] + _S2; i_0 = i_0 + 1; } livenessEnd_0(i_0, 0); - int _S2 = another_0[k_0 & 1]; - int _S3 = total_0; + int _S3 = another_0[k_0 & 1]; + int _S4 = total_0; livenessEnd_0(total_0, 0); - int total_1 = _S3 + _S2; - int _S4 = (k_0 + 7) % 5; - if(_S4 == 4) + int total_1 = _S4 + _S3; + int _S5 = (k_0 + 7) % 5; + if(_S5 == 4) { livenessEnd_0(k_0, 0); livenessEnd_1(another_0, 0); @@ -75,32 +76,32 @@ int calcThing_0(int offset_0) int total_2; if(total_0 > 4) { - int _S5 = total_0; + int _S6 = total_0; livenessEnd_0(total_0, 0); - int _S6 = - _S5; + int _S7 = - _S6; livenessStart_1(total_2, 0); - total_2 = _S6; + total_2 = _S7; } else { - int _S7 = total_0; + int _S8 = total_0; livenessEnd_0(total_0, 0); livenessStart_1(total_2, 0); - total_2 = _S7; + total_2 = _S8; } return total_2; } -layout(std430, binding = 0) buffer _S8 { +layout(std430, binding = 0) buffer _S9 { int _data[]; } outputBuffer_0; layout(local_size_x = 4, local_size_y = 1, local_size_z = 1) in; void main() { int index_0 = int(gl_GlobalInvocationID.x); - uint _S9 = uint(index_0); - int _S10 = calcThing_0(index_0); - ((outputBuffer_0)._data[(_S9)]) = _S10; + uint _S10 = uint(index_0); + int _S11 = calcThing_0(index_0); + ((outputBuffer_0)._data[(_S10)]) = _S11; return; } diff --git a/tests/experimental/liveness/liveness-6.slang.expected b/tests/experimental/liveness/liveness-6.slang.expected index ac1894f95..91ee98f8e 100644 --- a/tests/experimental/liveness/liveness-6.slang.expected +++ b/tests/experimental/liveness/liveness-6.slang.expected @@ -55,20 +55,21 @@ int calcThing_0(int offset_0) { break; } - another_0[i_0 & 1] = another_0[i_0 & 1] + (k_0 + i_0); + int _S3 = k_0 + i_0; + another_0[i_0 & 1] = another_0[i_0 & 1] + _S3; arr_0[k_0 & 1] = arr_0[k_0 & 1] + i_0; i_0 = i_0 + 1; } livenessEnd_0(i_0, 0); - int _S3 = another_0[k_0 & 1]; - int _S4 = total_0; + int _S4 = another_0[k_0 & 1]; + int _S5 = total_0; livenessEnd_0(total_0, 0); - int total_1 = _S4 + _S3; - int _S5 = arr_0[k_0 & 1]; + int total_1 = _S5 + _S4; + int _S6 = arr_0[k_0 & 1]; livenessEnd_1(arr_0, 0); - int total_2 = total_1 + _S5; - int _S6 = (k_0 + 7) % 5; - if(_S6 == 4) + int total_2 = total_1 + _S6; + int _S7 = (k_0 + 7) % 5; + if(_S7 == 4) { livenessEnd_0(k_0, 0); livenessEnd_1(another_0, 0); @@ -83,32 +84,32 @@ int calcThing_0(int offset_0) int total_3; if(total_0 > 4) { - int _S7 = total_0; + int _S8 = total_0; livenessEnd_0(total_0, 0); - int _S8 = - _S7; + int _S9 = - _S8; livenessStart_1(total_3, 0); - total_3 = _S8; + total_3 = _S9; } else { - int _S9 = total_0; + int _S10 = total_0; livenessEnd_0(total_0, 0); livenessStart_1(total_3, 0); - total_3 = _S9; + total_3 = _S10; } return total_3; } -layout(std430, binding = 0) buffer _S10 { +layout(std430, binding = 0) buffer _S11 { int _data[]; } outputBuffer_0; layout(local_size_x = 4, local_size_y = 1, local_size_z = 1) in; void main() { int index_0 = int(gl_GlobalInvocationID.x); - uint _S11 = uint(index_0); - int _S12 = calcThing_0(index_0); - ((outputBuffer_0)._data[(_S11)]) = _S12; + uint _S12 = uint(index_0); + int _S13 = calcThing_0(index_0); + ((outputBuffer_0)._data[(_S12)]) = _S13; return; } diff --git a/tests/language-feature/general-inline.slang b/tests/language-feature/general-inline.slang new file mode 100644 index 000000000..0fd323014 --- /dev/null +++ b/tests/language-feature/general-inline.slang @@ -0,0 +1,60 @@ +// multi-level-break.slang + +//TEST(compute):COMPARE_COMPUTE: -shaderobj +//TEST(compute):COMPARE_COMPUTE: -vk -shaderobj +//DISABLE_TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj + +[__unsafeForceInlineEarly] +int test(int r) +{ + int result = 0; + for (int i = 0; i < 2; i++) + { + for (int j = 0; j < 3; j++) + { + result++; + if (r == 0) + { + return result; + } + } + if (r == 1) + return result; + } + return result; +} + +[__unsafeForceInlineEarly] +void testVoid(int r, out int result) +{ + result = 0; + for (int i = 0; i < 2; i++) + { + for (int j = 0; j < 3; j++) + { + result++; + if (r == 0) + { + return; + } + } + if (r == 1) + return; + } + return; +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + outputBuffer[0] = test(0); + outputBuffer[1] = test(1); + outputBuffer[2] = test(2); + + int rs; + testVoid(2, rs); + outputBuffer[3] = rs; +} diff --git a/tests/language-feature/general-inline.slang.expected.txt b/tests/language-feature/general-inline.slang.expected.txt new file mode 100644 index 000000000..ab200b7cf --- /dev/null +++ b/tests/language-feature/general-inline.slang.expected.txt @@ -0,0 +1,4 @@ +1 +3 +6 +6 diff --git a/tests/language-feature/multi-level-break.slang b/tests/language-feature/multi-level-break.slang new file mode 100644 index 000000000..5ebe82e36 --- /dev/null +++ b/tests/language-feature/multi-level-break.slang @@ -0,0 +1,52 @@ +// multi-level-break.slang + +//TEST(compute):COMPARE_COMPUTE: -shaderobj +//DISABLE_TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj + +int test(int r) +{ + int result = 0; +iLoop: + for (int i = 0; i < 2; i++) + { + jLoop: + for (int j = 0; j < 3; j++) + { + for (;;) + { + result++; + if (r == 0) + { + // When r == 0, we break out the outter most loop, + // resulting the inner most statement being run only once. + break iLoop; + } + else if (r == 1) + { + // When r == 1, we break out the `j` loop, + // resulting the inner most statement being run loop-i times. + break jLoop; + } + else + { + // When r takes other values, we break out the inner most loop p, + // resulting the inner most statement being run i*j times. + break; + } + } + } + } + return result; +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + outputBuffer[0] = test(0); + outputBuffer[1] = test(1); + outputBuffer[2] = test(2); + outputBuffer[3] = 0; +} diff --git a/tests/language-feature/multi-level-break.slang.expected.txt b/tests/language-feature/multi-level-break.slang.expected.txt new file mode 100644 index 000000000..bb18e9e15 --- /dev/null +++ b/tests/language-feature/multi-level-break.slang.expected.txt @@ -0,0 +1,4 @@ +1 +2 +6 +0 |
