summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorTheresa Foley <10618364+tangent-vector@users.noreply.github.com>2025-04-17 09:53:37 -0700
committerGitHub <noreply@github.com>2025-04-17 09:53:37 -0700
commit8d1dca337e4b74c4b88a434eb2df5889410aff7c (patch)
tree8d729a92b249d90723863264d547e3d4f2fae012 /source
parent04db5a95657a8c1ad1db36570eadaeedbea01cbb (diff)
Eliminate back-reference in ChildStmt (#6835)
* Eliminate back-reference in ChildStmt This change is part of a larger effort to improve the code for AST serialization in the Slang compiler. Tree structures are understandably easier to serialize than DAGs, and DAGs are easier than fully generaal graphs. The Slang AST nodes form a tree structure... except when they don't. Among the exceptions to nice tree-structured ASTs are: 1. References to `Decl`s are encoded as pointers to the AST `Decl` nodes themselves. This can result in cycles in the graph, and requires care in serialization. 2. Nodes that inherit from `Val` represent, well, *values* instead of actual pieces of syntax, and as such they are deduplicated so that identical values will (hopefully) be identical pointers. This results in a DAG structure for `Val`s, but at least it's not a general graph (except for cycles that go through a `Decl`). 3. There are some minor cases of DAG-structured sharing that the parser can introduce to deal with cases when a traditional-style declaration includes multiple declarators. E.g., given: ``` static int a, b; ``` The resulting `DeclGroup` will include distinct `Decl`s for `a` and `b`, which will share the `static` modifier through a `SharedModifiers` node, and the `int` type specifier through a `SharedTypeExpr` node. This duplication can be ignored, for the purposes of serialization, since duplicating those parts of the AST has no major down-sides. 4. There is the case of `ChildStmt`, used for things like `break` and `continue`, which stores a direct `Stmt*` to the enclosing parent statement being targetted. Storing the target is useful so that IR lowering doesn't need to repeat the work that the semantic checking logic did to associate each child statement with its parent. The parent link inside of `ChildStmt` creates a cycle in the AST `Stmt` hierarchy, since the outer statement contains the inner, and the inner statement stores a pointer to the outer. This change eliminates the last of these sources of complication for AST serialization, by changing the `ChildStmt` type to stored an integer ID for the enclosing statement that it matches to, and having each `BreakableStmt` (used to represent the outer `switch`, or loop, or whatever) generate its own unique ID as part of semantic checking. Note: if necessary, it is reasonable for the outer statement to have its unique ID generated as part of parsing, rather than semantic checking. * format code * Change unique ID to be a proper Decl The fix here is to make the "unique ID" representation be a full `Decl`-derived AST node, so that it is both allowed to break the tree-structuring rules cleanly, and it is also trivially guaranteed to be unique across all loaded ASTs. * format code --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ast-builder.h2
-rw-r--r--source/slang/slang-ast-stmt.h43
-rw-r--r--source/slang/slang-check-impl.h2
-rw-r--r--source/slang/slang-check-stmt.cpp122
-rw-r--r--source/slang/slang-lower-to-ir.cpp40
5 files changed, 149 insertions, 60 deletions
diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h
index 67dfaaf52..daf49f3f7 100644
--- a/source/slang/slang-ast-builder.h
+++ b/source/slang/slang-ast-builder.h
@@ -670,6 +670,8 @@ public:
Index getId() { return m_id; }
+ BreakableStmt::UniqueID generateUniqueIDForStmt() { return create<UniqueStmtIDNode>(); }
+
/// Ctor
ASTBuilder(SharedASTBuilder* sharedASTBuilder, const String& name);
diff --git a/source/slang/slang-ast-stmt.h b/source/slang/slang-ast-stmt.h
index f6580ba21..4107664bf 100644
--- a/source/slang/slang-ast-stmt.h
+++ b/source/slang/slang-ast-stmt.h
@@ -84,10 +84,26 @@ class IfStmt : public Stmt
Stmt* negativeStatement = nullptr;
};
+class UniqueStmtIDNode : public Decl
+{
+ SLANG_AST_CLASS(UniqueStmtIDNode)
+};
+
// A statement that can be escaped with a `break`
class BreakableStmt : public ScopeStmt
{
SLANG_ABSTRACT_AST_CLASS(BreakableStmt)
+
+ /// A unique ID for this statement.
+ ///
+ /// Used by `ChildStmt` to reference the
+ /// enclosing statement.
+ ///
+ UniqueStmtIDNode* uniqueID = kInvalidUniqueID;
+
+ SLANG_UNREFLECTED
+ typedef UniqueStmtIDNode* UniqueID;
+ static constexpr UniqueID kInvalidUniqueID = nullptr;
};
class SwitchStmt : public BreakableStmt
@@ -98,7 +114,20 @@ class SwitchStmt : public BreakableStmt
Stmt* body = nullptr;
};
-class TargetCaseStmt : public Stmt
+// A statement that is expected to appear lexically nested inside
+// some other construct, and thus needs to keep track of the
+// outer statement that it is associated with...
+class ChildStmt : public Stmt
+{
+ SLANG_ABSTRACT_AST_CLASS(ChildStmt)
+
+ /// The unique ID of the enclosing statement this
+ /// child statement refers to.
+ ///
+ BreakableStmt::UniqueID targetOuterStmtID = BreakableStmt::kInvalidUniqueID;
+};
+
+class TargetCaseStmt : public ChildStmt
{
SLANG_AST_CLASS(TargetCaseStmt)
int32_t capability;
@@ -106,7 +135,7 @@ class TargetCaseStmt : public Stmt
Stmt* body = nullptr;
};
-class TargetSwitchStmt : public Stmt
+class TargetSwitchStmt : public BreakableStmt
{
SLANG_AST_CLASS(TargetSwitchStmt)
@@ -127,16 +156,6 @@ class IntrinsicAsmStmt : public Stmt
List<Expr*> args;
};
-// A statement that is expected to appear lexically nested inside
-// some other construct, and thus needs to keep track of the
-// outer statement that it is associated with...
-class ChildStmt : public Stmt
-{
- SLANG_ABSTRACT_AST_CLASS(ChildStmt)
-
- Stmt* parentStmt = nullptr;
-};
-
// a `case` or `default` statement inside a `switch`
//
// Note(tfoley): A correct AST for a C-like language would treat
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 8a1e79ce8..c9406cd1f 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -3037,6 +3037,8 @@ struct SemanticsStmtVisitor : public SemanticsVisitor, StmtVisitor<SemanticsStmt
private:
void validateCaseStmts(SwitchStmt* stmt, DiagnosticSink* sink);
+
+ void generateUniqueIDForStmt(BreakableStmt* stmt);
};
struct SemanticsDeclVisitorBase : public SemanticsVisitor
diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp
index c85eb7593..8a914c60b 100644
--- a/source/slang/slang-check-stmt.cpp
+++ b/source/slang/slang-check-stmt.cpp
@@ -149,60 +149,102 @@ Stmt* SemanticsStmtVisitor::findOuterStmtWithLabel(Name* label)
return nullptr;
}
+void SemanticsStmtVisitor::generateUniqueIDForStmt(BreakableStmt* stmt)
+{
+ stmt->uniqueID = getASTBuilder()->generateUniqueIDForStmt();
+}
+
void SemanticsStmtVisitor::visitBreakStmt(BreakStmt* stmt)
{
- Stmt* targetStmt = nullptr;
+ // We need to identify the enclosing statement that
+ // this `break` is meant to break out of.
+ //
+ BreakableStmt* targetOuterStmt = nullptr;
if (stmt->targetLabel.type == TokenType::Identifier)
{
- // This is a break statement with an explicit target label.
- // Try to find the outer stmt with the label.
- targetStmt = findOuterStmtWithLabel(stmt->targetLabel.getName());
- if (!targetStmt)
+ // If this is a `break` statement that specifies
+ // an explicit label, then we will search for
+ // an outer statement matching that label.
+ //
+ auto foundOuterStmt = findOuterStmtWithLabel(stmt->targetLabel.getName());
+ if (!foundOuterStmt)
{
getSink()->diagnose(stmt, Diagnostics::breakLabelNotFound, stmt->targetLabel.getName());
}
- if (!as<BreakableStmt>(targetStmt))
+ else
{
- getSink()->diagnose(
- stmt,
- Diagnostics::targetLabelDoesNotMarkBreakableStmt,
- stmt->targetLabel.getName());
+ // It is possible that the labelled statement
+ // is not a valid one for a `break` to target,
+ // so we check for that next.
+ //
+ targetOuterStmt = as<BreakableStmt>(foundOuterStmt);
+ if (!targetOuterStmt)
+ {
+ getSink()->diagnose(
+ stmt,
+ Diagnostics::targetLabelDoesNotMarkBreakableStmt,
+ stmt->targetLabel.getName());
+ }
}
}
else
{
- // For `break` statements without an explicit target,
- // find the inner most breakable stmt.
- targetStmt = FindOuterStmt<BreakableStmt>();
- if (!targetStmt)
+ // If there is no explicit label on the `break` statement,
+ // then we are simply searching for the inner-most
+ // enclosing statement that is a valid `break` target.
+ //
+ targetOuterStmt = FindOuterStmt<BreakableStmt>();
+ if (!targetOuterStmt)
{
getSink()->diagnose(stmt, Diagnostics::breakOutsideLoop);
}
}
- // If there is a defer statement before the breakable statement, it's
- // illegal.
- if (FindOuterStmt<DeferStmt>(targetStmt))
+ // We do not (currently) allow a `break` to proceed "through"
+ // an enclosing `defer` statement. Thus, we search for
+ // a possible enclosing `defer` statement, between the
+ // `stmt` being checked and the `targetOuterStmt` that
+ // `stmt` is trying to branch to.
+ //
+ // TODO: This is a reasonable feature to add down the line;
+ // it simply involves more implementation complexity than
+ // the simpler cases of `defer`.
+ //
+ if (targetOuterStmt)
{
- getSink()->diagnose(stmt, Diagnostics::breakInsideDefer);
- }
+ if (FindOuterStmt<DeferStmt>(targetOuterStmt))
+ {
+ getSink()->diagnose(stmt, Diagnostics::breakInsideDefer);
+ }
- stmt->parentStmt = targetStmt;
+ // We stash the ID of the target statement in the `break`
+ // statement so that they can be correlated later, during
+ // code generation.
+ //
+ stmt->targetOuterStmtID = targetOuterStmt->uniqueID;
+ }
}
void SemanticsStmtVisitor::visitContinueStmt(ContinueStmt* stmt)
{
- auto outer = FindOuterStmt<LoopStmt>();
- if (!outer)
+ auto targetOuterStmt = FindOuterStmt<LoopStmt>();
+ if (!targetOuterStmt)
{
getSink()->diagnose(stmt, Diagnostics::continueOutsideLoop);
}
-
- if (FindOuterStmt<DeferStmt>(outer))
+ else
{
- getSink()->diagnose(stmt, Diagnostics::continueInsideDefer);
+ if (FindOuterStmt<DeferStmt>(targetOuterStmt))
+ {
+ getSink()->diagnose(stmt, Diagnostics::continueInsideDefer);
+ }
+
+ // We stash the ID of the target statement in the `continue`
+ // statement so that they can be correlated later, during
+ // code generation.
+ //
+ stmt->targetOuterStmtID = targetOuterStmt->uniqueID;
}
- stmt->parentStmt = outer;
}
Expr* SemanticsVisitor::checkPredicateExpr(Expr* expr)
@@ -219,6 +261,7 @@ Expr* SemanticsVisitor::checkPredicateExpr(Expr* expr)
void SemanticsStmtVisitor::visitDoWhileStmt(DoWhileStmt* stmt)
{
+ generateUniqueIDForStmt(stmt);
checkModifiers(stmt);
WithOuterStmt subContext(this, stmt);
@@ -229,6 +272,7 @@ void SemanticsStmtVisitor::visitDoWhileStmt(DoWhileStmt* stmt)
void SemanticsStmtVisitor::visitForStmt(ForStmt* stmt)
{
+ generateUniqueIDForStmt(stmt);
WithOuterStmt subContext(this, stmt);
checkModifiers(stmt);
checkStmt(stmt->initialStatement);
@@ -342,6 +386,7 @@ void SemanticsStmtVisitor::validateCaseStmts(SwitchStmt* stmt, DiagnosticSink* s
void SemanticsStmtVisitor::visitSwitchStmt(SwitchStmt* stmt)
{
+ generateUniqueIDForStmt(stmt);
WithOuterStmt subContext(this, stmt);
// TODO(tfoley): need to coerce condition to an integral type...
@@ -372,11 +417,20 @@ void SemanticsStmtVisitor::visitCaseStmt(CaseStmt* stmt)
stmt->expr = expr;
stmt->exprVal = exprVal;
- stmt->parentStmt = switchStmt;
+
+ if (switchStmt)
+ {
+ // We stash the ID of the target statement in the `case`
+ // statement so that they can be correlated later, during
+ // code generation.
+ //
+ stmt->targetOuterStmtID = switchStmt->uniqueID;
+ }
}
void SemanticsStmtVisitor::visitTargetSwitchStmt(TargetSwitchStmt* stmt)
{
+ generateUniqueIDForStmt(stmt);
WithOuterStmt subContext(this, stmt);
HashSet<Stmt*> checkedStmt;
for (auto caseStmt : stmt->targetCases)
@@ -436,6 +490,10 @@ void SemanticsStmtVisitor::visitTargetCaseStmt(TargetCaseStmt* stmt)
{
getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch);
}
+ else
+ {
+ stmt->targetOuterStmtID = switchStmt->uniqueID;
+ }
WithOuterStmt subContext(this, stmt);
subContext.checkStmt(stmt->body);
}
@@ -454,7 +512,14 @@ void SemanticsStmtVisitor::visitDefaultStmt(DefaultStmt* stmt)
{
getSink()->diagnose(stmt, Diagnostics::defaultOutsideSwitch);
}
- stmt->parentStmt = switchStmt;
+ else
+ {
+ // We stash the ID of the target statement in the `case`
+ // statement so that they can be correlated later, during
+ // code generation.
+ //
+ stmt->targetOuterStmtID = switchStmt->uniqueID;
+ }
}
void SemanticsStmtVisitor::visitIfStmt(IfStmt* stmt)
@@ -520,6 +585,7 @@ void SemanticsStmtVisitor::visitReturnStmt(ReturnStmt* stmt)
void SemanticsStmtVisitor::visitWhileStmt(WhileStmt* stmt)
{
+ generateUniqueIDForStmt(stmt);
checkModifiers(stmt);
WithOuterStmt subContext(this, stmt);
stmt->predicate = checkPredicateExpr(stmt->predicate);
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 260596dc3..7bc7d4fb4 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -486,8 +486,8 @@ struct SharedIRGenContext
// Map from an AST-level statement that can be
// used as the target of a `break` or `continue`
// to the appropriate basic block to jump to.
- Dictionary<Stmt*, IRBlock*> breakLabels;
- Dictionary<Stmt*, IRBlock*> continueLabels;
+ Dictionary<BreakableStmt::UniqueID, IRBlock*> breakLabels;
+ Dictionary<BreakableStmt::UniqueID, IRBlock*> continueLabels;
Dictionary<SourceFile*, IRInst*> mapSourceFileToDebugSourceInst;
Dictionary<String, IRInst*> mapSourcePathToDebugSourceInst;
@@ -6181,8 +6181,8 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
// Register the `break` and `continue` labels so
// that we can find them for nested statements.
- context->shared->breakLabels.add(stmt, breakLabel);
- context->shared->continueLabels.add(stmt, continueLabel);
+ context->shared->breakLabels.add(stmt->uniqueID, breakLabel);
+ context->shared->continueLabels.add(stmt->uniqueID, continueLabel);
// Emit the branch that will start out loop,
// and then insert the block for the head.
@@ -6288,8 +6288,8 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
// Register the `break` and `continue` labels so
// that we can find them for nested statements.
- context->shared->breakLabels.add(stmt, breakLabel);
- context->shared->continueLabels.add(stmt, continueLabel);
+ context->shared->breakLabels.add(stmt->uniqueID, breakLabel);
+ context->shared->continueLabels.add(stmt->uniqueID, continueLabel);
// Emit the branch that will start out loop,
// and then insert the block for the head.
@@ -6347,8 +6347,8 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
// Register the `break` and `continue` labels so
// that we can find them for nested statements.
- context->shared->breakLabels.add(stmt, breakLabel);
- context->shared->continueLabels.add(stmt, continueLabel);
+ context->shared->breakLabels.add(stmt->uniqueID, breakLabel);
+ context->shared->continueLabels.add(stmt->uniqueID, continueLabel);
// Emit the branch that will start out loop,
// and then insert the block for the head.
@@ -6622,14 +6622,14 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
// Semantic checking is responsible for finding
// the statement taht this `break` breaks out of
- auto parentStmt = stmt->parentStmt;
- SLANG_ASSERT(parentStmt);
+ auto targetStmtID = stmt->targetOuterStmtID;
+ SLANG_ASSERT(targetStmtID != BreakableStmt::kInvalidUniqueID);
// We just need to look up the basic block that
// corresponds to the break label for that statement,
// and then emit an instruction to jump to it.
IRBlock* targetBlock = nullptr;
- context->shared->breakLabels.tryGetValue(parentStmt, targetBlock);
+ context->shared->breakLabels.tryGetValue(targetStmtID, targetBlock);
SLANG_ASSERT(targetBlock);
getBuilder()->emitBreak(targetBlock);
}
@@ -6640,15 +6640,15 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
// Semantic checking is responsible for finding
// the loop that this `continue` statement continues
- auto parentStmt = stmt->parentStmt;
- SLANG_ASSERT(parentStmt);
+ auto targetStmtID = stmt->targetOuterStmtID;
+ SLANG_ASSERT(targetStmtID != BreakableStmt::kInvalidUniqueID);
// We just need to look up the basic block that
// corresponds to the continue label for that statement,
// and then emit an instruction to jump to it.
IRBlock* targetBlock = nullptr;
- context->shared->continueLabels.tryGetValue(parentStmt, targetBlock);
+ context->shared->continueLabels.tryGetValue(targetStmtID, targetBlock);
SLANG_ASSERT(targetBlock);
getBuilder()->emitContinue(targetBlock);
}
@@ -6864,7 +6864,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
// Register the `break` label so
// that we can find it for nested statements.
- context->shared->breakLabels.add(stmt, breakLabel);
+ context->shared->breakLabels.add(stmt->uniqueID, breakLabel);
builder->setInsertInto(initialBlock->getParent());
@@ -6935,7 +6935,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
// (and that control flow will fall through to otherwise).
// This is the block that subsequent code will go into.
insertBlock(breakLabel);
- context->shared->breakLabels.remove(stmt);
+ context->shared->breakLabels.remove(stmt->uniqueID);
}
void visitTargetSwitchStmt(TargetSwitchStmt* stmt)
@@ -6947,7 +6947,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
startBlockIfNeeded(stmt);
auto initialBlock = builder->getBlock();
auto breakLabel = builder->createBlock();
- context->shared->breakLabels.add(stmt, breakLabel);
+ context->shared->breakLabels.add(stmt->uniqueID, breakLabel);
builder->setInsertInto(initialBlock->getParent());
List<IRInst*> args;
args.add(breakLabel);
@@ -6966,7 +6966,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
args.add(builder->getIntValue(builder->getIntType(), targetCase->capability));
args.add(caseBlock);
}
- context->shared->breakLabels.remove(stmt);
+ context->shared->breakLabels.remove(stmt->uniqueID);
builder->setInsertInto(initialBlock);
auto parentFunc = initialBlock->getParent();
@@ -7066,7 +7066,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
// Register the `break` label so
// that we can find it for nested statements.
- context->shared->breakLabels.add(stmt, breakLabel);
+ context->shared->breakLabels.add(stmt->uniqueID, breakLabel);
builder->setInsertInto(initialBlock->getParent());
@@ -7119,7 +7119,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
// (and that control flow will fall through to otherwise).
// This is the block that subsequent code will go into.
insertBlock(breakLabel);
- context->shared->breakLabels.remove(stmt);
+ context->shared->breakLabels.remove(stmt->uniqueID);
// If there is the branch attribute output the IR decoration
if (stmt->hasModifier<BranchAttribute>())