From 7cd502256dde2fc32a1dd77462a69b6f8e84c288 Mon Sep 17 00:00:00 2001 From: Julius Ikkala Date: Sun, 11 May 2025 08:14:44 +0300 Subject: Fix local constants in switch cases (#7053) * Fix using local constants in switch cases * Add test * format code * Always lower switch cases with exprVal * Fix formatting --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com> Co-authored-by: Yong He --- source/slang/slang-lower-to-ir.cpp | 8 +++++++- tests/bugs/switch-local-const.slang | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) create mode 100644 tests/bugs/switch-local-const.slang diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index c58eed1c1..90780882d 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -6813,7 +6813,13 @@ struct StmtLoweringVisitor : StmtVisitor IRBuilder subBuilder = *getBuilder(); subBuilder.setInsertInto(info->initialBlock); subContext.irBuilder = &subBuilder; - auto caseVal = getSimpleVal(context, lowerRValueExpr(&subContext, caseStmt->expr)); + + auto constVal = as(caseStmt->exprVal); + SLANG_ASSERT(constVal); + auto caseType = lowerType(context, constVal->getType()); + auto caseValInfo = + LoweredValInfo::simple(getBuilder()->getIntValue(caseType, constVal->getValue())); + auto caseVal = getSimpleVal(context, caseValInfo); // Figure out where we are branching to. auto label = getLabelForCase(info); diff --git a/tests/bugs/switch-local-const.slang b/tests/bugs/switch-local-const.slang new file mode 100644 index 000000000..30cd8a814 --- /dev/null +++ b/tests/bugs/switch-local-const.slang @@ -0,0 +1,34 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -dx12 +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -cpu + +// CHECK: 0 +// CHECK-NEXT: 1 +// CHECK-NEXT: 2 + +//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer outputBuffer; + +uint testFunc(uint n) +{ + const uint CONST_A = 0; + const uint CONST_B = 1; + + switch (n) + { + case CONST_A: + return 0; + case CONST_B: + return 1; + } + + return 2; +} + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + outputBuffer[0] = testFunc(0); + outputBuffer[1] = testFunc(1); + outputBuffer[2] = testFunc(2); +} -- cgit v1.2.3