summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ast-stmt.h2
-rw-r--r--source/slang/slang-check-impl.h3
-rw-r--r--source/slang/slang-check-stmt.cpp54
-rw-r--r--source/slang/slang-diagnostic-defs.h4
-rw-r--r--tests/diagnostics/switch-duplicate-case.slang33
-rw-r--r--tests/diagnostics/switch-multiple-defaults.slang14
6 files changed, 105 insertions, 5 deletions
diff --git a/source/slang/slang-ast-stmt.h b/source/slang/slang-ast-stmt.h
index 055785333..afa606456 100644
--- a/source/slang/slang-ast-stmt.h
+++ b/source/slang/slang-ast-stmt.h
@@ -143,6 +143,8 @@ class CaseStmt : public CaseStmtBase
SLANG_AST_CLASS(CaseStmt)
Expr* expr = nullptr;
+
+ Val* exprVal = nullptr;
};
// a `default` statement inside a `switch`
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 48bd2093b..55edba6b9 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -2741,6 +2741,9 @@ namespace Slang
void tryInferLoopMaxIterations(ForStmt* stmt);
void checkLoopInDifferentiableFunc(Stmt* stmt);
+
+ private:
+ void validateCaseStmts(SwitchStmt* stmt, DiagnosticSink* sink);
};
struct SemanticsDeclVisitorBase
diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp
index b68465087..729b24d35 100644
--- a/source/slang/slang-check-stmt.cpp
+++ b/source/slang/slang-check-stmt.cpp
@@ -239,6 +239,47 @@ namespace Slang
subContext.checkStmt(stmt->body);
}
+ void SemanticsStmtVisitor::validateCaseStmts(SwitchStmt* stmt, DiagnosticSink* sink)
+ {
+ auto blockStmt = as<BlockStmt>(stmt->body);
+ if (!blockStmt)
+ return;
+
+ auto seqStmt = as<SeqStmt>(blockStmt->body);
+ if (!seqStmt)
+ return;
+
+ bool hasDefaultStmt = false;
+ HashSet<Val*> caseStmtVals;
+ for (auto& sStmt : seqStmt->stmts)
+ {
+ if (auto caseStmt = as<CaseStmt>(sStmt))
+ {
+ // check that all case tags are unique
+ if (caseStmt->exprVal)
+ {
+ // exprVal contains the constant folded expr, that is checked for
+ // uniqueness within the scope of the switch statement.
+ if (!caseStmtVals.add(caseStmt->exprVal))
+ {
+ sink->diagnose(sStmt, Diagnostics::switchDuplicateCases);
+ return;
+ }
+ }
+ }
+ else if (auto defaultStmt = as<DefaultStmt>(sStmt))
+ {
+ // check that there is at most one `default` clause
+ if (hasDefaultStmt)
+ {
+ sink->diagnose(sStmt, Diagnostics::switchMultipleDefault);
+ return;
+ }
+ hasDefaultStmt = true;
+ }
+ }
+ }
+
void SemanticsStmtVisitor::visitSwitchStmt(SwitchStmt* stmt)
{
WithOuterStmt subContext(this, stmt);
@@ -247,16 +288,18 @@ namespace Slang
stmt->condition = CheckExpr(stmt->condition);
subContext.checkStmt(stmt->body);
- // TODO(tfoley): need to check that all case tags are unique
-
- // TODO(tfoley): check that there is at most one `default` clause
+ // check the case value exits within the switch
+ validateCaseStmts(stmt, getSink());
}
void SemanticsStmtVisitor::visitCaseStmt(CaseStmt* stmt)
{
- // TODO(tfoley): Need to coerce to type being switch on,
- // and ensure that value is a compile-time constant
auto expr = CheckExpr(stmt->expr);
+
+ // coerce to type being switch on, and ensure that value is a compile-time constant
+ // The Vals in the AST are pointer-unique, making them easy to check for duplicates
+ // by addeing them to a HashSet.
+ auto exprVal = tryConstantFoldExpr(expr, ConstantFoldingKind::CompileTime, nullptr);
auto switchStmt = FindOuterStmt<SwitchStmt>();
if (!switchStmt)
@@ -270,6 +313,7 @@ namespace Slang
}
stmt->expr = expr;
+ stmt->exprVal = exprVal;
stmt->parentStmt = switchStmt;
}
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 752bcf03b..882a4c314 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -476,6 +476,10 @@ DIAGNOSTIC(30504, Warning, forLoopTerminatesInFewerIterationsThanMaxIters, "the
DIAGNOSTIC(30505, Warning, loopRunsForZeroIterations, "the loop runs for 0 iterations and will be removed.")
DIAGNOSTIC(30510, Error, loopInDiffFuncRequireUnrollOrMaxIters, "loops inside a differentiable function need to provide either '[MaxIters(n)]' or '[ForceUnroll]' attribute.")
+// Switch
+DIAGNOSTIC(30600, Error, switchMultipleDefault, "multiple 'default' cases not allowed within a 'switch' statement")
+DIAGNOSTIC(30601, Error, switchDuplicateCases, "duplicate cases not allowed within a 'switch' statement")
+
// TODO: need to assign numbers to all these extra diagnostics...
DIAGNOSTIC(39999, Fatal, cyclicReference, "cyclic reference '$0'.")
DIAGNOSTIC(39999, Error, cyclicReferenceInInheritance, "cyclic reference in inheritance graph '$0'.")
diff --git a/tests/diagnostics/switch-duplicate-case.slang b/tests/diagnostics/switch-duplicate-case.slang
new file mode 100644
index 000000000..2d55dc4c8
--- /dev/null
+++ b/tests/diagnostics/switch-duplicate-case.slang
@@ -0,0 +1,33 @@
+//TEST:SIMPLE(filecheck=CHECK):
+
+// Tests to evaluate the behavior of code blocks within a switch statement. A switch statement with duplicate cases with same values is not allowed and should throw an error
+
+enum class Cases
+{
+ A,
+ B
+};
+
+void test1(Cases c)
+{
+ switch (c)
+ {
+ case Cases::A: break;
+ case Cases::B: break;
+ // CHECK: ([[# @LINE+1]]): error 30601: {{.*}}
+ case Cases::A: break;
+ }
+ return;
+}
+
+void test2()
+{
+ switch (0)
+ {
+ case 1: break;
+ case 2: break;
+ // CHECK: ([[# @LINE+1]]): error 30601: {{.*}}
+ case 1: break;
+ }
+ return;
+}
diff --git a/tests/diagnostics/switch-multiple-defaults.slang b/tests/diagnostics/switch-multiple-defaults.slang
new file mode 100644
index 000000000..bf164dd57
--- /dev/null
+++ b/tests/diagnostics/switch-multiple-defaults.slang
@@ -0,0 +1,14 @@
+//TEST:SIMPLE(filecheck=CHECK):
+
+// Test to evaluate the behavior of unreachable code blocks within a switch statement. A switch statement with multiple default cases is not allowed and should throw an error
+
+void test()
+{
+ switch (0)
+ {
+ default: break;
+ // CHECK: ([[# @LINE+1]]): error 30600: {{.*}}
+ default: break;
+ }
+ return;
+}