summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorJulius Ikkala <julius.ikkala@gmail.com>2025-04-07 06:08:29 +0300
committerGitHub <noreply@github.com>2025-04-06 20:08:29 -0700
commit1b82501dd0c74347cda4a2c7fe5a84fd610bb485 (patch)
treef283a491e0545aa6b890a988ac9fb14f192b4663 /source/slang
parent680fb0b4e9cbb65d46677183a3f68630be1f6179 (diff)
Add defer statement (#6619)
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-ast-iterator.h6
-rw-r--r--source/slang/slang-ast-stmt.h7
-rw-r--r--source/slang/slang-check-decl.cpp2
-rw-r--r--source/slang/slang-check-impl.h4
-rw-r--r--source/slang/slang-check-stmt.cpp29
-rw-r--r--source/slang/slang-diagnostic-defs.h8
-rw-r--r--source/slang/slang-ir-inst-defs.h4
-rw-r--r--source/slang/slang-ir-insts.h11
-rw-r--r--source/slang/slang-ir-lower-defer.cpp258
-rw-r--r--source/slang/slang-ir-lower-defer.h18
-rw-r--r--source/slang/slang-ir.cpp16
-rw-r--r--source/slang/slang-language-server-ast-lookup.cpp2
-rw-r--r--source/slang/slang-lower-to-ir.cpp96
-rw-r--r--source/slang/slang-parser.cpp14
14 files changed, 468 insertions, 7 deletions
diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h
index 436e97c1b..c7da945f2 100644
--- a/source/slang/slang-ast-iterator.h
+++ b/source/slang/slang-ast-iterator.h
@@ -430,6 +430,12 @@ struct ASTIterator
iterator->visitExpr(stmt->expression);
}
+ void visitDeferStmt(DeferStmt* stmt)
+ {
+ iterator->maybeDispatchCallback(stmt);
+ dispatchIfNotNull(stmt->statement);
+ }
+
void visitWhileStmt(WhileStmt* stmt)
{
iterator->maybeDispatchCallback(stmt);
diff --git a/source/slang/slang-ast-stmt.h b/source/slang/slang-ast-stmt.h
index 09c491287..f6580ba21 100644
--- a/source/slang/slang-ast-stmt.h
+++ b/source/slang/slang-ast-stmt.h
@@ -253,6 +253,13 @@ class ReturnStmt : public Stmt
Expr* expression = nullptr;
};
+class DeferStmt : public Stmt
+{
+ SLANG_AST_CLASS(DeferStmt)
+
+ Stmt* statement = nullptr;
+};
+
class ExpressionStmt : public Stmt
{
SLANG_AST_CLASS(ExpressionStmt)
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 34143ad75..4ab909118 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -650,6 +650,8 @@ struct SemanticsDeclReferenceVisitor : public SemanticsDeclVisitorBase,
void visitReturnStmt(ReturnStmt* stmt) { dispatchIfNotNull(stmt->expression); }
+ void visitDeferStmt(DeferStmt* stmt) { dispatchIfNotNull(stmt->statement); }
+
void visitWhileStmt(WhileStmt* stmt)
{
dispatchIfNotNull(stmt->predicate);
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 95716744c..f7681ba45 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -2976,7 +2976,7 @@ struct SemanticsStmtVisitor : public SemanticsVisitor, StmtVisitor<SemanticsStmt
void checkStmt(Stmt* stmt);
template<typename T>
- T* FindOuterStmt();
+ T* FindOuterStmt(Stmt* searchUntil = nullptr);
Stmt* findOuterStmtWithLabel(Name* label);
@@ -3020,6 +3020,8 @@ struct SemanticsStmtVisitor : public SemanticsVisitor, StmtVisitor<SemanticsStmt
void visitReturnStmt(ReturnStmt* stmt);
+ void visitDeferStmt(DeferStmt* stmt);
+
void visitWhileStmt(WhileStmt* stmt);
void visitGpuForeachStmt(GpuForeachStmt* stmt);
diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp
index db6f00d23..c85eb7593 100644
--- a/source/slang/slang-check-stmt.cpp
+++ b/source/slang/slang-check-stmt.cpp
@@ -119,9 +119,10 @@ void SemanticsStmtVisitor::checkStmt(Stmt* stmt)
}
template<typename T>
-T* SemanticsStmtVisitor::FindOuterStmt()
+T* SemanticsStmtVisitor::FindOuterStmt(Stmt* searchUntil)
{
- for (auto outerStmtInfo = m_outerStmts; outerStmtInfo; outerStmtInfo = outerStmtInfo->next)
+ for (auto outerStmtInfo = m_outerStmts; outerStmtInfo && outerStmtInfo->stmt != searchUntil;
+ outerStmtInfo = outerStmtInfo->next)
{
auto outerStmt = outerStmtInfo->stmt;
auto found = as<T>(outerStmt);
@@ -178,6 +179,14 @@ void SemanticsStmtVisitor::visitBreakStmt(BreakStmt* stmt)
getSink()->diagnose(stmt, Diagnostics::breakOutsideLoop);
}
}
+
+ // If there is a defer statement before the breakable statement, it's
+ // illegal.
+ if (FindOuterStmt<DeferStmt>(targetStmt))
+ {
+ getSink()->diagnose(stmt, Diagnostics::breakInsideDefer);
+ }
+
stmt->parentStmt = targetStmt;
}
@@ -188,6 +197,11 @@ void SemanticsStmtVisitor::visitContinueStmt(ContinueStmt* stmt)
{
getSink()->diagnose(stmt, Diagnostics::continueOutsideLoop);
}
+
+ if (FindOuterStmt<DeferStmt>(outer))
+ {
+ getSink()->diagnose(stmt, Diagnostics::continueInsideDefer);
+ }
stmt->parentStmt = outer;
}
@@ -497,6 +511,11 @@ void SemanticsStmtVisitor::visitReturnStmt(ReturnStmt* stmt)
}
}
}
+
+ if (FindOuterStmt<DeferStmt>())
+ {
+ getSink()->diagnose(stmt, Diagnostics::returnInsideDefer);
+ }
}
void SemanticsStmtVisitor::visitWhileStmt(WhileStmt* stmt)
@@ -508,6 +527,12 @@ void SemanticsStmtVisitor::visitWhileStmt(WhileStmt* stmt)
checkLoopInDifferentiableFunc(stmt);
}
+void SemanticsStmtVisitor::visitDeferStmt(DeferStmt* stmt)
+{
+ WithOuterStmt subContext(this, stmt);
+ subContext.checkStmt(stmt->statement);
+}
+
void SemanticsStmtVisitor::visitExpressionStmt(ExpressionStmt* stmt)
{
stmt->expression = CheckExpr(stmt->expression);
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index b7469deb0..f2c7fecc1 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -890,6 +890,14 @@ DIAGNOSTIC(
DIAGNOSTIC(30106, Error, improperUseOfType, "type '$0' cannot be used in this context.")
DIAGNOSTIC(30107, Error, parameterPackMustBeConst, "a parameter pack must be declared as 'const'.")
+DIAGNOSTIC(30108, Error, breakInsideDefer, "'break' must not appear inside a defer statement.")
+DIAGNOSTIC(
+ 30109,
+ Error,
+ continueInsideDefer,
+ "'continue' must not appear inside a defer statement.")
+DIAGNOSTIC(30110, Error, returnInsideDefer, "'return' must not appear inside a defer statement.")
+
// Include
DIAGNOSTIC(
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index bd590f08f..c7ed5affe 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -661,7 +661,9 @@ INST(MissingReturn, missingReturn, 0, 0)
INST(Unreachable, unreachable, 0, 0)
INST_RANGE(Unreachable, MissingReturn, Unreachable)
-INST_RANGE(TerminatorInst, Return, Unreachable)
+INST(Defer, defer, 3, 0)
+
+INST_RANGE(TerminatorInst, Return, Defer)
INST(discard, discard, 0, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 51608ebc9..645b13e00 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -2852,6 +2852,15 @@ struct IRTryCall : IRTerminatorInst
IRInst* getArg(UInt index) { return getOperand(index + 3); }
};
+struct IRDefer : IRTerminatorInst
+{
+ IR_LEAF_ISA(Defer);
+
+ IRBlock* getDeferBlock() { return cast<IRBlock>(getOperand(0)); }
+ IRBlock* getMergeBlock() { return cast<IRBlock>(getOperand(1)); }
+ IRBlock* getScopeBlock() { return cast<IRBlock>(getOperand(2)); }
+};
+
struct IRSwizzle : IRInst
{
IR_LEAF_ISA(swizzle);
@@ -4462,6 +4471,8 @@ public:
IRInst* emitThrow(IRInst* val);
+ IRInst* emitDefer(IRBlock* deferBlock, IRBlock* mergeBlock, IRBlock* scopeEndBlock);
+
IRInst* emitDiscard();
IRInst* emitCheckpointObject(IRInst* value);
diff --git a/source/slang/slang-ir-lower-defer.cpp b/source/slang/slang-ir-lower-defer.cpp
new file mode 100644
index 000000000..4e1ec9a8b
--- /dev/null
+++ b/source/slang/slang-ir-lower-defer.cpp
@@ -0,0 +1,258 @@
+// slang-ir-lower-defer.cpp
+
+#include "slang-ir-lower-defer.h"
+
+#include "slang-ir-clone.h"
+#include "slang-ir-dominators.h"
+#include "slang-ir-inst-pass-base.h"
+#include "slang-ir-insts.h"
+#include "slang-ir.h"
+
+namespace Slang
+{
+
+struct DeferLoweringContext : InstPassBase
+{
+ DiagnosticSink* diagnosticSink;
+
+ DeferLoweringContext(IRModule* inModule)
+ : InstPassBase(inModule)
+ {
+ }
+
+ void inlineSingleBlockDefer(IRInst* beforeInst, IRBlock* deferBlock, IRBuilder* builder)
+ {
+ builder->setInsertBefore(beforeInst);
+ IRCloneEnv env;
+ for (IRInst* inst : deferBlock->getChildren())
+ {
+ // Copy everything except the terminator; the terminator should only
+ // be a jump to mergeBlock, which isn't needed after inlining.
+ if (!as<IRTerminatorInst>(inst))
+ cloneInst(&env, builder, inst);
+ }
+ }
+
+ // Returns the new last block.
+ IRBlock* inlineDefer(
+ IRInst* beforeInst,
+ IRBlock* targetBlock,
+ const List<IRBlock*>& deferBlocks,
+ IRBlock* mergeBlock,
+ IRBuilder* builder)
+ {
+ // The single-block inlining case is simple, we can just dump the
+ // instructions at the target position, in the existing block.
+ if (deferBlocks.getCount() == 1)
+ {
+ inlineSingleBlockDefer(beforeInst, deferBlocks.getFirst(), builder);
+ return targetBlock;
+ }
+
+ // Otherwise, we'll have to splice the blocks in.
+ IRCloneEnv env;
+ builder->setInsertAfter(targetBlock);
+ auto lastBlock = targetBlock;
+
+ // Clone blocks first
+ for (auto block : deferBlocks)
+ {
+ auto clonedBlock = builder->createBlock();
+ builder->addInst(clonedBlock);
+ env.mapOldValToNew[block] = clonedBlock;
+ }
+
+ // Then, clone instructions, but mapping old blocks to new blocks.
+ for (auto block : deferBlocks)
+ {
+ auto clonedBlock = as<IRBlock>(env.mapOldValToNew.getValue(block));
+ builder->setInsertInto(clonedBlock);
+ for (auto inst : block->getChildren())
+ {
+ auto endBranch = as<IRUnconditionalBranch>(inst);
+ if (endBranch && endBranch->getTargetBlock() == mergeBlock)
+ {
+ lastBlock = clonedBlock;
+ }
+ else
+ cloneInst(&env, builder, inst);
+ }
+ }
+
+ // Move old instructions to the last block's end. The last defer block
+ // shouldn't have a terminator at this point yet.
+ while (beforeInst)
+ {
+ auto nextInst = beforeInst->getNextInst();
+ beforeInst->insertAtEnd(lastBlock);
+ beforeInst = nextInst;
+ }
+
+ // Make target block jump to the cloned blocks.
+ builder->setInsertInto(targetBlock);
+ auto mainBlock = as<IRBlock>(env.mapOldValToNew.getValue(deferBlocks[0]));
+ builder->emitBranch(mainBlock);
+
+ return lastBlock;
+ }
+
+ HashSet<IRBlock*> findSuccessorBlocks(IRGlobalValueWithCode* func, IRBlock* block)
+ {
+ HashSet<IRBlock*> successorBlocksSet;
+ List<IRBlock*> successorWorkList;
+ successorWorkList.add(block);
+
+ List<IRBlock*> postorder = getPostorder(func);
+ Index limitIndex = postorder.indexOf(block);
+ while (successorWorkList.getCount() > 0)
+ {
+ IRBlock* predecessor = successorWorkList.getLast();
+ successorWorkList.removeLast();
+ if (successorBlocksSet.contains(predecessor))
+ continue;
+
+ Index predecessorIndex = postorder.indexOf(predecessor);
+ // Does not succeed if it is after the given block in postorder.
+ if (predecessorIndex > limitIndex)
+ continue;
+
+ successorBlocksSet.add(predecessor);
+ for (IRBlock* successor : predecessor->getSuccessors())
+ successorWorkList.add(successor);
+ }
+ return successorBlocksSet;
+ }
+
+ void processFunc(IRGlobalValueWithCode* func)
+ {
+ // Iterating over `defer` instructions in reverse order allows us to
+ // expand them in the correct order, including nested `defer`s.
+ // We also use this to determine scope extents.
+ List<IRBlock*> reverseBlocks = getReversePostorderOnReverseCFG(func);
+ List<IRDefer*> unhandledDefers;
+
+ for (IRBlock* block : reverseBlocks)
+ {
+ for (auto child = block->getLastChild(); child; child = child->getPrevInst())
+ {
+ if (auto defer = as<IRDefer>(child))
+ unhandledDefers.add(defer);
+ }
+ }
+
+ IRBuilder builder(module);
+ Dictionary<IRBlock*, IRBlock*> mapOldScopeToNew;
+ for (IRDefer* defer : unhandledDefers)
+ {
+ IRBlock* firstDeferBlock = defer->getDeferBlock();
+ IRBlock* mergeBlock = defer->getMergeBlock();
+ IRBlock* scopeEndBlock = defer->getScopeBlock();
+ mapOldScopeToNew.tryGetValue(scopeEndBlock, scopeEndBlock);
+ IRBlock* parentBlock = as<IRBlock>(defer->getParent());
+
+ // The dominator tree gets invalidated on every iteration, so it's
+ // necessary to construct it inside the loop.
+ auto dom = module->findOrCreateDominatorTree(func);
+
+ // Enumerate defer block range. That is, all blocks dominated by
+ // parentBlock and not dominated by mergeBlock.
+ auto deferDominatedBlocks = dom->getProperlyDominatedBlocks(firstDeferBlock);
+ List<IRBlock*> deferBlocks;
+ deferBlocks.add(firstDeferBlock);
+ for (IRBlock* block : deferDominatedBlocks)
+ {
+ if (!dom->properlyDominates(mergeBlock, block) && block != mergeBlock)
+ deferBlocks.add(block);
+ }
+
+ auto dominatedBlocks = dom->getProperlyDominatedBlocks(mergeBlock);
+
+
+ HashSet<IRBlock*> scopeSuccessorBlocksSet = findSuccessorBlocks(func, scopeEndBlock);
+ HashSet<IRBlock*> scopeBlocksSet;
+ scopeBlocksSet.add(mergeBlock);
+ for (IRBlock* block : dominatedBlocks)
+ {
+ if (!scopeSuccessorBlocksSet.contains(block))
+ scopeBlocksSet.add(block);
+ }
+
+ // All jumps from blocks in scope to blocks out of scope are to be
+ // preceded by a copy of the deferBlocks.
+ for (IRBlock* block : scopeBlocksSet)
+ {
+ auto terminator = block->getTerminator();
+ SLANG_ASSERT(terminator);
+ bool exits = false;
+ switch (terminator->getOp())
+ {
+ case kIROp_Return:
+ case kIROp_discard:
+ case kIROp_Throw:
+ exits = true;
+ break;
+ case kIROp_unconditionalBranch:
+ {
+ auto targetBlock = as<IRBlock>(terminator->getOperand(0));
+ if (!scopeBlocksSet.contains(targetBlock))
+ {
+ exits = true;
+ }
+ }
+ break;
+ case kIROp_conditionalBranch:
+ {
+ auto trueBlock = as<IRBlock>(terminator->getOperand(1));
+ auto falseBlock = as<IRBlock>(terminator->getOperand(2));
+ if (!scopeBlocksSet.contains(trueBlock) ||
+ !scopeBlocksSet.contains(falseBlock))
+ {
+ exits = true;
+ }
+ }
+ break;
+ default:
+ break;
+ }
+
+ if (exits)
+ { // Duplicate child instructions to the end of this block.
+ auto newEnd = inlineDefer(terminator, block, deferBlocks, mergeBlock, &builder);
+ if (newEnd != block)
+ {
+ mapOldScopeToNew[block] = newEnd;
+ }
+ }
+ }
+
+ // Replace defer with unconditional branch to mergeBlock. Defer
+ // blocks should now be orphaned, and we can remove them too.
+ defer->removeAndDeallocate();
+ builder.setInsertInto(parentBlock);
+ builder.emitBranch(mergeBlock);
+
+ for (IRBlock* deferBlock : deferBlocks)
+ {
+ deferBlock->removeAndDeallocate();
+ }
+
+ // Some blocks got removed and added, so mark analysis of the
+ // function with defer as outdated.
+ module->invalidateAnalysisForInst(func);
+ }
+ }
+
+ void processModule()
+ {
+ processInstsOfType<IRFunc>(kIROp_Func, [&](IRFunc* func) { processFunc(func); });
+ }
+};
+
+void lowerDefer(IRModule* module, DiagnosticSink* sink)
+{
+ DeferLoweringContext context(module);
+ context.diagnosticSink = sink;
+ return context.processModule();
+}
+
+} // namespace Slang
diff --git a/source/slang/slang-ir-lower-defer.h b/source/slang/slang-ir-lower-defer.h
new file mode 100644
index 000000000..c00104745
--- /dev/null
+++ b/source/slang/slang-ir-lower-defer.h
@@ -0,0 +1,18 @@
+// slang-ir-lower-defer.h
+#pragma once
+
+#include "slang-ir.h"
+
+namespace Slang
+{
+struct IRModule;
+class DiagnosticSink;
+
+/// Lower the `defer` statements.
+///
+/// Duplicates the child instructions of each `defer` to the end of each
+/// dominated block whose terminator jumps to a location that is not dominated
+/// by the `defer`. Also removes all `IRDefer` instructions after that.
+void lowerDefer(IRModule* module, DiagnosticSink* sink);
+
+} // namespace Slang
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 1287d1598..f75fe2f48 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -596,6 +596,13 @@ static IRBlock::SuccessorList getSuccessors(IRInst* terminator)
end = operands + terminator->getOperandCount() + 1;
stride = 2;
break;
+
+ case kIROp_Defer:
+ // defer <deferBlock> <mergeBlock> <scopeEndBlock>
+ begin = operands + 0;
+ end = begin + 1;
+ break;
+
default:
SLANG_UNEXPECTED("unhandled terminator instruction");
UNREACHABLE_RETURN(IRBlock::SuccessorList(nullptr, nullptr));
@@ -869,6 +876,7 @@ bool isTerminatorInst(IROp op)
case kIROp_Switch:
case kIROp_Unreachable:
case kIROp_MissingReturn:
+ case kIROp_Defer:
return true;
}
}
@@ -5698,6 +5706,14 @@ IRInst* IRBuilder::emitReturn()
return inst;
}
+IRInst* IRBuilder::emitDefer(IRBlock* deferBlock, IRBlock* mergeBlock, IRBlock* scopeEndBlock)
+{
+ auto inst =
+ createInst<IRDefer>(this, kIROp_Defer, nullptr, deferBlock, mergeBlock, scopeEndBlock);
+ addInst(inst);
+ return inst;
+}
+
IRInst* IRBuilder::emitThrow(IRInst* val)
{
auto inst = createInst<IRThrow>(this, kIROp_Throw, nullptr, val);
diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp
index ebc3fe9ad..7375756f5 100644
--- a/source/slang/slang-language-server-ast-lookup.cpp
+++ b/source/slang/slang-language-server-ast-lookup.cpp
@@ -642,6 +642,8 @@ struct ASTLookupStmtVisitor : public StmtVisitor<ASTLookupStmtVisitor, bool>
bool visitReturnStmt(ReturnStmt* stmt) { return checkExpr(stmt->expression); }
+ bool visitDeferStmt(DeferStmt* stmt) { return dispatchIfNotNull(stmt->statement); }
+
bool visitWhileStmt(WhileStmt* stmt)
{
if (checkExpr(stmt->predicate))
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index e3c4ddf05..e6ec68660 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -19,6 +19,7 @@
#include "slang-ir-insert-debug-value-store.h"
#include "slang-ir-insts.h"
#include "slang-ir-loop-inversion.h"
+#include "slang-ir-lower-defer.h"
#include "slang-ir-lower-error-handling.h"
#include "slang-ir-lower-expand-type.h"
#include "slang-ir-missing-return.h"
@@ -593,6 +594,9 @@ struct IRGenContext
// The element index if we are inside an `expand` expression.
IRInst* expandIndex = nullptr;
+ // The current scope end for use with `defer`.
+ IRBlock* scopeEndBlock = nullptr;
+
// Callback function to call when after lowering a type.
std::function<IRType*(IRGenContext* context, Type* type, IRType* irType)> lowerTypeCallback =
nullptr;
@@ -6021,6 +6025,53 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
startBlock();
}
+ /// Create a new scope end block and return the previous one.
+ ///
+ /// This is needed for `defer` to be aware of scopes. `preallocated` can
+ /// be specified if you already have a block at the end of the scope, like
+ /// in `for` loops.
+ IRBlock* pushScopeBlock(IRBlock* preallocated = nullptr)
+ {
+ IRBlock* prevScopeEndBlock = context->scopeEndBlock;
+
+ auto builder = getBuilder();
+ context->scopeEndBlock = preallocated ? preallocated : builder->createBlock();
+ return prevScopeEndBlock;
+ }
+
+ /// Pop the current scope end block and restore the previous one.
+ ///
+ /// This is needed for `defer` to be aware of scopes. `previous` should be
+ /// the block returned from the corresponding pushScopeBlock. `preallocated`
+ /// should be true if the corresponding pushScopeBlock was given a block
+ /// as a parameter.
+ void popScopeBlock(IRBlock* previous, bool preallocated)
+ {
+ if (!preallocated)
+ {
+ // If pushScopeBlock actually created the block, we have to insert
+ // or deallocate it here. Otherwise, we assume that the caller
+ // handles the end block.
+ auto builder = getBuilder();
+ if (context->scopeEndBlock->hasUses())
+ {
+ // The end of the scope was referenced, so we need to actually
+ // keep it around and jump through it.
+ // Move the terminator to the scope end block.
+ emitBranchIfNeeded(context->scopeEndBlock);
+ builder->insertBlock(context->scopeEndBlock);
+ builder->setInsertInto(context->scopeEndBlock);
+ }
+ else
+ {
+ // Scope end block was left unused, so we may as well delete it.
+ context->scopeEndBlock->removeAndDeallocate();
+ }
+ }
+
+ context->scopeEndBlock = previous;
+ }
+
void visitIfStmt(IfStmt* stmt)
{
auto builder = getBuilder();
@@ -6043,11 +6094,13 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
ifInst = builder->emitIfElse(irCond, thenBlock, elseBlock, afterBlock);
insertBlock(thenBlock);
+ IRBlock* prevScopeEndBlock = pushScopeBlock(afterBlock);
lowerStmt(context, thenStmt);
emitBranchIfNeeded(afterBlock);
insertBlock(elseBlock);
lowerStmt(context, elseStmt);
+ popScopeBlock(prevScopeEndBlock, true);
insertBlock(afterBlock);
}
@@ -6059,7 +6112,10 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
ifInst = builder->emitIf(irCond, thenBlock, afterBlock);
insertBlock(thenBlock);
+
+ IRBlock* prevScopeEndBlock = pushScopeBlock(afterBlock);
lowerStmt(context, thenStmt);
+ popScopeBlock(prevScopeEndBlock, true);
insertBlock(afterBlock);
}
@@ -6150,7 +6206,9 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
// Emit the body of the loop
insertBlock(bodyLabel);
+ IRBlock* prevScopeEndBlock = pushScopeBlock(continueLabel);
lowerStmt(context, stmt->statement);
+ popScopeBlock(prevScopeEndBlock, true);
if (auto inferredMaxIters = stmt->findModifier<InferredMaxItersAttribute>())
{
@@ -6256,7 +6314,9 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
// Emit the body of the loop
insertBlock(bodyLabel);
+ IRBlock* prevScopeEndBlock = pushScopeBlock(continueLabel);
lowerStmt(context, stmt->statement);
+ popScopeBlock(prevScopeEndBlock, true);
// At the end of the body we need to jump back to the top.
emitBranchIfNeeded(loopHead);
@@ -6300,7 +6360,9 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
insertBlock(loopHead);
// Emit the body of the loop
+ IRBlock* prevScopeEndBlock = pushScopeBlock(continueLabel);
lowerStmt(context, stmt->statement);
+ popScopeBlock(prevScopeEndBlock, true);
insertBlock(testLabel);
@@ -6429,10 +6491,12 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
void visitBlockStmt(BlockStmt* stmt)
{
- // To lower a block (scope) statement,
- // just lower its body. The IR doesn't
- // need to reflect the scoping of the AST.
+ IRBlock* prevScopeEndBlock = pushScopeBlock(nullptr);
+
+ // To lower a block (scope) statement, just lower its body.
lowerStmt(context, stmt->body);
+
+ popScopeBlock(prevScopeEndBlock, false);
}
void visitReturnStmt(ReturnStmt* stmt)
@@ -6523,6 +6587,29 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
}
}
+ void visitDeferStmt(DeferStmt* stmt)
+ {
+ auto builder = getBuilder();
+ startBlockIfNeeded(stmt);
+
+ IRBlock* deferBlock = builder->createBlock();
+ IRBlock* mergeBlock = builder->createBlock();
+
+ builder->emitDefer(deferBlock, mergeBlock, context->scopeEndBlock);
+
+ builder->insertBlock(deferBlock);
+ builder->setInsertInto(deferBlock);
+
+ IRBlock* prevScopeEndBlock = pushScopeBlock(mergeBlock);
+ lowerStmt(context, stmt->statement);
+ popScopeBlock(prevScopeEndBlock, true);
+
+ builder->emitBranch(mergeBlock);
+
+ builder->insertBlock(mergeBlock);
+ builder->setInsertInto(mergeBlock);
+ }
+
void visitDiscardStmt(DiscardStmt* stmt)
{
startBlockIfNeeded(stmt);
@@ -11714,6 +11801,9 @@ RefPtr<IRModule> generateIRForTranslationUnit(
// normal `call` + `ifElse`, etc.
lowerErrorHandling(module, compileRequest->getSink());
+ // Lower `defer` so that later passes need not be aware of it.
+ lowerDefer(module, compileRequest->getSink());
+
// Synthesize some code we want to make sure is inlined and simplified
synthesizeBitFieldAccessors(module);
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index 1b862de77..97777c9fe 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -215,6 +215,7 @@ public:
BreakStmt* ParseBreakStatement();
ContinueStmt* ParseContinueStatement();
ReturnStmt* ParseReturnStatement();
+ DeferStmt* ParseDeferStatement();
ExpressionStmt* ParseExpressionStatement();
Expr* ParseExpression(Precedence level = Precedence::Comma);
@@ -5770,6 +5771,10 @@ Stmt* Parser::ParseStatement(Stmt* parentStmt)
{
statement = parseCompileTimeStmt(this);
}
+ else if (LookAheadToken("defer"))
+ {
+ statement = ParseDeferStatement();
+ }
else if (LookAheadToken("try"))
{
statement = ParseExpressionStatement();
@@ -6299,6 +6304,15 @@ ReturnStmt* Parser::ParseReturnStatement()
return returnStatement;
}
+DeferStmt* Parser::ParseDeferStatement()
+{
+ DeferStmt* deferStatement = astBuilder->create<DeferStmt>();
+ FillPosition(deferStatement);
+ ReadToken("defer");
+ deferStatement->statement = ParseStatement();
+ return deferStatement;
+}
+
ExpressionStmt* Parser::ParseExpressionStatement()
{
ExpressionStmt* statement = astBuilder->create<ExpressionStmt>();