summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-modifier.cpp
diff options
context:
space:
mode:
authorJulius Ikkala <julius.ikkala@gmail.com>2025-01-14 20:32:29 +0200
committerGitHub <noreply@github.com>2025-01-14 10:32:29 -0800
commitcbdc7e1219e472fd74f7f559d7e417f233e7df39 (patch)
treee051b90e317a875e264c2c8d951668bf0b7d3ad0 /source/slang/slang-check-modifier.cpp
parent971996b397711016d47fe961890d7001338c6f23 (diff)
Implement specialization constant support in numthreads / local_size (#5963)
* Allow using specialization constants in numthreads attribute * Add support for GLSL local_size_x_id syntax * Fix overeager specialization constant parsing * Add diagnostics for specialization constant numthreads * Remove unused variable * Fix local_size_x_id not finding existing specialization constant * Allow materializeGetWorkGroupSize to reference specialization constants * Use SpvOpExecutionModeId for modes that require it * Cleanup specialization constant numthreads code * Add tests for specialization constant work group sizes * Fix implicit Slang::Int -> int32_t cast * Fix querying thread group size in reflection API --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-check-modifier.cpp')
-rw-r--r--source/slang/slang-check-modifier.cpp108
1 files changed, 91 insertions, 17 deletions
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index 3723c98f8..6e451b5cf 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -114,6 +114,36 @@ void SemanticsVisitor::visitModifier(Modifier*)
// Do nothing with modifiers for now
}
+DeclRef<VarDeclBase> SemanticsVisitor::tryGetIntSpecializationConstant(Expr* expr)
+{
+ // First type-check the expression as normal
+ expr = CheckExpr(expr);
+
+ if (IsErrorExpr(expr))
+ return DeclRef<VarDeclBase>();
+
+ if (!isScalarIntegerType(expr->type))
+ return DeclRef<VarDeclBase>();
+
+ auto specConstVar = as<VarExpr>(expr);
+ if (!specConstVar || !specConstVar->declRef)
+ return DeclRef<VarDeclBase>();
+
+ auto decl = specConstVar->declRef.getDecl();
+ if (!decl)
+ return DeclRef<VarDeclBase>();
+
+ for (auto modifier : decl->modifiers)
+ {
+ if (as<SpecializationConstantAttribute>(modifier) || as<VkConstantIdAttribute>(modifier))
+ {
+ return specConstVar->declRef.as<VarDeclBase>();
+ }
+ }
+
+ return DeclRef<VarDeclBase>();
+}
+
static bool _isDeclAllowedAsAttribute(DeclRef<Decl> declRef)
{
if (as<AttributeDecl>(declRef.getDecl()))
@@ -350,8 +380,6 @@ Modifier* SemanticsVisitor::validateAttribute(
{
SLANG_ASSERT(attr->args.getCount() == 3);
- IntVal* values[3];
-
for (int i = 0; i < 3; ++i)
{
IntVal* value = nullptr;
@@ -359,6 +387,14 @@ Modifier* SemanticsVisitor::validateAttribute(
auto arg = attr->args[i];
if (arg)
{
+ auto specConstDecl = tryGetIntSpecializationConstant(arg);
+ if (specConstDecl)
+ {
+ numThreadsAttr->extents[i] = nullptr;
+ numThreadsAttr->specConstExtents[i] = specConstDecl;
+ continue;
+ }
+
auto intValue = checkLinkTimeConstantIntVal(arg);
if (!intValue)
{
@@ -390,12 +426,8 @@ Modifier* SemanticsVisitor::validateAttribute(
{
value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1);
}
- values[i] = value;
+ numThreadsAttr->extents[i] = value;
}
-
- numThreadsAttr->x = values[0];
- numThreadsAttr->y = values[1];
- numThreadsAttr->z = values[2];
}
else if (auto waveSizeAttr = as<WaveSizeAttribute>(attr))
{
@@ -1831,15 +1863,24 @@ Modifier* SemanticsVisitor::checkModifier(
{
SLANG_ASSERT(attr->args.getCount() == 3);
- IntVal* values[3];
+ // GLSLLayoutLocalSizeAttribute is always attached to an EmptyDecl.
+ auto decl = as<EmptyDecl>(syntaxNode);
+ SLANG_ASSERT(decl);
for (int i = 0; i < 3; ++i)
{
- IntVal* value = nullptr;
+ attr->extents[i] = nullptr;
auto arg = attr->args[i];
if (arg)
{
+ auto specConstDecl = tryGetIntSpecializationConstant(arg);
+ if (specConstDecl)
+ {
+ attr->specConstExtents[i] = specConstDecl;
+ continue;
+ }
+
auto intValue = checkConstantIntVal(arg);
if (!intValue)
{
@@ -1847,7 +1888,45 @@ Modifier* SemanticsVisitor::checkModifier(
}
if (auto cintVal = as<ConstantIntVal>(intValue))
{
- if (cintVal->getValue() < 1)
+ if (attr->axisIsSpecConstId[i])
+ {
+ // This integer should actually be a reference to a
+ // specialization constant with this ID.
+ Int specConstId = cintVal->getValue();
+
+ for (auto member : decl->parentDecl->members)
+ {
+ auto constantId = member->findModifier<VkConstantIdAttribute>();
+ if (constantId)
+ {
+ SLANG_ASSERT(constantId->args.getCount() == 1);
+ auto id = checkConstantIntVal(constantId->args[0]);
+ if (id->getValue() == specConstId)
+ {
+ attr->specConstExtents[i] =
+ DeclRef<VarDeclBase>(member->getDefaultDeclRef());
+ break;
+ }
+ }
+ }
+
+ // If not found, we need to create a new specialization
+ // constant with this ID.
+ if (!attr->specConstExtents[i])
+ {
+ auto specConstVarDecl = getASTBuilder()->create<VarDecl>();
+ auto constantIdModifier =
+ getASTBuilder()->create<VkConstantIdAttribute>();
+ constantIdModifier->location = (int32_t)specConstId;
+ specConstVarDecl->type.type = getASTBuilder()->getIntType();
+ addModifier(specConstVarDecl, constantIdModifier);
+ decl->parentDecl->addMember(specConstVarDecl);
+ attr->specConstExtents[i] =
+ DeclRef<VarDeclBase>(specConstVarDecl->getDefaultDeclRef());
+ }
+ continue;
+ }
+ else if (cintVal->getValue() < 1)
{
getSink()->diagnose(
attr,
@@ -1856,18 +1935,13 @@ Modifier* SemanticsVisitor::checkModifier(
return nullptr;
}
}
- value = intValue;
+ attr->extents[i] = intValue;
}
else
{
- value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1);
+ attr->extents[i] = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1);
}
- values[i] = value;
}
-
- attr->x = values[0];
- attr->y = values[1];
- attr->z = values[2];
}
// Default behavior is to leave things as they are,