diff options
| -rw-r--r-- | source/slang/slang-ast-stmt.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-check-stmt.cpp | 54 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 4 | ||||
| -rw-r--r-- | tests/diagnostics/switch-duplicate-case.slang | 33 | ||||
| -rw-r--r-- | tests/diagnostics/switch-multiple-defaults.slang | 14 |
6 files changed, 105 insertions, 5 deletions
diff --git a/source/slang/slang-ast-stmt.h b/source/slang/slang-ast-stmt.h index 055785333..afa606456 100644 --- a/source/slang/slang-ast-stmt.h +++ b/source/slang/slang-ast-stmt.h @@ -143,6 +143,8 @@ class CaseStmt : public CaseStmtBase SLANG_AST_CLASS(CaseStmt) Expr* expr = nullptr; + + Val* exprVal = nullptr; }; // a `default` statement inside a `switch` diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 48bd2093b..55edba6b9 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -2741,6 +2741,9 @@ namespace Slang void tryInferLoopMaxIterations(ForStmt* stmt); void checkLoopInDifferentiableFunc(Stmt* stmt); + + private: + void validateCaseStmts(SwitchStmt* stmt, DiagnosticSink* sink); }; struct SemanticsDeclVisitorBase diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp index b68465087..729b24d35 100644 --- a/source/slang/slang-check-stmt.cpp +++ b/source/slang/slang-check-stmt.cpp @@ -239,6 +239,47 @@ namespace Slang subContext.checkStmt(stmt->body); } + void SemanticsStmtVisitor::validateCaseStmts(SwitchStmt* stmt, DiagnosticSink* sink) + { + auto blockStmt = as<BlockStmt>(stmt->body); + if (!blockStmt) + return; + + auto seqStmt = as<SeqStmt>(blockStmt->body); + if (!seqStmt) + return; + + bool hasDefaultStmt = false; + HashSet<Val*> caseStmtVals; + for (auto& sStmt : seqStmt->stmts) + { + if (auto caseStmt = as<CaseStmt>(sStmt)) + { + // check that all case tags are unique + if (caseStmt->exprVal) + { + // exprVal contains the constant folded expr, that is checked for + // uniqueness within the scope of the switch statement. + if (!caseStmtVals.add(caseStmt->exprVal)) + { + sink->diagnose(sStmt, Diagnostics::switchDuplicateCases); + return; + } + } + } + else if (auto defaultStmt = as<DefaultStmt>(sStmt)) + { + // check that there is at most one `default` clause + if (hasDefaultStmt) + { + sink->diagnose(sStmt, Diagnostics::switchMultipleDefault); + return; + } + hasDefaultStmt = true; + } + } + } + void SemanticsStmtVisitor::visitSwitchStmt(SwitchStmt* stmt) { WithOuterStmt subContext(this, stmt); @@ -247,16 +288,18 @@ namespace Slang stmt->condition = CheckExpr(stmt->condition); subContext.checkStmt(stmt->body); - // TODO(tfoley): need to check that all case tags are unique - - // TODO(tfoley): check that there is at most one `default` clause + // check the case value exits within the switch + validateCaseStmts(stmt, getSink()); } void SemanticsStmtVisitor::visitCaseStmt(CaseStmt* stmt) { - // TODO(tfoley): Need to coerce to type being switch on, - // and ensure that value is a compile-time constant auto expr = CheckExpr(stmt->expr); + + // coerce to type being switch on, and ensure that value is a compile-time constant + // The Vals in the AST are pointer-unique, making them easy to check for duplicates + // by addeing them to a HashSet. + auto exprVal = tryConstantFoldExpr(expr, ConstantFoldingKind::CompileTime, nullptr); auto switchStmt = FindOuterStmt<SwitchStmt>(); if (!switchStmt) @@ -270,6 +313,7 @@ namespace Slang } stmt->expr = expr; + stmt->exprVal = exprVal; stmt->parentStmt = switchStmt; } diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 752bcf03b..882a4c314 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -476,6 +476,10 @@ DIAGNOSTIC(30504, Warning, forLoopTerminatesInFewerIterationsThanMaxIters, "the DIAGNOSTIC(30505, Warning, loopRunsForZeroIterations, "the loop runs for 0 iterations and will be removed.") DIAGNOSTIC(30510, Error, loopInDiffFuncRequireUnrollOrMaxIters, "loops inside a differentiable function need to provide either '[MaxIters(n)]' or '[ForceUnroll]' attribute.") +// Switch +DIAGNOSTIC(30600, Error, switchMultipleDefault, "multiple 'default' cases not allowed within a 'switch' statement") +DIAGNOSTIC(30601, Error, switchDuplicateCases, "duplicate cases not allowed within a 'switch' statement") + // TODO: need to assign numbers to all these extra diagnostics... DIAGNOSTIC(39999, Fatal, cyclicReference, "cyclic reference '$0'.") DIAGNOSTIC(39999, Error, cyclicReferenceInInheritance, "cyclic reference in inheritance graph '$0'.") diff --git a/tests/diagnostics/switch-duplicate-case.slang b/tests/diagnostics/switch-duplicate-case.slang new file mode 100644 index 000000000..2d55dc4c8 --- /dev/null +++ b/tests/diagnostics/switch-duplicate-case.slang @@ -0,0 +1,33 @@ +//TEST:SIMPLE(filecheck=CHECK): + +// Tests to evaluate the behavior of code blocks within a switch statement. A switch statement with duplicate cases with same values is not allowed and should throw an error + +enum class Cases +{ + A, + B +}; + +void test1(Cases c) +{ + switch (c) + { + case Cases::A: break; + case Cases::B: break; + // CHECK: ([[# @LINE+1]]): error 30601: {{.*}} + case Cases::A: break; + } + return; +} + +void test2() +{ + switch (0) + { + case 1: break; + case 2: break; + // CHECK: ([[# @LINE+1]]): error 30601: {{.*}} + case 1: break; + } + return; +} diff --git a/tests/diagnostics/switch-multiple-defaults.slang b/tests/diagnostics/switch-multiple-defaults.slang new file mode 100644 index 000000000..bf164dd57 --- /dev/null +++ b/tests/diagnostics/switch-multiple-defaults.slang @@ -0,0 +1,14 @@ +//TEST:SIMPLE(filecheck=CHECK): + +// Test to evaluate the behavior of unreachable code blocks within a switch statement. A switch statement with multiple default cases is not allowed and should throw an error + +void test() +{ + switch (0) + { + default: break; + // CHECK: ([[# @LINE+1]]): error 30600: {{.*}} + default: break; + } + return; +} |
