diff options
| author | Julius Ikkala <julius.ikkala@gmail.com> | 2025-01-14 20:32:29 +0200 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-01-14 10:32:29 -0800 |
| commit | cbdc7e1219e472fd74f7f559d7e417f233e7df39 (patch) | |
| tree | e051b90e317a875e264c2c8d951668bf0b7d3ad0 /source/slang/slang-reflection-api.cpp | |
| parent | 971996b397711016d47fe961890d7001338c6f23 (diff) | |
Implement specialization constant support in numthreads / local_size (#5963)
* Allow using specialization constants in numthreads attribute
* Add support for GLSL local_size_x_id syntax
* Fix overeager specialization constant parsing
* Add diagnostics for specialization constant numthreads
* Remove unused variable
* Fix local_size_x_id not finding existing specialization constant
* Allow materializeGetWorkGroupSize to reference specialization constants
* Use SpvOpExecutionModeId for modes that require it
* Cleanup specialization constant numthreads code
* Add tests for specialization constant work group sizes
* Fix implicit Slang::Int -> int32_t cast
* Fix querying thread group size in reflection API
---------
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-reflection-api.cpp')
| -rw-r--r-- | source/slang/slang-reflection-api.cpp | 20 |
1 files changed, 8 insertions, 12 deletions
diff --git a/source/slang/slang-reflection-api.cpp b/source/slang/slang-reflection-api.cpp index d235c8270..d1adfedc0 100644 --- a/source/slang/slang-reflection-api.cpp +++ b/source/slang/slang-reflection-api.cpp @@ -4033,18 +4033,14 @@ SLANG_API void spReflectionEntryPoint_getComputeThreadGroupSize( auto numThreadsAttribute = entryPointFunc.getDecl()->findModifier<NumThreadsAttribute>(); if (numThreadsAttribute) { - if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->x)) - sizeAlongAxis[0] = (SlangUInt)cint->getValue(); - else if (numThreadsAttribute->x) - sizeAlongAxis[0] = 0; - if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->y)) - sizeAlongAxis[1] = (SlangUInt)cint->getValue(); - else if (numThreadsAttribute->y) - sizeAlongAxis[1] = 0; - if (auto cint = entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->z)) - sizeAlongAxis[2] = (SlangUInt)cint->getValue(); - else if (numThreadsAttribute->z) - sizeAlongAxis[2] = 0; + for (int i = 0; i < 3; ++i) + { + if (auto cint = + entryPointLayout->program->tryFoldIntVal(numThreadsAttribute->extents[i])) + sizeAlongAxis[i] = (SlangUInt)cint->getValue(); + else if (numThreadsAttribute->extents[i]) + sizeAlongAxis[i] = 0; + } } // |
