summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
authorDarren Wihandi <65404740+fairywreath@users.noreply.github.com>2025-01-28 23:12:51 -0500
committerGitHub <noreply@github.com>2025-01-29 04:12:51 +0000
commit1c282b80b9fbcfea9dc3dab7f5f546b069143e01 (patch)
tree626a858fff466a0f0c54d4afbe4148a1a58caed4 /tests
parentcf66563cfdcff9b7d76017e5b73319705ccdb735 (diff)
Implement WaveMultiPrefix* for SPIRV and GLSL (#6182)
Diffstat (limited to 'tests')
-rw-r--r--tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang74
-rw-r--r--tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang.expected.txt40
-rw-r--r--tests/hlsl-intrinsic/wave-multi-prefix.slang171
-rw-r--r--tests/hlsl-intrinsic/wave-multi-prefix.slang.expected.txt8
-rw-r--r--tests/language-feature/capability/capabilitySimplification1.slang8
-rw-r--r--tests/language-feature/capability/capabilitySimplification3.slang8
6 files changed, 267 insertions, 42 deletions
diff --git a/tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang b/tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang
new file mode 100644
index 000000000..69240198e
--- /dev/null
+++ b/tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang
@@ -0,0 +1,74 @@
+//TEST_CATEGORY(wave, compute)
+//DISABLE_TEST:COMPARE_COMPUTE_EX:-cpu -compute -shaderobj
+//DISABLE_TEST:COMPARE_COMPUTE_EX:-slang -compute -shaderobj
+
+//TEST:COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile sm_6_5 -shaderobj
+//TEST:COMPARE_COMPUTE_EX:-vk -compute -shaderobj
+//TEST:COMPARE_COMPUTE_EX:-cuda -compute -render-features cuda_sm_7_0 -shaderobj
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
+RWStructuredBuffer<uint> outputBuffer;
+
+[numthreads(8, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint index = int(dispatchThreadID.x);
+
+ // Split into two groups.
+ uint4 mask = 0b00001111;
+ if (index >= 4)
+ {
+ mask = 0b11110000;
+ }
+
+ //
+ // WaveMultiPrefixSum.
+ // Results in hex: [0 1 3 7], [0 10 30 70]
+ //
+ uint sumValue = WaveMultiPrefixSum(1 << index, mask);
+ const uint sumBaseIndex = 0;
+ outputBuffer[sumBaseIndex + index] = sumValue;
+
+ //
+ // WaveMultiPrefixProduct.
+ // Results in hex: [1 1 2 8], [1 10 200 8000]
+ //
+ uint productValue = WaveMultiPrefixProduct(1 << index, mask);
+ const uint productBaseIndex = 8;
+ outputBuffer[productBaseIndex + index] = productValue;
+
+ //
+ // WaveMultiPrefixBitAnd.
+ // This prefix operation starts with all bits set.
+ // Results in hex: [FFFFFFFF 1 1 1], [FFFFFFFF F F F]
+ //
+ uint andBits = 0b1;
+ if (index >= 4)
+ {
+ andBits = 0b1111;
+ }
+ uint andValue = WaveMultiPrefixBitAnd(andBits, mask);
+ const uint andBaseIndex = 16;
+ outputBuffer[andBaseIndex + index] = andValue;
+
+ //
+ // WaveMultiPrefixBitOr.
+ // Results in hex: [0 1 3 7], [0 10 30 70]
+ //
+ uint orValue = WaveMultiPrefixBitOr(1 << index, mask);
+ const uint orBaseIndex = 24;
+ outputBuffer[orBaseIndex + index] = orValue;
+
+ //
+ // WaveMultiPrefixBitXor.
+ // Results in hex: [0 1 3 7], [0 F 0 F]
+ //
+ uint xorBits = (1 << index);
+ if (index >= 4)
+ {
+ xorBits = 0b1111;
+ }
+ uint xorValue = WaveMultiPrefixBitXor(xorBits, mask);
+ const uint xorBaseIndex = 32;
+ outputBuffer[xorBaseIndex + index] = xorValue;
+}
diff --git a/tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang.expected.txt b/tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang.expected.txt
new file mode 100644
index 000000000..c80baa5b1
--- /dev/null
+++ b/tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang.expected.txt
@@ -0,0 +1,40 @@
+0
+1
+3
+7
+0
+10
+30
+70
+1
+1
+2
+8
+1
+10
+200
+8000
+FFFFFFFF
+1
+1
+1
+FFFFFFFF
+F
+F
+F
+0
+1
+3
+7
+0
+10
+30
+70
+0
+1
+3
+7
+0
+F
+0
+F
diff --git a/tests/hlsl-intrinsic/wave-multi-prefix.slang b/tests/hlsl-intrinsic/wave-multi-prefix.slang
index 31dde2af4..99698e497 100644
--- a/tests/hlsl-intrinsic/wave-multi-prefix.slang
+++ b/tests/hlsl-intrinsic/wave-multi-prefix.slang
@@ -1,27 +1,146 @@
-//TEST_CATEGORY(wave, compute)
-//DISABLE_TEST:COMPARE_COMPUTE_EX:-cpu -compute -shaderobj
-//DISABLE_TEST:COMPARE_COMPUTE_EX:-slang -compute -shaderobj
-// We need SM6.5 for these tests
-// Disable because version of dxc we are currently using doesn't support SM6.5
-//DISABLE_TEST:COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile sm_6_5 -shaderobj
-// Disabled because we don't have GLSL intrinsics for these it seems
-//DISABLE_TEST(vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj
-//TEST:COMPARE_COMPUTE_EX:-cuda -compute -render-features cuda_sm_7_0 -shaderobj
-
-//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
-RWStructuredBuffer<int> outputBuffer;
-
-[numthreads(8, 1, 1)]
-void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+//TEST:SIMPLE(filecheck=CHECK_SPIRV): -stage compute -entry computeMain -target spirv -DNO_INTEGER_MATRIX
+//TEST:SIMPLE(filecheck=CHECK_GLSL): -stage compute -entry computeMain -target glsl -DNO_INTEGER_MATRIX
+//TEST:SIMPLE(filecheck=CHECK_CUDA): -stage compute -entry computeMain -target cuda
+//TEST:SIMPLE(filecheck=CHECK_HLSL): -stage compute -entry computeMain -target hlsl
+
+//
+// Tests all variants and overloads of WaveMultiPrefix* arithmetic intrinsics.
+//
+
+struct OutputData
+{
+ int scalarSum;
+ int scalarProduct;
+ int scalarBitAnd;
+ int scalarBitOr;
+ int scalarBitXor;
+ int vectorSum;
+ int vectorProduct;
+ int vectorBitAnd;
+ int vectorBitOr;
+ int vectorBitXor;
+ int matrixSum;
+ int matrixProduct;
+ int matrixBitAnd;
+ int matrixBitOr;
+ int matrixBitXor;
+ float floatScalarSum;
+ float floatScalarProduct;
+ float floatVectorSum;
+ float floatVectorProduct;
+ float floatMatrixSum;
+ float floatMatrixProduct;
+};
+
+RWStructuredBuffer<OutputData> outputBuffer;
+
+// CHECK_SPIRV: OpCapability GroupNonUniformPartitionedNV
+// CHECK_SPIRV: OpExtension "SPV_NV_shader_subgroup_partitioned"
+// CHECK_SPIRV: OpGroupNonUniformIAdd{{.*}}PartitionedExclusiveScanNV
+// CHECK_SPIRV: OpGroupNonUniformIMul{{.*}}PartitionedExclusiveScanNV
+// CHECK_SPIRV: OpGroupNonUniformBitwiseAnd{{.*}}PartitionedExclusiveScanNV
+// CHECK_SPIRV: OpGroupNonUniformBitwiseOr{{.*}}PartitionedExclusiveScanNV
+// CHECK_SPIRV: OpGroupNonUniformBitwiseXor{{.*}}PartitionedExclusiveScanNV
+// CHECK_SPIRV: OpGroupNonUniformFAdd{{.*}}PartitionedExclusiveScanNV
+
+// CHECK_GLSL: GL_NV_shader_subgroup_partitioned
+// CHECK_GLSL: subgroupPartitionedExclusiveAddNV
+// CHECK_GLSL: subgroupPartitionedExclusiveMulNV
+// CHECK_GLSL: subgroupPartitionedExclusiveAndNV
+// CHECK_GLSL: subgroupPartitionedExclusiveOrNV
+// CHECK_GLSL: subgroupPartitionedExclusiveXorNV
+
+// CHECK_CUDA: _wavePrefixSum
+// CHECK_CUDA: _wavePrefixProduct
+// CHECK_CUDA: _wavePrefixAnd
+// CHECK_CUDA: _wavePrefixOr
+// CHECK_CUDA: _wavePrefixXor
+// CHECK_CUDA: _wavePrefixSumMultiple
+// CHECK_CUDA: _wavePrefixProductMultiple
+// CHECK_CUDA: _wavePrefixAndMultiple
+// CHECK_CUDA: _wavePrefixOrMultiple
+// CHECK_CUDA: _wavePrefixXorMultiple
+
+// CHECK_HLSL: WaveMultiPrefixSum
+// CHECK_HLSL: WaveMultiPrefixProduct
+// CHECK_HLSL: WaveMultiPrefixBitAnd
+// CHECK_HLSL: WaveMultiPrefixBitOr
+// CHECK_HLSL: WaveMultiPrefixBitXor
+
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dTid : SV_DispatchThreadID)
{
- int idx = int(dispatchThreadID.x);
-
- int value = 0;
-
- uint4 mask = WaveMatch(true);
-
- // Scalar
- value += WaveMultiPrefixSum(1 << idx, mask);
-
- outputBuffer[idx] = value;
-} \ No newline at end of file
+ int scalarVal = dTid.x;
+ uint4 mask = WaveMatch(scalarVal);
+
+ int scalarSum = WaveMultiPrefixSum(scalarVal, mask);
+ int scalarProduct = WaveMultiPrefixProduct(scalarVal, mask);
+ int scalarBitAnd = WaveMultiPrefixBitAnd(scalarVal, mask);
+ int scalarBitOr = WaveMultiPrefixBitOr(scalarVal, mask);
+ int scalarBitXor = WaveMultiPrefixBitXor(scalarVal, mask);
+
+ int3 vectorVal = int3(dTid.x, dTid.y, dTid.z);
+ int3 vectorSum = WaveMultiPrefixSum(vectorVal, mask);
+ int3 vectorProduct = WaveMultiPrefixProduct(vectorVal, mask);
+ int3 vectorBitAnd = WaveMultiPrefixBitAnd(vectorVal, mask);
+ int3 vectorBitOr = WaveMultiPrefixBitOr(vectorVal, mask);
+ int3 vectorBitXor = WaveMultiPrefixBitXor(vectorVal, mask);
+
+ float floatScalarVal = float(dTid.x) + 0.5f; // Example floating-point scalar value
+ uint4 floatMask = WaveMatch(floatScalarVal); // Create a mask for matching lanes
+
+ float floatScalarSum = WaveMultiPrefixSum(floatScalarVal, floatMask);
+ float floatScalarProduct = WaveMultiPrefixProduct(floatScalarVal, floatMask);
+
+ float3 floatVectorVal = float3(dTid.x, dTid.y, dTid.z) + 0.5f; // Example floating-point vector value
+ float3 floatVectorSum = WaveMultiPrefixSum(floatVectorVal, floatMask);
+ float3 floatVectorProduct = WaveMultiPrefixProduct(floatVectorVal, floatMask);
+
+ OutputData output;
+ output.scalarSum = scalarSum;
+ output.scalarProduct = scalarProduct;
+ output.scalarBitAnd = scalarBitAnd;
+ output.scalarBitOr = scalarBitOr;
+ output.scalarBitXor = scalarBitXor;
+ output.vectorSum = vectorSum.x;
+ output.vectorProduct = vectorProduct.x;
+ output.vectorBitAnd = vectorBitAnd.x;
+ output.vectorBitOr = vectorBitOr.x;
+ output.vectorBitXor = vectorBitXor.x;
+ output.floatScalarSum = floatScalarSum;
+ output.floatScalarProduct = floatScalarProduct;
+ output.floatVectorSum = floatVectorSum.x;
+ output.floatVectorProduct = floatVectorProduct.x;
+
+ float3x3 floatMatrixVal = float3x3(
+ float(dTid.x) + 0.5f, float(dTid.y) + 0.5f, float(dTid.z) + 0.5f,
+ float(dTid.z) + 0.5f, float(dTid.x) + 0.5f, float(dTid.y) + 0.5f,
+ float(dTid.y) + 0.5f, float(dTid.z) + 0.5f, float(dTid.x) + 0.5f
+ );
+ float3x3 floatMatrixSum = WaveMultiPrefixSum(floatMatrixVal, floatMask);
+ float3x3 floatMatrixProduct = WaveMultiPrefixProduct(floatMatrixVal, floatMask);
+ output.floatMatrixSum = floatMatrixSum[0][0];
+ output.floatMatrixProduct = floatMatrixProduct[0][0];
+
+#if !defined(NO_INTEGER_MATRIX)
+ int3x3 matrixVal = int3x3(
+ dTid.x, dTid.y, dTid.z,
+ dTid.z, dTid.x, dTid.y,
+ dTid.y, dTid.z, dTid.x
+ );
+ int3x3 matrixSum = WaveMultiPrefixSum(matrixVal, mask);
+ int3x3 matrixProduct = WaveMultiPrefixProduct(matrixVal, mask);
+ int3x3 matrixBitAnd = WaveMultiPrefixBitAnd(matrixVal, mask);
+ int3x3 matrixBitOr = WaveMultiPrefixBitOr(matrixVal, mask);
+ int3x3 matrixBitXor = WaveMultiPrefixBitXor(matrixVal, mask);
+ output.matrixSum = matrixSum[0][0];
+ output.matrixProduct = matrixProduct[0][0];
+ output.matrixBitAnd = matrixBitAnd[0][0];
+ output.matrixBitOr = matrixBitOr[0][0];
+ output.matrixBitXor = matrixBitXor[0][0];
+#endif
+
+ outputBuffer[dTid.x] = output;
+}
+
diff --git a/tests/hlsl-intrinsic/wave-multi-prefix.slang.expected.txt b/tests/hlsl-intrinsic/wave-multi-prefix.slang.expected.txt
deleted file mode 100644
index 6ec6deeea..000000000
--- a/tests/hlsl-intrinsic/wave-multi-prefix.slang.expected.txt
+++ /dev/null
@@ -1,8 +0,0 @@
-0
-1
-3
-7
-F
-1F
-3F
-7F
diff --git a/tests/language-feature/capability/capabilitySimplification1.slang b/tests/language-feature/capability/capabilitySimplification1.slang
index b694673e9..1d781a45e 100644
--- a/tests/language-feature/capability/capabilitySimplification1.slang
+++ b/tests/language-feature/capability/capabilitySimplification1.slang
@@ -6,9 +6,9 @@
// CHECK: error 36107
// CHECK-SAME: entrypoint 'computeMain' does not support compilation target 'glsl' with stage 'compute'
-// CHECK: capabilitySimplification1.slang(21): note: see using of 'WaveMultiPrefixProduct'
-// CHECK-NOT: see using of 'WaveMultiPrefixProduct'
-// CHECK: {{.*}}.meta.slang({{.*}}): note: see definition of 'WaveMultiPrefixProduct'
+// CHECK: capabilitySimplification1.slang(21): note: see using of 'WaveMultiPrefixCountBits'
+// CHECK-NOT: see using of 'WaveMultiPrefixCountBits'
+// CHECK: {{.*}}.meta.slang({{.*}}): note: see definition of 'WaveMultiPrefixCountBits'
// CHECK: {{.*}}.meta.slang({{.*}}): note: see declaration of 'require'
void nestedSafeCall()
@@ -18,7 +18,7 @@ void nestedSafeCall()
void nestedBadCall()
{
- WaveMultiPrefixProduct(1, 0);
+ WaveMultiPrefixCountBits(true, 0);
}
void nestedCall()
diff --git a/tests/language-feature/capability/capabilitySimplification3.slang b/tests/language-feature/capability/capabilitySimplification3.slang
index faf161d15..808c19bf6 100644
--- a/tests/language-feature/capability/capabilitySimplification3.slang
+++ b/tests/language-feature/capability/capabilitySimplification3.slang
@@ -5,13 +5,13 @@
// CHECK_IGNORE_CAPS-NOT: error 36107
// CHECK: error 36107: entrypoint 'computeMain' does not support compilation target 'glsl' with stage 'compute'
-// CHECK: capabilitySimplification3.slang(16): note: see using of 'WaveMultiPrefixProduct'
-// CHECK-NOT: see using of 'WaveMultiPrefixProduct'
-// CHECK: {{.*}}.meta.slang({{.*}}): note: see definition of 'WaveMultiPrefixProduct'
+// CHECK: capabilitySimplification3.slang(16): note: see using of 'WaveMultiPrefixCountBits'
+// CHECK-NOT: see using of 'WaveMultiPrefixCountBits'
+// CHECK: {{.*}}.meta.slang({{.*}}): note: see definition of 'WaveMultiPrefixCountBits'
// CHECK: {{.*}}.meta.slang({{.*}}): note: see declaration of 'require'
[numthreads(1,1,1)]
void computeMain()
{
- WaveMultiPrefixProduct(1, 0);
+ WaveMultiPrefixCountBits(true, 0);
}