diff options
| author | Julius Ikkala <julius.ikkala@gmail.com> | 2025-01-14 20:32:29 +0200 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-01-14 10:32:29 -0800 |
| commit | cbdc7e1219e472fd74f7f559d7e417f233e7df39 (patch) | |
| tree | e051b90e317a875e264c2c8d951668bf0b7d3ad0 /source/slang/slang-check-modifier.cpp | |
| parent | 971996b397711016d47fe961890d7001338c6f23 (diff) | |
Implement specialization constant support in numthreads / local_size (#5963)
* Allow using specialization constants in numthreads attribute
* Add support for GLSL local_size_x_id syntax
* Fix overeager specialization constant parsing
* Add diagnostics for specialization constant numthreads
* Remove unused variable
* Fix local_size_x_id not finding existing specialization constant
* Allow materializeGetWorkGroupSize to reference specialization constants
* Use SpvOpExecutionModeId for modes that require it
* Cleanup specialization constant numthreads code
* Add tests for specialization constant work group sizes
* Fix implicit Slang::Int -> int32_t cast
* Fix querying thread group size in reflection API
---------
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-check-modifier.cpp')
| -rw-r--r-- | source/slang/slang-check-modifier.cpp | 108 |
1 files changed, 91 insertions, 17 deletions
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 3723c98f8..6e451b5cf 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -114,6 +114,36 @@ void SemanticsVisitor::visitModifier(Modifier*) // Do nothing with modifiers for now } +DeclRef<VarDeclBase> SemanticsVisitor::tryGetIntSpecializationConstant(Expr* expr) +{ + // First type-check the expression as normal + expr = CheckExpr(expr); + + if (IsErrorExpr(expr)) + return DeclRef<VarDeclBase>(); + + if (!isScalarIntegerType(expr->type)) + return DeclRef<VarDeclBase>(); + + auto specConstVar = as<VarExpr>(expr); + if (!specConstVar || !specConstVar->declRef) + return DeclRef<VarDeclBase>(); + + auto decl = specConstVar->declRef.getDecl(); + if (!decl) + return DeclRef<VarDeclBase>(); + + for (auto modifier : decl->modifiers) + { + if (as<SpecializationConstantAttribute>(modifier) || as<VkConstantIdAttribute>(modifier)) + { + return specConstVar->declRef.as<VarDeclBase>(); + } + } + + return DeclRef<VarDeclBase>(); +} + static bool _isDeclAllowedAsAttribute(DeclRef<Decl> declRef) { if (as<AttributeDecl>(declRef.getDecl())) @@ -350,8 +380,6 @@ Modifier* SemanticsVisitor::validateAttribute( { SLANG_ASSERT(attr->args.getCount() == 3); - IntVal* values[3]; - for (int i = 0; i < 3; ++i) { IntVal* value = nullptr; @@ -359,6 +387,14 @@ Modifier* SemanticsVisitor::validateAttribute( auto arg = attr->args[i]; if (arg) { + auto specConstDecl = tryGetIntSpecializationConstant(arg); + if (specConstDecl) + { + numThreadsAttr->extents[i] = nullptr; + numThreadsAttr->specConstExtents[i] = specConstDecl; + continue; + } + auto intValue = checkLinkTimeConstantIntVal(arg); if (!intValue) { @@ -390,12 +426,8 @@ Modifier* SemanticsVisitor::validateAttribute( { value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); } - values[i] = value; + numThreadsAttr->extents[i] = value; } - - numThreadsAttr->x = values[0]; - numThreadsAttr->y = values[1]; - numThreadsAttr->z = values[2]; } else if (auto waveSizeAttr = as<WaveSizeAttribute>(attr)) { @@ -1831,15 +1863,24 @@ Modifier* SemanticsVisitor::checkModifier( { SLANG_ASSERT(attr->args.getCount() == 3); - IntVal* values[3]; + // GLSLLayoutLocalSizeAttribute is always attached to an EmptyDecl. + auto decl = as<EmptyDecl>(syntaxNode); + SLANG_ASSERT(decl); for (int i = 0; i < 3; ++i) { - IntVal* value = nullptr; + attr->extents[i] = nullptr; auto arg = attr->args[i]; if (arg) { + auto specConstDecl = tryGetIntSpecializationConstant(arg); + if (specConstDecl) + { + attr->specConstExtents[i] = specConstDecl; + continue; + } + auto intValue = checkConstantIntVal(arg); if (!intValue) { @@ -1847,7 +1888,45 @@ Modifier* SemanticsVisitor::checkModifier( } if (auto cintVal = as<ConstantIntVal>(intValue)) { - if (cintVal->getValue() < 1) + if (attr->axisIsSpecConstId[i]) + { + // This integer should actually be a reference to a + // specialization constant with this ID. + Int specConstId = cintVal->getValue(); + + for (auto member : decl->parentDecl->members) + { + auto constantId = member->findModifier<VkConstantIdAttribute>(); + if (constantId) + { + SLANG_ASSERT(constantId->args.getCount() == 1); + auto id = checkConstantIntVal(constantId->args[0]); + if (id->getValue() == specConstId) + { + attr->specConstExtents[i] = + DeclRef<VarDeclBase>(member->getDefaultDeclRef()); + break; + } + } + } + + // If not found, we need to create a new specialization + // constant with this ID. + if (!attr->specConstExtents[i]) + { + auto specConstVarDecl = getASTBuilder()->create<VarDecl>(); + auto constantIdModifier = + getASTBuilder()->create<VkConstantIdAttribute>(); + constantIdModifier->location = (int32_t)specConstId; + specConstVarDecl->type.type = getASTBuilder()->getIntType(); + addModifier(specConstVarDecl, constantIdModifier); + decl->parentDecl->addMember(specConstVarDecl); + attr->specConstExtents[i] = + DeclRef<VarDeclBase>(specConstVarDecl->getDefaultDeclRef()); + } + continue; + } + else if (cintVal->getValue() < 1) { getSink()->diagnose( attr, @@ -1856,18 +1935,13 @@ Modifier* SemanticsVisitor::checkModifier( return nullptr; } } - value = intValue; + attr->extents[i] = intValue; } else { - value = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); + attr->extents[i] = m_astBuilder->getIntVal(m_astBuilder->getIntType(), 1); } - values[i] = value; } - - attr->x = values[0]; - attr->y = values[1]; - attr->z = values[2]; } // Default behavior is to leave things as they are, |
