summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ast-builder.h2
-rw-r--r--source/slang/slang-ast-stmt.h43
-rw-r--r--source/slang/slang-check-impl.h2
-rw-r--r--source/slang/slang-check-stmt.cpp122
-rw-r--r--source/slang/slang-lower-to-ir.cpp40
5 files changed, 149 insertions, 60 deletions
diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h
index 67dfaaf52..daf49f3f7 100644
--- a/source/slang/slang-ast-builder.h
+++ b/source/slang/slang-ast-builder.h
@@ -670,6 +670,8 @@ public:
Index getId() { return m_id; }
+ BreakableStmt::UniqueID generateUniqueIDForStmt() { return create<UniqueStmtIDNode>(); }
+
/// Ctor
ASTBuilder(SharedASTBuilder* sharedASTBuilder, const String& name);
diff --git a/source/slang/slang-ast-stmt.h b/source/slang/slang-ast-stmt.h
index f6580ba21..4107664bf 100644
--- a/source/slang/slang-ast-stmt.h
+++ b/source/slang/slang-ast-stmt.h
@@ -84,10 +84,26 @@ class IfStmt : public Stmt
Stmt* negativeStatement = nullptr;
};
+class UniqueStmtIDNode : public Decl
+{
+ SLANG_AST_CLASS(UniqueStmtIDNode)
+};
+
// A statement that can be escaped with a `break`
class BreakableStmt : public ScopeStmt
{
SLANG_ABSTRACT_AST_CLASS(BreakableStmt)
+
+ /// A unique ID for this statement.
+ ///
+ /// Used by `ChildStmt` to reference the
+ /// enclosing statement.
+ ///
+ UniqueStmtIDNode* uniqueID = kInvalidUniqueID;
+
+ SLANG_UNREFLECTED
+ typedef UniqueStmtIDNode* UniqueID;
+ static constexpr UniqueID kInvalidUniqueID = nullptr;
};
class SwitchStmt : public BreakableStmt
@@ -98,7 +114,20 @@ class SwitchStmt : public BreakableStmt
Stmt* body = nullptr;
};
-class TargetCaseStmt : public Stmt
+// A statement that is expected to appear lexically nested inside
+// some other construct, and thus needs to keep track of the
+// outer statement that it is associated with...
+class ChildStmt : public Stmt
+{
+ SLANG_ABSTRACT_AST_CLASS(ChildStmt)
+
+ /// The unique ID of the enclosing statement this
+ /// child statement refers to.
+ ///
+ BreakableStmt::UniqueID targetOuterStmtID = BreakableStmt::kInvalidUniqueID;
+};
+
+class TargetCaseStmt : public ChildStmt
{
SLANG_AST_CLASS(TargetCaseStmt)
int32_t capability;
@@ -106,7 +135,7 @@ class TargetCaseStmt : public Stmt
Stmt* body = nullptr;
};
-class TargetSwitchStmt : public Stmt
+class TargetSwitchStmt : public BreakableStmt
{
SLANG_AST_CLASS(TargetSwitchStmt)
@@ -127,16 +156,6 @@ class IntrinsicAsmStmt : public Stmt
List<Expr*> args;
};
-// A statement that is expected to appear lexically nested inside
-// some other construct, and thus needs to keep track of the
-// outer statement that it is associated with...
-class ChildStmt : public Stmt
-{
- SLANG_ABSTRACT_AST_CLASS(ChildStmt)
-
- Stmt* parentStmt = nullptr;
-};
-
// a `case` or `default` statement inside a `switch`
//
// Note(tfoley): A correct AST for a C-like language would treat
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 8a1e79ce8..c9406cd1f 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -3037,6 +3037,8 @@ struct SemanticsStmtVisitor : public SemanticsVisitor, StmtVisitor<SemanticsStmt
private:
void validateCaseStmts(SwitchStmt* stmt, DiagnosticSink* sink);
+
+ void generateUniqueIDForStmt(BreakableStmt* stmt);
};
struct SemanticsDeclVisitorBase : public SemanticsVisitor
diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp
index c85eb7593..8a914c60b 100644
--- a/source/slang/slang-check-stmt.cpp
+++ b/source/slang/slang-check-stmt.cpp
@@ -149,60 +149,102 @@ Stmt* SemanticsStmtVisitor::findOuterStmtWithLabel(Name* label)
return nullptr;
}
+void SemanticsStmtVisitor::generateUniqueIDForStmt(BreakableStmt* stmt)
+{
+ stmt->uniqueID = getASTBuilder()->generateUniqueIDForStmt();
+}
+
void SemanticsStmtVisitor::visitBreakStmt(BreakStmt* stmt)
{
- Stmt* targetStmt = nullptr;
+ // We need to identify the enclosing statement that
+ // this `break` is meant to break out of.
+ //
+ BreakableStmt* targetOuterStmt = nullptr;
if (stmt->targetLabel.type == TokenType::Identifier)
{
- // This is a break statement with an explicit target label.
- // Try to find the outer stmt with the label.
- targetStmt = findOuterStmtWithLabel(stmt->targetLabel.getName());
- if (!targetStmt)
+ // If this is a `break` statement that specifies
+ // an explicit label, then we will search for
+ // an outer statement matching that label.
+ //
+ auto foundOuterStmt = findOuterStmtWithLabel(stmt->targetLabel.getName());
+ if (!foundOuterStmt)
{
getSink()->diagnose(stmt, Diagnostics::breakLabelNotFound, stmt->targetLabel.getName());
}
- if (!as<BreakableStmt>(targetStmt))
+ else
{
- getSink()->diagnose(
- stmt,
- Diagnostics::targetLabelDoesNotMarkBreakableStmt,
- stmt->targetLabel.getName());
+ // It is possible that the labelled statement
+ // is not a valid one for a `break` to target,
+ // so we check for that next.
+ //
+ targetOuterStmt = as<BreakableStmt>(foundOuterStmt);
+ if (!targetOuterStmt)
+ {
+ getSink()->diagnose(
+ stmt,
+ Diagnostics::targetLabelDoesNotMarkBreakableStmt,
+ stmt->targetLabel.getName());
+ }
}
}
else
{
- // For `break` statements without an explicit target,
- // find the inner most breakable stmt.
- targetStmt = FindOuterStmt<BreakableStmt>();
- if (!targetStmt)
+ // If there is no explicit label on the `break` statement,
+ // then we are simply searching for the inner-most
+ // enclosing statement that is a valid `break` target.
+ //
+ targetOuterStmt = FindOuterStmt<BreakableStmt>();
+ if (!targetOuterStmt)
{
getSink()->diagnose(stmt, Diagnostics::breakOutsideLoop);
}
}
- // If there is a defer statement before the breakable statement, it's
- // illegal.
- if (FindOuterStmt<DeferStmt>(targetStmt))
+ // We do not (currently) allow a `break` to proceed "through"
+ // an enclosing `defer` statement. Thus, we search for
+ // a possible enclosing `defer` statement, between the
+ // `stmt` being checked and the `targetOuterStmt` that
+ // `stmt` is trying to branch to.
+ //
+ // TODO: This is a reasonable feature to add down the line;
+ // it simply involves more implementation complexity than
+ // the simpler cases of `defer`.
+ //
+ if (targetOuterStmt)
{
- getSink()->diagnose(stmt, Diagnostics::breakInsideDefer);
- }
+ if (FindOuterStmt<DeferStmt>(targetOuterStmt))
+ {
+ getSink()->diagnose(stmt, Diagnostics::breakInsideDefer);
+ }
- stmt->parentStmt = targetStmt;
+ // We stash the ID of the target statement in the `break`
+ // statement so that they can be correlated later, during
+ // code generation.
+ //
+ stmt->targetOuterStmtID = targetOuterStmt->uniqueID;
+ }
}
void SemanticsStmtVisitor::visitContinueStmt(ContinueStmt* stmt)
{
- auto outer = FindOuterStmt<LoopStmt>();
- if (!outer)
+ auto targetOuterStmt = FindOuterStmt<LoopStmt>();
+ if (!targetOuterStmt)
{
getSink()->diagnose(stmt, Diagnostics::continueOutsideLoop);
}
-
- if (FindOuterStmt<DeferStmt>(outer))
+ else
{
- getSink()->diagnose(stmt, Diagnostics::continueInsideDefer);
+ if (FindOuterStmt<DeferStmt>(targetOuterStmt))
+ {
+ getSink()->diagnose(stmt, Diagnostics::continueInsideDefer);
+ }
+
+ // We stash the ID of the target statement in the `continue`
+ // statement so that they can be correlated later, during
+ // code generation.
+ //
+ stmt->targetOuterStmtID = targetOuterStmt->uniqueID;
}
- stmt->parentStmt = outer;
}
Expr* SemanticsVisitor::checkPredicateExpr(Expr* expr)
@@ -219,6 +261,7 @@ Expr* SemanticsVisitor::checkPredicateExpr(Expr* expr)
void SemanticsStmtVisitor::visitDoWhileStmt(DoWhileStmt* stmt)
{
+ generateUniqueIDForStmt(stmt);
checkModifiers(stmt);
WithOuterStmt subContext(this, stmt);
@@ -229,6 +272,7 @@ void SemanticsStmtVisitor::visitDoWhileStmt(DoWhileStmt* stmt)
void SemanticsStmtVisitor::visitForStmt(ForStmt* stmt)
{
+ generateUniqueIDForStmt(stmt);
WithOuterStmt subContext(this, stmt);
checkModifiers(stmt);
checkStmt(stmt->initialStatement);
@@ -342,6 +386,7 @@ void SemanticsStmtVisitor::validateCaseStmts(SwitchStmt* stmt, DiagnosticSink* s
void SemanticsStmtVisitor::visitSwitchStmt(SwitchStmt* stmt)
{
+ generateUniqueIDForStmt(stmt);
WithOuterStmt subContext(this, stmt);
// TODO(tfoley): need to coerce condition to an integral type...
@@ -372,11 +417,20 @@ void SemanticsStmtVisitor::visitCaseStmt(CaseStmt* stmt)
stmt->expr = expr;
stmt->exprVal = exprVal;
- stmt->parentStmt = switchStmt;
+
+ if (switchStmt)
+ {
+ // We stash the ID of the target statement in the `case`
+ // statement so that they can be correlated later, during
+ // code generation.
+ //
+ stmt->targetOuterStmtID = switchStmt->uniqueID;
+ }
}
void SemanticsStmtVisitor::visitTargetSwitchStmt(TargetSwitchStmt* stmt)
{
+ generateUniqueIDForStmt(stmt);
WithOuterStmt subContext(this, stmt);
HashSet<Stmt*> checkedStmt;
for (auto caseStmt : stmt->targetCases)
@@ -436,6 +490,10 @@ void SemanticsStmtVisitor::visitTargetCaseStmt(TargetCaseStmt* stmt)
{
getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch);
}
+ else
+ {
+ stmt->targetOuterStmtID = switchStmt->uniqueID;
+ }
WithOuterStmt subContext(this, stmt);
subContext.checkStmt(stmt->body);
}
@@ -454,7 +512,14 @@ void SemanticsStmtVisitor::visitDefaultStmt(DefaultStmt* stmt)
{
getSink()->diagnose(stmt, Diagnostics::defaultOutsideSwitch);
}
- stmt->parentStmt = switchStmt;
+ else
+ {
+ // We stash the ID of the target statement in the `case`
+ // statement so that they can be correlated later, during
+ // code generation.
+ //
+ stmt->targetOuterStmtID = switchStmt->uniqueID;
+ }
}
void SemanticsStmtVisitor::visitIfStmt(IfStmt* stmt)
@@ -520,6 +585,7 @@ void SemanticsStmtVisitor::visitReturnStmt(ReturnStmt* stmt)
void SemanticsStmtVisitor::visitWhileStmt(WhileStmt* stmt)
{
+ generateUniqueIDForStmt(stmt);
checkModifiers(stmt);
WithOuterStmt subContext(this, stmt);
stmt->predicate = checkPredicateExpr(stmt->predicate);
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 260596dc3..7bc7d4fb4 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -486,8 +486,8 @@ struct SharedIRGenContext
// Map from an AST-level statement that can be
// used as the target of a `break` or `continue`
// to the appropriate basic block to jump to.
- Dictionary<Stmt*, IRBlock*> breakLabels;
- Dictionary<Stmt*, IRBlock*> continueLabels;
+ Dictionary<BreakableStmt::UniqueID, IRBlock*> breakLabels;
+ Dictionary<BreakableStmt::UniqueID, IRBlock*> continueLabels;
Dictionary<SourceFile*, IRInst*> mapSourceFileToDebugSourceInst;
Dictionary<String, IRInst*> mapSourcePathToDebugSourceInst;
@@ -6181,8 +6181,8 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
// Register the `break` and `continue` labels so
// that we can find them for nested statements.
- context->shared->breakLabels.add(stmt, breakLabel);
- context->shared->continueLabels.add(stmt, continueLabel);
+ context->shared->breakLabels.add(stmt->uniqueID, breakLabel);
+ context->shared->continueLabels.add(stmt->uniqueID, continueLabel);
// Emit the branch that will start out loop,
// and then insert the block for the head.
@@ -6288,8 +6288,8 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
// Register the `break` and `continue` labels so
// that we can find them for nested statements.
- context->shared->breakLabels.add(stmt, breakLabel);
- context->shared->continueLabels.add(stmt, continueLabel);
+ context->shared->breakLabels.add(stmt->uniqueID, breakLabel);
+ context->shared->continueLabels.add(stmt->uniqueID, continueLabel);
// Emit the branch that will start out loop,
// and then insert the block for the head.
@@ -6347,8 +6347,8 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
// Register the `break` and `continue` labels so
// that we can find them for nested statements.
- context->shared->breakLabels.add(stmt, breakLabel);
- context->shared->continueLabels.add(stmt, continueLabel);
+ context->shared->breakLabels.add(stmt->uniqueID, breakLabel);
+ context->shared->continueLabels.add(stmt->uniqueID, continueLabel);
// Emit the branch that will start out loop,
// and then insert the block for the head.
@@ -6622,14 +6622,14 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
// Semantic checking is responsible for finding
// the statement taht this `break` breaks out of
- auto parentStmt = stmt->parentStmt;
- SLANG_ASSERT(parentStmt);
+ auto targetStmtID = stmt->targetOuterStmtID;
+ SLANG_ASSERT(targetStmtID != BreakableStmt::kInvalidUniqueID);
// We just need to look up the basic block that
// corresponds to the break label for that statement,
// and then emit an instruction to jump to it.
IRBlock* targetBlock = nullptr;
- context->shared->breakLabels.tryGetValue(parentStmt, targetBlock);
+ context->shared->breakLabels.tryGetValue(targetStmtID, targetBlock);
SLANG_ASSERT(targetBlock);
getBuilder()->emitBreak(targetBlock);
}
@@ -6640,15 +6640,15 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
// Semantic checking is responsible for finding
// the loop that this `continue` statement continues
- auto parentStmt = stmt->parentStmt;
- SLANG_ASSERT(parentStmt);
+ auto targetStmtID = stmt->targetOuterStmtID;
+ SLANG_ASSERT(targetStmtID != BreakableStmt::kInvalidUniqueID);
// We just need to look up the basic block that
// corresponds to the continue label for that statement,
// and then emit an instruction to jump to it.
IRBlock* targetBlock = nullptr;
- context->shared->continueLabels.tryGetValue(parentStmt, targetBlock);
+ context->shared->continueLabels.tryGetValue(targetStmtID, targetBlock);
SLANG_ASSERT(targetBlock);
getBuilder()->emitContinue(targetBlock);
}
@@ -6864,7 +6864,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
// Register the `break` label so
// that we can find it for nested statements.
- context->shared->breakLabels.add(stmt, breakLabel);
+ context->shared->breakLabels.add(stmt->uniqueID, breakLabel);
builder->setInsertInto(initialBlock->getParent());
@@ -6935,7 +6935,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
// (and that control flow will fall through to otherwise).
// This is the block that subsequent code will go into.
insertBlock(breakLabel);
- context->shared->breakLabels.remove(stmt);
+ context->shared->breakLabels.remove(stmt->uniqueID);
}
void visitTargetSwitchStmt(TargetSwitchStmt* stmt)
@@ -6947,7 +6947,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
startBlockIfNeeded(stmt);
auto initialBlock = builder->getBlock();
auto breakLabel = builder->createBlock();
- context->shared->breakLabels.add(stmt, breakLabel);
+ context->shared->breakLabels.add(stmt->uniqueID, breakLabel);
builder->setInsertInto(initialBlock->getParent());
List<IRInst*> args;
args.add(breakLabel);
@@ -6966,7 +6966,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
args.add(builder->getIntValue(builder->getIntType(), targetCase->capability));
args.add(caseBlock);
}
- context->shared->breakLabels.remove(stmt);
+ context->shared->breakLabels.remove(stmt->uniqueID);
builder->setInsertInto(initialBlock);
auto parentFunc = initialBlock->getParent();
@@ -7066,7 +7066,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
// Register the `break` label so
// that we can find it for nested statements.
- context->shared->breakLabels.add(stmt, breakLabel);
+ context->shared->breakLabels.add(stmt->uniqueID, breakLabel);
builder->setInsertInto(initialBlock->getParent());
@@ -7119,7 +7119,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
// (and that control flow will fall through to otherwise).
// This is the block that subsequent code will go into.
insertBlock(breakLabel);
- context->shared->breakLabels.remove(stmt);
+ context->shared->breakLabels.remove(stmt->uniqueID);
// If there is the branch attribute output the IR decoration
if (stmt->hasModifier<BranchAttribute>())