diff options
| -rw-r--r-- | source/slang/glsl.meta.slang | 13 | ||||
| -rw-r--r-- | source/slang/hlsl.meta.slang | 17 | ||||
| -rw-r--r-- | tests/pipeline/rasterization/mesh/task-simple.slang | 12 |
3 files changed, 28 insertions, 14 deletions
diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang index b3323a327..dfa858eca 100644 --- a/source/slang/glsl.meta.slang +++ b/source/slang/glsl.meta.slang @@ -141,25 +141,18 @@ public property int gl_DeviceIndex public property uint3 gl_NumWorkGroups { [require(glsl_spirv, GLSL_430_SPIRV_1_0_compute)] + [require(glsl_spirv, meshshading)] get { - __target_switch - { - case glsl: - __intrinsic_asm "(gl_NumWorkGroups)"; - case spirv: - return spirv_asm { - result:$$uint3 = OpLoad builtin(NumWorkgroups:uint3); - }; - } + return WorkgroupCount(); } } -[require(compute)] public property uint3 gl_WorkGroupSize { [__unsafeForceInlineEarly] [require(compute)] + [require(meshshading)] get { return WorkgroupSize(); diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index c1449f3cb..f7efc3a51 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -6619,9 +6619,26 @@ void AllMemoryBarrierWithGroupSync() // Returns the workgroup size of the calling entry point. [require(compute)] +[require(meshshading)] __intrinsic_op($(kIROp_GetWorkGroupSize)) int3 WorkgroupSize(); +// Returns number of workgroups that have been dispatched to a GLSL or SPIR-V compute shader +[require(glsl_spirv, GLSL_430_SPIRV_1_0_compute)] +[require(glsl_spirv, meshshading)] +uint3 WorkgroupCount() +{ + __target_switch + { + case glsl: + __intrinsic_asm "(gl_NumWorkGroups)"; + case spirv: + return spirv_asm { + result:$$uint3 = OpLoad builtin(NumWorkgroups:uint3); + }; + } +} + // Test if any components is non-zero. __generic<T : __BuiltinType> diff --git a/tests/pipeline/rasterization/mesh/task-simple.slang b/tests/pipeline/rasterization/mesh/task-simple.slang index 61cc6da3d..a85fce8c0 100644 --- a/tests/pipeline/rasterization/mesh/task-simple.slang +++ b/tests/pipeline/rasterization/mesh/task-simple.slang @@ -34,12 +34,14 @@ struct MeshPayload int exponent; }; -[numthreads(1, 1, 1)] +const static uint AMPLIFICATION_NUM_THREADS_X = 1; + +[numthreads(AMPLIFICATION_NUM_THREADS_X, 1, 1)] [shader("amplification")] void taskMain(in uint tig : SV_GroupIndex) { MeshPayload p; - p.exponent = 3; + p.exponent = select(AMPLIFICATION_NUM_THREADS_X == WorkgroupSize().x, 3, 0); DispatchMesh(1,1,1,p); } @@ -71,8 +73,10 @@ struct Vertex const static uint MAX_VERTS = 12; const static uint MAX_PRIMS = 4; +const static uint MESH_NUM_THREADS_X = 12; + [outputtopology("triangle")] -[numthreads(12, 1, 1)] +[numthreads(MESH_NUM_THREADS_X, 1, 1)] void meshMain( in uint tig : SV_GroupIndex, in payload MeshPayload meshPayload, @@ -88,7 +92,7 @@ void meshMain( if(tig < numVertices) { - const int tri = tig / 3; + const int tri = select(WorkgroupSize().x == MESH_NUM_THREADS_X, tig / 3, -1); verts[tig] = {float4(positions[tig % 3], 0, 1), colors[tig % 3], tri, int(pow(tri, meshPayload.exponent))}; } |
