summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorsriramm-nv <85252063+sriramm-nv@users.noreply.github.com>2024-04-03 15:10:16 -0700
committerGitHub <noreply@github.com>2024-04-03 15:10:16 -0700
commitf6c49fdb2cc7ead1943d944097220cedd142792f (patch)
tree86f8a09544d89cab007f3168f10396bcee73d12e /source
parent2768e429e556f3978825beaf71cf361626057135 (diff)
Fix assertions due to malformed switch statements (#3858)
* Fix assertions due to malformed switch statements Fixes the issue #2955 * Checks for multiple case statements with same values * Checks for multiple default cases * Constant-folds case exprs into an Integer value * fix the comments, and updated error code * one-line comment on diagnostic code
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ast-stmt.h2
-rw-r--r--source/slang/slang-check-impl.h3
-rw-r--r--source/slang/slang-check-stmt.cpp54
-rw-r--r--source/slang/slang-diagnostic-defs.h4
4 files changed, 58 insertions, 5 deletions
diff --git a/source/slang/slang-ast-stmt.h b/source/slang/slang-ast-stmt.h
index 055785333..afa606456 100644
--- a/source/slang/slang-ast-stmt.h
+++ b/source/slang/slang-ast-stmt.h
@@ -143,6 +143,8 @@ class CaseStmt : public CaseStmtBase
SLANG_AST_CLASS(CaseStmt)
Expr* expr = nullptr;
+
+ Val* exprVal = nullptr;
};
// a `default` statement inside a `switch`
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 48bd2093b..55edba6b9 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -2741,6 +2741,9 @@ namespace Slang
void tryInferLoopMaxIterations(ForStmt* stmt);
void checkLoopInDifferentiableFunc(Stmt* stmt);
+
+ private:
+ void validateCaseStmts(SwitchStmt* stmt, DiagnosticSink* sink);
};
struct SemanticsDeclVisitorBase
diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp
index b68465087..729b24d35 100644
--- a/source/slang/slang-check-stmt.cpp
+++ b/source/slang/slang-check-stmt.cpp
@@ -239,6 +239,47 @@ namespace Slang
subContext.checkStmt(stmt->body);
}
+ void SemanticsStmtVisitor::validateCaseStmts(SwitchStmt* stmt, DiagnosticSink* sink)
+ {
+ auto blockStmt = as<BlockStmt>(stmt->body);
+ if (!blockStmt)
+ return;
+
+ auto seqStmt = as<SeqStmt>(blockStmt->body);
+ if (!seqStmt)
+ return;
+
+ bool hasDefaultStmt = false;
+ HashSet<Val*> caseStmtVals;
+ for (auto& sStmt : seqStmt->stmts)
+ {
+ if (auto caseStmt = as<CaseStmt>(sStmt))
+ {
+ // check that all case tags are unique
+ if (caseStmt->exprVal)
+ {
+ // exprVal contains the constant folded expr, that is checked for
+ // uniqueness within the scope of the switch statement.
+ if (!caseStmtVals.add(caseStmt->exprVal))
+ {
+ sink->diagnose(sStmt, Diagnostics::switchDuplicateCases);
+ return;
+ }
+ }
+ }
+ else if (auto defaultStmt = as<DefaultStmt>(sStmt))
+ {
+ // check that there is at most one `default` clause
+ if (hasDefaultStmt)
+ {
+ sink->diagnose(sStmt, Diagnostics::switchMultipleDefault);
+ return;
+ }
+ hasDefaultStmt = true;
+ }
+ }
+ }
+
void SemanticsStmtVisitor::visitSwitchStmt(SwitchStmt* stmt)
{
WithOuterStmt subContext(this, stmt);
@@ -247,16 +288,18 @@ namespace Slang
stmt->condition = CheckExpr(stmt->condition);
subContext.checkStmt(stmt->body);
- // TODO(tfoley): need to check that all case tags are unique
-
- // TODO(tfoley): check that there is at most one `default` clause
+ // check the case value exits within the switch
+ validateCaseStmts(stmt, getSink());
}
void SemanticsStmtVisitor::visitCaseStmt(CaseStmt* stmt)
{
- // TODO(tfoley): Need to coerce to type being switch on,
- // and ensure that value is a compile-time constant
auto expr = CheckExpr(stmt->expr);
+
+ // coerce to type being switch on, and ensure that value is a compile-time constant
+ // The Vals in the AST are pointer-unique, making them easy to check for duplicates
+ // by addeing them to a HashSet.
+ auto exprVal = tryConstantFoldExpr(expr, ConstantFoldingKind::CompileTime, nullptr);
auto switchStmt = FindOuterStmt<SwitchStmt>();
if (!switchStmt)
@@ -270,6 +313,7 @@ namespace Slang
}
stmt->expr = expr;
+ stmt->exprVal = exprVal;
stmt->parentStmt = switchStmt;
}
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 752bcf03b..882a4c314 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -476,6 +476,10 @@ DIAGNOSTIC(30504, Warning, forLoopTerminatesInFewerIterationsThanMaxIters, "the
DIAGNOSTIC(30505, Warning, loopRunsForZeroIterations, "the loop runs for 0 iterations and will be removed.")
DIAGNOSTIC(30510, Error, loopInDiffFuncRequireUnrollOrMaxIters, "loops inside a differentiable function need to provide either '[MaxIters(n)]' or '[ForceUnroll]' attribute.")
+// Switch
+DIAGNOSTIC(30600, Error, switchMultipleDefault, "multiple 'default' cases not allowed within a 'switch' statement")
+DIAGNOSTIC(30601, Error, switchDuplicateCases, "duplicate cases not allowed within a 'switch' statement")
+
// TODO: need to assign numbers to all these extra diagnostics...
DIAGNOSTIC(39999, Fatal, cyclicReference, "cyclic reference '$0'.")
DIAGNOSTIC(39999, Error, cyclicReferenceInInheritance, "cyclic reference in inheritance graph '$0'.")