summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-check-decl.cpp10
-rw-r--r--source/slang/slang-check-expr.cpp9
-rw-r--r--source/slang/slang-check-impl.h6
3 files changed, 19 insertions, 6 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 1dc230dae..806912abe 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -8485,6 +8485,16 @@ bool SemanticsVisitor::isScalarIntegerType(Type* type)
return isIntegerBaseType(baseType) || baseType == BaseType::Bool;
}
+Type* SemanticsVisitor::getMatchingIntType(Type* type)
+{
+ if (isScalarIntegerType(type))
+ return type;
+ if (auto enumTypeDecl = isDeclRefTypeOf<EnumDecl>(type))
+ if (enumTypeDecl.getDecl()->tagType)
+ return getMatchingIntType(enumTypeDecl.getDecl()->tagType);
+ return m_astBuilder->getIntType();
+}
+
bool SemanticsVisitor::isHalfType(Type* type)
{
auto basicType = as<BasicExpressionType>(type);
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 41c3bd510..2891c316f 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -2490,13 +2490,10 @@ Expr* SemanticsVisitor::CheckSimpleSubscriptExpr(IndexExpr* subscriptExpr, Type*
{
expr = CheckExpr(expr);
}
- auto indexExpr = subscriptExpr->indexExprs[0];
+ auto& indexExpr = subscriptExpr->indexExprs[0];
- if (!isScalarIntegerType(indexExpr->type.type))
- {
- getSink()->diagnose(indexExpr, Diagnostics::subscriptIndexNonInteger);
- return CreateErrorExpr(subscriptExpr);
- }
+ auto intTargetType = getMatchingIntType(indexExpr->type.type);
+ indexExpr = coerce(CoercionSite::Argument, intTargetType, indexExpr, getSink());
subscriptExpr->type = QualType(elementType);
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 86d0b42fe..5c6be2665 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -2139,6 +2139,12 @@ public:
/// Is `type` a scalar integer type.
bool isScalarIntegerType(Type* type);
+ // This function is used to get the best integer type that matches the given type.
+ // If `type` is already an integer type, return it as is.
+ // If `type` is a enum type, return the tag type if it exists.
+ // Otherwise, return the 32-bit signed integer type.
+ Type* getMatchingIntType(Type* type);
+
/// Is `type` a scalar half type.
bool isHalfType(Type* type);