diff options
| author | Julius Ikkala <julius.ikkala@gmail.com> | 2025-04-07 06:08:29 +0300 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-04-06 20:08:29 -0700 |
| commit | 1b82501dd0c74347cda4a2c7fe5a84fd610bb485 (patch) | |
| tree | f283a491e0545aa6b890a988ac9fb14f192b4663 /source | |
| parent | 680fb0b4e9cbb65d46677183a3f68630be1f6179 (diff) | |
Add defer statement (#6619)
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ast-iterator.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-ast-stmt.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-check-stmt.cpp | 29 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 11 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-defer.cpp | 258 | ||||
| -rw-r--r-- | source/slang/slang-ir-lower-defer.h | 18 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 16 | ||||
| -rw-r--r-- | source/slang/slang-language-server-ast-lookup.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 96 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 14 |
14 files changed, 468 insertions, 7 deletions
diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h index 436e97c1b..c7da945f2 100644 --- a/source/slang/slang-ast-iterator.h +++ b/source/slang/slang-ast-iterator.h @@ -430,6 +430,12 @@ struct ASTIterator iterator->visitExpr(stmt->expression); } + void visitDeferStmt(DeferStmt* stmt) + { + iterator->maybeDispatchCallback(stmt); + dispatchIfNotNull(stmt->statement); + } + void visitWhileStmt(WhileStmt* stmt) { iterator->maybeDispatchCallback(stmt); diff --git a/source/slang/slang-ast-stmt.h b/source/slang/slang-ast-stmt.h index 09c491287..f6580ba21 100644 --- a/source/slang/slang-ast-stmt.h +++ b/source/slang/slang-ast-stmt.h @@ -253,6 +253,13 @@ class ReturnStmt : public Stmt Expr* expression = nullptr; }; +class DeferStmt : public Stmt +{ + SLANG_AST_CLASS(DeferStmt) + + Stmt* statement = nullptr; +}; + class ExpressionStmt : public Stmt { SLANG_AST_CLASS(ExpressionStmt) diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 34143ad75..4ab909118 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -650,6 +650,8 @@ struct SemanticsDeclReferenceVisitor : public SemanticsDeclVisitorBase, void visitReturnStmt(ReturnStmt* stmt) { dispatchIfNotNull(stmt->expression); } + void visitDeferStmt(DeferStmt* stmt) { dispatchIfNotNull(stmt->statement); } + void visitWhileStmt(WhileStmt* stmt) { dispatchIfNotNull(stmt->predicate); diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 95716744c..f7681ba45 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -2976,7 +2976,7 @@ struct SemanticsStmtVisitor : public SemanticsVisitor, StmtVisitor<SemanticsStmt void checkStmt(Stmt* stmt); template<typename T> - T* FindOuterStmt(); + T* FindOuterStmt(Stmt* searchUntil = nullptr); Stmt* findOuterStmtWithLabel(Name* label); @@ -3020,6 +3020,8 @@ struct SemanticsStmtVisitor : public SemanticsVisitor, StmtVisitor<SemanticsStmt void visitReturnStmt(ReturnStmt* stmt); + void visitDeferStmt(DeferStmt* stmt); + void visitWhileStmt(WhileStmt* stmt); void visitGpuForeachStmt(GpuForeachStmt* stmt); diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp index db6f00d23..c85eb7593 100644 --- a/source/slang/slang-check-stmt.cpp +++ b/source/slang/slang-check-stmt.cpp @@ -119,9 +119,10 @@ void SemanticsStmtVisitor::checkStmt(Stmt* stmt) } template<typename T> -T* SemanticsStmtVisitor::FindOuterStmt() +T* SemanticsStmtVisitor::FindOuterStmt(Stmt* searchUntil) { - for (auto outerStmtInfo = m_outerStmts; outerStmtInfo; outerStmtInfo = outerStmtInfo->next) + for (auto outerStmtInfo = m_outerStmts; outerStmtInfo && outerStmtInfo->stmt != searchUntil; + outerStmtInfo = outerStmtInfo->next) { auto outerStmt = outerStmtInfo->stmt; auto found = as<T>(outerStmt); @@ -178,6 +179,14 @@ void SemanticsStmtVisitor::visitBreakStmt(BreakStmt* stmt) getSink()->diagnose(stmt, Diagnostics::breakOutsideLoop); } } + + // If there is a defer statement before the breakable statement, it's + // illegal. + if (FindOuterStmt<DeferStmt>(targetStmt)) + { + getSink()->diagnose(stmt, Diagnostics::breakInsideDefer); + } + stmt->parentStmt = targetStmt; } @@ -188,6 +197,11 @@ void SemanticsStmtVisitor::visitContinueStmt(ContinueStmt* stmt) { getSink()->diagnose(stmt, Diagnostics::continueOutsideLoop); } + + if (FindOuterStmt<DeferStmt>(outer)) + { + getSink()->diagnose(stmt, Diagnostics::continueInsideDefer); + } stmt->parentStmt = outer; } @@ -497,6 +511,11 @@ void SemanticsStmtVisitor::visitReturnStmt(ReturnStmt* stmt) } } } + + if (FindOuterStmt<DeferStmt>()) + { + getSink()->diagnose(stmt, Diagnostics::returnInsideDefer); + } } void SemanticsStmtVisitor::visitWhileStmt(WhileStmt* stmt) @@ -508,6 +527,12 @@ void SemanticsStmtVisitor::visitWhileStmt(WhileStmt* stmt) checkLoopInDifferentiableFunc(stmt); } +void SemanticsStmtVisitor::visitDeferStmt(DeferStmt* stmt) +{ + WithOuterStmt subContext(this, stmt); + subContext.checkStmt(stmt->statement); +} + void SemanticsStmtVisitor::visitExpressionStmt(ExpressionStmt* stmt) { stmt->expression = CheckExpr(stmt->expression); diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index b7469deb0..f2c7fecc1 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -890,6 +890,14 @@ DIAGNOSTIC( DIAGNOSTIC(30106, Error, improperUseOfType, "type '$0' cannot be used in this context.") DIAGNOSTIC(30107, Error, parameterPackMustBeConst, "a parameter pack must be declared as 'const'.") +DIAGNOSTIC(30108, Error, breakInsideDefer, "'break' must not appear inside a defer statement.") +DIAGNOSTIC( + 30109, + Error, + continueInsideDefer, + "'continue' must not appear inside a defer statement.") +DIAGNOSTIC(30110, Error, returnInsideDefer, "'return' must not appear inside a defer statement.") + // Include DIAGNOSTIC( diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index bd590f08f..c7ed5affe 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -661,7 +661,9 @@ INST(MissingReturn, missingReturn, 0, 0) INST(Unreachable, unreachable, 0, 0) INST_RANGE(Unreachable, MissingReturn, Unreachable) -INST_RANGE(TerminatorInst, Return, Unreachable) +INST(Defer, defer, 3, 0) + +INST_RANGE(TerminatorInst, Return, Defer) INST(discard, discard, 0, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 51608ebc9..645b13e00 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2852,6 +2852,15 @@ struct IRTryCall : IRTerminatorInst IRInst* getArg(UInt index) { return getOperand(index + 3); } }; +struct IRDefer : IRTerminatorInst +{ + IR_LEAF_ISA(Defer); + + IRBlock* getDeferBlock() { return cast<IRBlock>(getOperand(0)); } + IRBlock* getMergeBlock() { return cast<IRBlock>(getOperand(1)); } + IRBlock* getScopeBlock() { return cast<IRBlock>(getOperand(2)); } +}; + struct IRSwizzle : IRInst { IR_LEAF_ISA(swizzle); @@ -4462,6 +4471,8 @@ public: IRInst* emitThrow(IRInst* val); + IRInst* emitDefer(IRBlock* deferBlock, IRBlock* mergeBlock, IRBlock* scopeEndBlock); + IRInst* emitDiscard(); IRInst* emitCheckpointObject(IRInst* value); diff --git a/source/slang/slang-ir-lower-defer.cpp b/source/slang/slang-ir-lower-defer.cpp new file mode 100644 index 000000000..4e1ec9a8b --- /dev/null +++ b/source/slang/slang-ir-lower-defer.cpp @@ -0,0 +1,258 @@ +// slang-ir-lower-defer.cpp + +#include "slang-ir-lower-defer.h" + +#include "slang-ir-clone.h" +#include "slang-ir-dominators.h" +#include "slang-ir-inst-pass-base.h" +#include "slang-ir-insts.h" +#include "slang-ir.h" + +namespace Slang +{ + +struct DeferLoweringContext : InstPassBase +{ + DiagnosticSink* diagnosticSink; + + DeferLoweringContext(IRModule* inModule) + : InstPassBase(inModule) + { + } + + void inlineSingleBlockDefer(IRInst* beforeInst, IRBlock* deferBlock, IRBuilder* builder) + { + builder->setInsertBefore(beforeInst); + IRCloneEnv env; + for (IRInst* inst : deferBlock->getChildren()) + { + // Copy everything except the terminator; the terminator should only + // be a jump to mergeBlock, which isn't needed after inlining. + if (!as<IRTerminatorInst>(inst)) + cloneInst(&env, builder, inst); + } + } + + // Returns the new last block. + IRBlock* inlineDefer( + IRInst* beforeInst, + IRBlock* targetBlock, + const List<IRBlock*>& deferBlocks, + IRBlock* mergeBlock, + IRBuilder* builder) + { + // The single-block inlining case is simple, we can just dump the + // instructions at the target position, in the existing block. + if (deferBlocks.getCount() == 1) + { + inlineSingleBlockDefer(beforeInst, deferBlocks.getFirst(), builder); + return targetBlock; + } + + // Otherwise, we'll have to splice the blocks in. + IRCloneEnv env; + builder->setInsertAfter(targetBlock); + auto lastBlock = targetBlock; + + // Clone blocks first + for (auto block : deferBlocks) + { + auto clonedBlock = builder->createBlock(); + builder->addInst(clonedBlock); + env.mapOldValToNew[block] = clonedBlock; + } + + // Then, clone instructions, but mapping old blocks to new blocks. + for (auto block : deferBlocks) + { + auto clonedBlock = as<IRBlock>(env.mapOldValToNew.getValue(block)); + builder->setInsertInto(clonedBlock); + for (auto inst : block->getChildren()) + { + auto endBranch = as<IRUnconditionalBranch>(inst); + if (endBranch && endBranch->getTargetBlock() == mergeBlock) + { + lastBlock = clonedBlock; + } + else + cloneInst(&env, builder, inst); + } + } + + // Move old instructions to the last block's end. The last defer block + // shouldn't have a terminator at this point yet. + while (beforeInst) + { + auto nextInst = beforeInst->getNextInst(); + beforeInst->insertAtEnd(lastBlock); + beforeInst = nextInst; + } + + // Make target block jump to the cloned blocks. + builder->setInsertInto(targetBlock); + auto mainBlock = as<IRBlock>(env.mapOldValToNew.getValue(deferBlocks[0])); + builder->emitBranch(mainBlock); + + return lastBlock; + } + + HashSet<IRBlock*> findSuccessorBlocks(IRGlobalValueWithCode* func, IRBlock* block) + { + HashSet<IRBlock*> successorBlocksSet; + List<IRBlock*> successorWorkList; + successorWorkList.add(block); + + List<IRBlock*> postorder = getPostorder(func); + Index limitIndex = postorder.indexOf(block); + while (successorWorkList.getCount() > 0) + { + IRBlock* predecessor = successorWorkList.getLast(); + successorWorkList.removeLast(); + if (successorBlocksSet.contains(predecessor)) + continue; + + Index predecessorIndex = postorder.indexOf(predecessor); + // Does not succeed if it is after the given block in postorder. + if (predecessorIndex > limitIndex) + continue; + + successorBlocksSet.add(predecessor); + for (IRBlock* successor : predecessor->getSuccessors()) + successorWorkList.add(successor); + } + return successorBlocksSet; + } + + void processFunc(IRGlobalValueWithCode* func) + { + // Iterating over `defer` instructions in reverse order allows us to + // expand them in the correct order, including nested `defer`s. + // We also use this to determine scope extents. + List<IRBlock*> reverseBlocks = getReversePostorderOnReverseCFG(func); + List<IRDefer*> unhandledDefers; + + for (IRBlock* block : reverseBlocks) + { + for (auto child = block->getLastChild(); child; child = child->getPrevInst()) + { + if (auto defer = as<IRDefer>(child)) + unhandledDefers.add(defer); + } + } + + IRBuilder builder(module); + Dictionary<IRBlock*, IRBlock*> mapOldScopeToNew; + for (IRDefer* defer : unhandledDefers) + { + IRBlock* firstDeferBlock = defer->getDeferBlock(); + IRBlock* mergeBlock = defer->getMergeBlock(); + IRBlock* scopeEndBlock = defer->getScopeBlock(); + mapOldScopeToNew.tryGetValue(scopeEndBlock, scopeEndBlock); + IRBlock* parentBlock = as<IRBlock>(defer->getParent()); + + // The dominator tree gets invalidated on every iteration, so it's + // necessary to construct it inside the loop. + auto dom = module->findOrCreateDominatorTree(func); + + // Enumerate defer block range. That is, all blocks dominated by + // parentBlock and not dominated by mergeBlock. + auto deferDominatedBlocks = dom->getProperlyDominatedBlocks(firstDeferBlock); + List<IRBlock*> deferBlocks; + deferBlocks.add(firstDeferBlock); + for (IRBlock* block : deferDominatedBlocks) + { + if (!dom->properlyDominates(mergeBlock, block) && block != mergeBlock) + deferBlocks.add(block); + } + + auto dominatedBlocks = dom->getProperlyDominatedBlocks(mergeBlock); + + + HashSet<IRBlock*> scopeSuccessorBlocksSet = findSuccessorBlocks(func, scopeEndBlock); + HashSet<IRBlock*> scopeBlocksSet; + scopeBlocksSet.add(mergeBlock); + for (IRBlock* block : dominatedBlocks) + { + if (!scopeSuccessorBlocksSet.contains(block)) + scopeBlocksSet.add(block); + } + + // All jumps from blocks in scope to blocks out of scope are to be + // preceded by a copy of the deferBlocks. + for (IRBlock* block : scopeBlocksSet) + { + auto terminator = block->getTerminator(); + SLANG_ASSERT(terminator); + bool exits = false; + switch (terminator->getOp()) + { + case kIROp_Return: + case kIROp_discard: + case kIROp_Throw: + exits = true; + break; + case kIROp_unconditionalBranch: + { + auto targetBlock = as<IRBlock>(terminator->getOperand(0)); + if (!scopeBlocksSet.contains(targetBlock)) + { + exits = true; + } + } + break; + case kIROp_conditionalBranch: + { + auto trueBlock = as<IRBlock>(terminator->getOperand(1)); + auto falseBlock = as<IRBlock>(terminator->getOperand(2)); + if (!scopeBlocksSet.contains(trueBlock) || + !scopeBlocksSet.contains(falseBlock)) + { + exits = true; + } + } + break; + default: + break; + } + + if (exits) + { // Duplicate child instructions to the end of this block. + auto newEnd = inlineDefer(terminator, block, deferBlocks, mergeBlock, &builder); + if (newEnd != block) + { + mapOldScopeToNew[block] = newEnd; + } + } + } + + // Replace defer with unconditional branch to mergeBlock. Defer + // blocks should now be orphaned, and we can remove them too. + defer->removeAndDeallocate(); + builder.setInsertInto(parentBlock); + builder.emitBranch(mergeBlock); + + for (IRBlock* deferBlock : deferBlocks) + { + deferBlock->removeAndDeallocate(); + } + + // Some blocks got removed and added, so mark analysis of the + // function with defer as outdated. + module->invalidateAnalysisForInst(func); + } + } + + void processModule() + { + processInstsOfType<IRFunc>(kIROp_Func, [&](IRFunc* func) { processFunc(func); }); + } +}; + +void lowerDefer(IRModule* module, DiagnosticSink* sink) +{ + DeferLoweringContext context(module); + context.diagnosticSink = sink; + return context.processModule(); +} + +} // namespace Slang diff --git a/source/slang/slang-ir-lower-defer.h b/source/slang/slang-ir-lower-defer.h new file mode 100644 index 000000000..c00104745 --- /dev/null +++ b/source/slang/slang-ir-lower-defer.h @@ -0,0 +1,18 @@ +// slang-ir-lower-defer.h +#pragma once + +#include "slang-ir.h" + +namespace Slang +{ +struct IRModule; +class DiagnosticSink; + +/// Lower the `defer` statements. +/// +/// Duplicates the child instructions of each `defer` to the end of each +/// dominated block whose terminator jumps to a location that is not dominated +/// by the `defer`. Also removes all `IRDefer` instructions after that. +void lowerDefer(IRModule* module, DiagnosticSink* sink); + +} // namespace Slang diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 1287d1598..f75fe2f48 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -596,6 +596,13 @@ static IRBlock::SuccessorList getSuccessors(IRInst* terminator) end = operands + terminator->getOperandCount() + 1; stride = 2; break; + + case kIROp_Defer: + // defer <deferBlock> <mergeBlock> <scopeEndBlock> + begin = operands + 0; + end = begin + 1; + break; + default: SLANG_UNEXPECTED("unhandled terminator instruction"); UNREACHABLE_RETURN(IRBlock::SuccessorList(nullptr, nullptr)); @@ -869,6 +876,7 @@ bool isTerminatorInst(IROp op) case kIROp_Switch: case kIROp_Unreachable: case kIROp_MissingReturn: + case kIROp_Defer: return true; } } @@ -5698,6 +5706,14 @@ IRInst* IRBuilder::emitReturn() return inst; } +IRInst* IRBuilder::emitDefer(IRBlock* deferBlock, IRBlock* mergeBlock, IRBlock* scopeEndBlock) +{ + auto inst = + createInst<IRDefer>(this, kIROp_Defer, nullptr, deferBlock, mergeBlock, scopeEndBlock); + addInst(inst); + return inst; +} + IRInst* IRBuilder::emitThrow(IRInst* val) { auto inst = createInst<IRThrow>(this, kIROp_Throw, nullptr, val); diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp index ebc3fe9ad..7375756f5 100644 --- a/source/slang/slang-language-server-ast-lookup.cpp +++ b/source/slang/slang-language-server-ast-lookup.cpp @@ -642,6 +642,8 @@ struct ASTLookupStmtVisitor : public StmtVisitor<ASTLookupStmtVisitor, bool> bool visitReturnStmt(ReturnStmt* stmt) { return checkExpr(stmt->expression); } + bool visitDeferStmt(DeferStmt* stmt) { return dispatchIfNotNull(stmt->statement); } + bool visitWhileStmt(WhileStmt* stmt) { if (checkExpr(stmt->predicate)) diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index e3c4ddf05..e6ec68660 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -19,6 +19,7 @@ #include "slang-ir-insert-debug-value-store.h" #include "slang-ir-insts.h" #include "slang-ir-loop-inversion.h" +#include "slang-ir-lower-defer.h" #include "slang-ir-lower-error-handling.h" #include "slang-ir-lower-expand-type.h" #include "slang-ir-missing-return.h" @@ -593,6 +594,9 @@ struct IRGenContext // The element index if we are inside an `expand` expression. IRInst* expandIndex = nullptr; + // The current scope end for use with `defer`. + IRBlock* scopeEndBlock = nullptr; + // Callback function to call when after lowering a type. std::function<IRType*(IRGenContext* context, Type* type, IRType* irType)> lowerTypeCallback = nullptr; @@ -6021,6 +6025,53 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> startBlock(); } + /// Create a new scope end block and return the previous one. + /// + /// This is needed for `defer` to be aware of scopes. `preallocated` can + /// be specified if you already have a block at the end of the scope, like + /// in `for` loops. + IRBlock* pushScopeBlock(IRBlock* preallocated = nullptr) + { + IRBlock* prevScopeEndBlock = context->scopeEndBlock; + + auto builder = getBuilder(); + context->scopeEndBlock = preallocated ? preallocated : builder->createBlock(); + return prevScopeEndBlock; + } + + /// Pop the current scope end block and restore the previous one. + /// + /// This is needed for `defer` to be aware of scopes. `previous` should be + /// the block returned from the corresponding pushScopeBlock. `preallocated` + /// should be true if the corresponding pushScopeBlock was given a block + /// as a parameter. + void popScopeBlock(IRBlock* previous, bool preallocated) + { + if (!preallocated) + { + // If pushScopeBlock actually created the block, we have to insert + // or deallocate it here. Otherwise, we assume that the caller + // handles the end block. + auto builder = getBuilder(); + if (context->scopeEndBlock->hasUses()) + { + // The end of the scope was referenced, so we need to actually + // keep it around and jump through it. + // Move the terminator to the scope end block. + emitBranchIfNeeded(context->scopeEndBlock); + builder->insertBlock(context->scopeEndBlock); + builder->setInsertInto(context->scopeEndBlock); + } + else + { + // Scope end block was left unused, so we may as well delete it. + context->scopeEndBlock->removeAndDeallocate(); + } + } + + context->scopeEndBlock = previous; + } + void visitIfStmt(IfStmt* stmt) { auto builder = getBuilder(); @@ -6043,11 +6094,13 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> ifInst = builder->emitIfElse(irCond, thenBlock, elseBlock, afterBlock); insertBlock(thenBlock); + IRBlock* prevScopeEndBlock = pushScopeBlock(afterBlock); lowerStmt(context, thenStmt); emitBranchIfNeeded(afterBlock); insertBlock(elseBlock); lowerStmt(context, elseStmt); + popScopeBlock(prevScopeEndBlock, true); insertBlock(afterBlock); } @@ -6059,7 +6112,10 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> ifInst = builder->emitIf(irCond, thenBlock, afterBlock); insertBlock(thenBlock); + + IRBlock* prevScopeEndBlock = pushScopeBlock(afterBlock); lowerStmt(context, thenStmt); + popScopeBlock(prevScopeEndBlock, true); insertBlock(afterBlock); } @@ -6150,7 +6206,9 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> // Emit the body of the loop insertBlock(bodyLabel); + IRBlock* prevScopeEndBlock = pushScopeBlock(continueLabel); lowerStmt(context, stmt->statement); + popScopeBlock(prevScopeEndBlock, true); if (auto inferredMaxIters = stmt->findModifier<InferredMaxItersAttribute>()) { @@ -6256,7 +6314,9 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> // Emit the body of the loop insertBlock(bodyLabel); + IRBlock* prevScopeEndBlock = pushScopeBlock(continueLabel); lowerStmt(context, stmt->statement); + popScopeBlock(prevScopeEndBlock, true); // At the end of the body we need to jump back to the top. emitBranchIfNeeded(loopHead); @@ -6300,7 +6360,9 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> insertBlock(loopHead); // Emit the body of the loop + IRBlock* prevScopeEndBlock = pushScopeBlock(continueLabel); lowerStmt(context, stmt->statement); + popScopeBlock(prevScopeEndBlock, true); insertBlock(testLabel); @@ -6429,10 +6491,12 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> void visitBlockStmt(BlockStmt* stmt) { - // To lower a block (scope) statement, - // just lower its body. The IR doesn't - // need to reflect the scoping of the AST. + IRBlock* prevScopeEndBlock = pushScopeBlock(nullptr); + + // To lower a block (scope) statement, just lower its body. lowerStmt(context, stmt->body); + + popScopeBlock(prevScopeEndBlock, false); } void visitReturnStmt(ReturnStmt* stmt) @@ -6523,6 +6587,29 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> } } + void visitDeferStmt(DeferStmt* stmt) + { + auto builder = getBuilder(); + startBlockIfNeeded(stmt); + + IRBlock* deferBlock = builder->createBlock(); + IRBlock* mergeBlock = builder->createBlock(); + + builder->emitDefer(deferBlock, mergeBlock, context->scopeEndBlock); + + builder->insertBlock(deferBlock); + builder->setInsertInto(deferBlock); + + IRBlock* prevScopeEndBlock = pushScopeBlock(mergeBlock); + lowerStmt(context, stmt->statement); + popScopeBlock(prevScopeEndBlock, true); + + builder->emitBranch(mergeBlock); + + builder->insertBlock(mergeBlock); + builder->setInsertInto(mergeBlock); + } + void visitDiscardStmt(DiscardStmt* stmt) { startBlockIfNeeded(stmt); @@ -11714,6 +11801,9 @@ RefPtr<IRModule> generateIRForTranslationUnit( // normal `call` + `ifElse`, etc. lowerErrorHandling(module, compileRequest->getSink()); + // Lower `defer` so that later passes need not be aware of it. + lowerDefer(module, compileRequest->getSink()); + // Synthesize some code we want to make sure is inlined and simplified synthesizeBitFieldAccessors(module); diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 1b862de77..97777c9fe 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -215,6 +215,7 @@ public: BreakStmt* ParseBreakStatement(); ContinueStmt* ParseContinueStatement(); ReturnStmt* ParseReturnStatement(); + DeferStmt* ParseDeferStatement(); ExpressionStmt* ParseExpressionStatement(); Expr* ParseExpression(Precedence level = Precedence::Comma); @@ -5770,6 +5771,10 @@ Stmt* Parser::ParseStatement(Stmt* parentStmt) { statement = parseCompileTimeStmt(this); } + else if (LookAheadToken("defer")) + { + statement = ParseDeferStatement(); + } else if (LookAheadToken("try")) { statement = ParseExpressionStatement(); @@ -6299,6 +6304,15 @@ ReturnStmt* Parser::ParseReturnStatement() return returnStatement; } +DeferStmt* Parser::ParseDeferStatement() +{ + DeferStmt* deferStatement = astBuilder->create<DeferStmt>(); + FillPosition(deferStatement); + ReadToken("defer"); + deferStatement->statement = ParseStatement(); + return deferStatement; +} + ExpressionStmt* Parser::ParseExpressionStatement() { ExpressionStmt* statement = astBuilder->create<ExpressionStmt>(); |
