summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-emit.cpp14
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp4
-rw-r--r--source/slang/slang-ir-dominators.cpp12
-rw-r--r--source/slang/slang-ir-eliminate-multilevel-break.cpp308
-rw-r--r--source/slang/slang-ir-eliminate-multilevel-break.h12
-rw-r--r--source/slang/slang-ir-eliminate-phis.cpp16
-rw-r--r--source/slang/slang-ir-eliminate-phis.h4
-rw-r--r--source/slang/slang-ir-inline.cpp47
-rw-r--r--source/slang/slang-ir-inst-pass-base.h26
-rw-r--r--source/slang/slang-ir-insts.h2
-rw-r--r--source/slang/slang-ir-simplify-cfg.cpp18
-rw-r--r--source/slang/slang-ir-simplify-cfg.h4
-rw-r--r--source/slang/slang-ir-single-return.cpp103
-rw-r--r--source/slang/slang-ir-single-return.h12
-rw-r--r--source/slang/slang-ir-ssa.cpp3
15 files changed, 529 insertions, 56 deletions
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index a8d3390f0..a916d0d63 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -13,6 +13,7 @@
#include "slang-ir-dll-export.h"
#include "slang-ir-dll-import.h"
#include "slang-ir-eliminate-phis.h"
+#include "slang-ir-eliminate-multilevel-break.h"
#include "slang-ir-entry-point-uniforms.h"
#include "slang-ir-entry-point-raw-ptr-params.h"
#include "slang-ir-explicit-global-context.h"
@@ -48,7 +49,6 @@
#include "slang-ir-wrap-structured-buffers.h"
#include "slang-ir-liveness.h"
#include "slang-ir-glsl-liveness.h"
-
#include "slang-legalize-types.h"
#include "slang-lower-to-ir.h"
#include "slang-mangle.h"
@@ -376,6 +376,8 @@ Result linkAndOptimizeIR(
if (sink->getErrorCount() != 0)
return SLANG_FAIL;
+ eliminateMultiLevelBreak(irModule);
+
// TODO(DG): There are multiple DCE steps here, which need to be changed
// so that they don't just throw out any non-entry point code
// Debugging code for IR transformations...
@@ -784,7 +786,7 @@ Result linkAndOptimizeIR(
{
// We only want to accumulate locations if liveness tracking is enabled.
- eliminatePhis(codeGenContext, livenessMode, irModule);
+ eliminatePhis(livenessMode, irModule);
#if 0
dumpIRIfEnabled(codeGenContext, irModule, "PHIS ELIMINATED");
#endif
@@ -934,7 +936,7 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outAr
linkingAndOptimizationOptions.sourceEmitter = sourceEmitter;
- switch( sourceLanguage )
+ switch (sourceLanguage)
{
default:
break;
@@ -962,7 +964,11 @@ SlangResult CodeGenContext::emitEntryPointsSourceFromIR(ComPtr<IArtifact>& outAr
// TODO: do we want to emit directly from IR, or translate the
// IR back into AST for emission?
#if 0
- dumpIR(compileRequest, irModule, "PRE-EMIT");
+ {
+ StringBuilder sb;
+ StringWriter writer(&sb, Slang::WriterFlag::AutoFlush);
+ dumpIR(irModule, getIRDumpOptions(), sourceManager, &writer);
+ }
#endif
sourceEmitter->emitModule(irModule, sink);
}
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 554a407ee..5eee13d5e 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -5,6 +5,7 @@
#include "slang-ir-insts.h"
#include "slang-ir-clone.h"
#include "slang-ir-dce.h"
+#include "slang-ir-eliminate-phis.h"
namespace Slang
{
@@ -1308,7 +1309,8 @@ struct JVPDerivativeContext
IRFunc* emitJVPFunction(IRBuilder* builder,
IRFunc* primalFn)
{
-
+ eliminatePhisInFunc(LivenessMode::Disabled, module, primalFn);
+
builder->setInsertBefore(primalFn->getNextInst());
auto jvpFn = builder->createFunc();
diff --git a/source/slang/slang-ir-dominators.cpp b/source/slang/slang-ir-dominators.cpp
index bae9f772f..72b156228 100644
--- a/source/slang/slang-ir-dominators.cpp
+++ b/source/slang/slang-ir-dominators.cpp
@@ -305,6 +305,18 @@ void computePostorder(IRGlobalValueWithCode* code, List<IRBlock*>& outOrder)
PostorderComputationContext context;
context.order = &outOrder;
context.walk(code);
+
+ // Append unvisited blocks (unreachable blocks) to the begining of postOrder.
+ List<IRBlock*> prefix;
+ for (auto block : code->getBlocks())
+ {
+ if (!context.visited.Contains(block))
+ {
+ prefix.add(block);
+ }
+ }
+ prefix.addRange(outOrder);
+ outOrder = _Move(prefix);
}
//
diff --git a/source/slang/slang-ir-eliminate-multilevel-break.cpp b/source/slang/slang-ir-eliminate-multilevel-break.cpp
new file mode 100644
index 000000000..269b74aad
--- /dev/null
+++ b/source/slang/slang-ir-eliminate-multilevel-break.cpp
@@ -0,0 +1,308 @@
+// slang-ir-eliminate-multilevel-break.cpp
+#include "slang-ir-eliminate-multilevel-break.h"
+#include "slang-ir.h"
+#include "slang-ir-clone.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-eliminate-phis.h"
+#include "slang-ir-dominators.h"
+
+namespace Slang
+{
+
+struct EliminateMultiLevelBreakContext
+{
+ IRModule* irModule;
+
+ struct LoopInfo : RefObject
+ {
+ LoopInfo* parent = nullptr;
+ int level = 0;
+ IRLoop* loopInst;
+ List<IRBlock*> blocks;
+ HashSet<IRBlock*> blockSet;
+ List<RefPtr<LoopInfo>> childLoops;
+ IRBlock* getBreakBlock() { return loopInst->getBreakBlock(); }
+ template<typename Func>
+ void forEach(const Func& f)
+ {
+ f(this);
+ for (auto child : childLoops)
+ child->forEach(f);
+ }
+ };
+
+ struct MultiLevelBreakInfo
+ {
+ IRUnconditionalBranch* breakInst;
+ LoopInfo* currentLoop;
+ LoopInfo* breakTargetLoop;
+ };
+
+ struct FuncContext
+ {
+ List<RefPtr<LoopInfo>> loops;
+ HashSet<IRBlock*> breakBlocks;
+ Dictionary<IRBlock*, LoopInfo*> mapBreakBlockToLoop;
+ Dictionary<IRBlock*, LoopInfo*> mapBlockToLoop;
+ HashSet<IRBlock*> processedBlocks;
+ List<MultiLevelBreakInfo> multiLevelBreaks;
+
+ void collectLoopBlocks(LoopInfo& info)
+ {
+ auto startBlock = info.loopInst->getTargetBlock();
+ info.blockSet.Add(startBlock);
+ info.blocks.add(startBlock);
+ breakBlocks.Add(info.loopInst->getBreakBlock());
+ for (Index i = 0; i < info.blocks.getCount(); i++)
+ {
+ auto block = info.blocks[i];
+ if (!processedBlocks.Add(block))
+ continue;
+ if (auto loopInst = as<IRLoop>(block->getTerminator()))
+ {
+ RefPtr<LoopInfo> childLoop = new LoopInfo();
+ childLoop->loopInst = loopInst;
+ childLoop->parent = &info;
+ childLoop->level = info.level + 1;
+ collectLoopBlocks(*childLoop);
+ info.childLoops.add(childLoop);
+ block = loopInst->getBreakBlock();
+ if (info.blockSet.Add(block))
+ {
+ info.blocks.add(block);
+ }
+ continue;
+ }
+ for (auto succ : block->getSuccessors())
+ {
+ if (!breakBlocks.Contains(succ))
+ {
+ if (info.blockSet.Add(succ))
+ info.blocks.add(succ);
+ }
+ }
+ }
+ breakBlocks.Remove(info.loopInst->getBreakBlock());
+ }
+
+ void gatherInfo(IRGlobalValueWithCode* func)
+ {
+ for (auto block : func->getBlocks())
+ {
+ if (processedBlocks.Contains(block))
+ continue;
+ auto terminator = block->getTerminator();
+ if (auto loop = as<IRLoop>(terminator))
+ {
+ RefPtr<LoopInfo> loopInfo = new LoopInfo();
+ loopInfo->loopInst = loop;
+ collectLoopBlocks(*loopInfo);
+ loops.add(loopInfo);
+ }
+ }
+
+ for (auto& l : loops)
+ {
+ l->forEach(
+ [&](LoopInfo* loop)
+ {
+ mapBreakBlockToLoop.Add(loop->loopInst->getBreakBlock(), loop);
+ for (auto block : loop->blocks)
+ mapBlockToLoop.Add(block, loop);
+ });
+ }
+
+ for (auto block : func->getBlocks())
+ {
+ auto terminator = block->getTerminator();
+ if (auto branch = as<IRUnconditionalBranch>(terminator))
+ {
+ if (as<IRLoop>(terminator))
+ continue;
+ LoopInfo* breakLoop = nullptr;
+ LoopInfo* currentLoop = nullptr;
+ if (!mapBreakBlockToLoop.TryGetValue(branch->getTargetBlock(), breakLoop))
+ continue;
+ if (mapBlockToLoop.TryGetValue(block, currentLoop))
+ {
+ if (currentLoop != breakLoop)
+ {
+ MultiLevelBreakInfo breakInfo;
+ breakInfo.breakInst = branch;
+ breakInfo.breakTargetLoop = breakLoop;
+ breakInfo.currentLoop = currentLoop;
+ multiLevelBreaks.add(breakInfo);
+ }
+ }
+ }
+ }
+ }
+ };
+
+ void processFunc(IRGlobalValueWithCode* func)
+ {
+ // If func does not have any multi-level breaks, return.
+ {
+ FuncContext funcInfo;
+ funcInfo.gatherInfo(func);
+
+ if (funcInfo.multiLevelBreaks.getCount() == 0)
+ return;
+ }
+
+ // To make things easy, eliminate Phis before perform transformations.
+ eliminatePhisInFunc(LivenessMode::Disabled, irModule, func);
+
+ // Before modifying the cfg, we gather all required info from the existing cfg.
+ FuncContext funcInfo;
+ funcInfo.gatherInfo(func);
+
+ if (funcInfo.multiLevelBreaks.getCount() == 0)
+ return;
+
+ SharedIRBuilder sharedBuilder;
+ sharedBuilder.init(irModule);
+ IRBuilder builder(&sharedBuilder);
+ builder.setInsertInto(func);
+
+ OrderedHashSet<LoopInfo*> skippedOverLoops;
+ auto unreachableBlock = builder.emitBlock();
+ builder.setInsertInto(unreachableBlock);
+ builder.emitUnreachable();
+ builder.setInsertInto(func);
+
+ // Rewrite multi-level breaks with single level break + target level argument.
+ for (auto breakInfo : funcInfo.multiLevelBreaks)
+ {
+ auto loop = breakInfo.currentLoop;
+ while (loop)
+ {
+ skippedOverLoops.Add(loop);
+ loop = loop->parent;
+ if (loop == breakInfo.breakTargetLoop)
+ break;
+ }
+ builder.setInsertBefore(breakInfo.breakInst);
+ auto targetLevelInst = builder.getIntValue(builder.getIntType(), breakInfo.breakTargetLoop->level);
+ builder.emitBranch(breakInfo.currentLoop->getBreakBlock(), 1, &targetLevelInst);
+ breakInfo.breakInst->removeAndDeallocate();
+ }
+
+ // Rewrite skipped-over break blocks to accept a target level argument.
+ builder.setInsertInto(func);
+ OrderedDictionary<IRBlock*, int> mapNewBreakBlockToLoopLevel;
+ for (auto skippedLoop : skippedOverLoops)
+ {
+ auto breakBlock = skippedLoop->getBreakBlock();
+
+ // The existing break block cannot have parameters. We assume that PHI-elimination is
+ // run before this pass.
+ SLANG_RELEASE_ASSERT(breakBlock->getFirstParam() == nullptr);
+
+ // The new CFG structure will be: newBreakBlock --> newBreakBodyBlock { IfElse (-->oldBreakBlock, -->outerBreakBlock) }
+ // `newBreakBlock` defines the `IRParam` for the break target, then immediately jumps to `newBreakBodyBlock` for the actual branch. We need this
+ // separation to avoid introducing critical edge to the CFG (blocks cannot have more
+ // than 1 predecessors and more than 1 successors at the same time).
+ auto jumpToOuterBlock = builder.createBlock();
+ auto newBreakBodyBlock = builder.createBlock();
+ auto newBreakBlock = builder.createBlock();
+ newBreakBlock->insertBefore(breakBlock);
+ newBreakBodyBlock->insertAfter(breakBlock);
+ jumpToOuterBlock->insertAfter(newBreakBlock);
+ mapNewBreakBlockToLoopLevel[newBreakBlock] = skippedLoop->level;
+ breakBlock->replaceUsesWith(newBreakBlock);
+ builder.setInsertInto(newBreakBlock);
+ auto targetLevelParam = builder.emitParam(builder.getIntType());
+ builder.emitBranch(newBreakBodyBlock);
+ builder.setInsertInto(newBreakBodyBlock);
+ auto levelNeq = builder.emitNeq(targetLevelParam, builder.getIntValue(builder.getIntType(), skippedLoop->level));
+ builder.emitIfElse(levelNeq, jumpToOuterBlock, breakBlock, unreachableBlock);
+ builder.setInsertInto(jumpToOuterBlock);
+ if (skippedOverLoops.Contains(skippedLoop->parent))
+ {
+ builder.emitBranch(skippedLoop->parent->getBreakBlock(), 1, (IRInst**)&targetLevelParam);
+ }
+ else
+ {
+ builder.emitBranch(skippedLoop->parent->getBreakBlock());
+ }
+ }
+
+ // Once we have rewritten loops' break blocks with additional targetLevel parameter, all
+ // original branches into that block without a parameter will now need to provide a default
+ // value equal to the level of its corresponding loop.
+ for (auto breakBlockKV : mapNewBreakBlockToLoopLevel)
+ {
+ auto breakBlock = breakBlockKV.Key;
+ auto level = breakBlockKV.Value;
+ IRInst* levelInst = nullptr;
+ List<IRUse*> uses;
+ for (auto use = breakBlock->firstUse; use; use = use->nextUse)
+ {
+ uses.add(use);
+ }
+ for (auto use : uses)
+ {
+ auto user = use->getUser();
+ switch (user->getOp())
+ {
+ case kIROp_conditionalBranch:
+ case kIROp_ifElse:
+ case kIROp_Switch:
+ // For complex branches, insert an intermediate block so we can specify the
+ // target index argument.
+ {
+ builder.setInsertInto(func);
+ auto tmpBlock = builder.createBlock();
+ tmpBlock->insertAfter(user->getParent());
+ builder.setInsertInto(tmpBlock);
+ if (!levelInst)
+ levelInst = builder.getIntValue(builder.getIntType(), level);
+ builder.emitBranch(breakBlock, 1, &levelInst);
+ use->set(tmpBlock);
+ }
+ break;
+ case kIROp_loop:
+ // Ignore.
+ continue;
+ case kIROp_unconditionalBranch:
+ {
+ auto originalBranch = as<IRUnconditionalBranch>(user);
+ if (originalBranch->getOperandCount() == 1)
+ {
+ builder.setInsertBefore(originalBranch);
+ if (!levelInst)
+ levelInst = builder.getIntValue(builder.getIntType(), level);
+ builder.emitBranch(breakBlock, 1, &levelInst);
+ originalBranch->removeAndDeallocate();
+ }
+ }
+ break;
+ }
+
+ }
+ }
+ }
+};
+
+void eliminateMultiLevelBreak(IRModule* irModule)
+{
+ EliminateMultiLevelBreakContext context;
+ context.irModule = irModule;
+ for (auto globalInst : irModule->getGlobalInsts())
+ {
+ if (auto codeInst = as<IRGlobalValueWithCode>(globalInst))
+ {
+ context.processFunc(codeInst);
+ }
+ }
+}
+
+void eliminateMultiLevelBreakForFunc(IRModule* irModule, IRGlobalValueWithCode* func)
+{
+ EliminateMultiLevelBreakContext context;
+ context.irModule = irModule;
+ context.processFunc(func);
+}
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-eliminate-multilevel-break.h b/source/slang/slang-ir-eliminate-multilevel-break.h
new file mode 100644
index 000000000..f6210bc12
--- /dev/null
+++ b/source/slang/slang-ir-eliminate-multilevel-break.h
@@ -0,0 +1,12 @@
+// slang-ir-eliminate-multi-level-break.h
+#pragma once
+
+namespace Slang
+{
+ struct IRModule;
+ struct IRGlobalValueWithCode;
+
+ void eliminateMultiLevelBreak(IRModule* module);
+ void eliminateMultiLevelBreakForFunc(IRModule* module, IRGlobalValueWithCode* func);
+
+}
diff --git a/source/slang/slang-ir-eliminate-phis.cpp b/source/slang/slang-ir-eliminate-phis.cpp
index 9aac7de3a..c4f0b8b9d 100644
--- a/source/slang/slang-ir-eliminate-phis.cpp
+++ b/source/slang/slang-ir-eliminate-phis.cpp
@@ -64,15 +64,13 @@ struct PhiEliminationContext
// At the top level, our pas needs to have access to the IR module, and needs
// a builder it can use to generate code.
//
- CodeGenContext* m_codeGenContext = nullptr;
IRModule* m_module = nullptr;
SharedIRBuilder m_sharedBuilder;
IRBuilder m_builder;
LivenessMode m_livenessMode;
- PhiEliminationContext(CodeGenContext* codeGenContext, LivenessMode livenessMode, IRModule* module)
- : m_codeGenContext(codeGenContext)
- , m_module(module)
+ PhiEliminationContext(LivenessMode livenessMode, IRModule* module)
+ : m_module(module)
, m_sharedBuilder(module)
, m_builder(m_sharedBuilder)
, m_livenessMode(livenessMode)
@@ -900,10 +898,16 @@ struct PhiEliminationContext
}
};
-void eliminatePhis(CodeGenContext* codeGenContext, LivenessMode livenessMode, IRModule* module)
+void eliminatePhis(LivenessMode livenessMode, IRModule* module)
{
- PhiEliminationContext context(codeGenContext, livenessMode, module);
+ PhiEliminationContext context(livenessMode, module);
context.eliminatePhisInModule();
}
+void eliminatePhisInFunc(LivenessMode livenessMode, IRModule* module, IRGlobalValueWithCode* func)
+{
+ PhiEliminationContext context(livenessMode, module);
+ context.eliminatePhisInFunc(func);
+}
+
}
diff --git a/source/slang/slang-ir-eliminate-phis.h b/source/slang/slang-ir-eliminate-phis.h
index b21363cc8..ff81d5b38 100644
--- a/source/slang/slang-ir-eliminate-phis.h
+++ b/source/slang/slang-ir-eliminate-phis.h
@@ -15,5 +15,7 @@ namespace Slang
/// are not themselves based on an SSA representation.
///
/// If livenessMode is enabled LiveRangeStarts will be inserted into the module.
- void eliminatePhis(CodeGenContext* context, LivenessMode livenessMode, IRModule* module);
+ void eliminatePhis(LivenessMode livenessMode, IRModule* module);
+
+ void eliminatePhisInFunc(LivenessMode livenessMode, IRModule* module, IRGlobalValueWithCode* func);
}
diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp
index 77691d169..70d98d10c 100644
--- a/source/slang/slang-ir-inline.cpp
+++ b/source/slang/slang-ir-inline.cpp
@@ -2,6 +2,7 @@
#include "slang-ir-inline.h"
#include "slang-ir-ssa-simplification.h"
+#include "slang-ir-single-return.h"
// This file provides general facilities for inlining function calls.
@@ -320,48 +321,9 @@ struct InliningPassBase
// For now, our inlining pass only handles the case where
// the callee is a "single-return" function, which means the callee
// function contains only one return at the end of the body.
- if (isSingleReturnFunc(callee))
- {
- inlineSingleReturnFuncBody(callSite, &env, &builder);
- }
- else
- {
- // Running into any non-trivial function to be inlined
- // is currently an internal compiler error.
- //
- SLANG_UNIMPLEMENTED_X("general case of inlining");
- }
- }
-
- /// Check if `func` represents a simple callee that has only a single `return`.
- bool isSingleReturnFunc(IRFunc* func)
- {
- auto firstBlock = func->getFirstBlock();
-
- // If the body block is decorated (for some reason), then the function is non-trivial.
- //
- if( firstBlock->getFirstDecoration() )
- return false;
-
- // If the body has more than one returns, we cannot inline it now.
- bool returnFound = false;
- for (auto block : func->getBlocks())
- {
- for (auto inst : block->getChildren())
- {
- if (inst->getOp() == kIROp_Return)
- {
- // If the return is not at the end of the block, we cannot handle it.
- if (inst != block->getTerminator())
- return false;
- // If there is already a return found, this function cannot be simple.
- if (returnFound)
- return false;
- returnFound = true;
- }
- }
- }
- return true;
+
+ convertFuncToSingleReturnForm(m_module, callSite.callee);
+ inlineSingleReturnFuncBody(callSite, &env, &builder);
}
// When instructions are cloned, with cloneInst no sourceLoc information is copied over by default.
@@ -527,6 +489,7 @@ struct InliningPassBase
//
call->removeAndDeallocate();
}
+
};
/// An inlining pass that inlines calls to `[unsafeForceInlineEarly]` functions
diff --git a/source/slang/slang-ir-inst-pass-base.h b/source/slang/slang-ir-inst-pass-base.h
index ec4506272..2e251e46d 100644
--- a/source/slang/slang-ir-inst-pass-base.h
+++ b/source/slang/slang-ir-inst-pass-base.h
@@ -56,6 +56,32 @@ namespace Slang
}
}
+ template <typename InstType, typename Func>
+ void processChildInstsOfType(IROp instOp, IRInst* parent, const Func& f)
+ {
+ workList.clear();
+ workListSet.Clear();
+
+ addToWorkList(parent);
+
+ while (workList.getCount() != 0)
+ {
+ IRInst* inst = workList.getLast();
+
+ workList.removeLast();
+ workListSet.Remove(inst);
+ if (inst->getOp() == instOp)
+ {
+ f(as<InstType>(inst));
+ }
+
+ for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
+ {
+ addToWorkList(child);
+ }
+ }
+ }
+
template <typename Func>
void processAllInsts(const Func& f)
{
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 76dbe55f9..e09d60c3e 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -1543,6 +1543,8 @@ struct IRContinue : IRUnconditionalBranch {};
// about the loop structure:
struct IRLoop : IRUnconditionalBranch
{
+ IR_LEAF_ISA(loop);
+
// The next block after the loop, which
// is where we expect control flow to
// re-converge, and also where a
diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp
index af0e7c0ce..db0274d32 100644
--- a/source/slang/slang-ir-simplify-cfg.cpp
+++ b/source/slang/slang-ir-simplify-cfg.cpp
@@ -6,7 +6,7 @@
namespace Slang
{
-bool processFunc(IRFunc* func)
+bool processFunc(IRGlobalValueWithCode* func)
{
auto firstBlock = func->getFirstBlock();
if (!firstBlock)
@@ -23,6 +23,15 @@ bool processFunc(IRFunc* func)
workList.fastRemoveAt(0);
while (block)
{
+ if (auto loop = as<IRLoop>(block->getTerminator()))
+ {
+ auto continueBlock = loop->getContinueBlock();
+ if (continueBlock && !continueBlock->hasMoreThanOneUse())
+ {
+ loop->continueBlock.set(loop->getTargetBlock());
+ continueBlock->removeAndDeallocate();
+ }
+ }
// If `block` does not end with an unconditional branch, bail.
if (block->getTerminator()->getOp() != kIROp_unconditionalBranch)
break;
@@ -33,6 +42,8 @@ bool processFunc(IRFunc* func)
// merge point in CFG. Such blocks will have more than one use.
if (successor->hasMoreThanOneUse())
break;
+ if (block->hasMoreThanOneUse())
+ break;
changed = true;
Index paramIndex = 0;
auto inst = successor->getFirstDecorationOrChild();
@@ -79,4 +90,9 @@ bool simplifyCFG(IRModule* module)
return changed;
}
+bool simplifyCFG(IRGlobalValueWithCode* func)
+{
+ return processFunc(func);
+}
+
} // namespace Slang
diff --git a/source/slang/slang-ir-simplify-cfg.h b/source/slang/slang-ir-simplify-cfg.h
index 3d8729274..6bfa6e2bf 100644
--- a/source/slang/slang-ir-simplify-cfg.h
+++ b/source/slang/slang-ir-simplify-cfg.h
@@ -4,9 +4,13 @@
namespace Slang
{
struct IRModule;
+ struct IRGlobalValueWithCode;
/// Simplifies control flow graph by merging basic blocks that
/// forms a simple linear chain.
/// Returns true if changed.
bool simplifyCFG(IRModule* module);
+
+ bool simplifyCFG(IRGlobalValueWithCode* func);
+
}
diff --git a/source/slang/slang-ir-single-return.cpp b/source/slang/slang-ir-single-return.cpp
new file mode 100644
index 000000000..a00066556
--- /dev/null
+++ b/source/slang/slang-ir-single-return.cpp
@@ -0,0 +1,103 @@
+// slang-ir-single-return.cpp
+#include "slang-ir-single-return.h"
+#include "slang-ir.h"
+#include "slang-ir-clone.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-inst-pass-base.h"
+#include "slang-ir-eliminate-multilevel-break.h"
+#include "slang-ir-simplify-cfg.h"
+
+namespace Slang
+{
+
+struct SingleReturnContext : public InstPassBase
+{
+ SingleReturnContext(IRModule* inModule)
+ : InstPassBase(inModule)
+ {}
+ void processFunc(IRGlobalValueWithCode* func)
+ {
+ SharedIRBuilder sharedBuilder;
+ sharedBuilder.init(module);
+ IRBuilder builder(&sharedBuilder);
+ simplifyCFG(func);
+
+ // We make use of the `eliminate-multi-level-break` pass to implement the transformation.
+ // To be able to do that, we need to prepare `func` so that the entire function body
+ // is wrapped in a trivial loop and turn all `return`s into `break`s out of the outter most
+ // loop.
+ builder.setInsertInto(func);
+ auto breakBlock = builder.emitBlock();
+ auto returnBlock = builder.emitBlock();
+ builder.setInsertInto(breakBlock);
+ auto resultType = as<IRFuncType>(func->getDataType())->getResultType();
+
+ IRInst* retValParam = nullptr;
+ if (resultType->getOp() != kIROp_VoidType)
+ {
+ retValParam = builder.emitParam(resultType);
+ }
+ builder.emitBranch(returnBlock);
+
+ auto originalStartBlock = func->getFirstBlock();
+ auto loopHeaderBlock = builder.createBlock();
+ loopHeaderBlock->insertBefore(originalStartBlock);
+ builder.setInsertInto(loopHeaderBlock);
+
+ // Move all params into `loopHeaderBlock`.
+ List<IRParam*> params;
+ for (auto param : originalStartBlock->getParams())
+ {
+ params.add(param);
+ }
+ for (auto param : params)
+ {
+ loopHeaderBlock->addParam(param);
+ }
+ auto loopInst = (IRLoop*)builder.emitLoop(originalStartBlock, breakBlock, originalStartBlock);
+
+ // Now replace all return insts as break insts.
+ processChildInstsOfType<IRReturn>(kIROp_Return, func, [&](IRReturn* returnInst)
+ {
+ IRInst* retVal = nullptr;
+ if (returnInst->getOperandCount() == 0)
+ retVal = builder.getVoidValue();
+ else
+ retVal = returnInst->getVal();
+ builder.setInsertBefore(returnInst);
+ if (resultType->getOp()==kIROp_VoidType)
+ {
+ builder.emitBranch(breakBlock);
+ }
+ else
+ {
+ builder.emitBranch(breakBlock, 1, &retVal);
+ }
+ returnInst->removeAndDeallocate();
+ });
+
+ builder.setInsertInto(returnBlock);
+ if (retValParam)
+ builder.emitReturn(retValParam);
+ else
+ builder.emitReturn();
+
+ // Now run the multi-level-break pass.
+ eliminateMultiLevelBreakForFunc(module, func);
+
+ // Now remove the trivial loop header.
+ SLANG_RELEASE_ASSERT(loopInst->getContinueBlock() == loopInst->getTargetBlock());
+ auto targetBlock = loopInst->getTargetBlock();
+ for (auto param : params)
+ targetBlock->addParam(param);
+ loopHeaderBlock->removeAndDeallocate();
+ }
+};
+
+void convertFuncToSingleReturnForm(IRModule* irModule, IRGlobalValueWithCode* func)
+{
+ SingleReturnContext context(irModule);
+ context.processFunc(func);
+}
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-single-return.h b/source/slang/slang-ir-single-return.h
new file mode 100644
index 000000000..2ddfa280b
--- /dev/null
+++ b/source/slang/slang-ir-single-return.h
@@ -0,0 +1,12 @@
+// slang-ir-single-return.h
+#pragma once
+
+namespace Slang
+{
+ struct IRModule;
+ struct IRGlobalValueWithCode;
+
+ // Convert the CFG of `func` to have only a single `return` at the end.
+ void convertFuncToSingleReturnForm(IRModule* module, IRGlobalValueWithCode* func);
+
+}
diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp
index 797fcb25c..a496db3a8 100644
--- a/source/slang/slang-ir-ssa.cpp
+++ b/source/slang/slang-ir-ssa.cpp
@@ -531,12 +531,13 @@ IRInst* addPhiOperands(
auto block = blockInfo->block;
List<IRInst*> operandValues;
+ auto predecessorCount = block->getPredecessors().getCount();
for (auto predBlock : block->getPredecessors())
{
// Precondition: if we have multiple predecessors, then
// each must have only one successor (no critical edges).
//
- SLANG_ASSERT(predBlock->getSuccessors().getCount() == 1);
+ SLANG_RELEASE_ASSERT(predecessorCount <= 1 || predBlock->getSuccessors().getCount() == 1);
auto predInfo = *context->blockInfos.TryGetValue(predBlock);