summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-04-02 11:51:36 -0700
committerGitHub <noreply@github.com>2024-04-02 11:51:36 -0700
commitb5f4cf63a8b952731053a0d04af0fc8c946d86f3 (patch)
tree5ff2c4fc31a9c6728d7e0af6b60d9b7c074c7a81
parent251f55c5ec4cb2b7432e71d6ba8adc96700d35c2 (diff)
Allow enum values to be used as generic arguments. (#3874)
* Allow enum values to be used as generic arguments. * Fix constant folding.
-rw-r--r--source/slang/core.meta.slang21
-rw-r--r--source/slang/slang-ast-val.cpp73
-rw-r--r--source/slang/slang-check-decl.cpp5
-rw-r--r--source/slang/slang-check-expr.cpp9
-rw-r--r--source/slang/slang-check-impl.h3
-rw-r--r--tests/language-feature/enums/enum-generic-arg.slang33
6 files changed, 110 insertions, 34 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 8d91f27ab..ef3614e2e 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -2221,6 +2221,27 @@ __prefix T operator !(T v0)
return v0.not();
}
+// The operator overloads defined above already allows Enum types to be used
+// in logical operators, but we still provide overloads for __EnumTypes and map
+// them directly to intrinsic op to allow constant propagation at AST level to
+// work on enum types.
+
+__generic<T : __EnumType>
+[__unsafeForceInlineEarly]
+__intrinsic_op($(kIROp_BitAnd))
+T operator &(T v0, T v1);
+
+__generic<T : __EnumType>
+[__unsafeForceInlineEarly]
+__intrinsic_op($(kIROp_BitOr))
+T operator |(T v0, T v1);
+
+__generic<T : __EnumType>
+[__unsafeForceInlineEarly]
+__intrinsic_op($(kIROp_BitNot))
+__prefix T operator ~(T v0);
+
+
// IR level type traits.
__generic<T>
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index 0dbe65ee0..1d5a875dd 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -8,6 +8,7 @@
#include "slang-syntax.h"
#include "slang-ast-val.h"
#include "slang-mangle.h"
+#include "slang-check-impl.h"
namespace Slang {
@@ -1054,44 +1055,54 @@ void TypeCastIntVal::_toTextOverride(StringBuilder& out)
Val* TypeCastIntVal::tryFoldImpl(ASTBuilder* astBuilder, Type* resultType, Val* base, DiagnosticSink* sink)
{
SLANG_UNUSED(sink);
-
+ auto convertValue = [&](BasicExpressionType* baseType, IntegerLiteralValue& resultValue) -> bool
+ {
+ switch (baseType->getBaseType())
+ {
+ case BaseType::Int:
+ resultValue = (int)resultValue;
+ return true;
+ case BaseType::UInt:
+ resultValue = (unsigned int)resultValue;
+ return true;
+ case BaseType::Int64:
+ case BaseType::IntPtr:
+ resultValue = (Int64)resultValue;
+ return true;
+ case BaseType::UInt64:
+ case BaseType::UIntPtr:
+ resultValue = (UInt64)resultValue;
+ return true;
+ case BaseType::Int16:
+ resultValue = (int16_t)resultValue;
+ return true;
+ case BaseType::UInt16:
+ resultValue = (uint16_t)resultValue;
+ return true;
+ case BaseType::Int8:
+ resultValue = (int8_t)resultValue;
+ return true;
+ case BaseType::UInt8:
+ resultValue = (uint8_t)resultValue;
+ return true;
+ default:
+ return false;
+ }
+ };
if (auto c = as<ConstantIntVal>(base))
{
IntegerLiteralValue resultValue = c->getValue();
auto baseType = as<BasicExpressionType>(resultType);
if (baseType)
{
- switch (baseType->getBaseType())
- {
- case BaseType::Int:
- resultValue = (int)resultValue;
- break;
- case BaseType::UInt:
- resultValue = (unsigned int)resultValue;
- break;
- case BaseType::Int64:
- case BaseType::IntPtr:
- resultValue = (Int64)resultValue;
- break;
- case BaseType::UInt64:
- case BaseType::UIntPtr:
- resultValue = (UInt64)resultValue;
- break;
- case BaseType::Int16:
- resultValue = (int16_t)resultValue;
- break;
- case BaseType::UInt16:
- resultValue = (uint16_t)resultValue;
- break;
- case BaseType::Int8:
- resultValue = (int8_t)resultValue;
- break;
- case BaseType::UInt8:
- resultValue = (uint8_t)resultValue;
- break;
- default:
+ if (!convertValue(baseType, resultValue))
+ return nullptr;
+ }
+ else if (auto enumDecl = isEnumType(resultType))
+ {
+ baseType = as<BasicExpressionType>(enumDecl->tagType);
+ if (!baseType || !convertValue(baseType, resultValue))
return nullptr;
- }
}
return astBuilder->getIntVal(resultType, resultValue);
}
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 07c8b0cba..aec3b463b 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -5946,6 +5946,11 @@ namespace Slang
return isIntegerBaseType(baseType) || baseType == BaseType::Bool;
}
+ bool SemanticsVisitor::isValidCompileTimeConstantType(Type* type)
+ {
+ return isScalarIntegerType(type) || isEnumType(type);
+ }
+
bool SemanticsVisitor::isIntValueInRangeOfType(IntegerLiteralValue value, Type* type)
{
auto basicType = as<BasicExpressionType>(type);
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index d3341e87b..ff5dd4af5 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1768,7 +1768,10 @@ namespace Slang
return nullptr;
ConstantFoldingCircularityInfo newCircularityInfo(enumCaseDecl, circularityInfo);
- return tryConstantFoldExpr(tagExpr, kind, &newCircularityInfo);
+ auto intVal = as<IntVal>(tryConstantFoldExpr(tagExpr, kind, &newCircularityInfo));
+ if (!intVal)
+ return nullptr;
+ return as<IntVal>(m_astBuilder->getTypeCastIntVal(enumCaseDecl->getType(), intVal)->resolve());
}
}
}
@@ -1778,7 +1781,7 @@ namespace Slang
auto substType = getType(m_astBuilder, expr);
if (!substType)
return nullptr;
- if (!isScalarIntegerType(substType))
+ if (!isValidCompileTimeConstantType(substType))
return nullptr;
auto val = tryConstantFoldExpr(getArg(castExpr, 0), kind, circularityInfo);
if (val)
@@ -1826,7 +1829,7 @@ namespace Slang
{
// Check if type is acceptable for an integer constant expression
//
- if(!isScalarIntegerType(getType(m_astBuilder, expr)))
+ if(!isValidCompileTimeConstantType(getType(m_astBuilder, expr)))
return nullptr;
// Consider operations that we might be able to constant-fold...
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 002ef1f71..48bd2093b 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -1816,6 +1816,9 @@ namespace Slang
/// Is `type` a scalar integer type.
bool isScalarIntegerType(Type* type);
+ /// Is `type` something we allow as compile time constants, i.e. scalar integer and enum types.
+ bool isValidCompileTimeConstantType(Type* type);
+
bool isIntValueInRangeOfType(IntegerLiteralValue value, Type* type);
// Validate that `type` is a suitable type to use
diff --git a/tests/language-feature/enums/enum-generic-arg.slang b/tests/language-feature/enums/enum-generic-arg.slang
new file mode 100644
index 000000000..e851727e1
--- /dev/null
+++ b/tests/language-feature/enums/enum-generic-arg.slang
@@ -0,0 +1,33 @@
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj
+
+// Test that enum values can be used as compile time constants
+// to specialize generics.
+
+[Flags]
+enum BitFlags
+{
+ One, Two, Three
+}
+
+int test<let F : BitFlags>()
+{
+ return F;
+}
+
+int testInt<let f : int>()
+{
+ return f;
+}
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+[numthreads(1, 1, 1)]
+void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)
+{
+ // CHECK: 3
+ outputBuffer[0] = test<BitFlags.One | BitFlags.Two>();
+
+ // CHECK: 3
+ outputBuffer[1] = testInt<BitFlags.One | BitFlags.Two>();
+}