summary refs log tree commit diff stats
diff options
context:
space:
mode:
author Darren Wihandi <65404740+fairywreath@users.noreply.github.com> 2025-01-28 23:12:51 -0500
committer GitHub <noreply@github.com> 2025-01-29 04:12:51 +0000
commit 1c282b80b9fbcfea9dc3dab7f5f546b069143e01 (patch)
tree 626a858fff466a0f0c54d4afbe4148a1a58caed4
parent cf66563cfdcff9b7d76017e5b73319705ccdb735 (diff)
Implement WaveMultiPrefix* for SPIRV and GLSL (#6182)
-rw-r--r-- docs/user-guide/a3-02-reference-capability-atoms.md 2
-rw-r--r-- source/slang/hlsl.meta.slang 250
-rw-r--r-- source/slang/slang-capabilities.capdef 4
-rw-r--r-- tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang 74
-rw-r--r-- tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang.expected.txt 40
-rw-r--r-- tests/hlsl-intrinsic/wave-multi-prefix.slang 171
-rw-r--r-- tests/hlsl-intrinsic/wave-multi-prefix.slang.expected.txt 8
-rw-r--r-- tests/language-feature/capability/capabilitySimplification1.slang 8
-rw-r--r-- tests/language-feature/capability/capabilitySimplification3.slang 8
9 files changed, 484 insertions, 81 deletions
diff --git a/docs/user-guide/a3-02-reference-capability-atoms.md b/docs/user-guide/a3-02-reference-capability-atoms.md
index a70a9f88c..092dab9f0 100644
--- a/docs/user-guide/a3-02-reference-capability-atoms.md
+++ b/docs/user-guide/a3-02-reference-capability-atoms.md
@@ -852,7 +852,7 @@ Compound Capabilities
`shadermemorycontrol`
> (gfx targets) Capabilities needed to use memory barriers
-`waveprefix`
+`wave_multi_prefix`
> Capabilities needed to use HLSL tier wave operations
`bufferreference`
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index ba5c95a0c..1853a82b6 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -15517,7 +15517,7 @@ uint4 WaveMatch(matrix<T,N,M> value)
}
/// @category wave
-[require(cuda_hlsl, waveprefix)]
+[require(cuda_hlsl, wave_multi_prefix)]
uint WaveMultiPrefixCountBits(bool value, uint4 mask)
{
__target_switch
@@ -15528,190 +15528,366 @@ uint WaveMultiPrefixCountBits(bool value, uint4 mask)
}
/// @category wave
-__generic<T : __BuiltinArithmeticType>
-__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
+__generic<T : __BuiltinIntegerType>
+__glsl_extension(GL_NV_shader_subgroup_partitioned)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl, waveprefix)]
+[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)]
T WaveMultiPrefixBitAnd(T expr, uint4 mask)
{
__target_switch
{
case cuda: __intrinsic_asm "_wavePrefixAnd(_getMultiPrefixMask(($1).x), $0)";
- case glsl: __intrinsic_asm "subgroupExclusiveAnd($0)";
case hlsl: __intrinsic_asm "WaveMultiPrefixBitAnd";
+ case glsl: __intrinsic_asm "subgroupPartitionedExclusiveAndNV";
+ case spirv:
+ return spirv_asm
+ {
+ OpExtension "SPV_NV_shader_subgroup_partitioned";
+ OpCapability GroupNonUniformPartitionedNV;
+ result:$$T = OpGroupNonUniformBitwiseAnd Subgroup PartitionedExclusiveScanNV $expr $mask
+ };
}
}
-__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
+__generic<T : __BuiltinIntegerType, let N : int>
+__glsl_extension(GL_NV_shader_subgroup_partitioned)
__spirv_version(1.3)
-__generic<T : __BuiltinArithmeticType, let N : int>
-[require(cuda_glsl_hlsl, waveprefix)]
+[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)]
vector<T,N> WaveMultiPrefixBitAnd(vector<T,N> expr, uint4 mask)
{
__target_switch
{
case cuda: __intrinsic_asm "_wavePrefixAndMultiple(_getMultiPrefixMask(($1).x), $0)";
- case glsl: __intrinsic_asm "subgroupExclusiveAnd($0)";
case hlsl: __intrinsic_asm "WaveMultiPrefixBitAnd";
+ case glsl: __intrinsic_asm "subgroupPartitionedExclusiveAndNV";
+ case spirv:
+ return spirv_asm
+ {
+ OpExtension "SPV_NV_shader_subgroup_partitioned";
+ OpCapability GroupNonUniformPartitionedNV;
+ result:$$vector<T,N> = OpGroupNonUniformBitwiseAnd Subgroup PartitionedExclusiveScanNV $expr $mask
+ };
}
}
-__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
-[require(cuda_hlsl, waveprefix)]
+__generic<T : __BuiltinIntegerType, let N : int, let M : int>
+[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)]
matrix<T,N,M> WaveMultiPrefixBitAnd(matrix<T,N,M> expr, uint4 mask)
{
__target_switch
{
case cuda: __intrinsic_asm "_wavePrefixAndMultiple(_getMultiPrefixMask(($1).x), $0)";
case hlsl: __intrinsic_asm "WaveMultiPrefixBitAnd";
+ case glsl:
+ case spirv:
+ matrix<T, N, M> result;
+ for (int i = 0; i < N; ++i)
+ result[i] = WaveMultiPrefixBitAnd(expr[i], mask);
+ return result;
}
}
/// @category wave
-__generic<T : __BuiltinArithmeticType>
-__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
+__generic<T : __BuiltinIntegerType>
+__glsl_extension(GL_NV_shader_subgroup_partitioned)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl, waveprefix)]
+[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)]
T WaveMultiPrefixBitOr(T expr, uint4 mask)
{
__target_switch
{
case cuda: __intrinsic_asm "_wavePrefixOr(_getMultiPrefixMask(($1).x), $0)";
- case glsl: __intrinsic_asm "subgroupExclusiveOr($0)";
case hlsl: __intrinsic_asm "WaveMultiPrefixBitOr";
+ case glsl: __intrinsic_asm "subgroupPartitionedExclusiveOrNV";
+ case spirv:
+ return spirv_asm
+ {
+ OpExtension "SPV_NV_shader_subgroup_partitioned";
+ OpCapability GroupNonUniformPartitionedNV;
+ result:$$T = OpGroupNonUniformBitwiseOr Subgroup PartitionedExclusiveScanNV $expr $mask
+ };
}
}
-__generic<T : __BuiltinArithmeticType, let N : int>
-__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
+__generic<T : __BuiltinIntegerType, let N : int>
+__glsl_extension(GL_NV_shader_subgroup_partitioned)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl, waveprefix)]
+[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)]
vector<T,N> WaveMultiPrefixBitOr(vector<T,N> expr, uint4 mask)
{
__target_switch
{
case cuda: __intrinsic_asm "_wavePrefixOrMultiple(_getMultiPrefixMask(($1).x), $0)";
- case glsl: __intrinsic_asm "subgroupExclusiveOr($0)";
case hlsl: __intrinsic_asm "WaveMultiPrefixBitOr";
+ case glsl: __intrinsic_asm "subgroupPartitionedExclusiveOrNV";
+ case spirv:
+ return spirv_asm
+ {
+ OpExtension "SPV_NV_shader_subgroup_partitioned";
+ OpCapability GroupNonUniformPartitionedNV;
+ result:$$vector<T,N> = OpGroupNonUniformBitwiseOr Subgroup PartitionedExclusiveScanNV $expr $mask
+ };
}
}
-__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
-[require(cuda_hlsl, waveprefix)]
+__generic<T : __BuiltinIntegerType, let N : int, let M : int>
+[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)]
matrix<T,N,M> WaveMultiPrefixBitOr(matrix<T,N,M> expr, uint4 mask)
{
__target_switch
{
case cuda: __intrinsic_asm "_wavePrefixOrMultiple(_getMultiPrefixMask(($1).x), $0)";
case hlsl: __intrinsic_asm "WaveMultiPrefixBitOr";
+ case glsl:
+ case spirv:
+ matrix<T, N, M> result;
+ for (int i = 0; i < N; ++i)
+ result[i] = WaveMultiPrefixBitOr(expr[i], mask);
+ return result;
}
}
/// @category wave
-__generic<T : __BuiltinArithmeticType>
-__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
+__generic<T : __BuiltinIntegerType>
+__glsl_extension(GL_NV_shader_subgroup_partitioned)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl, waveprefix)]
+[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)]
T WaveMultiPrefixBitXor(T expr, uint4 mask)
{
__target_switch
{
case cuda: __intrinsic_asm "_wavePrefixXor(_getMultiPrefixMask(($1).x), $0)";
- case glsl: __intrinsic_asm "subgroupExclusiveXor($0)";
case hlsl: __intrinsic_asm "WaveMultiPrefixBitXor";
+ case glsl: __intrinsic_asm "subgroupPartitionedExclusiveXorNV";
+ case spirv:
+ return spirv_asm
+ {
+ OpExtension "SPV_NV_shader_subgroup_partitioned";
+ OpCapability GroupNonUniformPartitionedNV;
+ result:$$T = OpGroupNonUniformBitwiseXor Subgroup PartitionedExclusiveScanNV $expr $mask
+ };
}
}
-__generic<T : __BuiltinArithmeticType, let N : int>
-__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
+__generic<T : __BuiltinIntegerType, let N : int>
+__glsl_extension(GL_NV_shader_subgroup_partitioned)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl, waveprefix)]
+[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)]
vector<T,N> WaveMultiPrefixBitXor(vector<T,N> expr, uint4 mask)
{
__target_switch
{
case cuda: __intrinsic_asm "_wavePrefixXorMultiple(_getMultiPrefixMask(($1).x), $0)";
- case glsl: __intrinsic_asm "subgroupExclusiveXor($0)";
case hlsl: __intrinsic_asm "WaveMultiPrefixBitXor";
+ case glsl: __intrinsic_asm "subgroupPartitionedExclusiveXorNV";
+ case spirv:
+ return spirv_asm
+ {
+ OpExtension "SPV_NV_shader_subgroup_partitioned";
+ OpCapability GroupNonUniformPartitionedNV;
+ result:$$vector<T,N> = OpGroupNonUniformBitwiseXor Subgroup PartitionedExclusiveScanNV $expr $mask
+ };
}
}
-__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
-[require(cuda_hlsl, waveprefix)]
+__generic<T : __BuiltinIntegerType, let N : int, let M : int>
+[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)]
matrix<T,N,M> WaveMultiPrefixBitXor(matrix<T,N,M> expr, uint4 mask)
{
__target_switch
{
case cuda: __intrinsic_asm "_wavePrefixXorMultiple(_getMultiPrefixMask(($1).x), $0)";
case hlsl: __intrinsic_asm "WaveMultiPrefixBitXor";
+ case glsl:
+ case spirv:
+ matrix<T, N, M> result;
+ for (int i = 0; i < N; ++i)
+ result[i] = WaveMultiPrefixBitXor(expr[i], mask);
+ return result;
}
}
/// @category wave
__generic<T : __BuiltinArithmeticType>
-[require(cuda_hlsl, waveprefix)]
+__glsl_extension(GL_NV_shader_subgroup_partitioned)
+__spirv_version(1.3)
+[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)]
T WaveMultiPrefixProduct(T value, uint4 mask)
{
__target_switch
{
case cuda: __intrinsic_asm "_wavePrefixProduct(_getMultiPrefixMask(($1).x), $0)";
case hlsl: __intrinsic_asm "WaveMultiPrefixProduct";
+ case glsl: __intrinsic_asm "subgroupPartitionedExclusiveMulNV";
+ case spirv:
+ {
+ spirv_asm
+ {
+ OpExtension "SPV_NV_shader_subgroup_partitioned";
+ OpCapability GroupNonUniformPartitionedNV;
+ };
+
+ if (__isFloat<T>())
+ {
+ return spirv_asm
+ {
+ result:$$T = OpGroupNonUniformFMul Subgroup PartitionedExclusiveScanNV $value $mask
+ };
+ }
+ else
+ {
+ return spirv_asm
+ {
+ result:$$T = OpGroupNonUniformIMul Subgroup PartitionedExclusiveScanNV $value $mask
+ };
+ }
+ }
}
}
__generic<T : __BuiltinArithmeticType, let N : int>
-[require(cuda_hlsl, waveprefix)]
+__glsl_extension(GL_NV_shader_subgroup_partitioned)
+__spirv_version(1.3)
+[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)]
vector<T,N> WaveMultiPrefixProduct(vector<T,N> value, uint4 mask)
{
__target_switch
{
case cuda: __intrinsic_asm "_wavePrefixProductMultiple(_getMultiPrefixMask(($1).x), $0)";
case hlsl: __intrinsic_asm "WaveMultiPrefixProduct";
+ case glsl: __intrinsic_asm "subgroupPartitionedExclusiveMulNV";
+ case spirv:
+ {
+ spirv_asm
+ {
+ OpExtension "SPV_NV_shader_subgroup_partitioned";
+ OpCapability GroupNonUniformPartitionedNV;
+ };
+
+ if (__isFloat<T>())
+ {
+ return spirv_asm
+ {
+ result:$$vector<T,N> = OpGroupNonUniformFMul Subgroup PartitionedExclusiveScanNV $value $mask
+ };
+ }
+ else
+ {
+ return spirv_asm
+ {
+ result:$$vector<T,N> = OpGroupNonUniformIMul Subgroup PartitionedExclusiveScanNV $value $mask
+ };
+ }
+ }
}
}
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
-[require(cuda_hlsl, waveprefix)]
+[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)]
matrix<T,N,M> WaveMultiPrefixProduct(matrix<T,N,M> value, uint4 mask)
{
__target_switch
{
case cuda: __intrinsic_asm "_wavePrefixProductMultiple(_getMultiPrefixMask(($1).x), $0)";
case hlsl: __intrinsic_asm "WaveMultiPrefixProduct";
+ case glsl:
+ case spirv:
+ matrix<T, N, M> result;
+ for (int i = 0; i < N; ++i)
+ result[i] = WaveMultiPrefixProduct(value[i], mask);
+ return result;
}
}
/// @category wave
__generic<T : __BuiltinArithmeticType>
-[require(cuda_hlsl, waveprefix)]
+__glsl_extension(GL_NV_shader_subgroup_partitioned)
+__spirv_version(1.3)
+[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)]
T WaveMultiPrefixSum(T value, uint4 mask)
{
__target_switch
{
case cuda: __intrinsic_asm "_wavePrefixSum(_getMultiPrefixMask(($1).x), $0)";
case hlsl: __intrinsic_asm "WaveMultiPrefixSum";
+ case glsl: __intrinsic_asm "subgroupPartitionedExclusiveAddNV";
+ case spirv:
+ {
+ spirv_asm
+ {
+ OpExtension "SPV_NV_shader_subgroup_partitioned";
+ OpCapability GroupNonUniformPartitionedNV;
+ };
+
+ if (__isFloat<T>())
+ {
+ return spirv_asm
+ {
+ result:$$T = OpGroupNonUniformFAdd Subgroup PartitionedExclusiveScanNV $value $mask
+ };
+ }
+ else
+ {
+ return spirv_asm
+ {
+ result:$$T = OpGroupNonUniformIAdd Subgroup PartitionedExclusiveScanNV $value $mask
+ };
+ }
+ }
}
}
__generic<T : __BuiltinArithmeticType, let N : int>
-[require(cuda_hlsl, waveprefix)]
+__glsl_extension(GL_NV_shader_subgroup_partitioned)
+[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)]
+__spirv_version(1.3)
vector<T,N> WaveMultiPrefixSum(vector<T,N> value, uint4 mask)
{
__target_switch
{
case cuda: __intrinsic_asm "_wavePrefixSumMultiple(_getMultiPrefixMask(($1).x), $0 )";
case hlsl: __intrinsic_asm "WaveMultiPrefixSum";
+ case glsl: __intrinsic_asm "subgroupPartitionedExclusiveAddNV";
+ case spirv:
+ {
+ spirv_asm
+ {
+ OpExtension "SPV_NV_shader_subgroup_partitioned";
+ OpCapability GroupNonUniformPartitionedNV;
+ };
+
+ if (__isFloat<T>())
+ {
+ return spirv_asm
+ {
+ result:$$vector<T,N> = OpGroupNonUniformFAdd Subgroup PartitionedExclusiveScanNV $value $mask
+ };
+ }
+ else
+ {
+ return spirv_asm
+ {
+ result:$$vector<T,N> = OpGroupNonUniformIAdd Subgroup PartitionedExclusiveScanNV $value $mask
+ };
+ }
+ }
}
}
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
-[require(cuda_hlsl, waveprefix)]
+[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)]
matrix<T,N,M> WaveMultiPrefixSum(matrix<T,N,M> value, uint4 mask)
{
__target_switch
{
case cuda: __intrinsic_asm "_wavePrefixSumMultiple(_getMultiPrefixMask(($1).x), $0)";
case hlsl: __intrinsic_asm "WaveMultiPrefixSum";
+ case glsl:
+ case spirv:
+ matrix<T, N, M> result;
+ for (int i = 0; i < N; ++i)
+ result[i] = WaveMultiPrefixSum(value[i], mask);
+ return result;
}
}
diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef
index 4f6357779..3bc54c080 100644
--- a/source/slang/slang-capabilities.capdef
+++ b/source/slang/slang-capabilities.capdef
@@ -1024,7 +1024,9 @@ alias fragmentshaderbarycentric = GL_EXT_fragment_shader_barycentric | _sm_6_1;
alias shadermemorycontrol = glsl | _spirv_1_0 | _sm_5_0;
/// Capabilities needed to use HLSL tier wave operations
/// [Compound]
-alias waveprefix = _sm_6_5 | _cuda_sm_7_0 | GL_KHR_shader_subgroup_arithmetic;
+alias wave_multi_prefix = _sm_6_5
+ | _cuda_sm_7_0
+ | GL_KHR_shader_subgroup_ballot + GL_KHR_shader_subgroup_arithmetic + GL_NV_shader_subgroup_partitioned;
/// Capabilities needed to use GLSL buffer-reference's
/// [Compound]
alias bufferreference = GL_EXT_buffer_reference;
diff --git a/tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang b/tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang
new file mode 100644
index 000000000..69240198e
--- /dev/null
+++ b/tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang
@@ -0,0 +1,74 @@
+//TEST_CATEGORY(wave, compute)
+//DISABLE_TEST:COMPARE_COMPUTE_EX:-cpu -compute -shaderobj
+//DISABLE_TEST:COMPARE_COMPUTE_EX:-slang -compute -shaderobj
+
+//TEST:COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile sm_6_5 -shaderobj
+//TEST:COMPARE_COMPUTE_EX:-vk -compute -shaderobj
+//TEST:COMPARE_COMPUTE_EX:-cuda -compute -render-features cuda_sm_7_0 -shaderobj
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
+RWStructuredBuffer<uint> outputBuffer;
+
+[numthreads(8, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint index = int(dispatchThreadID.x);
+
+ // Split into two groups.
+ uint4 mask = 0b00001111;
+ if (index >= 4)
+ {
+ mask = 0b11110000;
+ }
+
+ //
+ // WaveMultiPrefixSum.
+ // Results in hex: [0 1 3 7], [0 10 30 70]
+ //
+ uint sumValue = WaveMultiPrefixSum(1 << index, mask);
+ const uint sumBaseIndex = 0;
+ outputBuffer[sumBaseIndex + index] = sumValue;
+
+ //
+ // WaveMultiPrefixProduct.
+ // Results in hex: [1 1 2 8], [1 10 200 8000]
+ //
+ uint productValue = WaveMultiPrefixProduct(1 << index, mask);
+ const uint productBaseIndex = 8;
+ outputBuffer[productBaseIndex + index] = productValue;
+
+ //
+ // WaveMultiPrefixBitAnd.
+ // This prefix operation starts with all bits set.
+ // Results in hex: [FFFFFFFF 1 1 1], [FFFFFFFF F F F]
+ //
+ uint andBits = 0b1;
+ if (index >= 4)
+ {
+ andBits = 0b1111;
+ }
+ uint andValue = WaveMultiPrefixBitAnd(andBits, mask);
+ const uint andBaseIndex = 16;
+ outputBuffer[andBaseIndex + index] = andValue;
+
+ //
+ // WaveMultiPrefixBitOr.
+ // Results in hex: [0 1 3 7], [0 10 30 70]
+ //
+ uint orValue = WaveMultiPrefixBitOr(1 << index, mask);
+ const uint orBaseIndex = 24;
+ outputBuffer[orBaseIndex + index] = orValue;
+
+ //
+ // WaveMultiPrefixBitXor.
+ // Results in hex: [0 1 3 7], [0 F 0 F]
+ //
+ uint xorBits = (1 << index);
+ if (index >= 4)
+ {
+ xorBits = 0b1111;
+ }
+ uint xorValue = WaveMultiPrefixBitXor(xorBits, mask);
+ const uint xorBaseIndex = 32;
+ outputBuffer[xorBaseIndex + index] = xorValue;
+}
diff --git a/tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang.expected.txt b/tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang.expected.txt
new file mode 100644
index 000000000..c80baa5b1
--- /dev/null
+++ b/tests/hlsl-intrinsic/wave-multi-prefix-scalar-functional.slang.expected.txt
@@ -0,0 +1,40 @@
+0
+1
+3
+7
+0
+10
+30
+70
+1
+1
+2
+8
+1
+10
+200
+8000
+FFFFFFFF
+1
+1
+1
+FFFFFFFF
+F
+F
+F
+0
+1
+3
+7
+0
+10
+30
+70
+0
+1
+3
+7
+0
+F
+0
+F
diff --git a/tests/hlsl-intrinsic/wave-multi-prefix.slang b/tests/hlsl-intrinsic/wave-multi-prefix.slang
index 31dde2af4..99698e497 100644
--- a/tests/hlsl-intrinsic/wave-multi-prefix.slang
+++ b/tests/hlsl-intrinsic/wave-multi-prefix.slang
@@ -1,27 +1,146 @@
-//TEST_CATEGORY(wave, compute)
-//DISABLE_TEST:COMPARE_COMPUTE_EX:-cpu -compute -shaderobj
-//DISABLE_TEST:COMPARE_COMPUTE_EX:-slang -compute -shaderobj
-// We need SM6.5 for these tests
-// Disable because version of dxc we are currently using doesn't support SM6.5
-//DISABLE_TEST:COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile sm_6_5 -shaderobj
-// Disabled because we don't have GLSL intrinsics for these it seems
-//DISABLE_TEST(vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj
-//TEST:COMPARE_COMPUTE_EX:-cuda -compute -render-features cuda_sm_7_0 -shaderobj
-
-//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
-RWStructuredBuffer<int> outputBuffer;
-
-[numthreads(8, 1, 1)]
-void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+//TEST:SIMPLE(filecheck=CHECK_SPIRV): -stage compute -entry computeMain -target spirv -DNO_INTEGER_MATRIX
+//TEST:SIMPLE(filecheck=CHECK_GLSL): -stage compute -entry computeMain -target glsl -DNO_INTEGER_MATRIX
+//TEST:SIMPLE(filecheck=CHECK_CUDA): -stage compute -entry computeMain -target cuda
+//TEST:SIMPLE(filecheck=CHECK_HLSL): -stage compute -entry computeMain -target hlsl
+
+//
+// Tests all variants and overloads of WaveMultiPrefix* arithmetic intrinsics.
+//
+
+struct OutputData
+{
+ int scalarSum;
+ int scalarProduct;
+ int scalarBitAnd;
+ int scalarBitOr;
+ int scalarBitXor;
+ int vectorSum;
+ int vectorProduct;
+ int vectorBitAnd;
+ int vectorBitOr;
+ int vectorBitXor;
+ int matrixSum;
+ int matrixProduct;
+ int matrixBitAnd;
+ int matrixBitOr;
+ int matrixBitXor;
+ float floatScalarSum;
+ float floatScalarProduct;
+ float floatVectorSum;
+ float floatVectorProduct;
+ float floatMatrixSum;
+ float floatMatrixProduct;
+};
+
+RWStructuredBuffer<OutputData> outputBuffer;
+
+// CHECK_SPIRV: OpCapability GroupNonUniformPartitionedNV
+// CHECK_SPIRV: OpExtension "SPV_NV_shader_subgroup_partitioned"
+// CHECK_SPIRV: OpGroupNonUniformIAdd{{.*}}PartitionedExclusiveScanNV
+// CHECK_SPIRV: OpGroupNonUniformIMul{{.*}}PartitionedExclusiveScanNV
+// CHECK_SPIRV: OpGroupNonUniformBitwiseAnd{{.*}}PartitionedExclusiveScanNV
+// CHECK_SPIRV: OpGroupNonUniformBitwiseOr{{.*}}PartitionedExclusiveScanNV
+// CHECK_SPIRV: OpGroupNonUniformBitwiseXor{{.*}}PartitionedExclusiveScanNV
+// CHECK_SPIRV: OpGroupNonUniformFAdd{{.*}}PartitionedExclusiveScanNV
+
+// CHECK_GLSL: GL_NV_shader_subgroup_partitioned
+// CHECK_GLSL: subgroupPartitionedExclusiveAddNV
+// CHECK_GLSL: subgroupPartitionedExclusiveMulNV
+// CHECK_GLSL: subgroupPartitionedExclusiveAndNV
+// CHECK_GLSL: subgroupPartitionedExclusiveOrNV
+// CHECK_GLSL: subgroupPartitionedExclusiveXorNV
+
+// CHECK_CUDA: _wavePrefixSum
+// CHECK_CUDA: _wavePrefixProduct
+// CHECK_CUDA: _wavePrefixAnd
+// CHECK_CUDA: _wavePrefixOr
+// CHECK_CUDA: _wavePrefixXor
+// CHECK_CUDA: _wavePrefixSumMultiple
+// CHECK_CUDA: _wavePrefixProductMultiple
+// CHECK_CUDA: _wavePrefixAndMultiple
+// CHECK_CUDA: _wavePrefixOrMultiple
+// CHECK_CUDA: _wavePrefixXorMultiple
+
+// CHECK_HLSL: WaveMultiPrefixSum
+// CHECK_HLSL: WaveMultiPrefixProduct
+// CHECK_HLSL: WaveMultiPrefixBitAnd
+// CHECK_HLSL: WaveMultiPrefixBitOr
+// CHECK_HLSL: WaveMultiPrefixBitXor
+
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dTid : SV_DispatchThreadID)
{
- int idx = int(dispatchThreadID.x);
-
- int value = 0;
-
- uint4 mask = WaveMatch(true);
-
- // Scalar
- value += WaveMultiPrefixSum(1 << idx, mask);
-
- outputBuffer[idx] = value;
-}
\ No newline at end of file
+ int scalarVal = dTid.x;
+ uint4 mask = WaveMatch(scalarVal);
+
+ int scalarSum = WaveMultiPrefixSum(scalarVal, mask);
+ int scalarProduct = WaveMultiPrefixProduct(scalarVal, mask);
+ int scalarBitAnd = WaveMultiPrefixBitAnd(scalarVal, mask);
+ int scalarBitOr = WaveMultiPrefixBitOr(scalarVal, mask);
+ int scalarBitXor = WaveMultiPrefixBitXor(scalarVal, mask);
+
+ int3 vectorVal = int3(dTid.x, dTid.y, dTid.z);
+ int3 vectorSum = WaveMultiPrefixSum(vectorVal, mask);
+ int3 vectorProduct = WaveMultiPrefixProduct(vectorVal, mask);
+ int3 vectorBitAnd = WaveMultiPrefixBitAnd(vectorVal, mask);
+ int3 vectorBitOr = WaveMultiPrefixBitOr(vectorVal, mask);
+ int3 vectorBitXor = WaveMultiPrefixBitXor(vectorVal, mask);
+
+ float floatScalarVal = float(dTid.x) + 0.5f; // Example floating-point scalar value
+ uint4 floatMask = WaveMatch(floatScalarVal); // Create a mask for matching lanes
+
+ float floatScalarSum = WaveMultiPrefixSum(floatScalarVal, floatMask);
+ float floatScalarProduct = WaveMultiPrefixProduct(floatScalarVal, floatMask);
+
+ float3 floatVectorVal = float3(dTid.x, dTid.y, dTid.z) + 0.5f; // Example floating-point vector value
+ float3 floatVectorSum = WaveMultiPrefixSum(floatVectorVal, floatMask);
+ float3 floatVectorProduct = WaveMultiPrefixProduct(floatVectorVal, floatMask);
+
+ OutputData output;
+ output.scalarSum = scalarSum;
+ output.scalarProduct = scalarProduct;
+ output.scalarBitAnd = scalarBitAnd;
+ output.scalarBitOr = scalarBitOr;
+ output.scalarBitXor = scalarBitXor;
+ output.vectorSum = vectorSum.x;
+ output.vectorProduct = vectorProduct.x;
+ output.vectorBitAnd = vectorBitAnd.x;
+ output.vectorBitOr = vectorBitOr.x;
+ output.vectorBitXor = vectorBitXor.x;
+ output.floatScalarSum = floatScalarSum;
+ output.floatScalarProduct = floatScalarProduct;
+ output.floatVectorSum = floatVectorSum.x;
+ output.floatVectorProduct = floatVectorProduct.x;
+
+ float3x3 floatMatrixVal = float3x3(
+ float(dTid.x) + 0.5f, float(dTid.y) + 0.5f, float(dTid.z) + 0.5f,
+ float(dTid.z) + 0.5f, float(dTid.x) + 0.5f, float(dTid.y) + 0.5f,
+ float(dTid.y) + 0.5f, float(dTid.z) + 0.5f, float(dTid.x) + 0.5f
+ );
+ float3x3 floatMatrixSum = WaveMultiPrefixSum(floatMatrixVal, floatMask);
+ float3x3 floatMatrixProduct = WaveMultiPrefixProduct(floatMatrixVal, floatMask);
+ output.floatMatrixSum = floatMatrixSum[0][0];
+ output.floatMatrixProduct = floatMatrixProduct[0][0];
+
+#if !defined(NO_INTEGER_MATRIX)
+ int3x3 matrixVal = int3x3(
+ dTid.x, dTid.y, dTid.z,
+ dTid.z, dTid.x, dTid.y,
+ dTid.y, dTid.z, dTid.x
+ );
+ int3x3 matrixSum = WaveMultiPrefixSum(matrixVal, mask);
+ int3x3 matrixProduct = WaveMultiPrefixProduct(matrixVal, mask);
+ int3x3 matrixBitAnd = WaveMultiPrefixBitAnd(matrixVal, mask);
+ int3x3 matrixBitOr = WaveMultiPrefixBitOr(matrixVal, mask);
+ int3x3 matrixBitXor = WaveMultiPrefixBitXor(matrixVal, mask);
+ output.matrixSum = matrixSum[0][0];
+ output.matrixProduct = matrixProduct[0][0];
+ output.matrixBitAnd = matrixBitAnd[0][0];
+ output.matrixBitOr = matrixBitOr[0][0];
+ output.matrixBitXor = matrixBitXor[0][0];
+#endif
+
+ outputBuffer[dTid.x] = output;
+}
+
diff --git a/tests/hlsl-intrinsic/wave-multi-prefix.slang.expected.txt b/tests/hlsl-intrinsic/wave-multi-prefix.slang.expected.txt
deleted file mode 100644
index 6ec6deeea..000000000
--- a/tests/hlsl-intrinsic/wave-multi-prefix.slang.expected.txt
+++ /dev/null
@@ -1,8 +0,0 @@
-0
-1
-3
-7
-F
-1F
-3F
-7F
diff --git a/tests/language-feature/capability/capabilitySimplification1.slang b/tests/language-feature/capability/capabilitySimplification1.slang
index b694673e9..1d781a45e 100644
--- a/tests/language-feature/capability/capabilitySimplification1.slang
+++ b/tests/language-feature/capability/capabilitySimplification1.slang
@@ -6,9 +6,9 @@
// CHECK: error 36107
// CHECK-SAME: entrypoint 'computeMain' does not support compilation target 'glsl' with stage 'compute'
-// CHECK: capabilitySimplification1.slang(21): note: see using of 'WaveMultiPrefixProduct'
-// CHECK-NOT: see using of 'WaveMultiPrefixProduct'
-// CHECK: {{.*}}.meta.slang({{.*}}): note: see definition of 'WaveMultiPrefixProduct'
+// CHECK: capabilitySimplification1.slang(21): note: see using of 'WaveMultiPrefixCountBits'
+// CHECK-NOT: see using of 'WaveMultiPrefixCountBits'
+// CHECK: {{.*}}.meta.slang({{.*}}): note: see definition of 'WaveMultiPrefixCountBits'
// CHECK: {{.*}}.meta.slang({{.*}}): note: see declaration of 'require'
void nestedSafeCall()
@@ -18,7 +18,7 @@ void nestedSafeCall()
void nestedBadCall()
{
- WaveMultiPrefixProduct(1, 0);
+ WaveMultiPrefixCountBits(true, 0);
}
void nestedCall()
diff --git a/tests/language-feature/capability/capabilitySimplification3.slang b/tests/language-feature/capability/capabilitySimplification3.slang
index faf161d15..808c19bf6 100644
--- a/tests/language-feature/capability/capabilitySimplification3.slang
+++ b/tests/language-feature/capability/capabilitySimplification3.slang
@@ -5,13 +5,13 @@
// CHECK_IGNORE_CAPS-NOT: error 36107
// CHECK: error 36107: entrypoint 'computeMain' does not support compilation target 'glsl' with stage 'compute'
-// CHECK: capabilitySimplification3.slang(16): note: see using of 'WaveMultiPrefixProduct'
-// CHECK-NOT: see using of 'WaveMultiPrefixProduct'
-// CHECK: {{.*}}.meta.slang({{.*}}): note: see definition of 'WaveMultiPrefixProduct'
+// CHECK: capabilitySimplification3.slang(16): note: see using of 'WaveMultiPrefixCountBits'
+// CHECK-NOT: see using of 'WaveMultiPrefixCountBits'
+// CHECK: {{.*}}.meta.slang({{.*}}): note: see definition of 'WaveMultiPrefixCountBits'
// CHECK: {{.*}}.meta.slang({{.*}}): note: see declaration of 'require'
[numthreads(1,1,1)]
void computeMain()
{
- WaveMultiPrefixProduct(1, 0);
+ WaveMultiPrefixCountBits(true, 0);
}