diff options
| author | Julius Ikkala <julius.ikkala@gmail.com> | 2025-04-07 06:08:29 +0300 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-04-06 20:08:29 -0700 |
| commit | 1b82501dd0c74347cda4a2c7fe5a84fd610bb485 (patch) | |
| tree | f283a491e0545aa6b890a988ac9fb14f192b4663 | |
| parent | 680fb0b4e9cbb65d46677183a3f68630be1f6179 (diff) | |
Add defer statement (#6619)
35 files changed, 910 insertions, 7 deletions
diff --git a/docs/user-guide/02-conventional-features.md b/docs/user-guide/02-conventional-features.md index aaeea4114..5c15986f8 100644 --- a/docs/user-guide/02-conventional-features.md +++ b/docs/user-guide/02-conventional-features.md @@ -435,6 +435,12 @@ Slang supports the following statement forms with nearly identical syntax to HLS * `return` statements +* `defer` statements + +> #### Note #### +> The `defer` statement in Slang is tied to scope. The deferred statement runs at the end of the scope like in Swift, not just at the end of the function like in Go. +> `defer` supports but does not require block statements: both `defer f();` and `defer { f(); g(); }` are legal. + > #### Note #### > Slang does not support the C/C++ `goto` keyword. diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h index 436e97c1b..c7da945f2 100644 --- a/source/slang/slang-ast-iterator.h +++ b/source/slang/slang-ast-iterator.h @@ -430,6 +430,12 @@ struct ASTIterator iterator->visitExpr(stmt->expression); } + void visitDeferStmt(DeferStmt* stmt) + { + iterator->maybeDispatchCallback(stmt); + dispatchIfNotNull(stmt->statement); + } + void visitWhileStmt(WhileStmt* stmt) { iterator->maybeDispatchCallback(stmt); diff --git a/source/slang/slang-ast-stmt.h b/source/slang/slang-ast-stmt.h index 09c491287..f6580ba21 100644 --- a/source/slang/slang-ast-stmt.h +++ b/source/slang/slang-ast-stmt.h @@ -253,6 +253,13 @@ class ReturnStmt : public Stmt Expr* expression = nullptr; }; +class DeferStmt : public Stmt +{ + SLANG_AST_CLASS(DeferStmt) + + Stmt* statement = nullptr; +}; + class ExpressionStmt : public Stmt { SLANG_AST_CLASS(ExpressionStmt) diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 34143ad75..4ab909118 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -650,6 +650,8 @@ struct SemanticsDeclReferenceVisitor : public SemanticsDeclVisitorBase, void visitReturnStmt(ReturnStmt* stmt) { dispatchIfNotNull(stmt->expression); } + void visitDeferStmt(DeferStmt* stmt) { dispatchIfNotNull(stmt->statement); } + void visitWhileStmt(WhileStmt* stmt) { dispatchIfNotNull(stmt->predicate); diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 95716744c..f7681ba45 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -2976,7 +2976,7 @@ struct SemanticsStmtVisitor : public SemanticsVisitor, StmtVisitor<SemanticsStmt void checkStmt(Stmt* stmt); template<typename T> - T* FindOuterStmt(); + T* FindOuterStmt(Stmt* searchUntil = nullptr); Stmt* findOuterStmtWithLabel(Name* label); @@ -3020,6 +3020,8 @@ struct SemanticsStmtVisitor : public SemanticsVisitor, StmtVisitor<SemanticsStmt void visitReturnStmt(ReturnStmt* stmt); + void visitDeferStmt(DeferStmt* stmt); + void visitWhileStmt(WhileStmt* stmt); void visitGpuForeachStmt(GpuForeachStmt* stmt); diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp index db6f00d23..c85eb7593 100644 --- a/source/slang/slang-check-stmt.cpp +++ b/source/slang/slang-check-stmt.cpp @@ -119,9 +119,10 @@ void SemanticsStmtVisitor::checkStmt(Stmt* stmt) } template<typename T> -T* SemanticsStmtVisitor::FindOuterStmt() +T* SemanticsStmtVisitor::FindOuterStmt(Stmt* searchUntil) { - for (auto outerStmtInfo = m_outerStmts; outerStmtInfo; outerStmtInfo = outerStmtInfo->next) + for (auto outerStmtInfo = m_outerStmts; outerStmtInfo && outerStmtInfo->stmt != searchUntil; + outerStmtInfo = outerStmtInfo->next) { auto outerStmt = outerStmtInfo->stmt; auto found = as<T>(outerStmt); @@ -178,6 +179,14 @@ void SemanticsStmtVisitor::visitBreakStmt(BreakStmt* stmt) getSink()->diagnose(stmt, Diagnostics::breakOutsideLoop); } } + + // If there is a defer statement before the breakable statement, it's + // illegal. + if (FindOuterStmt<DeferStmt>(targetStmt)) + { + getSink()->diagnose(stmt, Diagnostics::breakInsideDefer); + } + stmt->parentStmt = targetStmt; } @@ -188,6 +197,11 @@ void SemanticsStmtVisitor::visitContinueStmt(ContinueStmt* stmt) { getSink()->diagnose(stmt, Diagnostics::continueOutsideLoop); } + + if (FindOuterStmt<DeferStmt>(outer)) + { + getSink()->diagnose(stmt, Diagnostics::continueInsideDefer); + } stmt->parentStmt = outer; } @@ -497,6 +511,11 @@ void SemanticsStmtVisitor::visitReturnStmt(ReturnStmt* stmt) } } } + + if (FindOuterStmt<DeferStmt>()) + { + getSink()->diagnose(stmt, Diagnostics::returnInsideDefer); + } } void SemanticsStmtVisitor::visitWhileStmt(WhileStmt* stmt) @@ -508,6 +527,12 @@ void SemanticsStmtVisitor::visitWhileStmt(WhileStmt* stmt) checkLoopInDifferentiableFunc(stmt); } +void SemanticsStmtVisitor::visitDeferStmt(DeferStmt* stmt) +{ + WithOuterStmt subContext(this, stmt); + subContext.checkStmt(stmt->statement); +} + void SemanticsStmtVisitor::visitExpressionStmt(ExpressionStmt* stmt) { stmt->expression = CheckExpr(stmt->expression); diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index b7469deb0..f2c7fecc1 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -890,6 +890,14 @@ DIAGNOSTIC( DIAGNOSTIC(30106, Error, improperUseOfType, "type '$0' cannot be used in this context.") DIAGNOSTIC(30107, Error, parameterPackMustBeConst, "a parameter pack must be declared as 'const'.") +DIAGNOSTIC(30108, Error, breakInsideDefer, "'break' must not appear inside a defer statement.") +DIAGNOSTIC( + 30109, + Error, + continueInsideDefer, + "'continue' must not appear inside a defer statement.") +DIAGNOSTIC(30110, Error, returnInsideDefer, "'return' must not appear inside a defer statement.") + // Include DIAGNOSTIC( diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index bd590f08f..c7ed5affe 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -661,7 +661,9 @@ INST(MissingReturn, missingReturn, 0, 0) INST(Unreachable, unreachable, 0, 0) INST_RANGE(Unreachable, MissingReturn, Unreachable) -INST_RANGE(TerminatorInst, Return, Unreachable) +INST(Defer, defer, 3, 0) + +INST_RANGE(TerminatorInst, Return, Defer) INST(discard, discard, 0, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 51608ebc9..645b13e00 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2852,6 +2852,15 @@ struct IRTryCall : IRTerminatorInst IRInst* getArg(UInt index) { return getOperand(index + 3); } }; +struct IRDefer : IRTerminatorInst +{ + IR_LEAF_ISA(Defer); + + IRBlock* getDeferBlock() { return cast<IRBlock>(getOperand(0)); } + IRBlock* getMergeBlock() { return cast<IRBlock>(getOperand(1)); } + IRBlock* getScopeBlock() { return cast<IRBlock>(getOperand(2)); } +}; + struct IRSwizzle : IRInst { IR_LEAF_ISA(swizzle); @@ -4462,6 +4471,8 @@ public: IRInst* emitThrow(IRInst* val); + IRInst* emitDefer(IRBlock* deferBlock, IRBlock* mergeBlock, IRBlock* scopeEndBlock); + IRInst* emitDiscard(); IRInst* emitCheckpointObject(IRInst* value); diff --git a/source/slang/slang-ir-lower-defer.cpp b/source/slang/slang-ir-lower-defer.cpp new file mode 100644 index 000000000..4e1ec9a8b --- /dev/null +++ b/source/slang/slang-ir-lower-defer.cpp @@ -0,0 +1,258 @@ +// slang-ir-lower-defer.cpp + +#include "slang-ir-lower-defer.h" + +#include "slang-ir-clone.h" +#include "slang-ir-dominators.h" +#include "slang-ir-inst-pass-base.h" +#include "slang-ir-insts.h" +#include "slang-ir.h" + +namespace Slang +{ + +struct DeferLoweringContext : InstPassBase +{ + DiagnosticSink* diagnosticSink; + + DeferLoweringContext(IRModule* inModule) + : InstPassBase(inModule) + { + } + + void inlineSingleBlockDefer(IRInst* beforeInst, IRBlock* deferBlock, IRBuilder* builder) + { + builder->setInsertBefore(beforeInst); + IRCloneEnv env; + for (IRInst* inst : deferBlock->getChildren()) + { + // Copy everything except the terminator; the terminator should only + // be a jump to mergeBlock, which isn't needed after inlining. + if (!as<IRTerminatorInst>(inst)) + cloneInst(&env, builder, inst); + } + } + + // Returns the new last block. + IRBlock* inlineDefer( + IRInst* beforeInst, + IRBlock* targetBlock, + const List<IRBlock*>& deferBlocks, + IRBlock* mergeBlock, + IRBuilder* builder) + { + // The single-block inlining case is simple, we can just dump the + // instructions at the target position, in the existing block. + if (deferBlocks.getCount() == 1) + { + inlineSingleBlockDefer(beforeInst, deferBlocks.getFirst(), builder); + return targetBlock; + } + + // Otherwise, we'll have to splice the blocks in. + IRCloneEnv env; + builder->setInsertAfter(targetBlock); + auto lastBlock = targetBlock; + + // Clone blocks first + for (auto block : deferBlocks) + { + auto clonedBlock = builder->createBlock(); + builder->addInst(clonedBlock); + env.mapOldValToNew[block] = clonedBlock; + } + + // Then, clone instructions, but mapping old blocks to new blocks. + for (auto block : deferBlocks) + { + auto clonedBlock = as<IRBlock>(env.mapOldValToNew.getValue(block)); + builder->setInsertInto(clonedBlock); + for (auto inst : block->getChildren()) + { + auto endBranch = as<IRUnconditionalBranch>(inst); + if (endBranch && endBranch->getTargetBlock() == mergeBlock) + { + lastBlock = clonedBlock; + } + else + cloneInst(&env, builder, inst); + } + } + + // Move old instructions to the last block's end. The last defer block + // shouldn't have a terminator at this point yet. + while (beforeInst) + { + auto nextInst = beforeInst->getNextInst(); + beforeInst->insertAtEnd(lastBlock); + beforeInst = nextInst; + } + + // Make target block jump to the cloned blocks. + builder->setInsertInto(targetBlock); + auto mainBlock = as<IRBlock>(env.mapOldValToNew.getValue(deferBlocks[0])); + builder->emitBranch(mainBlock); + + return lastBlock; + } + + HashSet<IRBlock*> findSuccessorBlocks(IRGlobalValueWithCode* func, IRBlock* block) + { + HashSet<IRBlock*> successorBlocksSet; + List<IRBlock*> successorWorkList; + successorWorkList.add(block); + + List<IRBlock*> postorder = getPostorder(func); + Index limitIndex = postorder.indexOf(block); + while (successorWorkList.getCount() > 0) + { + IRBlock* predecessor = successorWorkList.getLast(); + successorWorkList.removeLast(); + if (successorBlocksSet.contains(predecessor)) + continue; + + Index predecessorIndex = postorder.indexOf(predecessor); + // Does not succeed if it is after the given block in postorder. + if (predecessorIndex > limitIndex) + continue; + + successorBlocksSet.add(predecessor); + for (IRBlock* successor : predecessor->getSuccessors()) + successorWorkList.add(successor); + } + return successorBlocksSet; + } + + void processFunc(IRGlobalValueWithCode* func) + { + // Iterating over `defer` instructions in reverse order allows us to + // expand them in the correct order, including nested `defer`s. + // We also use this to determine scope extents. + List<IRBlock*> reverseBlocks = getReversePostorderOnReverseCFG(func); + List<IRDefer*> unhandledDefers; + + for (IRBlock* block : reverseBlocks) + { + for (auto child = block->getLastChild(); child; child = child->getPrevInst()) + { + if (auto defer = as<IRDefer>(child)) + unhandledDefers.add(defer); + } + } + + IRBuilder builder(module); + Dictionary<IRBlock*, IRBlock*> mapOldScopeToNew; + for (IRDefer* defer : unhandledDefers) + { + IRBlock* firstDeferBlock = defer->getDeferBlock(); + IRBlock* mergeBlock = defer->getMergeBlock(); + IRBlock* scopeEndBlock = defer->getScopeBlock(); + mapOldScopeToNew.tryGetValue(scopeEndBlock, scopeEndBlock); + IRBlock* parentBlock = as<IRBlock>(defer->getParent()); + + // The dominator tree gets invalidated on every iteration, so it's + // necessary to construct it inside the loop. + auto dom = module->findOrCreateDominatorTree(func); + + // Enumerate defer block range. That is, all blocks dominated by + // parentBlock and not dominated by mergeBlock. + auto deferDominatedBlocks = dom->getProperlyDominatedBlocks(firstDeferBlock); + List<IRBlock*> deferBlocks; + deferBlocks.add(firstDeferBlock); + for (IRBlock* block : deferDominatedBlocks) + { + if (!dom->properlyDominates(mergeBlock, block) && block != mergeBlock) + deferBlocks.add(block); + } + + auto dominatedBlocks = dom->getProperlyDominatedBlocks(mergeBlock); + + + HashSet<IRBlock*> scopeSuccessorBlocksSet = findSuccessorBlocks(func, scopeEndBlock); + HashSet<IRBlock*> scopeBlocksSet; + scopeBlocksSet.add(mergeBlock); + for (IRBlock* block : dominatedBlocks) + { + if (!scopeSuccessorBlocksSet.contains(block)) + scopeBlocksSet.add(block); + } + + // All jumps from blocks in scope to blocks out of scope are to be + // preceded by a copy of the deferBlocks. + for (IRBlock* block : scopeBlocksSet) + { + auto terminator = block->getTerminator(); + SLANG_ASSERT(terminator); + bool exits = false; + switch (terminator->getOp()) + { + case kIROp_Return: + case kIROp_discard: + case kIROp_Throw: + exits = true; + break; + case kIROp_unconditionalBranch: + { + auto targetBlock = as<IRBlock>(terminator->getOperand(0)); + if (!scopeBlocksSet.contains(targetBlock)) + { + exits = true; + } + } + break; + case kIROp_conditionalBranch: + { + auto trueBlock = as<IRBlock>(terminator->getOperand(1)); + auto falseBlock = as<IRBlock>(terminator->getOperand(2)); + if (!scopeBlocksSet.contains(trueBlock) || + !scopeBlocksSet.contains(falseBlock)) + { + exits = true; + } + } + break; + default: + break; + } + + if (exits) + { // Duplicate child instructions to the end of this block. + auto newEnd = inlineDefer(terminator, block, deferBlocks, mergeBlock, &builder); + if (newEnd != block) + { + mapOldScopeToNew[block] = newEnd; + } + } + } + + // Replace defer with unconditional branch to mergeBlock. Defer + // blocks should now be orphaned, and we can remove them too. + defer->removeAndDeallocate(); + builder.setInsertInto(parentBlock); + builder.emitBranch(mergeBlock); + + for (IRBlock* deferBlock : deferBlocks) + { + deferBlock->removeAndDeallocate(); + } + + // Some blocks got removed and added, so mark analysis of the + // function with defer as outdated. + module->invalidateAnalysisForInst(func); + } + } + + void processModule() + { + processInstsOfType<IRFunc>(kIROp_Func, [&](IRFunc* func) { processFunc(func); }); + } +}; + +void lowerDefer(IRModule* module, DiagnosticSink* sink) +{ + DeferLoweringContext context(module); + context.diagnosticSink = sink; + return context.processModule(); +} + +} // namespace Slang diff --git a/source/slang/slang-ir-lower-defer.h b/source/slang/slang-ir-lower-defer.h new file mode 100644 index 000000000..c00104745 --- /dev/null +++ b/source/slang/slang-ir-lower-defer.h @@ -0,0 +1,18 @@ +// slang-ir-lower-defer.h +#pragma once + +#include "slang-ir.h" + +namespace Slang +{ +struct IRModule; +class DiagnosticSink; + +/// Lower the `defer` statements. +/// +/// Duplicates the child instructions of each `defer` to the end of each +/// dominated block whose terminator jumps to a location that is not dominated +/// by the `defer`. Also removes all `IRDefer` instructions after that. +void lowerDefer(IRModule* module, DiagnosticSink* sink); + +} // namespace Slang diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 1287d1598..f75fe2f48 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -596,6 +596,13 @@ static IRBlock::SuccessorList getSuccessors(IRInst* terminator) end = operands + terminator->getOperandCount() + 1; stride = 2; break; + + case kIROp_Defer: + // defer <deferBlock> <mergeBlock> <scopeEndBlock> + begin = operands + 0; + end = begin + 1; + break; + default: SLANG_UNEXPECTED("unhandled terminator instruction"); UNREACHABLE_RETURN(IRBlock::SuccessorList(nullptr, nullptr)); @@ -869,6 +876,7 @@ bool isTerminatorInst(IROp op) case kIROp_Switch: case kIROp_Unreachable: case kIROp_MissingReturn: + case kIROp_Defer: return true; } } @@ -5698,6 +5706,14 @@ IRInst* IRBuilder::emitReturn() return inst; } +IRInst* IRBuilder::emitDefer(IRBlock* deferBlock, IRBlock* mergeBlock, IRBlock* scopeEndBlock) +{ + auto inst = + createInst<IRDefer>(this, kIROp_Defer, nullptr, deferBlock, mergeBlock, scopeEndBlock); + addInst(inst); + return inst; +} + IRInst* IRBuilder::emitThrow(IRInst* val) { auto inst = createInst<IRThrow>(this, kIROp_Throw, nullptr, val); diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp index ebc3fe9ad..7375756f5 100644 --- a/source/slang/slang-language-server-ast-lookup.cpp +++ b/source/slang/slang-language-server-ast-lookup.cpp @@ -642,6 +642,8 @@ struct ASTLookupStmtVisitor : public StmtVisitor<ASTLookupStmtVisitor, bool> bool visitReturnStmt(ReturnStmt* stmt) { return checkExpr(stmt->expression); } + bool visitDeferStmt(DeferStmt* stmt) { return dispatchIfNotNull(stmt->statement); } + bool visitWhileStmt(WhileStmt* stmt) { if (checkExpr(stmt->predicate)) diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index e3c4ddf05..e6ec68660 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -19,6 +19,7 @@ #include "slang-ir-insert-debug-value-store.h" #include "slang-ir-insts.h" #include "slang-ir-loop-inversion.h" +#include "slang-ir-lower-defer.h" #include "slang-ir-lower-error-handling.h" #include "slang-ir-lower-expand-type.h" #include "slang-ir-missing-return.h" @@ -593,6 +594,9 @@ struct IRGenContext // The element index if we are inside an `expand` expression. IRInst* expandIndex = nullptr; + // The current scope end for use with `defer`. + IRBlock* scopeEndBlock = nullptr; + // Callback function to call when after lowering a type. std::function<IRType*(IRGenContext* context, Type* type, IRType* irType)> lowerTypeCallback = nullptr; @@ -6021,6 +6025,53 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> startBlock(); } + /// Create a new scope end block and return the previous one. + /// + /// This is needed for `defer` to be aware of scopes. `preallocated` can + /// be specified if you already have a block at the end of the scope, like + /// in `for` loops. + IRBlock* pushScopeBlock(IRBlock* preallocated = nullptr) + { + IRBlock* prevScopeEndBlock = context->scopeEndBlock; + + auto builder = getBuilder(); + context->scopeEndBlock = preallocated ? preallocated : builder->createBlock(); + return prevScopeEndBlock; + } + + /// Pop the current scope end block and restore the previous one. + /// + /// This is needed for `defer` to be aware of scopes. `previous` should be + /// the block returned from the corresponding pushScopeBlock. `preallocated` + /// should be true if the corresponding pushScopeBlock was given a block + /// as a parameter. + void popScopeBlock(IRBlock* previous, bool preallocated) + { + if (!preallocated) + { + // If pushScopeBlock actually created the block, we have to insert + // or deallocate it here. Otherwise, we assume that the caller + // handles the end block. + auto builder = getBuilder(); + if (context->scopeEndBlock->hasUses()) + { + // The end of the scope was referenced, so we need to actually + // keep it around and jump through it. + // Move the terminator to the scope end block. + emitBranchIfNeeded(context->scopeEndBlock); + builder->insertBlock(context->scopeEndBlock); + builder->setInsertInto(context->scopeEndBlock); + } + else + { + // Scope end block was left unused, so we may as well delete it. + context->scopeEndBlock->removeAndDeallocate(); + } + } + + context->scopeEndBlock = previous; + } + void visitIfStmt(IfStmt* stmt) { auto builder = getBuilder(); @@ -6043,11 +6094,13 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> ifInst = builder->emitIfElse(irCond, thenBlock, elseBlock, afterBlock); insertBlock(thenBlock); + IRBlock* prevScopeEndBlock = pushScopeBlock(afterBlock); lowerStmt(context, thenStmt); emitBranchIfNeeded(afterBlock); insertBlock(elseBlock); lowerStmt(context, elseStmt); + popScopeBlock(prevScopeEndBlock, true); insertBlock(afterBlock); } @@ -6059,7 +6112,10 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> ifInst = builder->emitIf(irCond, thenBlock, afterBlock); insertBlock(thenBlock); + + IRBlock* prevScopeEndBlock = pushScopeBlock(afterBlock); lowerStmt(context, thenStmt); + popScopeBlock(prevScopeEndBlock, true); insertBlock(afterBlock); } @@ -6150,7 +6206,9 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> // Emit the body of the loop insertBlock(bodyLabel); + IRBlock* prevScopeEndBlock = pushScopeBlock(continueLabel); lowerStmt(context, stmt->statement); + popScopeBlock(prevScopeEndBlock, true); if (auto inferredMaxIters = stmt->findModifier<InferredMaxItersAttribute>()) { @@ -6256,7 +6314,9 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> // Emit the body of the loop insertBlock(bodyLabel); + IRBlock* prevScopeEndBlock = pushScopeBlock(continueLabel); lowerStmt(context, stmt->statement); + popScopeBlock(prevScopeEndBlock, true); // At the end of the body we need to jump back to the top. emitBranchIfNeeded(loopHead); @@ -6300,7 +6360,9 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> insertBlock(loopHead); // Emit the body of the loop + IRBlock* prevScopeEndBlock = pushScopeBlock(continueLabel); lowerStmt(context, stmt->statement); + popScopeBlock(prevScopeEndBlock, true); insertBlock(testLabel); @@ -6429,10 +6491,12 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> void visitBlockStmt(BlockStmt* stmt) { - // To lower a block (scope) statement, - // just lower its body. The IR doesn't - // need to reflect the scoping of the AST. + IRBlock* prevScopeEndBlock = pushScopeBlock(nullptr); + + // To lower a block (scope) statement, just lower its body. lowerStmt(context, stmt->body); + + popScopeBlock(prevScopeEndBlock, false); } void visitReturnStmt(ReturnStmt* stmt) @@ -6523,6 +6587,29 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> } } + void visitDeferStmt(DeferStmt* stmt) + { + auto builder = getBuilder(); + startBlockIfNeeded(stmt); + + IRBlock* deferBlock = builder->createBlock(); + IRBlock* mergeBlock = builder->createBlock(); + + builder->emitDefer(deferBlock, mergeBlock, context->scopeEndBlock); + + builder->insertBlock(deferBlock); + builder->setInsertInto(deferBlock); + + IRBlock* prevScopeEndBlock = pushScopeBlock(mergeBlock); + lowerStmt(context, stmt->statement); + popScopeBlock(prevScopeEndBlock, true); + + builder->emitBranch(mergeBlock); + + builder->insertBlock(mergeBlock); + builder->setInsertInto(mergeBlock); + } + void visitDiscardStmt(DiscardStmt* stmt) { startBlockIfNeeded(stmt); @@ -11714,6 +11801,9 @@ RefPtr<IRModule> generateIRForTranslationUnit( // normal `call` + `ifElse`, etc. lowerErrorHandling(module, compileRequest->getSink()); + // Lower `defer` so that later passes need not be aware of it. + lowerDefer(module, compileRequest->getSink()); + // Synthesize some code we want to make sure is inlined and simplified synthesizeBitFieldAccessors(module); diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 1b862de77..97777c9fe 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -215,6 +215,7 @@ public: BreakStmt* ParseBreakStatement(); ContinueStmt* ParseContinueStatement(); ReturnStmt* ParseReturnStatement(); + DeferStmt* ParseDeferStatement(); ExpressionStmt* ParseExpressionStatement(); Expr* ParseExpression(Precedence level = Precedence::Comma); @@ -5770,6 +5771,10 @@ Stmt* Parser::ParseStatement(Stmt* parentStmt) { statement = parseCompileTimeStmt(this); } + else if (LookAheadToken("defer")) + { + statement = ParseDeferStatement(); + } else if (LookAheadToken("try")) { statement = ParseExpressionStatement(); @@ -6299,6 +6304,15 @@ ReturnStmt* Parser::ParseReturnStatement() return returnStatement; } +DeferStmt* Parser::ParseDeferStatement() +{ + DeferStmt* deferStatement = astBuilder->create<DeferStmt>(); + FillPosition(deferStatement); + ReadToken("defer"); + deferStatement->statement = ParseStatement(); + return deferStatement; +} + ExpressionStmt* Parser::ParseExpressionStatement() { ExpressionStmt* statement = astBuilder->create<ExpressionStmt>(); diff --git a/tests/language-feature/defer/autodiff.slang b/tests/language-feature/defer/autodiff.slang new file mode 100644 index 000000000..e7dd871fa --- /dev/null +++ b/tests/language-feature/defer/autodiff.slang @@ -0,0 +1,56 @@ +//TEST(compute):COMPARE_COMPUTE: -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE: -vk -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[Differentiable] +float testFunc(float a) +{ + float x = a / (abs(a) + 0.5f); + { + defer x = sqrt(x); + + x = x * 0.5f + 0.5f; + + if (a < 0) + { + x += 1.0f; + // NOTE suprising but correct behaviour here: 'defer' occurs after + // the return statement's value has been computed, so mutating 'x' + // no longer affects anything. + return x; + } + + x += 0.5f; + } + return x; +} + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + outputBuffer[0] = testFunc(0); + outputBuffer[1] = testFunc(0.5); + outputBuffer[2] = testFunc(-0.5); + + DifferentialPair<float> d1 = diffPair(0.0); + bwd_diff(testFunc)(d1, 1.0); + DifferentialPair<float> d2 = diffPair(0.5); + bwd_diff(testFunc)(d2, 1.0); + DifferentialPair<float> d3 = diffPair(-0.5); + bwd_diff(testFunc)(d3, 1.0); + + outputBuffer[3] = d1.d; + outputBuffer[4] = d2.d; + outputBuffer[5] = d3.d; + + d1 = diffPair(0.0, 1.0); + d2 = diffPair(0.5, 1.0); + d3 = diffPair(-0.5, 1.0); + + outputBuffer[6] = fwd_diff(testFunc)(d1).d; + outputBuffer[7] = fwd_diff(testFunc)(d2).d; + outputBuffer[8] = fwd_diff(testFunc)(d3).d; +} diff --git a/tests/language-feature/defer/autodiff.slang.expected.txt b/tests/language-feature/defer/autodiff.slang.expected.txt new file mode 100644 index 000000000..471a03b7c --- /dev/null +++ b/tests/language-feature/defer/autodiff.slang.expected.txt @@ -0,0 +1,10 @@ +type: float +1.000000 +1.118034 +1.250000 +0.500000 +0.111803 +0.250000 +0.500000 +0.111803 +0.250000 diff --git a/tests/language-feature/defer/complex-block.slang b/tests/language-feature/defer/complex-block.slang new file mode 100644 index 000000000..c2d232035 --- /dev/null +++ b/tests/language-feature/defer/complex-block.slang @@ -0,0 +1,36 @@ +//TEST(compute):COMPARE_COMPUTE: -shaderobj +//TEST(compute):COMPARE_COMPUTE: -vk -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + int i = 0; + int j = 0; + defer { + if (j == 2) + { + outputBuffer[i++] = j; + } + else + { + outputBuffer[i++] = 2048; + } + } + defer { + if (j == 1) + { + outputBuffer[i++] = j; + } + else + { + outputBuffer[i++] = 1024; + } + j++; + } + outputBuffer[i++] = j; + j++; +} diff --git a/tests/language-feature/defer/complex-block.slang.expected.txt b/tests/language-feature/defer/complex-block.slang.expected.txt new file mode 100644 index 000000000..4539bbf2d --- /dev/null +++ b/tests/language-feature/defer/complex-block.slang.expected.txt @@ -0,0 +1,3 @@ +0 +1 +2 diff --git a/tests/language-feature/defer/deferred-loop.slang b/tests/language-feature/defer/deferred-loop.slang new file mode 100644 index 000000000..0108dbbc1 --- /dev/null +++ b/tests/language-feature/defer/deferred-loop.slang @@ -0,0 +1,33 @@ +//TEST(compute):COMPARE_COMPUTE: -shaderobj +//TEST(compute):COMPARE_COMPUTE: -vk -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + int i = 0; + int j = 0; + defer { + defer outputBuffer[i++] = j*3+3; + outputBuffer[i++] = 1; + for (int k = 0; k < 6; ++k) + { + defer outputBuffer[i++] = k*3+2; + + if(k == j-4) + continue; + + outputBuffer[i++] = k*3; + + if(k == j) + break; + + outputBuffer[i++] = k*3+1; + } + } + outputBuffer[i++] = j; + j += 4; +} diff --git a/tests/language-feature/defer/deferred-loop.slang.expected.txt b/tests/language-feature/defer/deferred-loop.slang.expected.txt new file mode 100644 index 000000000..6b625c8b4 --- /dev/null +++ b/tests/language-feature/defer/deferred-loop.slang.expected.txt @@ -0,0 +1,15 @@ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +A +B +C +E +F diff --git a/tests/language-feature/defer/generics.slang b/tests/language-feature/defer/generics.slang new file mode 100644 index 000000000..7ecd7de31 --- /dev/null +++ b/tests/language-feature/defer/generics.slang @@ -0,0 +1,48 @@ +//TEST(compute):COMPARE_COMPUTE: -shaderobj +//TEST(compute):COMPARE_COMPUTE: -vk -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +interface IDroppable +{ + void drop(inout int i); +}; + +interface IDiscombobulatable +{ + void discombobulate(inout int i, int param); +} + +void genericFunc<T>(int param) + where T : IDroppable, IDiscombobulatable +{ + T t; + int i = 0; + defer t.drop(i); + + int p = param; + defer t.discombobulate(i, p); + t.discombobulate(i, p); + p += 1; +} + +struct TestType : IDroppable, IDiscombobulatable +{ + void drop(inout int i) + { + outputBuffer[i++] = 0xFF; + } + + void discombobulate(inout int i, int param) + { + outputBuffer[i++] = param; + } +} + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + genericFunc<TestType>(2); +} diff --git a/tests/language-feature/defer/generics.slang.expected.txt b/tests/language-feature/defer/generics.slang.expected.txt new file mode 100644 index 000000000..4a7a64d00 --- /dev/null +++ b/tests/language-feature/defer/generics.slang.expected.txt @@ -0,0 +1,3 @@ +2 +3 +FF diff --git a/tests/language-feature/defer/loop.slang b/tests/language-feature/defer/loop.slang new file mode 100644 index 000000000..b51955a14 --- /dev/null +++ b/tests/language-feature/defer/loop.slang @@ -0,0 +1,30 @@ +//TEST(compute):COMPARE_COMPUTE: -shaderobj +//TEST(compute):COMPARE_COMPUTE: -vk -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + int i = 0; + defer outputBuffer[i++] = 128; + for (int j = 0; j < 4; j++) + { + defer outputBuffer[i++] = j*2+1; + outputBuffer[i++] = j*2; + } + + for (int j = 0; j < 5; j++) + { + if (j == 1) + continue; + defer outputBuffer[i++] = j*2+1; + if (j == 2) + continue; + outputBuffer[i++] = j*2; + if (j == 3) + break; + } +} diff --git a/tests/language-feature/defer/loop.slang.expected.txt b/tests/language-feature/defer/loop.slang.expected.txt new file mode 100644 index 000000000..cc7323432 --- /dev/null +++ b/tests/language-feature/defer/loop.slang.expected.txt @@ -0,0 +1,14 @@ +0 +1 +2 +3 +4 +5 +6 +7 +0 +1 +5 +6 +7 +80 diff --git a/tests/language-feature/defer/multiple-return.slang b/tests/language-feature/defer/multiple-return.slang new file mode 100644 index 000000000..a364c288e --- /dev/null +++ b/tests/language-feature/defer/multiple-return.slang @@ -0,0 +1,35 @@ +//TEST(compute):COMPARE_COMPUTE: -shaderobj +//TEST(compute):COMPARE_COMPUTE: -vk -shaderobj +//DISABLE_TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj +// ^ Due to DeviceMemoryBarrier() missing. + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +void testFunc(int index1, int index2, int condition) +{ + outputBuffer[index1] = 128; + defer outputBuffer[index1] = 0; + + outputBuffer[index2] = 3; + DeviceMemoryBarrier(); + + if (condition == 0) + { + outputBuffer[index2] = 1; + return; + } + + if (condition == 1) + { + outputBuffer[index2] = 2; + return; + } +} + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + for (int i = 0; i < 3; ++i) + testFunc(2*i, 2*i+1, i); +} diff --git a/tests/language-feature/defer/multiple-return.slang.expected.txt b/tests/language-feature/defer/multiple-return.slang.expected.txt new file mode 100644 index 000000000..5b06fc619 --- /dev/null +++ b/tests/language-feature/defer/multiple-return.slang.expected.txt @@ -0,0 +1,6 @@ +0 +1 +0 +2 +0 +3 diff --git a/tests/language-feature/defer/nested.slang b/tests/language-feature/defer/nested.slang new file mode 100644 index 000000000..6de518209 --- /dev/null +++ b/tests/language-feature/defer/nested.slang @@ -0,0 +1,18 @@ +//TEST(compute):COMPARE_COMPUTE: -shaderobj +//TEST(compute):COMPARE_COMPUTE: -vk -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + int i = 0; + defer { + defer outputBuffer[i++] = 3; + defer outputBuffer[i++] = 2; + outputBuffer[i++] = 1; + } + outputBuffer[i++] = 0; +} diff --git a/tests/language-feature/defer/nested.slang.expected.txt b/tests/language-feature/defer/nested.slang.expected.txt new file mode 100644 index 000000000..bc856dafa --- /dev/null +++ b/tests/language-feature/defer/nested.slang.expected.txt @@ -0,0 +1,4 @@ +0 +1 +2 +3 diff --git a/tests/language-feature/defer/no-block.slang b/tests/language-feature/defer/no-block.slang new file mode 100644 index 000000000..e78c91112 --- /dev/null +++ b/tests/language-feature/defer/no-block.slang @@ -0,0 +1,28 @@ +//TEST(compute):COMPARE_COMPUTE: -shaderobj +//TEST(compute):COMPARE_COMPUTE: -vk -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + int i = 0; + defer defer outputBuffer[i++] = 0xFF; + + outputBuffer[i++] = 0; + + for (int j = 0; j < 4; j++) + defer outputBuffer[i++] = j; + + if (i == 5) + defer outputBuffer[i++] = 0x80; + + while (i < 7) + defer outputBuffer[i++] = 0x81; + + do defer outputBuffer[i++] = 0x82; while (false); + + outputBuffer[i++] = 5; +} diff --git a/tests/language-feature/defer/no-block.slang.expected.txt b/tests/language-feature/defer/no-block.slang.expected.txt new file mode 100644 index 000000000..ad8c0f654 --- /dev/null +++ b/tests/language-feature/defer/no-block.slang.expected.txt @@ -0,0 +1,10 @@ +0 +0 +1 +2 +3 +80 +81 +82 +5 +FF diff --git a/tests/language-feature/defer/scoped.slang b/tests/language-feature/defer/scoped.slang new file mode 100644 index 000000000..447c02357 --- /dev/null +++ b/tests/language-feature/defer/scoped.slang @@ -0,0 +1,33 @@ +//TEST(compute):COMPARE_COMPUTE: -shaderobj +//TEST(compute):COMPARE_COMPUTE: -vk -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + int i = 0; + defer outputBuffer[i++] = 7; + + { + defer outputBuffer[i++] = 0; + } + + outputBuffer[i++] = 1; + + { + defer outputBuffer[i++] = 2; + } + + for (int j = 0; j < 4; ++j) + { + { + defer outputBuffer[i++] = j+3; + if (j == 2) + continue; + } + outputBuffer[i++] = j+0x13; + } +} diff --git a/tests/language-feature/defer/scoped.slang.expected.txt b/tests/language-feature/defer/scoped.slang.expected.txt new file mode 100644 index 000000000..8d280bd37 --- /dev/null +++ b/tests/language-feature/defer/scoped.slang.expected.txt @@ -0,0 +1,11 @@ +0 +1 +2 +3 +13 +4 +14 +5 +6 +16 +7 diff --git a/tests/language-feature/defer/switch.slang b/tests/language-feature/defer/switch.slang new file mode 100644 index 000000000..704e6e884 --- /dev/null +++ b/tests/language-feature/defer/switch.slang @@ -0,0 +1,34 @@ +//TEST(compute):COMPARE_COMPUTE: -shaderobj +//TEST(compute):COMPARE_COMPUTE: -vk -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +void testFunc(int i, int param) +{ + defer outputBuffer[i++] = param*3+2; + switch(param) + { + case 0: + defer outputBuffer[i++] = 1; + outputBuffer[i++] = 0; + break; + case 1: + defer outputBuffer[i++] = 4; + outputBuffer[i++] = 3; + break; + default: + defer outputBuffer[i++] = 7; + outputBuffer[i++] = 6; + break; + }; +} + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + testFunc(0, 0); + testFunc(3, 1); + testFunc(6, 2); +} diff --git a/tests/language-feature/defer/switch.slang.expected.txt b/tests/language-feature/defer/switch.slang.expected.txt new file mode 100644 index 000000000..1000f9005 --- /dev/null +++ b/tests/language-feature/defer/switch.slang.expected.txt @@ -0,0 +1,9 @@ +0 +1 +2 +3 +4 +5 +6 +7 +8 |
