From 1b82501dd0c74347cda4a2c7fe5a84fd610bb485 Mon Sep 17 00:00:00 2001 From: Julius Ikkala Date: Mon, 7 Apr 2025 06:08:29 +0300 Subject: Add defer statement (#6619) --- docs/user-guide/02-conventional-features.md | 6 + source/slang/slang-ast-iterator.h | 6 + source/slang/slang-ast-stmt.h | 7 + source/slang/slang-check-decl.cpp | 2 + source/slang/slang-check-impl.h | 4 +- source/slang/slang-check-stmt.cpp | 29 ++- source/slang/slang-diagnostic-defs.h | 8 + source/slang/slang-ir-inst-defs.h | 4 +- source/slang/slang-ir-insts.h | 11 + source/slang/slang-ir-lower-defer.cpp | 258 +++++++++++++++++++++ source/slang/slang-ir-lower-defer.h | 18 ++ source/slang/slang-ir.cpp | 16 ++ source/slang/slang-language-server-ast-lookup.cpp | 2 + source/slang/slang-lower-to-ir.cpp | 96 +++++++- source/slang/slang-parser.cpp | 14 ++ tests/language-feature/defer/autodiff.slang | 56 +++++ .../defer/autodiff.slang.expected.txt | 10 + tests/language-feature/defer/complex-block.slang | 36 +++ .../defer/complex-block.slang.expected.txt | 3 + tests/language-feature/defer/deferred-loop.slang | 33 +++ .../defer/deferred-loop.slang.expected.txt | 15 ++ tests/language-feature/defer/generics.slang | 48 ++++ .../defer/generics.slang.expected.txt | 3 + tests/language-feature/defer/loop.slang | 30 +++ .../language-feature/defer/loop.slang.expected.txt | 14 ++ tests/language-feature/defer/multiple-return.slang | 35 +++ .../defer/multiple-return.slang.expected.txt | 6 + tests/language-feature/defer/nested.slang | 18 ++ .../defer/nested.slang.expected.txt | 4 + tests/language-feature/defer/no-block.slang | 28 +++ .../defer/no-block.slang.expected.txt | 10 + tests/language-feature/defer/scoped.slang | 33 +++ .../defer/scoped.slang.expected.txt | 11 + tests/language-feature/defer/switch.slang | 34 +++ .../defer/switch.slang.expected.txt | 9 + 35 files changed, 910 insertions(+), 7 deletions(-) create mode 100644 source/slang/slang-ir-lower-defer.cpp create mode 100644 source/slang/slang-ir-lower-defer.h create mode 100644 tests/language-feature/defer/autodiff.slang create mode 100644 tests/language-feature/defer/autodiff.slang.expected.txt create mode 100644 tests/language-feature/defer/complex-block.slang create mode 100644 tests/language-feature/defer/complex-block.slang.expected.txt create mode 100644 tests/language-feature/defer/deferred-loop.slang create mode 100644 tests/language-feature/defer/deferred-loop.slang.expected.txt create mode 100644 tests/language-feature/defer/generics.slang create mode 100644 tests/language-feature/defer/generics.slang.expected.txt create mode 100644 tests/language-feature/defer/loop.slang create mode 100644 tests/language-feature/defer/loop.slang.expected.txt create mode 100644 tests/language-feature/defer/multiple-return.slang create mode 100644 tests/language-feature/defer/multiple-return.slang.expected.txt create mode 100644 tests/language-feature/defer/nested.slang create mode 100644 tests/language-feature/defer/nested.slang.expected.txt create mode 100644 tests/language-feature/defer/no-block.slang create mode 100644 tests/language-feature/defer/no-block.slang.expected.txt create mode 100644 tests/language-feature/defer/scoped.slang create mode 100644 tests/language-feature/defer/scoped.slang.expected.txt create mode 100644 tests/language-feature/defer/switch.slang create mode 100644 tests/language-feature/defer/switch.slang.expected.txt diff --git a/docs/user-guide/02-conventional-features.md b/docs/user-guide/02-conventional-features.md index aaeea4114..5c15986f8 100644 --- a/docs/user-guide/02-conventional-features.md +++ b/docs/user-guide/02-conventional-features.md @@ -435,6 +435,12 @@ Slang supports the following statement forms with nearly identical syntax to HLS * `return` statements +* `defer` statements + +> #### Note #### +> The `defer` statement in Slang is tied to scope. The deferred statement runs at the end of the scope like in Swift, not just at the end of the function like in Go. +> `defer` supports but does not require block statements: both `defer f();` and `defer { f(); g(); }` are legal. + > #### Note #### > Slang does not support the C/C++ `goto` keyword. diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h index 436e97c1b..c7da945f2 100644 --- a/source/slang/slang-ast-iterator.h +++ b/source/slang/slang-ast-iterator.h @@ -430,6 +430,12 @@ struct ASTIterator iterator->visitExpr(stmt->expression); } + void visitDeferStmt(DeferStmt* stmt) + { + iterator->maybeDispatchCallback(stmt); + dispatchIfNotNull(stmt->statement); + } + void visitWhileStmt(WhileStmt* stmt) { iterator->maybeDispatchCallback(stmt); diff --git a/source/slang/slang-ast-stmt.h b/source/slang/slang-ast-stmt.h index 09c491287..f6580ba21 100644 --- a/source/slang/slang-ast-stmt.h +++ b/source/slang/slang-ast-stmt.h @@ -253,6 +253,13 @@ class ReturnStmt : public Stmt Expr* expression = nullptr; }; +class DeferStmt : public Stmt +{ + SLANG_AST_CLASS(DeferStmt) + + Stmt* statement = nullptr; +}; + class ExpressionStmt : public Stmt { SLANG_AST_CLASS(ExpressionStmt) diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 34143ad75..4ab909118 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -650,6 +650,8 @@ struct SemanticsDeclReferenceVisitor : public SemanticsDeclVisitorBase, void visitReturnStmt(ReturnStmt* stmt) { dispatchIfNotNull(stmt->expression); } + void visitDeferStmt(DeferStmt* stmt) { dispatchIfNotNull(stmt->statement); } + void visitWhileStmt(WhileStmt* stmt) { dispatchIfNotNull(stmt->predicate); diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 95716744c..f7681ba45 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -2976,7 +2976,7 @@ struct SemanticsStmtVisitor : public SemanticsVisitor, StmtVisitor - T* FindOuterStmt(); + T* FindOuterStmt(Stmt* searchUntil = nullptr); Stmt* findOuterStmtWithLabel(Name* label); @@ -3020,6 +3020,8 @@ struct SemanticsStmtVisitor : public SemanticsVisitor, StmtVisitor -T* SemanticsStmtVisitor::FindOuterStmt() +T* SemanticsStmtVisitor::FindOuterStmt(Stmt* searchUntil) { - for (auto outerStmtInfo = m_outerStmts; outerStmtInfo; outerStmtInfo = outerStmtInfo->next) + for (auto outerStmtInfo = m_outerStmts; outerStmtInfo && outerStmtInfo->stmt != searchUntil; + outerStmtInfo = outerStmtInfo->next) { auto outerStmt = outerStmtInfo->stmt; auto found = as(outerStmt); @@ -178,6 +179,14 @@ void SemanticsStmtVisitor::visitBreakStmt(BreakStmt* stmt) getSink()->diagnose(stmt, Diagnostics::breakOutsideLoop); } } + + // If there is a defer statement before the breakable statement, it's + // illegal. + if (FindOuterStmt(targetStmt)) + { + getSink()->diagnose(stmt, Diagnostics::breakInsideDefer); + } + stmt->parentStmt = targetStmt; } @@ -188,6 +197,11 @@ void SemanticsStmtVisitor::visitContinueStmt(ContinueStmt* stmt) { getSink()->diagnose(stmt, Diagnostics::continueOutsideLoop); } + + if (FindOuterStmt(outer)) + { + getSink()->diagnose(stmt, Diagnostics::continueInsideDefer); + } stmt->parentStmt = outer; } @@ -497,6 +511,11 @@ void SemanticsStmtVisitor::visitReturnStmt(ReturnStmt* stmt) } } } + + if (FindOuterStmt()) + { + getSink()->diagnose(stmt, Diagnostics::returnInsideDefer); + } } void SemanticsStmtVisitor::visitWhileStmt(WhileStmt* stmt) @@ -508,6 +527,12 @@ void SemanticsStmtVisitor::visitWhileStmt(WhileStmt* stmt) checkLoopInDifferentiableFunc(stmt); } +void SemanticsStmtVisitor::visitDeferStmt(DeferStmt* stmt) +{ + WithOuterStmt subContext(this, stmt); + subContext.checkStmt(stmt->statement); +} + void SemanticsStmtVisitor::visitExpressionStmt(ExpressionStmt* stmt) { stmt->expression = CheckExpr(stmt->expression); diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index b7469deb0..f2c7fecc1 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -890,6 +890,14 @@ DIAGNOSTIC( DIAGNOSTIC(30106, Error, improperUseOfType, "type '$0' cannot be used in this context.") DIAGNOSTIC(30107, Error, parameterPackMustBeConst, "a parameter pack must be declared as 'const'.") +DIAGNOSTIC(30108, Error, breakInsideDefer, "'break' must not appear inside a defer statement.") +DIAGNOSTIC( + 30109, + Error, + continueInsideDefer, + "'continue' must not appear inside a defer statement.") +DIAGNOSTIC(30110, Error, returnInsideDefer, "'return' must not appear inside a defer statement.") + // Include DIAGNOSTIC( diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index bd590f08f..c7ed5affe 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -661,7 +661,9 @@ INST(MissingReturn, missingReturn, 0, 0) INST(Unreachable, unreachable, 0, 0) INST_RANGE(Unreachable, MissingReturn, Unreachable) -INST_RANGE(TerminatorInst, Return, Unreachable) +INST(Defer, defer, 3, 0) + +INST_RANGE(TerminatorInst, Return, Defer) INST(discard, discard, 0, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 51608ebc9..645b13e00 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2852,6 +2852,15 @@ struct IRTryCall : IRTerminatorInst IRInst* getArg(UInt index) { return getOperand(index + 3); } }; +struct IRDefer : IRTerminatorInst +{ + IR_LEAF_ISA(Defer); + + IRBlock* getDeferBlock() { return cast(getOperand(0)); } + IRBlock* getMergeBlock() { return cast(getOperand(1)); } + IRBlock* getScopeBlock() { return cast(getOperand(2)); } +}; + struct IRSwizzle : IRInst { IR_LEAF_ISA(swizzle); @@ -4462,6 +4471,8 @@ public: IRInst* emitThrow(IRInst* val); + IRInst* emitDefer(IRBlock* deferBlock, IRBlock* mergeBlock, IRBlock* scopeEndBlock); + IRInst* emitDiscard(); IRInst* emitCheckpointObject(IRInst* value); diff --git a/source/slang/slang-ir-lower-defer.cpp b/source/slang/slang-ir-lower-defer.cpp new file mode 100644 index 000000000..4e1ec9a8b --- /dev/null +++ b/source/slang/slang-ir-lower-defer.cpp @@ -0,0 +1,258 @@ +// slang-ir-lower-defer.cpp + +#include "slang-ir-lower-defer.h" + +#include "slang-ir-clone.h" +#include "slang-ir-dominators.h" +#include "slang-ir-inst-pass-base.h" +#include "slang-ir-insts.h" +#include "slang-ir.h" + +namespace Slang +{ + +struct DeferLoweringContext : InstPassBase +{ + DiagnosticSink* diagnosticSink; + + DeferLoweringContext(IRModule* inModule) + : InstPassBase(inModule) + { + } + + void inlineSingleBlockDefer(IRInst* beforeInst, IRBlock* deferBlock, IRBuilder* builder) + { + builder->setInsertBefore(beforeInst); + IRCloneEnv env; + for (IRInst* inst : deferBlock->getChildren()) + { + // Copy everything except the terminator; the terminator should only + // be a jump to mergeBlock, which isn't needed after inlining. + if (!as(inst)) + cloneInst(&env, builder, inst); + } + } + + // Returns the new last block. + IRBlock* inlineDefer( + IRInst* beforeInst, + IRBlock* targetBlock, + const List& deferBlocks, + IRBlock* mergeBlock, + IRBuilder* builder) + { + // The single-block inlining case is simple, we can just dump the + // instructions at the target position, in the existing block. + if (deferBlocks.getCount() == 1) + { + inlineSingleBlockDefer(beforeInst, deferBlocks.getFirst(), builder); + return targetBlock; + } + + // Otherwise, we'll have to splice the blocks in. + IRCloneEnv env; + builder->setInsertAfter(targetBlock); + auto lastBlock = targetBlock; + + // Clone blocks first + for (auto block : deferBlocks) + { + auto clonedBlock = builder->createBlock(); + builder->addInst(clonedBlock); + env.mapOldValToNew[block] = clonedBlock; + } + + // Then, clone instructions, but mapping old blocks to new blocks. + for (auto block : deferBlocks) + { + auto clonedBlock = as(env.mapOldValToNew.getValue(block)); + builder->setInsertInto(clonedBlock); + for (auto inst : block->getChildren()) + { + auto endBranch = as(inst); + if (endBranch && endBranch->getTargetBlock() == mergeBlock) + { + lastBlock = clonedBlock; + } + else + cloneInst(&env, builder, inst); + } + } + + // Move old instructions to the last block's end. The last defer block + // shouldn't have a terminator at this point yet. + while (beforeInst) + { + auto nextInst = beforeInst->getNextInst(); + beforeInst->insertAtEnd(lastBlock); + beforeInst = nextInst; + } + + // Make target block jump to the cloned blocks. + builder->setInsertInto(targetBlock); + auto mainBlock = as(env.mapOldValToNew.getValue(deferBlocks[0])); + builder->emitBranch(mainBlock); + + return lastBlock; + } + + HashSet findSuccessorBlocks(IRGlobalValueWithCode* func, IRBlock* block) + { + HashSet successorBlocksSet; + List successorWorkList; + successorWorkList.add(block); + + List postorder = getPostorder(func); + Index limitIndex = postorder.indexOf(block); + while (successorWorkList.getCount() > 0) + { + IRBlock* predecessor = successorWorkList.getLast(); + successorWorkList.removeLast(); + if (successorBlocksSet.contains(predecessor)) + continue; + + Index predecessorIndex = postorder.indexOf(predecessor); + // Does not succeed if it is after the given block in postorder. + if (predecessorIndex > limitIndex) + continue; + + successorBlocksSet.add(predecessor); + for (IRBlock* successor : predecessor->getSuccessors()) + successorWorkList.add(successor); + } + return successorBlocksSet; + } + + void processFunc(IRGlobalValueWithCode* func) + { + // Iterating over `defer` instructions in reverse order allows us to + // expand them in the correct order, including nested `defer`s. + // We also use this to determine scope extents. + List reverseBlocks = getReversePostorderOnReverseCFG(func); + List unhandledDefers; + + for (IRBlock* block : reverseBlocks) + { + for (auto child = block->getLastChild(); child; child = child->getPrevInst()) + { + if (auto defer = as(child)) + unhandledDefers.add(defer); + } + } + + IRBuilder builder(module); + Dictionary mapOldScopeToNew; + for (IRDefer* defer : unhandledDefers) + { + IRBlock* firstDeferBlock = defer->getDeferBlock(); + IRBlock* mergeBlock = defer->getMergeBlock(); + IRBlock* scopeEndBlock = defer->getScopeBlock(); + mapOldScopeToNew.tryGetValue(scopeEndBlock, scopeEndBlock); + IRBlock* parentBlock = as(defer->getParent()); + + // The dominator tree gets invalidated on every iteration, so it's + // necessary to construct it inside the loop. + auto dom = module->findOrCreateDominatorTree(func); + + // Enumerate defer block range. That is, all blocks dominated by + // parentBlock and not dominated by mergeBlock. + auto deferDominatedBlocks = dom->getProperlyDominatedBlocks(firstDeferBlock); + List deferBlocks; + deferBlocks.add(firstDeferBlock); + for (IRBlock* block : deferDominatedBlocks) + { + if (!dom->properlyDominates(mergeBlock, block) && block != mergeBlock) + deferBlocks.add(block); + } + + auto dominatedBlocks = dom->getProperlyDominatedBlocks(mergeBlock); + + + HashSet scopeSuccessorBlocksSet = findSuccessorBlocks(func, scopeEndBlock); + HashSet scopeBlocksSet; + scopeBlocksSet.add(mergeBlock); + for (IRBlock* block : dominatedBlocks) + { + if (!scopeSuccessorBlocksSet.contains(block)) + scopeBlocksSet.add(block); + } + + // All jumps from blocks in scope to blocks out of scope are to be + // preceded by a copy of the deferBlocks. + for (IRBlock* block : scopeBlocksSet) + { + auto terminator = block->getTerminator(); + SLANG_ASSERT(terminator); + bool exits = false; + switch (terminator->getOp()) + { + case kIROp_Return: + case kIROp_discard: + case kIROp_Throw: + exits = true; + break; + case kIROp_unconditionalBranch: + { + auto targetBlock = as(terminator->getOperand(0)); + if (!scopeBlocksSet.contains(targetBlock)) + { + exits = true; + } + } + break; + case kIROp_conditionalBranch: + { + auto trueBlock = as(terminator->getOperand(1)); + auto falseBlock = as(terminator->getOperand(2)); + if (!scopeBlocksSet.contains(trueBlock) || + !scopeBlocksSet.contains(falseBlock)) + { + exits = true; + } + } + break; + default: + break; + } + + if (exits) + { // Duplicate child instructions to the end of this block. + auto newEnd = inlineDefer(terminator, block, deferBlocks, mergeBlock, &builder); + if (newEnd != block) + { + mapOldScopeToNew[block] = newEnd; + } + } + } + + // Replace defer with unconditional branch to mergeBlock. Defer + // blocks should now be orphaned, and we can remove them too. + defer->removeAndDeallocate(); + builder.setInsertInto(parentBlock); + builder.emitBranch(mergeBlock); + + for (IRBlock* deferBlock : deferBlocks) + { + deferBlock->removeAndDeallocate(); + } + + // Some blocks got removed and added, so mark analysis of the + // function with defer as outdated. + module->invalidateAnalysisForInst(func); + } + } + + void processModule() + { + processInstsOfType(kIROp_Func, [&](IRFunc* func) { processFunc(func); }); + } +}; + +void lowerDefer(IRModule* module, DiagnosticSink* sink) +{ + DeferLoweringContext context(module); + context.diagnosticSink = sink; + return context.processModule(); +} + +} // namespace Slang diff --git a/source/slang/slang-ir-lower-defer.h b/source/slang/slang-ir-lower-defer.h new file mode 100644 index 000000000..c00104745 --- /dev/null +++ b/source/slang/slang-ir-lower-defer.h @@ -0,0 +1,18 @@ +// slang-ir-lower-defer.h +#pragma once + +#include "slang-ir.h" + +namespace Slang +{ +struct IRModule; +class DiagnosticSink; + +/// Lower the `defer` statements. +/// +/// Duplicates the child instructions of each `defer` to the end of each +/// dominated block whose terminator jumps to a location that is not dominated +/// by the `defer`. Also removes all `IRDefer` instructions after that. +void lowerDefer(IRModule* module, DiagnosticSink* sink); + +} // namespace Slang diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 1287d1598..f75fe2f48 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -596,6 +596,13 @@ static IRBlock::SuccessorList getSuccessors(IRInst* terminator) end = operands + terminator->getOperandCount() + 1; stride = 2; break; + + case kIROp_Defer: + // defer + begin = operands + 0; + end = begin + 1; + break; + default: SLANG_UNEXPECTED("unhandled terminator instruction"); UNREACHABLE_RETURN(IRBlock::SuccessorList(nullptr, nullptr)); @@ -869,6 +876,7 @@ bool isTerminatorInst(IROp op) case kIROp_Switch: case kIROp_Unreachable: case kIROp_MissingReturn: + case kIROp_Defer: return true; } } @@ -5698,6 +5706,14 @@ IRInst* IRBuilder::emitReturn() return inst; } +IRInst* IRBuilder::emitDefer(IRBlock* deferBlock, IRBlock* mergeBlock, IRBlock* scopeEndBlock) +{ + auto inst = + createInst(this, kIROp_Defer, nullptr, deferBlock, mergeBlock, scopeEndBlock); + addInst(inst); + return inst; +} + IRInst* IRBuilder::emitThrow(IRInst* val) { auto inst = createInst(this, kIROp_Throw, nullptr, val); diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp index ebc3fe9ad..7375756f5 100644 --- a/source/slang/slang-language-server-ast-lookup.cpp +++ b/source/slang/slang-language-server-ast-lookup.cpp @@ -642,6 +642,8 @@ struct ASTLookupStmtVisitor : public StmtVisitor bool visitReturnStmt(ReturnStmt* stmt) { return checkExpr(stmt->expression); } + bool visitDeferStmt(DeferStmt* stmt) { return dispatchIfNotNull(stmt->statement); } + bool visitWhileStmt(WhileStmt* stmt) { if (checkExpr(stmt->predicate)) diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index e3c4ddf05..e6ec68660 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -19,6 +19,7 @@ #include "slang-ir-insert-debug-value-store.h" #include "slang-ir-insts.h" #include "slang-ir-loop-inversion.h" +#include "slang-ir-lower-defer.h" #include "slang-ir-lower-error-handling.h" #include "slang-ir-lower-expand-type.h" #include "slang-ir-missing-return.h" @@ -593,6 +594,9 @@ struct IRGenContext // The element index if we are inside an `expand` expression. IRInst* expandIndex = nullptr; + // The current scope end for use with `defer`. + IRBlock* scopeEndBlock = nullptr; + // Callback function to call when after lowering a type. std::function lowerTypeCallback = nullptr; @@ -6021,6 +6025,53 @@ struct StmtLoweringVisitor : StmtVisitor startBlock(); } + /// Create a new scope end block and return the previous one. + /// + /// This is needed for `defer` to be aware of scopes. `preallocated` can + /// be specified if you already have a block at the end of the scope, like + /// in `for` loops. + IRBlock* pushScopeBlock(IRBlock* preallocated = nullptr) + { + IRBlock* prevScopeEndBlock = context->scopeEndBlock; + + auto builder = getBuilder(); + context->scopeEndBlock = preallocated ? preallocated : builder->createBlock(); + return prevScopeEndBlock; + } + + /// Pop the current scope end block and restore the previous one. + /// + /// This is needed for `defer` to be aware of scopes. `previous` should be + /// the block returned from the corresponding pushScopeBlock. `preallocated` + /// should be true if the corresponding pushScopeBlock was given a block + /// as a parameter. + void popScopeBlock(IRBlock* previous, bool preallocated) + { + if (!preallocated) + { + // If pushScopeBlock actually created the block, we have to insert + // or deallocate it here. Otherwise, we assume that the caller + // handles the end block. + auto builder = getBuilder(); + if (context->scopeEndBlock->hasUses()) + { + // The end of the scope was referenced, so we need to actually + // keep it around and jump through it. + // Move the terminator to the scope end block. + emitBranchIfNeeded(context->scopeEndBlock); + builder->insertBlock(context->scopeEndBlock); + builder->setInsertInto(context->scopeEndBlock); + } + else + { + // Scope end block was left unused, so we may as well delete it. + context->scopeEndBlock->removeAndDeallocate(); + } + } + + context->scopeEndBlock = previous; + } + void visitIfStmt(IfStmt* stmt) { auto builder = getBuilder(); @@ -6043,11 +6094,13 @@ struct StmtLoweringVisitor : StmtVisitor ifInst = builder->emitIfElse(irCond, thenBlock, elseBlock, afterBlock); insertBlock(thenBlock); + IRBlock* prevScopeEndBlock = pushScopeBlock(afterBlock); lowerStmt(context, thenStmt); emitBranchIfNeeded(afterBlock); insertBlock(elseBlock); lowerStmt(context, elseStmt); + popScopeBlock(prevScopeEndBlock, true); insertBlock(afterBlock); } @@ -6059,7 +6112,10 @@ struct StmtLoweringVisitor : StmtVisitor ifInst = builder->emitIf(irCond, thenBlock, afterBlock); insertBlock(thenBlock); + + IRBlock* prevScopeEndBlock = pushScopeBlock(afterBlock); lowerStmt(context, thenStmt); + popScopeBlock(prevScopeEndBlock, true); insertBlock(afterBlock); } @@ -6150,7 +6206,9 @@ struct StmtLoweringVisitor : StmtVisitor // Emit the body of the loop insertBlock(bodyLabel); + IRBlock* prevScopeEndBlock = pushScopeBlock(continueLabel); lowerStmt(context, stmt->statement); + popScopeBlock(prevScopeEndBlock, true); if (auto inferredMaxIters = stmt->findModifier()) { @@ -6256,7 +6314,9 @@ struct StmtLoweringVisitor : StmtVisitor // Emit the body of the loop insertBlock(bodyLabel); + IRBlock* prevScopeEndBlock = pushScopeBlock(continueLabel); lowerStmt(context, stmt->statement); + popScopeBlock(prevScopeEndBlock, true); // At the end of the body we need to jump back to the top. emitBranchIfNeeded(loopHead); @@ -6300,7 +6360,9 @@ struct StmtLoweringVisitor : StmtVisitor insertBlock(loopHead); // Emit the body of the loop + IRBlock* prevScopeEndBlock = pushScopeBlock(continueLabel); lowerStmt(context, stmt->statement); + popScopeBlock(prevScopeEndBlock, true); insertBlock(testLabel); @@ -6429,10 +6491,12 @@ struct StmtLoweringVisitor : StmtVisitor void visitBlockStmt(BlockStmt* stmt) { - // To lower a block (scope) statement, - // just lower its body. The IR doesn't - // need to reflect the scoping of the AST. + IRBlock* prevScopeEndBlock = pushScopeBlock(nullptr); + + // To lower a block (scope) statement, just lower its body. lowerStmt(context, stmt->body); + + popScopeBlock(prevScopeEndBlock, false); } void visitReturnStmt(ReturnStmt* stmt) @@ -6523,6 +6587,29 @@ struct StmtLoweringVisitor : StmtVisitor } } + void visitDeferStmt(DeferStmt* stmt) + { + auto builder = getBuilder(); + startBlockIfNeeded(stmt); + + IRBlock* deferBlock = builder->createBlock(); + IRBlock* mergeBlock = builder->createBlock(); + + builder->emitDefer(deferBlock, mergeBlock, context->scopeEndBlock); + + builder->insertBlock(deferBlock); + builder->setInsertInto(deferBlock); + + IRBlock* prevScopeEndBlock = pushScopeBlock(mergeBlock); + lowerStmt(context, stmt->statement); + popScopeBlock(prevScopeEndBlock, true); + + builder->emitBranch(mergeBlock); + + builder->insertBlock(mergeBlock); + builder->setInsertInto(mergeBlock); + } + void visitDiscardStmt(DiscardStmt* stmt) { startBlockIfNeeded(stmt); @@ -11714,6 +11801,9 @@ RefPtr generateIRForTranslationUnit( // normal `call` + `ifElse`, etc. lowerErrorHandling(module, compileRequest->getSink()); + // Lower `defer` so that later passes need not be aware of it. + lowerDefer(module, compileRequest->getSink()); + // Synthesize some code we want to make sure is inlined and simplified synthesizeBitFieldAccessors(module); diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 1b862de77..97777c9fe 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -215,6 +215,7 @@ public: BreakStmt* ParseBreakStatement(); ContinueStmt* ParseContinueStatement(); ReturnStmt* ParseReturnStatement(); + DeferStmt* ParseDeferStatement(); ExpressionStmt* ParseExpressionStatement(); Expr* ParseExpression(Precedence level = Precedence::Comma); @@ -5770,6 +5771,10 @@ Stmt* Parser::ParseStatement(Stmt* parentStmt) { statement = parseCompileTimeStmt(this); } + else if (LookAheadToken("defer")) + { + statement = ParseDeferStatement(); + } else if (LookAheadToken("try")) { statement = ParseExpressionStatement(); @@ -6299,6 +6304,15 @@ ReturnStmt* Parser::ParseReturnStatement() return returnStatement; } +DeferStmt* Parser::ParseDeferStatement() +{ + DeferStmt* deferStatement = astBuilder->create(); + FillPosition(deferStatement); + ReadToken("defer"); + deferStatement->statement = ParseStatement(); + return deferStatement; +} + ExpressionStmt* Parser::ParseExpressionStatement() { ExpressionStmt* statement = astBuilder->create(); diff --git a/tests/language-feature/defer/autodiff.slang b/tests/language-feature/defer/autodiff.slang new file mode 100644 index 000000000..e7dd871fa --- /dev/null +++ b/tests/language-feature/defer/autodiff.slang @@ -0,0 +1,56 @@ +//TEST(compute):COMPARE_COMPUTE: -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE: -vk -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[Differentiable] +float testFunc(float a) +{ + float x = a / (abs(a) + 0.5f); + { + defer x = sqrt(x); + + x = x * 0.5f + 0.5f; + + if (a < 0) + { + x += 1.0f; + // NOTE suprising but correct behaviour here: 'defer' occurs after + // the return statement's value has been computed, so mutating 'x' + // no longer affects anything. + return x; + } + + x += 0.5f; + } + return x; +} + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + outputBuffer[0] = testFunc(0); + outputBuffer[1] = testFunc(0.5); + outputBuffer[2] = testFunc(-0.5); + + DifferentialPair d1 = diffPair(0.0); + bwd_diff(testFunc)(d1, 1.0); + DifferentialPair d2 = diffPair(0.5); + bwd_diff(testFunc)(d2, 1.0); + DifferentialPair d3 = diffPair(-0.5); + bwd_diff(testFunc)(d3, 1.0); + + outputBuffer[3] = d1.d; + outputBuffer[4] = d2.d; + outputBuffer[5] = d3.d; + + d1 = diffPair(0.0, 1.0); + d2 = diffPair(0.5, 1.0); + d3 = diffPair(-0.5, 1.0); + + outputBuffer[6] = fwd_diff(testFunc)(d1).d; + outputBuffer[7] = fwd_diff(testFunc)(d2).d; + outputBuffer[8] = fwd_diff(testFunc)(d3).d; +} diff --git a/tests/language-feature/defer/autodiff.slang.expected.txt b/tests/language-feature/defer/autodiff.slang.expected.txt new file mode 100644 index 000000000..471a03b7c --- /dev/null +++ b/tests/language-feature/defer/autodiff.slang.expected.txt @@ -0,0 +1,10 @@ +type: float +1.000000 +1.118034 +1.250000 +0.500000 +0.111803 +0.250000 +0.500000 +0.111803 +0.250000 diff --git a/tests/language-feature/defer/complex-block.slang b/tests/language-feature/defer/complex-block.slang new file mode 100644 index 000000000..c2d232035 --- /dev/null +++ b/tests/language-feature/defer/complex-block.slang @@ -0,0 +1,36 @@ +//TEST(compute):COMPARE_COMPUTE: -shaderobj +//TEST(compute):COMPARE_COMPUTE: -vk -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + int i = 0; + int j = 0; + defer { + if (j == 2) + { + outputBuffer[i++] = j; + } + else + { + outputBuffer[i++] = 2048; + } + } + defer { + if (j == 1) + { + outputBuffer[i++] = j; + } + else + { + outputBuffer[i++] = 1024; + } + j++; + } + outputBuffer[i++] = j; + j++; +} diff --git a/tests/language-feature/defer/complex-block.slang.expected.txt b/tests/language-feature/defer/complex-block.slang.expected.txt new file mode 100644 index 000000000..4539bbf2d --- /dev/null +++ b/tests/language-feature/defer/complex-block.slang.expected.txt @@ -0,0 +1,3 @@ +0 +1 +2 diff --git a/tests/language-feature/defer/deferred-loop.slang b/tests/language-feature/defer/deferred-loop.slang new file mode 100644 index 000000000..0108dbbc1 --- /dev/null +++ b/tests/language-feature/defer/deferred-loop.slang @@ -0,0 +1,33 @@ +//TEST(compute):COMPARE_COMPUTE: -shaderobj +//TEST(compute):COMPARE_COMPUTE: -vk -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + int i = 0; + int j = 0; + defer { + defer outputBuffer[i++] = j*3+3; + outputBuffer[i++] = 1; + for (int k = 0; k < 6; ++k) + { + defer outputBuffer[i++] = k*3+2; + + if(k == j-4) + continue; + + outputBuffer[i++] = k*3; + + if(k == j) + break; + + outputBuffer[i++] = k*3+1; + } + } + outputBuffer[i++] = j; + j += 4; +} diff --git a/tests/language-feature/defer/deferred-loop.slang.expected.txt b/tests/language-feature/defer/deferred-loop.slang.expected.txt new file mode 100644 index 000000000..6b625c8b4 --- /dev/null +++ b/tests/language-feature/defer/deferred-loop.slang.expected.txt @@ -0,0 +1,15 @@ +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 +A +B +C +E +F diff --git a/tests/language-feature/defer/generics.slang b/tests/language-feature/defer/generics.slang new file mode 100644 index 000000000..7ecd7de31 --- /dev/null +++ b/tests/language-feature/defer/generics.slang @@ -0,0 +1,48 @@ +//TEST(compute):COMPARE_COMPUTE: -shaderobj +//TEST(compute):COMPARE_COMPUTE: -vk -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +interface IDroppable +{ + void drop(inout int i); +}; + +interface IDiscombobulatable +{ + void discombobulate(inout int i, int param); +} + +void genericFunc(int param) + where T : IDroppable, IDiscombobulatable +{ + T t; + int i = 0; + defer t.drop(i); + + int p = param; + defer t.discombobulate(i, p); + t.discombobulate(i, p); + p += 1; +} + +struct TestType : IDroppable, IDiscombobulatable +{ + void drop(inout int i) + { + outputBuffer[i++] = 0xFF; + } + + void discombobulate(inout int i, int param) + { + outputBuffer[i++] = param; + } +} + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + genericFunc(2); +} diff --git a/tests/language-feature/defer/generics.slang.expected.txt b/tests/language-feature/defer/generics.slang.expected.txt new file mode 100644 index 000000000..4a7a64d00 --- /dev/null +++ b/tests/language-feature/defer/generics.slang.expected.txt @@ -0,0 +1,3 @@ +2 +3 +FF diff --git a/tests/language-feature/defer/loop.slang b/tests/language-feature/defer/loop.slang new file mode 100644 index 000000000..b51955a14 --- /dev/null +++ b/tests/language-feature/defer/loop.slang @@ -0,0 +1,30 @@ +//TEST(compute):COMPARE_COMPUTE: -shaderobj +//TEST(compute):COMPARE_COMPUTE: -vk -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + int i = 0; + defer outputBuffer[i++] = 128; + for (int j = 0; j < 4; j++) + { + defer outputBuffer[i++] = j*2+1; + outputBuffer[i++] = j*2; + } + + for (int j = 0; j < 5; j++) + { + if (j == 1) + continue; + defer outputBuffer[i++] = j*2+1; + if (j == 2) + continue; + outputBuffer[i++] = j*2; + if (j == 3) + break; + } +} diff --git a/tests/language-feature/defer/loop.slang.expected.txt b/tests/language-feature/defer/loop.slang.expected.txt new file mode 100644 index 000000000..cc7323432 --- /dev/null +++ b/tests/language-feature/defer/loop.slang.expected.txt @@ -0,0 +1,14 @@ +0 +1 +2 +3 +4 +5 +6 +7 +0 +1 +5 +6 +7 +80 diff --git a/tests/language-feature/defer/multiple-return.slang b/tests/language-feature/defer/multiple-return.slang new file mode 100644 index 000000000..a364c288e --- /dev/null +++ b/tests/language-feature/defer/multiple-return.slang @@ -0,0 +1,35 @@ +//TEST(compute):COMPARE_COMPUTE: -shaderobj +//TEST(compute):COMPARE_COMPUTE: -vk -shaderobj +//DISABLE_TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj +// ^ Due to DeviceMemoryBarrier() missing. + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +void testFunc(int index1, int index2, int condition) +{ + outputBuffer[index1] = 128; + defer outputBuffer[index1] = 0; + + outputBuffer[index2] = 3; + DeviceMemoryBarrier(); + + if (condition == 0) + { + outputBuffer[index2] = 1; + return; + } + + if (condition == 1) + { + outputBuffer[index2] = 2; + return; + } +} + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + for (int i = 0; i < 3; ++i) + testFunc(2*i, 2*i+1, i); +} diff --git a/tests/language-feature/defer/multiple-return.slang.expected.txt b/tests/language-feature/defer/multiple-return.slang.expected.txt new file mode 100644 index 000000000..5b06fc619 --- /dev/null +++ b/tests/language-feature/defer/multiple-return.slang.expected.txt @@ -0,0 +1,6 @@ +0 +1 +0 +2 +0 +3 diff --git a/tests/language-feature/defer/nested.slang b/tests/language-feature/defer/nested.slang new file mode 100644 index 000000000..6de518209 --- /dev/null +++ b/tests/language-feature/defer/nested.slang @@ -0,0 +1,18 @@ +//TEST(compute):COMPARE_COMPUTE: -shaderobj +//TEST(compute):COMPARE_COMPUTE: -vk -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + int i = 0; + defer { + defer outputBuffer[i++] = 3; + defer outputBuffer[i++] = 2; + outputBuffer[i++] = 1; + } + outputBuffer[i++] = 0; +} diff --git a/tests/language-feature/defer/nested.slang.expected.txt b/tests/language-feature/defer/nested.slang.expected.txt new file mode 100644 index 000000000..bc856dafa --- /dev/null +++ b/tests/language-feature/defer/nested.slang.expected.txt @@ -0,0 +1,4 @@ +0 +1 +2 +3 diff --git a/tests/language-feature/defer/no-block.slang b/tests/language-feature/defer/no-block.slang new file mode 100644 index 000000000..e78c91112 --- /dev/null +++ b/tests/language-feature/defer/no-block.slang @@ -0,0 +1,28 @@ +//TEST(compute):COMPARE_COMPUTE: -shaderobj +//TEST(compute):COMPARE_COMPUTE: -vk -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + int i = 0; + defer defer outputBuffer[i++] = 0xFF; + + outputBuffer[i++] = 0; + + for (int j = 0; j < 4; j++) + defer outputBuffer[i++] = j; + + if (i == 5) + defer outputBuffer[i++] = 0x80; + + while (i < 7) + defer outputBuffer[i++] = 0x81; + + do defer outputBuffer[i++] = 0x82; while (false); + + outputBuffer[i++] = 5; +} diff --git a/tests/language-feature/defer/no-block.slang.expected.txt b/tests/language-feature/defer/no-block.slang.expected.txt new file mode 100644 index 000000000..ad8c0f654 --- /dev/null +++ b/tests/language-feature/defer/no-block.slang.expected.txt @@ -0,0 +1,10 @@ +0 +0 +1 +2 +3 +80 +81 +82 +5 +FF diff --git a/tests/language-feature/defer/scoped.slang b/tests/language-feature/defer/scoped.slang new file mode 100644 index 000000000..447c02357 --- /dev/null +++ b/tests/language-feature/defer/scoped.slang @@ -0,0 +1,33 @@ +//TEST(compute):COMPARE_COMPUTE: -shaderobj +//TEST(compute):COMPARE_COMPUTE: -vk -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + int i = 0; + defer outputBuffer[i++] = 7; + + { + defer outputBuffer[i++] = 0; + } + + outputBuffer[i++] = 1; + + { + defer outputBuffer[i++] = 2; + } + + for (int j = 0; j < 4; ++j) + { + { + defer outputBuffer[i++] = j+3; + if (j == 2) + continue; + } + outputBuffer[i++] = j+0x13; + } +} diff --git a/tests/language-feature/defer/scoped.slang.expected.txt b/tests/language-feature/defer/scoped.slang.expected.txt new file mode 100644 index 000000000..8d280bd37 --- /dev/null +++ b/tests/language-feature/defer/scoped.slang.expected.txt @@ -0,0 +1,11 @@ +0 +1 +2 +3 +13 +4 +14 +5 +6 +16 +7 diff --git a/tests/language-feature/defer/switch.slang b/tests/language-feature/defer/switch.slang new file mode 100644 index 000000000..704e6e884 --- /dev/null +++ b/tests/language-feature/defer/switch.slang @@ -0,0 +1,34 @@ +//TEST(compute):COMPARE_COMPUTE: -shaderobj +//TEST(compute):COMPARE_COMPUTE: -vk -shaderobj +//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +void testFunc(int i, int param) +{ + defer outputBuffer[i++] = param*3+2; + switch(param) + { + case 0: + defer outputBuffer[i++] = 1; + outputBuffer[i++] = 0; + break; + case 1: + defer outputBuffer[i++] = 4; + outputBuffer[i++] = 3; + break; + default: + defer outputBuffer[i++] = 7; + outputBuffer[i++] = 6; + break; + }; +} + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + testFunc(0, 0); + testFunc(3, 1); + testFunc(6, 2); +} diff --git a/tests/language-feature/defer/switch.slang.expected.txt b/tests/language-feature/defer/switch.slang.expected.txt new file mode 100644 index 000000000..1000f9005 --- /dev/null +++ b/tests/language-feature/defer/switch.slang.expected.txt @@ -0,0 +1,9 @@ +0 +1 +2 +3 +4 +5 +6 +7 +8 -- cgit v1.2.3