summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-emit-spirv.cpp
diff options
context:
space:
mode:
authorJulius Ikkala <julius.ikkala@gmail.com>2025-01-14 20:32:29 +0200
committerGitHub <noreply@github.com>2025-01-14 10:32:29 -0800
commitcbdc7e1219e472fd74f7f559d7e417f233e7df39 (patch)
treee051b90e317a875e264c2c8d951668bf0b7d3ad0 /source/slang/slang-emit-spirv.cpp
parent971996b397711016d47fe961890d7001338c6f23 (diff)
Implement specialization constant support in numthreads / local_size (#5963)
* Allow using specialization constants in numthreads attribute * Add support for GLSL local_size_x_id syntax * Fix overeager specialization constant parsing * Add diagnostics for specialization constant numthreads * Remove unused variable * Fix local_size_x_id not finding existing specialization constant * Allow materializeGetWorkGroupSize to reference specialization constants * Use SpvOpExecutionModeId for modes that require it * Cleanup specialization constant numthreads code * Add tests for specialization constant work group sizes * Fix implicit Slang::Int -> int32_t cast * Fix querying thread group size in reflection API --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-emit-spirv.cpp')
-rw-r--r--source/slang/slang-emit-spirv.cpp55
1 files changed, 38 insertions, 17 deletions
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 068e1563c..2cf84a854 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -4353,23 +4353,36 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
// [3.6. Execution Mode]: LocalSize
case kIROp_NumThreadsDecoration:
{
- // TODO: The `LocalSize` execution mode option requires
- // literal values for the X,Y,Z thread-group sizes.
- // There is a `LocalSizeId` variant that takes `<id>`s
- // for those sizes, and we should consider using that
- // and requiring the appropriate capabilities
- // if any of the operands to the decoration are not
- // literals (in a future where we support non-literals
- // in those positions in the Slang IR).
- //
auto numThreads = cast<IRNumThreadsDecoration>(decoration);
- requireSPIRVExecutionMode(
- decoration,
- dstID,
- SpvExecutionModeLocalSize,
- SpvLiteralInteger::from32(int32_t(numThreads->getX()->getValue())),
- SpvLiteralInteger::from32(int32_t(numThreads->getY()->getValue())),
- SpvLiteralInteger::from32(int32_t(numThreads->getZ()->getValue())));
+ if (numThreads->getXSpecConst() || numThreads->getYSpecConst() ||
+ numThreads->getZSpecConst())
+ {
+ // If any of the dimensions needs an ID, we need to emit
+ // all dimensions as an ID due to how LocalSizeId works.
+ int32_t ids[3];
+ for (int i = 0; i < 3; ++i)
+ ids[i] = ensureInst(numThreads->getOperand(i))->id;
+
+ // LocalSizeId is supported from SPIR-V 1.2 onwards without
+ // any extra capabilities.
+ requireSPIRVExecutionMode(
+ decoration,
+ dstID,
+ SpvExecutionModeLocalSizeId,
+ SpvLiteralInteger::from32(int32_t(ids[0])),
+ SpvLiteralInteger::from32(int32_t(ids[1])),
+ SpvLiteralInteger::from32(int32_t(ids[2])));
+ }
+ else
+ {
+ requireSPIRVExecutionMode(
+ decoration,
+ dstID,
+ SpvExecutionModeLocalSize,
+ SpvLiteralInteger::from32(int32_t(numThreads->getX()->getValue())),
+ SpvLiteralInteger::from32(int32_t(numThreads->getY()->getValue())),
+ SpvLiteralInteger::from32(int32_t(numThreads->getZ()->getValue())));
+ }
}
break;
case kIROp_MaxVertexCountDecoration:
@@ -7977,10 +7990,18 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
{
if (m_executionModes[entryPoint].add(executionMode))
{
+ SpvOp execModeOp = SpvOpExecutionMode;
+ if (executionMode == SpvExecutionModeLocalSizeId ||
+ executionMode == SpvExecutionModeLocalSizeHintId ||
+ executionMode == SpvExecutionModeSubgroupsPerWorkgroupId)
+ {
+ execModeOp = SpvOpExecutionModeId;
+ }
+
emitInst(
getSection(SpvLogicalSectionID::ExecutionModes),
parentInst,
- SpvOpExecutionMode,
+ execModeOp,
entryPoint,
executionMode,
ops...);