diff options
| author | Darren Wihandi <65404740+fairywreath@users.noreply.github.com> | 2025-01-28 23:12:51 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-01-29 04:12:51 +0000 |
| commit | 1c282b80b9fbcfea9dc3dab7f5f546b069143e01 (patch) | |
| tree | 626a858fff466a0f0c54d4afbe4148a1a58caed4 /tests | |
| parent | cf66563cfdcff9b7d76017e5b73319705ccdb735 (diff) | |
Implement WaveMultiPrefix* for SPIRV and GLSL (#6182)
Diffstat (limited to 'tests')
6 files changed, 267 insertions, 42 deletions
diff --git a/tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang b/tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang new file mode 100644 index 000000000..69240198e --- /dev/null +++ b/tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang @@ -0,0 +1,74 @@ +//TEST_CATEGORY(wave, compute) +//DISABLE_TEST:COMPARE_COMPUTE_EX:-cpu -compute -shaderobj +//DISABLE_TEST:COMPARE_COMPUTE_EX:-slang -compute -shaderobj + +//TEST:COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile sm_6_5 -shaderobj +//TEST:COMPARE_COMPUTE_EX:-vk -compute -shaderobj +//TEST:COMPARE_COMPUTE_EX:-cuda -compute -render-features cuda_sm_7_0 -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer +RWStructuredBuffer<uint> outputBuffer; + +[numthreads(8, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint index = int(dispatchThreadID.x); + + // Split into two groups. + uint4 mask = 0b00001111; + if (index >= 4) + { + mask = 0b11110000; + } + + // + // WaveMultiPrefixSum. + // Results in hex: [0 1 3 7], [0 10 30 70] + // + uint sumValue = WaveMultiPrefixSum(1 << index, mask); + const uint sumBaseIndex = 0; + outputBuffer[sumBaseIndex + index] = sumValue; + + // + // WaveMultiPrefixProduct. + // Results in hex: [1 1 2 8], [1 10 200 8000] + // + uint productValue = WaveMultiPrefixProduct(1 << index, mask); + const uint productBaseIndex = 8; + outputBuffer[productBaseIndex + index] = productValue; + + // + // WaveMultiPrefixBitAnd. + // This prefix operation starts with all bits set. + // Results in hex: [FFFFFFFF 1 1 1], [FFFFFFFF F F F] + // + uint andBits = 0b1; + if (index >= 4) + { + andBits = 0b1111; + } + uint andValue = WaveMultiPrefixBitAnd(andBits, mask); + const uint andBaseIndex = 16; + outputBuffer[andBaseIndex + index] = andValue; + + // + // WaveMultiPrefixBitOr. + // Results in hex: [0 1 3 7], [0 10 30 70] + // + uint orValue = WaveMultiPrefixBitOr(1 << index, mask); + const uint orBaseIndex = 24; + outputBuffer[orBaseIndex + index] = orValue; + + // + // WaveMultiPrefixBitXor. + // Results in hex: [0 1 3 7], [0 F 0 F] + // + uint xorBits = (1 << index); + if (index >= 4) + { + xorBits = 0b1111; + } + uint xorValue = WaveMultiPrefixBitXor(xorBits, mask); + const uint xorBaseIndex = 32; + outputBuffer[xorBaseIndex + index] = xorValue; +} diff --git a/tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang.expected.txt b/tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang.expected.txt new file mode 100644 index 000000000..c80baa5b1 --- /dev/null +++ b/tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang.expected.txt @@ -0,0 +1,40 @@ +0 +1 +3 +7 +0 +10 +30 +70 +1 +1 +2 +8 +1 +10 +200 +8000 +FFFFFFFF +1 +1 +1 +FFFFFFFF +F +F +F +0 +1 +3 +7 +0 +10 +30 +70 +0 +1 +3 +7 +0 +F +0 +F diff --git a/tests/hlsl-intrinsic/wave-multi-prefix.slang b/tests/hlsl-intrinsic/wave-multi-prefix.slang index 31dde2af4..99698e497 100644 --- a/tests/hlsl-intrinsic/wave-multi-prefix.slang +++ b/tests/hlsl-intrinsic/wave-multi-prefix.slang @@ -1,27 +1,146 @@ -//TEST_CATEGORY(wave, compute) -//DISABLE_TEST:COMPARE_COMPUTE_EX:-cpu -compute -shaderobj -//DISABLE_TEST:COMPARE_COMPUTE_EX:-slang -compute -shaderobj -// We need SM6.5 for these tests -// Disable because version of dxc we are currently using doesn't support SM6.5 -//DISABLE_TEST:COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile sm_6_5 -shaderobj -// Disabled because we don't have GLSL intrinsics for these it seems -//DISABLE_TEST(vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -//TEST:COMPARE_COMPUTE_EX:-cuda -compute -render-features cuda_sm_7_0 -shaderobj - -//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer -RWStructuredBuffer<int> outputBuffer; - -[numthreads(8, 1, 1)] -void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +//TEST:SIMPLE(filecheck=CHECK_SPIRV): -stage compute -entry computeMain -target spirv -DNO_INTEGER_MATRIX +//TEST:SIMPLE(filecheck=CHECK_GLSL): -stage compute -entry computeMain -target glsl -DNO_INTEGER_MATRIX +//TEST:SIMPLE(filecheck=CHECK_CUDA): -stage compute -entry computeMain -target cuda +//TEST:SIMPLE(filecheck=CHECK_HLSL): -stage compute -entry computeMain -target hlsl + +// +// Tests all variants and overloads of WaveMultiPrefix* arithmetic intrinsics. +// + +struct OutputData +{ + int scalarSum; + int scalarProduct; + int scalarBitAnd; + int scalarBitOr; + int scalarBitXor; + int vectorSum; + int vectorProduct; + int vectorBitAnd; + int vectorBitOr; + int vectorBitXor; + int matrixSum; + int matrixProduct; + int matrixBitAnd; + int matrixBitOr; + int matrixBitXor; + float floatScalarSum; + float floatScalarProduct; + float floatVectorSum; + float floatVectorProduct; + float floatMatrixSum; + float floatMatrixProduct; +}; + +RWStructuredBuffer<OutputData> outputBuffer; + +// CHECK_SPIRV: OpCapability GroupNonUniformPartitionedNV +// CHECK_SPIRV: OpExtension "SPV_NV_shader_subgroup_partitioned" +// CHECK_SPIRV: OpGroupNonUniformIAdd{{.*}}PartitionedExclusiveScanNV +// CHECK_SPIRV: OpGroupNonUniformIMul{{.*}}PartitionedExclusiveScanNV +// CHECK_SPIRV: OpGroupNonUniformBitwiseAnd{{.*}}PartitionedExclusiveScanNV +// CHECK_SPIRV: OpGroupNonUniformBitwiseOr{{.*}}PartitionedExclusiveScanNV +// CHECK_SPIRV: OpGroupNonUniformBitwiseXor{{.*}}PartitionedExclusiveScanNV +// CHECK_SPIRV: OpGroupNonUniformFAdd{{.*}}PartitionedExclusiveScanNV + +// CHECK_GLSL: GL_NV_shader_subgroup_partitioned +// CHECK_GLSL: subgroupPartitionedExclusiveAddNV +// CHECK_GLSL: subgroupPartitionedExclusiveMulNV +// CHECK_GLSL: subgroupPartitionedExclusiveAndNV +// CHECK_GLSL: subgroupPartitionedExclusiveOrNV +// CHECK_GLSL: subgroupPartitionedExclusiveXorNV + +// CHECK_CUDA: _wavePrefixSum +// CHECK_CUDA: _wavePrefixProduct +// CHECK_CUDA: _wavePrefixAnd +// CHECK_CUDA: _wavePrefixOr +// CHECK_CUDA: _wavePrefixXor +// CHECK_CUDA: _wavePrefixSumMultiple +// CHECK_CUDA: _wavePrefixProductMultiple +// CHECK_CUDA: _wavePrefixAndMultiple +// CHECK_CUDA: _wavePrefixOrMultiple +// CHECK_CUDA: _wavePrefixXorMultiple + +// CHECK_HLSL: WaveMultiPrefixSum +// CHECK_HLSL: WaveMultiPrefixProduct +// CHECK_HLSL: WaveMultiPrefixBitAnd +// CHECK_HLSL: WaveMultiPrefixBitOr +// CHECK_HLSL: WaveMultiPrefixBitXor + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dTid : SV_DispatchThreadID) { - int idx = int(dispatchThreadID.x); - - int value = 0; - - uint4 mask = WaveMatch(true); - - // Scalar - value += WaveMultiPrefixSum(1 << idx, mask); - - outputBuffer[idx] = value; -}
\ No newline at end of file + int scalarVal = dTid.x; + uint4 mask = WaveMatch(scalarVal); + + int scalarSum = WaveMultiPrefixSum(scalarVal, mask); + int scalarProduct = WaveMultiPrefixProduct(scalarVal, mask); + int scalarBitAnd = WaveMultiPrefixBitAnd(scalarVal, mask); + int scalarBitOr = WaveMultiPrefixBitOr(scalarVal, mask); + int scalarBitXor = WaveMultiPrefixBitXor(scalarVal, mask); + + int3 vectorVal = int3(dTid.x, dTid.y, dTid.z); + int3 vectorSum = WaveMultiPrefixSum(vectorVal, mask); + int3 vectorProduct = WaveMultiPrefixProduct(vectorVal, mask); + int3 vectorBitAnd = WaveMultiPrefixBitAnd(vectorVal, mask); + int3 vectorBitOr = WaveMultiPrefixBitOr(vectorVal, mask); + int3 vectorBitXor = WaveMultiPrefixBitXor(vectorVal, mask); + + float floatScalarVal = float(dTid.x) + 0.5f; // Example floating-point scalar value + uint4 floatMask = WaveMatch(floatScalarVal); // Create a mask for matching lanes + + float floatScalarSum = WaveMultiPrefixSum(floatScalarVal, floatMask); + float floatScalarProduct = WaveMultiPrefixProduct(floatScalarVal, floatMask); + + float3 floatVectorVal = float3(dTid.x, dTid.y, dTid.z) + 0.5f; // Example floating-point vector value + float3 floatVectorSum = WaveMultiPrefixSum(floatVectorVal, floatMask); + float3 floatVectorProduct = WaveMultiPrefixProduct(floatVectorVal, floatMask); + + OutputData output; + output.scalarSum = scalarSum; + output.scalarProduct = scalarProduct; + output.scalarBitAnd = scalarBitAnd; + output.scalarBitOr = scalarBitOr; + output.scalarBitXor = scalarBitXor; + output.vectorSum = vectorSum.x; + output.vectorProduct = vectorProduct.x; + output.vectorBitAnd = vectorBitAnd.x; + output.vectorBitOr = vectorBitOr.x; + output.vectorBitXor = vectorBitXor.x; + output.floatScalarSum = floatScalarSum; + output.floatScalarProduct = floatScalarProduct; + output.floatVectorSum = floatVectorSum.x; + output.floatVectorProduct = floatVectorProduct.x; + + float3x3 floatMatrixVal = float3x3( + float(dTid.x) + 0.5f, float(dTid.y) + 0.5f, float(dTid.z) + 0.5f, + float(dTid.z) + 0.5f, float(dTid.x) + 0.5f, float(dTid.y) + 0.5f, + float(dTid.y) + 0.5f, float(dTid.z) + 0.5f, float(dTid.x) + 0.5f + ); + float3x3 floatMatrixSum = WaveMultiPrefixSum(floatMatrixVal, floatMask); + float3x3 floatMatrixProduct = WaveMultiPrefixProduct(floatMatrixVal, floatMask); + output.floatMatrixSum = floatMatrixSum[0][0]; + output.floatMatrixProduct = floatMatrixProduct[0][0]; + +#if !defined(NO_INTEGER_MATRIX) + int3x3 matrixVal = int3x3( + dTid.x, dTid.y, dTid.z, + dTid.z, dTid.x, dTid.y, + dTid.y, dTid.z, dTid.x + ); + int3x3 matrixSum = WaveMultiPrefixSum(matrixVal, mask); + int3x3 matrixProduct = WaveMultiPrefixProduct(matrixVal, mask); + int3x3 matrixBitAnd = WaveMultiPrefixBitAnd(matrixVal, mask); + int3x3 matrixBitOr = WaveMultiPrefixBitOr(matrixVal, mask); + int3x3 matrixBitXor = WaveMultiPrefixBitXor(matrixVal, mask); + output.matrixSum = matrixSum[0][0]; + output.matrixProduct = matrixProduct[0][0]; + output.matrixBitAnd = matrixBitAnd[0][0]; + output.matrixBitOr = matrixBitOr[0][0]; + output.matrixBitXor = matrixBitXor[0][0]; +#endif + + outputBuffer[dTid.x] = output; +} + diff --git a/tests/hlsl-intrinsic/wave-multi-prefix.slang.expected.txt b/tests/hlsl-intrinsic/wave-multi-prefix.slang.expected.txt deleted file mode 100644 index 6ec6deeea..000000000 --- a/tests/hlsl-intrinsic/wave-multi-prefix.slang.expected.txt +++ /dev/null @@ -1,8 +0,0 @@ -0 -1 -3 -7 -F -1F -3F -7F diff --git a/tests/language-feature/capability/capabilitySimplification1.slang b/tests/language-feature/capability/capabilitySimplification1.slang index b694673e9..1d781a45e 100644 --- a/tests/language-feature/capability/capabilitySimplification1.slang +++ b/tests/language-feature/capability/capabilitySimplification1.slang @@ -6,9 +6,9 @@ // CHECK: error 36107 // CHECK-SAME: entrypoint 'computeMain' does not support compilation target 'glsl' with stage 'compute' -// CHECK: capabilitySimplification1.slang(21): note: see using of 'WaveMultiPrefixProduct' -// CHECK-NOT: see using of 'WaveMultiPrefixProduct' -// CHECK: {{.*}}.meta.slang({{.*}}): note: see definition of 'WaveMultiPrefixProduct' +// CHECK: capabilitySimplification1.slang(21): note: see using of 'WaveMultiPrefixCountBits' +// CHECK-NOT: see using of 'WaveMultiPrefixCountBits' +// CHECK: {{.*}}.meta.slang({{.*}}): note: see definition of 'WaveMultiPrefixCountBits' // CHECK: {{.*}}.meta.slang({{.*}}): note: see declaration of 'require' void nestedSafeCall() @@ -18,7 +18,7 @@ void nestedSafeCall() void nestedBadCall() { - WaveMultiPrefixProduct(1, 0); + WaveMultiPrefixCountBits(true, 0); } void nestedCall() diff --git a/tests/language-feature/capability/capabilitySimplification3.slang b/tests/language-feature/capability/capabilitySimplification3.slang index faf161d15..808c19bf6 100644 --- a/tests/language-feature/capability/capabilitySimplification3.slang +++ b/tests/language-feature/capability/capabilitySimplification3.slang @@ -5,13 +5,13 @@ // CHECK_IGNORE_CAPS-NOT: error 36107 // CHECK: error 36107: entrypoint 'computeMain' does not support compilation target 'glsl' with stage 'compute' -// CHECK: capabilitySimplification3.slang(16): note: see using of 'WaveMultiPrefixProduct' -// CHECK-NOT: see using of 'WaveMultiPrefixProduct' -// CHECK: {{.*}}.meta.slang({{.*}}): note: see definition of 'WaveMultiPrefixProduct' +// CHECK: capabilitySimplification3.slang(16): note: see using of 'WaveMultiPrefixCountBits' +// CHECK-NOT: see using of 'WaveMultiPrefixCountBits' +// CHECK: {{.*}}.meta.slang({{.*}}): note: see definition of 'WaveMultiPrefixCountBits' // CHECK: {{.*}}.meta.slang({{.*}}): note: see declaration of 'require' [numthreads(1,1,1)] void computeMain() { - WaveMultiPrefixProduct(1, 0); + WaveMultiPrefixCountBits(true, 0); } |
