summaryrefslogtreecommitdiff
path: root/source/slang/slang-ast-val.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-04-02 11:51:36 -0700
committerGitHub <noreply@github.com>2024-04-02 11:51:36 -0700
commitb5f4cf63a8b952731053a0d04af0fc8c946d86f3 (patch)
tree5ff2c4fc31a9c6728d7e0af6b60d9b7c074c7a81 /source/slang/slang-ast-val.cpp
parent251f55c5ec4cb2b7432e71d6ba8adc96700d35c2 (diff)
Allow enum values to be used as generic arguments. (#3874)
* Allow enum values to be used as generic arguments. * Fix constant folding.
Diffstat (limited to 'source/slang/slang-ast-val.cpp')
-rw-r--r--source/slang/slang-ast-val.cpp73
1 files changed, 42 insertions, 31 deletions
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index 0dbe65ee0..1d5a875dd 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -8,6 +8,7 @@
#include "slang-syntax.h"
#include "slang-ast-val.h"
#include "slang-mangle.h"
+#include "slang-check-impl.h"
namespace Slang {
@@ -1054,44 +1055,54 @@ void TypeCastIntVal::_toTextOverride(StringBuilder& out)
Val* TypeCastIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val* base, DiagnosticSink* sink)
{
SLANG_UNUSED(sink);
-
+ auto convertValue = [&](BasicExpressionType* baseType, IntegerLiteralValue& resultValue) -> bool
+ {
+ switch (baseType->getBaseType())
+ {
+ case BaseType::Int:
+ resultValue = (int)resultValue;
+ return true;
+ case BaseType::UInt:
+ resultValue = (unsigned int)resultValue;
+ return true;
+ case BaseType::Int64:
+ case BaseType::IntPtr:
+ resultValue = (Int64)resultValue;
+ return true;
+ case BaseType::UInt64:
+ case BaseType::UIntPtr:
+ resultValue = (UInt64)resultValue;
+ return true;
+ case BaseType::Int16:
+ resultValue = (int16_t)resultValue;
+ return true;
+ case BaseType::UInt16:
+ resultValue = (uint16_t)resultValue;
+ return true;
+ case BaseType::Int8:
+ resultValue = (int8_t)resultValue;
+ return true;
+ case BaseType::UInt8:
+ resultValue = (uint8_t)resultValue;
+ return true;
+ default:
+ return false;
+ }
+ };
if (auto c = as<ConstantIntVal>(base))
{
IntegerLiteralValue resultValue = c->getValue();
auto baseType = as<BasicExpressionType>(resultType);
if (baseType)
{
- switch (baseType->getBaseType())
- {
- case BaseType::Int:
- resultValue = (int)resultValue;
- break;
- case BaseType::UInt:
- resultValue = (unsigned int)resultValue;
- break;
- case BaseType::Int64:
- case BaseType::IntPtr:
- resultValue = (Int64)resultValue;
- break;
- case BaseType::UInt64:
- case BaseType::UIntPtr:
- resultValue = (UInt64)resultValue;
- break;
- case BaseType::Int16:
- resultValue = (int16_t)resultValue;
- break;
- case BaseType::UInt16:
- resultValue = (uint16_t)resultValue;
- break;
- case BaseType::Int8:
- resultValue = (int8_t)resultValue;
- break;
- case BaseType::UInt8:
- resultValue = (uint8_t)resultValue;
- break;
- default:
+ if (!convertValue(baseType, resultValue))
+ return nullptr;
+ }
+ else if (auto enumDecl = isEnumType(resultType))
+ {
+ baseType = as<BasicExpressionType>(enumDecl->tagType);
+ if (!baseType || !convertValue(baseType, resultValue))
return nullptr;
- }
}
return astBuilder->getIntVal(resultType, resultValue);
}