From cbdc7e1219e472fd74f7f559d7e417f233e7df39 Mon Sep 17 00:00:00 2001 From: Julius Ikkala Date: Tue, 14 Jan 2025 20:32:29 +0200 Subject: Implement specialization constant support in numthreads / local_size (#5963) * Allow using specialization constants in numthreads attribute * Add support for GLSL local_size_x_id syntax * Fix overeager specialization constant parsing * Add diagnostics for specialization constant numthreads * Remove unused variable * Fix local_size_x_id not finding existing specialization constant * Allow materializeGetWorkGroupSize to reference specialization constants * Use SpvOpExecutionModeId for modes that require it * Cleanup specialization constant numthreads code * Add tests for specialization constant work group sizes * Fix implicit Slang::Int -> int32_t cast * Fix querying thread group size in reflection API --------- Co-authored-by: Yong He --- source/slang/slang-parser.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) (limited to 'source/slang/slang-parser.cpp') diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index c275a868b..6ae41a2eb 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -8437,7 +8437,9 @@ static NodeBase* parseLayoutModifier(Parser* parser, void* /*userData*/) int localSizeIndex = -1; if (nameText.startsWith(localSizePrefix) && - nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 1) + (nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 1 || + (nameText.endsWith("_id") && + (nameText.getLength() == SLANG_COUNT_OF(localSizePrefix) - 1 + 4)))) { char lastChar = nameText[SLANG_COUNT_OF(localSizePrefix) - 1]; localSizeIndex = (lastChar >= 'x' && lastChar <= 'z') ? (lastChar - 'x') : -1; @@ -8451,6 +8453,8 @@ static NodeBase* parseLayoutModifier(Parser* parser, void* /*userData*/) numThreadsAttrib->args.setCount(3); for (auto& i : numThreadsAttrib->args) i = nullptr; + for (auto& b : numThreadsAttrib->axisIsSpecConstId) + b = false; // Just mark the loc and name from the first in the list numThreadsAttrib->keywordName = getName(parser, "numthreads"); @@ -8467,6 +8471,11 @@ static NodeBase* parseLayoutModifier(Parser* parser, void* /*userData*/) } numThreadsAttrib->args[localSizeIndex] = expr; + + // We can't resolve the specialization constant declaration + // here, because it may not even exist. IDs pointing to unnamed + // specialization constants are allowed in GLSL. + numThreadsAttrib->axisIsSpecConstId[localSizeIndex] = nameText.endsWith("_id"); } } else if (nameText == "derivative_group_quadsNV") -- cgit v1.2.3