diff options
| author | Theresa Foley <10618364+tangent-vector@users.noreply.github.com> | 2025-04-17 09:53:37 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-04-17 09:53:37 -0700 |
| commit | 8d1dca337e4b74c4b88a434eb2df5889410aff7c (patch) | |
| tree | 8d729a92b249d90723863264d547e3d4f2fae012 /source | |
| parent | 04db5a95657a8c1ad1db36570eadaeedbea01cbb (diff) | |
Eliminate back-reference in ChildStmt (#6835)
* Eliminate back-reference in ChildStmt
This change is part of a larger effort to improve the code for AST
serialization in the Slang compiler.
Tree structures are understandably easier to serialize than DAGs,
and DAGs are easier than fully general graphs.
The Slang AST nodes form a tree structure... except when they don't.
Among the exceptions to nice tree-structured ASTs are:
1. References to `Decl`s are encoded as pointers to the AST `Decl`
nodes themselves. This can result in cycles in the graph, and
requires care in serialization.
2. Nodes that inherit from `Val` represent, well, *values* instead
of actual pieces of syntax, and as such they are deduplicated so
that identical values will (hopefully) be identical pointers.
This results in a DAG structure for `Val`s, but at least it's not
a general graph (except for cycles that go through a `Decl`).
3. There are some minor cases of DAG-structured sharing that the
parser can introduce to deal with cases when a traditional-style
declaration includes multiple declarators. E.g., given:
```
static int a, b;
```
The resulting `DeclGroup` will include distinct `Decl`s for `a`
and `b`, which will share the `static` modifier through a
`SharedModifiers` node, and the `int` type specifier through a
`SharedTypeExpr` node.
This duplication can be ignored, for the purposes of serialization,
since duplicating those parts of the AST has no major down-sides.
4. There is the case of `ChildStmt`, used for things like `break`
and `continue`, which stores a direct `Stmt*` to the enclosing
parent statement being targeted. Storing the target is useful so
that IR lowering doesn't need to repeat the work that the semantic
checking logic did to associate each child statement with its parent.
The parent link inside of `ChildStmt` creates a cycle in the AST
`Stmt` hierarchy, since the outer statement contains the inner,
and the inner statement stores a pointer to the outer.
This change eliminates the last of these sources of complication for
AST serialization, by changing the `ChildStmt` type to store an
integer ID for the enclosing statement that it matches to, and having
each `BreakableStmt` (used to represent the outer `switch`, or loop,
or whatever) generate its own unique ID as part of semantic checking.
Note: if necessary, it is reasonable for the outer statement to have
its unique ID generated as part of parsing, rather than semantic
checking.
* format code
* Change unique ID to be a proper Decl
The fix here is to make the "unique ID" representation be a full
`Decl`-derived AST node, so that it is both allowed to break the
tree-structuring rules cleanly, and it is also trivially guaranteed
to be unique across all loaded ASTs.
* format code
---------
Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ast-builder.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ast-stmt.h | 43 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-stmt.cpp | 122 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 40 |
5 files changed, 149 insertions, 60 deletions
diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index 67dfaaf52..daf49f3f7 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -670,6 +670,8 @@ public: Index getId() { return m_id; } + BreakableStmt::UniqueID generateUniqueIDForStmt() { return create<UniqueStmtIDNode>(); } + /// Ctor ASTBuilder(SharedASTBuilder* sharedASTBuilder, const String& name); diff --git a/source/slang/slang-ast-stmt.h b/source/slang/slang-ast-stmt.h index f6580ba21..4107664bf 100644 --- a/source/slang/slang-ast-stmt.h +++ b/source/slang/slang-ast-stmt.h @@ -84,10 +84,26 @@ class IfStmt : public Stmt Stmt* negativeStatement = nullptr; }; +class UniqueStmtIDNode : public Decl +{ + SLANG_AST_CLASS(UniqueStmtIDNode) +}; + // A statement that can be escaped with a `break` class BreakableStmt : public ScopeStmt { SLANG_ABSTRACT_AST_CLASS(BreakableStmt) + + /// A unique ID for this statement. + /// + /// Used by `ChildStmt` to reference the + /// enclosing statement. + /// + UniqueStmtIDNode* uniqueID = kInvalidUniqueID; + + SLANG_UNREFLECTED + typedef UniqueStmtIDNode* UniqueID; + static constexpr UniqueID kInvalidUniqueID = nullptr; }; class SwitchStmt : public BreakableStmt @@ -98,7 +114,20 @@ class SwitchStmt : public BreakableStmt Stmt* body = nullptr; }; -class TargetCaseStmt : public Stmt +// A statement that is expected to appear lexically nested inside +// some other construct, and thus needs to keep track of the +// outer statement that it is associated with... +class ChildStmt : public Stmt +{ + SLANG_ABSTRACT_AST_CLASS(ChildStmt) + + /// The unique ID of the enclosing statement this + /// child statement refers to. 
+ /// + BreakableStmt::UniqueID targetOuterStmtID = BreakableStmt::kInvalidUniqueID; +}; + +class TargetCaseStmt : public ChildStmt { SLANG_AST_CLASS(TargetCaseStmt) int32_t capability; @@ -106,7 +135,7 @@ class TargetCaseStmt : public Stmt Stmt* body = nullptr; }; -class TargetSwitchStmt : public Stmt +class TargetSwitchStmt : public BreakableStmt { SLANG_AST_CLASS(TargetSwitchStmt) @@ -127,16 +156,6 @@ class IntrinsicAsmStmt : public Stmt List<Expr*> args; }; -// A statement that is expected to appear lexically nested inside -// some other construct, and thus needs to keep track of the -// outer statement that it is associated with... -class ChildStmt : public Stmt -{ - SLANG_ABSTRACT_AST_CLASS(ChildStmt) - - Stmt* parentStmt = nullptr; -}; - // a `case` or `default` statement inside a `switch` // // Note(tfoley): A correct AST for a C-like language would treat diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 8a1e79ce8..c9406cd1f 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -3037,6 +3037,8 @@ struct SemanticsStmtVisitor : public SemanticsVisitor, StmtVisitor<SemanticsStmt private: void validateCaseStmts(SwitchStmt* stmt, DiagnosticSink* sink); + + void generateUniqueIDForStmt(BreakableStmt* stmt); }; struct SemanticsDeclVisitorBase : public SemanticsVisitor diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp index c85eb7593..8a914c60b 100644 --- a/source/slang/slang-check-stmt.cpp +++ b/source/slang/slang-check-stmt.cpp @@ -149,60 +149,102 @@ Stmt* SemanticsStmtVisitor::findOuterStmtWithLabel(Name* label) return nullptr; } +void SemanticsStmtVisitor::generateUniqueIDForStmt(BreakableStmt* stmt) +{ + stmt->uniqueID = getASTBuilder()->generateUniqueIDForStmt(); +} + void SemanticsStmtVisitor::visitBreakStmt(BreakStmt* stmt) { - Stmt* targetStmt = nullptr; + // We need to identify the enclosing statement that + // this `break` is meant to break out of. 
+ // + BreakableStmt* targetOuterStmt = nullptr; if (stmt->targetLabel.type == TokenType::Identifier) { - // This is a break statement with an explicit target label. - // Try to find the outer stmt with the label. - targetStmt = findOuterStmtWithLabel(stmt->targetLabel.getName()); - if (!targetStmt) + // If this is a `break` statement that specifies + // an explicit label, then we will search for + // an outer statement matching that label. + // + auto foundOuterStmt = findOuterStmtWithLabel(stmt->targetLabel.getName()); + if (!foundOuterStmt) { getSink()->diagnose(stmt, Diagnostics::breakLabelNotFound, stmt->targetLabel.getName()); } - if (!as<BreakableStmt>(targetStmt)) + else { - getSink()->diagnose( - stmt, - Diagnostics::targetLabelDoesNotMarkBreakableStmt, - stmt->targetLabel.getName()); + // It is possible that the labelled statement + // is not a valid one for a `break` to target, + // so we check for that next. + // + targetOuterStmt = as<BreakableStmt>(foundOuterStmt); + if (!targetOuterStmt) + { + getSink()->diagnose( + stmt, + Diagnostics::targetLabelDoesNotMarkBreakableStmt, + stmt->targetLabel.getName()); + } } } else { - // For `break` statements without an explicit target, - // find the inner most breakable stmt. - targetStmt = FindOuterStmt<BreakableStmt>(); - if (!targetStmt) + // If there is no explicit label on the `break` statement, + // then we are simply searching for the inner-most + // enclosing statement that is a valid `break` target. + // + targetOuterStmt = FindOuterStmt<BreakableStmt>(); + if (!targetOuterStmt) { getSink()->diagnose(stmt, Diagnostics::breakOutsideLoop); } } - // If there is a defer statement before the breakable statement, it's - // illegal. - if (FindOuterStmt<DeferStmt>(targetStmt)) + // We do not (currently) allow a `break` to proceed "through" + // an enclosing `defer` statement. 
Thus, we search for + // a possible enclosing `defer` statement, between the + // `stmt` being checked and the `targetOuterStmt` that + // `stmt` is trying to branch to. + // + // TODO: This is a reasonable feature to add down the line; + // it simply involves more implementation complexity than + // the simpler cases of `defer`. + // + if (targetOuterStmt) { - getSink()->diagnose(stmt, Diagnostics::breakInsideDefer); - } + if (FindOuterStmt<DeferStmt>(targetOuterStmt)) + { + getSink()->diagnose(stmt, Diagnostics::breakInsideDefer); + } - stmt->parentStmt = targetStmt; + // We stash the ID of the target statement in the `break` + // statement so that they can be correlated later, during + // code generation. + // + stmt->targetOuterStmtID = targetOuterStmt->uniqueID; + } } void SemanticsStmtVisitor::visitContinueStmt(ContinueStmt* stmt) { - auto outer = FindOuterStmt<LoopStmt>(); - if (!outer) + auto targetOuterStmt = FindOuterStmt<LoopStmt>(); + if (!targetOuterStmt) { getSink()->diagnose(stmt, Diagnostics::continueOutsideLoop); } - - if (FindOuterStmt<DeferStmt>(outer)) + else { - getSink()->diagnose(stmt, Diagnostics::continueInsideDefer); + if (FindOuterStmt<DeferStmt>(targetOuterStmt)) + { + getSink()->diagnose(stmt, Diagnostics::continueInsideDefer); + } + + // We stash the ID of the target statement in the `continue` + // statement so that they can be correlated later, during + // code generation. 
+ // + stmt->targetOuterStmtID = targetOuterStmt->uniqueID; } - stmt->parentStmt = outer; } Expr* SemanticsVisitor::checkPredicateExpr(Expr* expr) @@ -219,6 +261,7 @@ Expr* SemanticsVisitor::checkPredicateExpr(Expr* expr) void SemanticsStmtVisitor::visitDoWhileStmt(DoWhileStmt* stmt) { + generateUniqueIDForStmt(stmt); checkModifiers(stmt); WithOuterStmt subContext(this, stmt); @@ -229,6 +272,7 @@ void SemanticsStmtVisitor::visitDoWhileStmt(DoWhileStmt* stmt) void SemanticsStmtVisitor::visitForStmt(ForStmt* stmt) { + generateUniqueIDForStmt(stmt); WithOuterStmt subContext(this, stmt); checkModifiers(stmt); checkStmt(stmt->initialStatement); @@ -342,6 +386,7 @@ void SemanticsStmtVisitor::validateCaseStmts(SwitchStmt* stmt, DiagnosticSink* s void SemanticsStmtVisitor::visitSwitchStmt(SwitchStmt* stmt) { + generateUniqueIDForStmt(stmt); WithOuterStmt subContext(this, stmt); // TODO(tfoley): need to coerce condition to an integral type... @@ -372,11 +417,20 @@ void SemanticsStmtVisitor::visitCaseStmt(CaseStmt* stmt) stmt->expr = expr; stmt->exprVal = exprVal; - stmt->parentStmt = switchStmt; + + if (switchStmt) + { + // We stash the ID of the target statement in the `case` + // statement so that they can be correlated later, during + // code generation. 
+ // + stmt->targetOuterStmtID = switchStmt->uniqueID; + } } void SemanticsStmtVisitor::visitTargetSwitchStmt(TargetSwitchStmt* stmt) { + generateUniqueIDForStmt(stmt); WithOuterStmt subContext(this, stmt); HashSet<Stmt*> checkedStmt; for (auto caseStmt : stmt->targetCases) @@ -436,6 +490,10 @@ void SemanticsStmtVisitor::visitTargetCaseStmt(TargetCaseStmt* stmt) { getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch); } + else + { + stmt->targetOuterStmtID = switchStmt->uniqueID; + } WithOuterStmt subContext(this, stmt); subContext.checkStmt(stmt->body); } @@ -454,7 +512,14 @@ void SemanticsStmtVisitor::visitDefaultStmt(DefaultStmt* stmt) { getSink()->diagnose(stmt, Diagnostics::defaultOutsideSwitch); } - stmt->parentStmt = switchStmt; + else + { + // We stash the ID of the target statement in the `case` + // statement so that they can be correlated later, during + // code generation. + // + stmt->targetOuterStmtID = switchStmt->uniqueID; + } } void SemanticsStmtVisitor::visitIfStmt(IfStmt* stmt) @@ -520,6 +585,7 @@ void SemanticsStmtVisitor::visitReturnStmt(ReturnStmt* stmt) void SemanticsStmtVisitor::visitWhileStmt(WhileStmt* stmt) { + generateUniqueIDForStmt(stmt); checkModifiers(stmt); WithOuterStmt subContext(this, stmt); stmt->predicate = checkPredicateExpr(stmt->predicate); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 260596dc3..7bc7d4fb4 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -486,8 +486,8 @@ struct SharedIRGenContext // Map from an AST-level statement that can be // used as the target of a `break` or `continue` // to the appropriate basic block to jump to. 
- Dictionary<Stmt*, IRBlock*> breakLabels; - Dictionary<Stmt*, IRBlock*> continueLabels; + Dictionary<BreakableStmt::UniqueID, IRBlock*> breakLabels; + Dictionary<BreakableStmt::UniqueID, IRBlock*> continueLabels; Dictionary<SourceFile*, IRInst*> mapSourceFileToDebugSourceInst; Dictionary<String, IRInst*> mapSourcePathToDebugSourceInst; @@ -6181,8 +6181,8 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> // Register the `break` and `continue` labels so // that we can find them for nested statements. - context->shared->breakLabels.add(stmt, breakLabel); - context->shared->continueLabels.add(stmt, continueLabel); + context->shared->breakLabels.add(stmt->uniqueID, breakLabel); + context->shared->continueLabels.add(stmt->uniqueID, continueLabel); // Emit the branch that will start out loop, // and then insert the block for the head. @@ -6288,8 +6288,8 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> // Register the `break` and `continue` labels so // that we can find them for nested statements. - context->shared->breakLabels.add(stmt, breakLabel); - context->shared->continueLabels.add(stmt, continueLabel); + context->shared->breakLabels.add(stmt->uniqueID, breakLabel); + context->shared->continueLabels.add(stmt->uniqueID, continueLabel); // Emit the branch that will start out loop, // and then insert the block for the head. @@ -6347,8 +6347,8 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> // Register the `break` and `continue` labels so // that we can find them for nested statements. - context->shared->breakLabels.add(stmt, breakLabel); - context->shared->continueLabels.add(stmt, continueLabel); + context->shared->breakLabels.add(stmt->uniqueID, breakLabel); + context->shared->continueLabels.add(stmt->uniqueID, continueLabel); // Emit the branch that will start out loop, // and then insert the block for the head. 
@@ -6622,14 +6622,14 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> // Semantic checking is responsible for finding // the statement taht this `break` breaks out of - auto parentStmt = stmt->parentStmt; - SLANG_ASSERT(parentStmt); + auto targetStmtID = stmt->targetOuterStmtID; + SLANG_ASSERT(targetStmtID != BreakableStmt::kInvalidUniqueID); // We just need to look up the basic block that // corresponds to the break label for that statement, // and then emit an instruction to jump to it. IRBlock* targetBlock = nullptr; - context->shared->breakLabels.tryGetValue(parentStmt, targetBlock); + context->shared->breakLabels.tryGetValue(targetStmtID, targetBlock); SLANG_ASSERT(targetBlock); getBuilder()->emitBreak(targetBlock); } @@ -6640,15 +6640,15 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> // Semantic checking is responsible for finding // the loop that this `continue` statement continues - auto parentStmt = stmt->parentStmt; - SLANG_ASSERT(parentStmt); + auto targetStmtID = stmt->targetOuterStmtID; + SLANG_ASSERT(targetStmtID != BreakableStmt::kInvalidUniqueID); // We just need to look up the basic block that // corresponds to the continue label for that statement, // and then emit an instruction to jump to it. IRBlock* targetBlock = nullptr; - context->shared->continueLabels.tryGetValue(parentStmt, targetBlock); + context->shared->continueLabels.tryGetValue(targetStmtID, targetBlock); SLANG_ASSERT(targetBlock); getBuilder()->emitContinue(targetBlock); } @@ -6864,7 +6864,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> // Register the `break` label so // that we can find it for nested statements. - context->shared->breakLabels.add(stmt, breakLabel); + context->shared->breakLabels.add(stmt->uniqueID, breakLabel); builder->setInsertInto(initialBlock->getParent()); @@ -6935,7 +6935,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> // (and that control flow will fall through to otherwise). 
// This is the block that subsequent code will go into. insertBlock(breakLabel); - context->shared->breakLabels.remove(stmt); + context->shared->breakLabels.remove(stmt->uniqueID); } void visitTargetSwitchStmt(TargetSwitchStmt* stmt) @@ -6947,7 +6947,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> startBlockIfNeeded(stmt); auto initialBlock = builder->getBlock(); auto breakLabel = builder->createBlock(); - context->shared->breakLabels.add(stmt, breakLabel); + context->shared->breakLabels.add(stmt->uniqueID, breakLabel); builder->setInsertInto(initialBlock->getParent()); List<IRInst*> args; args.add(breakLabel); @@ -6966,7 +6966,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> args.add(builder->getIntValue(builder->getIntType(), targetCase->capability)); args.add(caseBlock); } - context->shared->breakLabels.remove(stmt); + context->shared->breakLabels.remove(stmt->uniqueID); builder->setInsertInto(initialBlock); auto parentFunc = initialBlock->getParent(); @@ -7066,7 +7066,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> // Register the `break` label so // that we can find it for nested statements. - context->shared->breakLabels.add(stmt, breakLabel); + context->shared->breakLabels.add(stmt->uniqueID, breakLabel); builder->setInsertInto(initialBlock->getParent()); @@ -7119,7 +7119,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> // (and that control flow will fall through to otherwise). // This is the block that subsequent code will go into. insertBlock(breakLabel); - context->shared->breakLabels.remove(stmt); + context->shared->breakLabels.remove(stmt->uniqueID); // If there is the branch attribute output the IR decoration if (stmt->hasModifier<BranchAttribute>()) |
