diff options
| author | Yong He <yonghe@outlook.com> | 2024-04-02 11:51:36 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-04-02 11:51:36 -0700 |
| commit | b5f4cf63a8b952731053a0d04af0fc8c946d86f3 (patch) | |
| tree | 5ff2c4fc31a9c6728d7e0af6b60d9b7c074c7a81 | |
| parent | 251f55c5ec4cb2b7432e71d6ba8adc96700d35c2 (diff) | |
Allow enum values to be used as generic arguments. (#3874)
* Allow enum values to be used as generic arguments.
* Fix constant folding.
| -rw-r--r-- | source/slang/core.meta.slang | 21 | ||||
| -rw-r--r-- | source/slang/slang-ast-val.cpp | 73 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 3 | ||||
| -rw-r--r-- | tests/language-feature/enums/enum-generic-arg.slang | 33 |
6 files changed, 110 insertions, 34 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 8d91f27ab..ef3614e2e 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2221,6 +2221,27 @@ __prefix T operator !(T v0) return v0.not(); } +// The operator overloads defined above already allows Enum types to be used +// in logical operators, but we still provide overloads for __EnumTypes and map +// them directly to intrinsic op to allow constant propagation at AST level to +// work on enum types. + +__generic<T : __EnumType> +[__unsafeForceInlineEarly] +__intrinsic_op($(kIROp_BitAnd)) +T operator &(T v0, T v1); + +__generic<T : __EnumType> +[__unsafeForceInlineEarly] +__intrinsic_op($(kIROp_BitOr)) +T operator |(T v0, T v1); + +__generic<T : __EnumType> +[__unsafeForceInlineEarly] +__intrinsic_op($(kIROp_BitNot)) +__prefix T operator ~(T v0); + + // IR level type traits. __generic<T> diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index 0dbe65ee0..1d5a875dd 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -8,6 +8,7 @@ #include "slang-syntax.h" #include "slang-ast-val.h" #include "slang-mangle.h" +#include "slang-check-impl.h" namespace Slang { @@ -1054,44 +1055,54 @@ void TypeCastIntVal::_toTextOverride(StringBuilder& out) Val* TypeCastIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val* base, DiagnosticSink* sink) { SLANG_UNUSED(sink); - + auto convertValue = [&](BasicExpressionType* baseType, IntegerLiteralValue& resultValue) -> bool + { + switch (baseType->getBaseType()) + { + case BaseType::Int: + resultValue = (int)resultValue; + return true; + case BaseType::UInt: + resultValue = (unsigned int)resultValue; + return true; + case BaseType::Int64: + case BaseType::IntPtr: + resultValue = (Int64)resultValue; + return true; + case BaseType::UInt64: + case BaseType::UIntPtr: + resultValue = (UInt64)resultValue; + return true; + case BaseType::Int16: + resultValue = (int16_t)resultValue; + return true; + case BaseType::UInt16: + resultValue = (uint16_t)resultValue; + return true; + case BaseType::Int8: + resultValue = (int8_t)resultValue; + return true; + case BaseType::UInt8: + resultValue = (uint8_t)resultValue; + return true; + default: + return false; + } + }; if (auto c = as<ConstantIntVal>(base)) { IntegerLiteralValue resultValue = c->getValue(); auto baseType = as<BasicExpressionType>(resultType); if (baseType) { - switch (baseType->getBaseType()) - { - case BaseType::Int: - resultValue = (int)resultValue; - break; - case BaseType::UInt: - resultValue = (unsigned int)resultValue; - break; - case BaseType::Int64: - case BaseType::IntPtr: - resultValue = (Int64)resultValue; - break; - case BaseType::UInt64: - case BaseType::UIntPtr: - resultValue = (UInt64)resultValue; - break; - case BaseType::Int16: - resultValue = (int16_t)resultValue; - break; - case BaseType::UInt16: - resultValue = (uint16_t)resultValue; - break; - case BaseType::Int8: - resultValue = (int8_t)resultValue; - break; - case BaseType::UInt8: - resultValue = (uint8_t)resultValue; - break; - default: + if (!convertValue(baseType, resultValue)) + return nullptr; + } + else if (auto enumDecl = isEnumType(resultType)) + { + baseType = as<BasicExpressionType>(enumDecl->tagType); + if (!baseType || !convertValue(baseType, resultValue)) return nullptr; - } } return astBuilder->getIntVal(resultType, resultValue); } diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 07c8b0cba..aec3b463b 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -5946,6 +5946,11 @@ namespace Slang return isIntegerBaseType(baseType) || baseType == BaseType::Bool; } + bool SemanticsVisitor::isValidCompileTimeConstantType(Type* type) + { + return isScalarIntegerType(type) || isEnumType(type); + } + bool SemanticsVisitor::isIntValueInRangeOfType(IntegerLiteralValue value, Type* type) { auto basicType = as<BasicExpressionType>(type); diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index d3341e87b..ff5dd4af5 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1768,7 +1768,10 @@ namespace Slang return nullptr; ConstantFoldingCircularityInfo newCircularityInfo(enumCaseDecl, circularityInfo); - return tryConstantFoldExpr(tagExpr, kind, &newCircularityInfo); + auto intVal = as<IntVal>(tryConstantFoldExpr(tagExpr, kind, &newCircularityInfo)); + if (!intVal) + return nullptr; + return as<IntVal>(m_astBuilder->getTypeCastIntVal(enumCaseDecl->getType(), intVal)->resolve()); } } } @@ -1778,7 +1781,7 @@ namespace Slang auto substType = getType(m_astBuilder, expr); if (!substType) return nullptr; - if (!isScalarIntegerType(substType)) + if (!isValidCompileTimeConstantType(substType)) return nullptr; auto val = tryConstantFoldExpr(getArg(castExpr, 0), kind, circularityInfo); if (val) @@ -1826,7 +1829,7 @@ namespace Slang { // Check if type is acceptable for an integer constant expression // - if(!isScalarIntegerType(getType(m_astBuilder, expr))) + if(!isValidCompileTimeConstantType(getType(m_astBuilder, expr))) return nullptr; // Consider operations that we might be able to constant-fold... diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 002ef1f71..48bd2093b 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1816,6 +1816,9 @@ namespace Slang /// Is `type` a scalar integer type. bool isScalarIntegerType(Type* type); + /// Is `type` something we allow as compile time constants, i.e. scalar integer and enum types. + bool isValidCompileTimeConstantType(Type* type); + bool isIntValueInRangeOfType(IntegerLiteralValue value, Type* type); // Validate that `type` is a suitable type to use diff --git a/tests/language-feature/enums/enum-generic-arg.slang b/tests/language-feature/enums/enum-generic-arg.slang new file mode 100644 index 000000000..e851727e1 --- /dev/null +++ b/tests/language-feature/enums/enum-generic-arg.slang @@ -0,0 +1,33 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj + +// Test that enum values can be used as compile time constants +// to specialize generics. + +[Flags] +enum BitFlags +{ + One, Two, Three +} + +int test<let F : BitFlags>() +{ + return F; +} + +int testInt<let f : int>() +{ + return f; +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID : SV_DispatchThreadID) +{ + // CHECK: 3 + outputBuffer[0] = test<BitFlags.One | BitFlags.Two>(); + + // CHECK: 3 + outputBuffer[1] = testInt<BitFlags.One | BitFlags.Two>(); +} |
