summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-check-decl.cpp10
-rw-r--r--source/slang/slang-check-expr.cpp9
-rw-r--r--source/slang/slang-check-impl.h6
-rw-r--r--tests/language-feature/enums/enum-array-indexing.slang36
4 files changed, 55 insertions, 6 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 1dc230dae..806912abe 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -8485,6 +8485,16 @@ bool SemanticsVisitor::isScalarIntegerType(Type* type)
return isIntegerBaseType(baseType) || baseType == BaseType::Bool;
}
+Type* SemanticsVisitor::getMatchingIntType(Type* type)
+{
+ if (isScalarIntegerType(type))
+ return type;
+ if (auto enumTypeDecl = isDeclRefTypeOf<EnumDecl>(type))
+ if (enumTypeDecl.getDecl()->tagType)
+ return getMatchingIntType(enumTypeDecl.getDecl()->tagType);
+ return m_astBuilder->getIntType();
+}
+
bool SemanticsVisitor::isHalfType(Type* type)
{
auto basicType = as<BasicExpressionType>(type);
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 41c3bd510..2891c316f 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -2490,13 +2490,10 @@ Expr* SemanticsVisitor::CheckSimpleSubscriptExpr(IndexExpr* subscriptExpr, Type*
{
expr = CheckExpr(expr);
}
- auto indexExpr = subscriptExpr->indexExprs[0];
+ auto& indexExpr = subscriptExpr->indexExprs[0];
- if (!isScalarIntegerType(indexExpr->type.type))
- {
- getSink()->diagnose(indexExpr, Diagnostics::subscriptIndexNonInteger);
- return CreateErrorExpr(subscriptExpr);
- }
+ auto intTargetType = getMatchingIntType(indexExpr->type.type);
+ indexExpr = coerce(CoercionSite::Argument, intTargetType, indexExpr, getSink());
subscriptExpr->type = QualType(elementType);
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 86d0b42fe..5c6be2665 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -2139,6 +2139,12 @@ public:
/// Is `type` a scalar integer type.
bool isScalarIntegerType(Type* type);
+ // This function is used to get the best integer type that matches the given type.
+ // If `type` is already an integer type, return it as is.
+ // If `type` is a enum type, return the tag type if it exists.
+ // Otherwise, return the 32-bit signed integer type.
+ Type* getMatchingIntType(Type* type);
+
/// Is `type` a scalar half type.
bool isHalfType(Type* type);
diff --git a/tests/language-feature/enums/enum-array-indexing.slang b/tests/language-feature/enums/enum-array-indexing.slang
new file mode 100644
index 000000000..c9294ad0c
--- /dev/null
+++ b/tests/language-feature/enums/enum-array-indexing.slang
@@ -0,0 +1,36 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-cpu -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -shaderobj -output-using-type
+
+// Test that enums can be used as array indices without explicit casting
+
+enum Fruit { Apple, Orange, Banana };
+
+//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0], stride=4)
+RWStructuredBuffer<int> outputBuffer;
+
+[numthreads(1,1,1)]
+void computeMain()
+{
+ int fruits[10];
+
+ // Initialize arrays with some values
+ for (int i = 0; i < 10; i++)
+ {
+ fruits[i] = i * 10;
+ }
+
+ // Test basic enum indexing - this should work with our fix
+ int appleCost = fruits[Fruit::Apple]; // Should access fruits[0] = 0
+ int orangeCost = fruits[Fruit::Orange]; // Should access fruits[1] = 10
+ int bananaCost = fruits[Fruit::Banana]; // Should access fruits[2] = 20
+
+ // CHECK: 0
+ outputBuffer[0] = appleCost;
+ // CHECK: 10
+ outputBuffer[1] = orangeCost;
+ // CHECK: 20
+ outputBuffer[2] = bananaCost;
+ // CHECK: 42
+ outputBuffer[3] = 42; // Just a test value
+} \ No newline at end of file