summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/pipeline/rasterization/mesh/task-simple.slang12
1 files changed, 8 insertions, 4 deletions
diff --git a/tests/pipeline/rasterization/mesh/task-simple.slang b/tests/pipeline/rasterization/mesh/task-simple.slang
index 61cc6da3d..a85fce8c0 100644
--- a/tests/pipeline/rasterization/mesh/task-simple.slang
+++ b/tests/pipeline/rasterization/mesh/task-simple.slang
@@ -34,12 +34,14 @@ struct MeshPayload
int exponent;
};
-[numthreads(1, 1, 1)]
+const static uint AMPLIFICATION_NUM_THREADS_X = 1;
+
+[numthreads(AMPLIFICATION_NUM_THREADS_X, 1, 1)]
[shader("amplification")]
void taskMain(in uint tig : SV_GroupIndex)
{
MeshPayload p;
- p.exponent = 3;
+ p.exponent = select(AMPLIFICATION_NUM_THREADS_X == WorkgroupSize().x, 3, 0);
DispatchMesh(1,1,1,p);
}
@@ -71,8 +73,10 @@ struct Vertex
const static uint MAX_VERTS = 12;
const static uint MAX_PRIMS = 4;
+const static uint MESH_NUM_THREADS_X = 12;
+
[outputtopology("triangle")]
-[numthreads(12, 1, 1)]
+[numthreads(MESH_NUM_THREADS_X, 1, 1)]
void meshMain(
in uint tig : SV_GroupIndex,
in payload MeshPayload meshPayload,
@@ -88,7 +92,7 @@ void meshMain(
if(tig < numVertices)
{
- const int tri = tig / 3;
+ const int tri = select(WorkgroupSize().x == MESH_NUM_THREADS_X, tig / 3, -1);
verts[tig] = {float4(positions[tig % 3], 0, 1), colors[tig % 3], tri, int(pow(tri, meshPayload.exponent))};
}