diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ast-builder.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ast-stmt.h | 43 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-stmt.cpp | 122 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 40 |
5 files changed, 149 insertions, 60 deletions
diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index 67dfaaf52..daf49f3f7 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -670,6 +670,8 @@ public: Index getId() { return m_id; } + BreakableStmt::UniqueID generateUniqueIDForStmt() { return create<UniqueStmtIDNode>(); } + /// Ctor ASTBuilder(SharedASTBuilder* sharedASTBuilder, const String& name); diff --git a/source/slang/slang-ast-stmt.h b/source/slang/slang-ast-stmt.h index f6580ba21..4107664bf 100644 --- a/source/slang/slang-ast-stmt.h +++ b/source/slang/slang-ast-stmt.h @@ -84,10 +84,26 @@ class IfStmt : public Stmt Stmt* negativeStatement = nullptr; }; +class UniqueStmtIDNode : public Decl +{ + SLANG_AST_CLASS(UniqueStmtIDNode) +}; + // A statement that can be escaped with a `break` class BreakableStmt : public ScopeStmt { SLANG_ABSTRACT_AST_CLASS(BreakableStmt) + + /// A unique ID for this statement. + /// + /// Used by `ChildStmt` to reference the + /// enclosing statement. + /// + UniqueStmtIDNode* uniqueID = kInvalidUniqueID; + + SLANG_UNREFLECTED + typedef UniqueStmtIDNode* UniqueID; + static constexpr UniqueID kInvalidUniqueID = nullptr; }; class SwitchStmt : public BreakableStmt @@ -98,7 +114,20 @@ class SwitchStmt : public BreakableStmt Stmt* body = nullptr; }; -class TargetCaseStmt : public Stmt +// A statement that is expected to appear lexically nested inside +// some other construct, and thus needs to keep track of the +// outer statement that it is associated with... +class ChildStmt : public Stmt +{ + SLANG_ABSTRACT_AST_CLASS(ChildStmt) + + /// The unique ID of the enclosing statement this + /// child statement refers to. + /// + BreakableStmt::UniqueID targetOuterStmtID = BreakableStmt::kInvalidUniqueID; +}; + +class TargetCaseStmt : public ChildStmt { SLANG_AST_CLASS(TargetCaseStmt) int32_t capability; @@ -106,7 +135,7 @@ class TargetCaseStmt : public Stmt Stmt* body = nullptr; }; -class TargetSwitchStmt : public Stmt +class TargetSwitchStmt : public BreakableStmt { SLANG_AST_CLASS(TargetSwitchStmt) @@ -127,16 +156,6 @@ class IntrinsicAsmStmt : public Stmt List<Expr*> args; }; -// A statement that is expected to appear lexically nested inside -// some other construct, and thus needs to keep track of the -// outer statement that it is associated with... -class ChildStmt : public Stmt -{ - SLANG_ABSTRACT_AST_CLASS(ChildStmt) - - Stmt* parentStmt = nullptr; -}; - // a `case` or `default` statement inside a `switch` // // Note(tfoley): A correct AST for a C-like language would treat diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 8a1e79ce8..c9406cd1f 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -3037,6 +3037,8 @@ struct SemanticsStmtVisitor : public SemanticsVisitor, StmtVisitor<SemanticsStmt private: void validateCaseStmts(SwitchStmt* stmt, DiagnosticSink* sink); + + void generateUniqueIDForStmt(BreakableStmt* stmt); }; struct SemanticsDeclVisitorBase : public SemanticsVisitor diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp index c85eb7593..8a914c60b 100644 --- a/source/slang/slang-check-stmt.cpp +++ b/source/slang/slang-check-stmt.cpp @@ -149,60 +149,102 @@ Stmt* SemanticsStmtVisitor::findOuterStmtWithLabel(Name* label) return nullptr; } +void SemanticsStmtVisitor::generateUniqueIDForStmt(BreakableStmt* stmt) +{ + stmt->uniqueID = getASTBuilder()->generateUniqueIDForStmt(); +} + void SemanticsStmtVisitor::visitBreakStmt(BreakStmt* stmt) { - Stmt* targetStmt = nullptr; + // We need to identify the enclosing statement that + // this `break` is meant to break out of. + // + BreakableStmt* targetOuterStmt = nullptr; if (stmt->targetLabel.type == TokenType::Identifier) { - // This is a break statement with an explicit target label. - // Try to find the outer stmt with the label. - targetStmt = findOuterStmtWithLabel(stmt->targetLabel.getName()); - if (!targetStmt) + // If this is a `break` statement that specifies + // an explicit label, then we will search for + // an outer statement matching that label. + // + auto foundOuterStmt = findOuterStmtWithLabel(stmt->targetLabel.getName()); + if (!foundOuterStmt) { getSink()->diagnose(stmt, Diagnostics::breakLabelNotFound, stmt->targetLabel.getName()); } - if (!as<BreakableStmt>(targetStmt)) + else { - getSink()->diagnose( - stmt, - Diagnostics::targetLabelDoesNotMarkBreakableStmt, - stmt->targetLabel.getName()); + // It is possible that the labelled statement + // is not a valid one for a `break` to target, + // so we check for that next. + // + targetOuterStmt = as<BreakableStmt>(foundOuterStmt); + if (!targetOuterStmt) + { + getSink()->diagnose( + stmt, + Diagnostics::targetLabelDoesNotMarkBreakableStmt, + stmt->targetLabel.getName()); + } } } else { - // For `break` statements without an explicit target, - // find the inner most breakable stmt. - targetStmt = FindOuterStmt<BreakableStmt>(); - if (!targetStmt) + // If there is no explicit label on the `break` statement, + // then we are simply searching for the inner-most + // enclosing statement that is a valid `break` target. + // + targetOuterStmt = FindOuterStmt<BreakableStmt>(); + if (!targetOuterStmt) { getSink()->diagnose(stmt, Diagnostics::breakOutsideLoop); } } - // If there is a defer statement before the breakable statement, it's - // illegal. - if (FindOuterStmt<DeferStmt>(targetStmt)) + // We do not (currently) allow a `break` to proceed "through" + // an enclosing `defer` statement. Thus, we search for + // a possible enclosing `defer` statement, between the + // `stmt` being checked and the `targetOuterStmt` that + // `stmt` is trying to branch to. + // + // TODO: This is a reasonable feature to add down the line; + // it simply involves more implementation complexity than + // the simpler cases of `defer`. + // + if (targetOuterStmt) { - getSink()->diagnose(stmt, Diagnostics::breakInsideDefer); - } + if (FindOuterStmt<DeferStmt>(targetOuterStmt)) + { + getSink()->diagnose(stmt, Diagnostics::breakInsideDefer); + } - stmt->parentStmt = targetStmt; + // We stash the ID of the target statement in the `break` + // statement so that they can be correlated later, during + // code generation. + // + stmt->targetOuterStmtID = targetOuterStmt->uniqueID; + } } void SemanticsStmtVisitor::visitContinueStmt(ContinueStmt* stmt) { - auto outer = FindOuterStmt<LoopStmt>(); - if (!outer) + auto targetOuterStmt = FindOuterStmt<LoopStmt>(); + if (!targetOuterStmt) { getSink()->diagnose(stmt, Diagnostics::continueOutsideLoop); } - - if (FindOuterStmt<DeferStmt>(outer)) + else { - getSink()->diagnose(stmt, Diagnostics::continueInsideDefer); + if (FindOuterStmt<DeferStmt>(targetOuterStmt)) + { + getSink()->diagnose(stmt, Diagnostics::continueInsideDefer); + } + + // We stash the ID of the target statement in the `continue` + // statement so that they can be correlated later, during + // code generation. + // + stmt->targetOuterStmtID = targetOuterStmt->uniqueID; } - stmt->parentStmt = outer; } Expr* SemanticsVisitor::checkPredicateExpr(Expr* expr) @@ -219,6 +261,7 @@ Expr* SemanticsVisitor::checkPredicateExpr(Expr* expr) void SemanticsStmtVisitor::visitDoWhileStmt(DoWhileStmt* stmt) { + generateUniqueIDForStmt(stmt); checkModifiers(stmt); WithOuterStmt subContext(this, stmt); @@ -229,6 +272,7 @@ void SemanticsStmtVisitor::visitDoWhileStmt(DoWhileStmt* stmt) void SemanticsStmtVisitor::visitForStmt(ForStmt* stmt) { + generateUniqueIDForStmt(stmt); WithOuterStmt subContext(this, stmt); checkModifiers(stmt); checkStmt(stmt->initialStatement); @@ -342,6 +386,7 @@ void SemanticsStmtVisitor::validateCaseStmts(SwitchStmt* stmt, DiagnosticSink* s void SemanticsStmtVisitor::visitSwitchStmt(SwitchStmt* stmt) { + generateUniqueIDForStmt(stmt); WithOuterStmt subContext(this, stmt); // TODO(tfoley): need to coerce condition to an integral type... @@ -372,11 +417,20 @@ void SemanticsStmtVisitor::visitCaseStmt(CaseStmt* stmt) stmt->expr = expr; stmt->exprVal = exprVal; - stmt->parentStmt = switchStmt; + + if (switchStmt) + { + // We stash the ID of the target statement in the `case` + // statement so that they can be correlated later, during + // code generation. + // + stmt->targetOuterStmtID = switchStmt->uniqueID; + } } void SemanticsStmtVisitor::visitTargetSwitchStmt(TargetSwitchStmt* stmt) { + generateUniqueIDForStmt(stmt); WithOuterStmt subContext(this, stmt); HashSet<Stmt*> checkedStmt; for (auto caseStmt : stmt->targetCases) @@ -436,6 +490,10 @@ void SemanticsStmtVisitor::visitTargetCaseStmt(TargetCaseStmt* stmt) { getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch); } + else + { + stmt->targetOuterStmtID = switchStmt->uniqueID; + } WithOuterStmt subContext(this, stmt); subContext.checkStmt(stmt->body); } @@ -454,7 +512,14 @@ void SemanticsStmtVisitor::visitDefaultStmt(DefaultStmt* stmt) { getSink()->diagnose(stmt, Diagnostics::defaultOutsideSwitch); } - stmt->parentStmt = switchStmt; + else + { + // We stash the ID of the target statement in the `case` + // statement so that they can be correlated later, during + // code generation. + // + stmt->targetOuterStmtID = switchStmt->uniqueID; + } } void SemanticsStmtVisitor::visitIfStmt(IfStmt* stmt) @@ -520,6 +585,7 @@ void SemanticsStmtVisitor::visitReturnStmt(ReturnStmt* stmt) void SemanticsStmtVisitor::visitWhileStmt(WhileStmt* stmt) { + generateUniqueIDForStmt(stmt); checkModifiers(stmt); WithOuterStmt subContext(this, stmt); stmt->predicate = checkPredicateExpr(stmt->predicate); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 260596dc3..7bc7d4fb4 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -486,8 +486,8 @@ struct SharedIRGenContext // Map from an AST-level statement that can be // used as the target of a `break` or `continue` // to the appropriate basic block to jump to. - Dictionary<Stmt*, IRBlock*> breakLabels; - Dictionary<Stmt*, IRBlock*> continueLabels; + Dictionary<BreakableStmt::UniqueID, IRBlock*> breakLabels; + Dictionary<BreakableStmt::UniqueID, IRBlock*> continueLabels; Dictionary<SourceFile*, IRInst*> mapSourceFileToDebugSourceInst; Dictionary<String, IRInst*> mapSourcePathToDebugSourceInst; @@ -6181,8 +6181,8 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> // Register the `break` and `continue` labels so // that we can find them for nested statements. - context->shared->breakLabels.add(stmt, breakLabel); - context->shared->continueLabels.add(stmt, continueLabel); + context->shared->breakLabels.add(stmt->uniqueID, breakLabel); + context->shared->continueLabels.add(stmt->uniqueID, continueLabel); // Emit the branch that will start out loop, // and then insert the block for the head. @@ -6288,8 +6288,8 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> // Register the `break` and `continue` labels so // that we can find them for nested statements. - context->shared->breakLabels.add(stmt, breakLabel); - context->shared->continueLabels.add(stmt, continueLabel); + context->shared->breakLabels.add(stmt->uniqueID, breakLabel); + context->shared->continueLabels.add(stmt->uniqueID, continueLabel); // Emit the branch that will start out loop, // and then insert the block for the head. @@ -6347,8 +6347,8 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> // Register the `break` and `continue` labels so // that we can find them for nested statements. - context->shared->breakLabels.add(stmt, breakLabel); - context->shared->continueLabels.add(stmt, continueLabel); + context->shared->breakLabels.add(stmt->uniqueID, breakLabel); + context->shared->continueLabels.add(stmt->uniqueID, continueLabel); // Emit the branch that will start out loop, // and then insert the block for the head. @@ -6622,14 +6622,14 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> // Semantic checking is responsible for finding // the statement taht this `break` breaks out of - auto parentStmt = stmt->parentStmt; - SLANG_ASSERT(parentStmt); + auto targetStmtID = stmt->targetOuterStmtID; + SLANG_ASSERT(targetStmtID != BreakableStmt::kInvalidUniqueID); // We just need to look up the basic block that // corresponds to the break label for that statement, // and then emit an instruction to jump to it. IRBlock* targetBlock = nullptr; - context->shared->breakLabels.tryGetValue(parentStmt, targetBlock); + context->shared->breakLabels.tryGetValue(targetStmtID, targetBlock); SLANG_ASSERT(targetBlock); getBuilder()->emitBreak(targetBlock); } @@ -6640,15 +6640,15 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> // Semantic checking is responsible for finding // the loop that this `continue` statement continues - auto parentStmt = stmt->parentStmt; - SLANG_ASSERT(parentStmt); + auto targetStmtID = stmt->targetOuterStmtID; + SLANG_ASSERT(targetStmtID != BreakableStmt::kInvalidUniqueID); // We just need to look up the basic block that // corresponds to the continue label for that statement, // and then emit an instruction to jump to it. IRBlock* targetBlock = nullptr; - context->shared->continueLabels.tryGetValue(parentStmt, targetBlock); + context->shared->continueLabels.tryGetValue(targetStmtID, targetBlock); SLANG_ASSERT(targetBlock); getBuilder()->emitContinue(targetBlock); } @@ -6864,7 +6864,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> // Register the `break` label so // that we can find it for nested statements. - context->shared->breakLabels.add(stmt, breakLabel); + context->shared->breakLabels.add(stmt->uniqueID, breakLabel); builder->setInsertInto(initialBlock->getParent()); @@ -6935,7 +6935,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> // (and that control flow will fall through to otherwise). // This is the block that subsequent code will go into. insertBlock(breakLabel); - context->shared->breakLabels.remove(stmt); + context->shared->breakLabels.remove(stmt->uniqueID); } void visitTargetSwitchStmt(TargetSwitchStmt* stmt) @@ -6947,7 +6947,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> startBlockIfNeeded(stmt); auto initialBlock = builder->getBlock(); auto breakLabel = builder->createBlock(); - context->shared->breakLabels.add(stmt, breakLabel); + context->shared->breakLabels.add(stmt->uniqueID, breakLabel); builder->setInsertInto(initialBlock->getParent()); List<IRInst*> args; args.add(breakLabel); @@ -6966,7 +6966,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> args.add(builder->getIntValue(builder->getIntType(), targetCase->capability)); args.add(caseBlock); } - context->shared->breakLabels.remove(stmt); + context->shared->breakLabels.remove(stmt->uniqueID); builder->setInsertInto(initialBlock); auto parentFunc = initialBlock->getParent(); @@ -7066,7 +7066,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> // Register the `break` label so // that we can find it for nested statements. - context->shared->breakLabels.add(stmt, breakLabel); + context->shared->breakLabels.add(stmt->uniqueID, breakLabel); builder->setInsertInto(initialBlock->getParent()); @@ -7119,7 +7119,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> // (and that control flow will fall through to otherwise). // This is the block that subsequent code will go into. insertBlock(breakLabel); - context->shared->breakLabels.remove(stmt); + context->shared->breakLabels.remove(stmt->uniqueID); // If there is the branch attribute output the IR decoration if (stmt->hasModifier<BranchAttribute>()) |
