From 1c282b80b9fbcfea9dc3dab7f5f546b069143e01 Mon Sep 17 00:00:00 2001 From: Darren Wihandi <65404740+fairywreath@users.noreply.github.com> Date: Tue, 28 Jan 2025 23:12:51 -0500 Subject: Implement WaveMultiPrefix* for SPIRV and GLSL (#6182) --- source/slang/hlsl.meta.slang | 250 ++++++++++++++++++++++++++++----- source/slang/slang-capabilities.capdef | 4 +- 2 files changed, 216 insertions(+), 38 deletions(-) (limited to 'source') diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index ba5c95a0c..1853a82b6 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -15517,7 +15517,7 @@ uint4 WaveMatch(matrix value) } /// @category wave -[require(cuda_hlsl, waveprefix)] +[require(cuda_hlsl, wave_multi_prefix)] uint WaveMultiPrefixCountBits(bool value, uint4 mask) { __target_switch @@ -15528,190 +15528,366 @@ uint WaveMultiPrefixCountBits(bool value, uint4 mask) } /// @category wave -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) +__generic +__glsl_extension(GL_NV_shader_subgroup_partitioned) __spirv_version(1.3) -[require(cuda_glsl_hlsl, waveprefix)] +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] T WaveMultiPrefixBitAnd(T expr, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixAnd(_getMultiPrefixMask(($1).x), $0)"; - case glsl: __intrinsic_asm "subgroupExclusiveAnd($0)"; case hlsl: __intrinsic_asm "WaveMultiPrefixBitAnd"; + case glsl: __intrinsic_asm "subgroupPartitionedExclusiveAndNV"; + case spirv: + return spirv_asm + { + OpExtension "SPV_NV_shader_subgroup_partitioned"; + OpCapability GroupNonUniformPartitionedNV; + result:$$T = OpGroupNonUniformBitwiseAnd Subgroup PartitionedExclusiveScanNV $expr $mask + }; } } -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) +__generic +__glsl_extension(GL_NV_shader_subgroup_partitioned) __spirv_version(1.3) -__generic -[require(cuda_glsl_hlsl, waveprefix)] +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] vector WaveMultiPrefixBitAnd(vector expr, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixAndMultiple(_getMultiPrefixMask(($1).x), $0)"; - case glsl: __intrinsic_asm "subgroupExclusiveAnd($0)"; case hlsl: __intrinsic_asm "WaveMultiPrefixBitAnd"; + case glsl: __intrinsic_asm "subgroupPartitionedExclusiveAndNV"; + case spirv: + return spirv_asm + { + OpExtension "SPV_NV_shader_subgroup_partitioned"; + OpCapability GroupNonUniformPartitionedNV; + result:$$vector = OpGroupNonUniformBitwiseAnd Subgroup PartitionedExclusiveScanNV $expr $mask + }; } } -__generic -[require(cuda_hlsl, waveprefix)] +__generic +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] matrix WaveMultiPrefixBitAnd(matrix expr, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixAndMultiple(_getMultiPrefixMask(($1).x), $0)"; case hlsl: __intrinsic_asm "WaveMultiPrefixBitAnd"; + case glsl: + case spirv: + matrix result; + for (int i = 0; i < N; ++i) + result[i] = WaveMultiPrefixBitAnd(expr[i], mask); + return result; } } /// @category wave -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) +__generic +__glsl_extension(GL_NV_shader_subgroup_partitioned) __spirv_version(1.3) -[require(cuda_glsl_hlsl, waveprefix)] +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] T WaveMultiPrefixBitOr(T expr, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixOr(, _getMultiPrefixMask(($1).x), $0)"; - case glsl: __intrinsic_asm "subgroupExclusiveOr($0)"; case hlsl: __intrinsic_asm "WaveMultiPrefixBitOr"; + case glsl: __intrinsic_asm "subgroupPartitionedExclusiveOrNV"; + case spirv: + return spirv_asm + { + OpExtension "SPV_NV_shader_subgroup_partitioned"; + OpCapability GroupNonUniformPartitionedNV; + result:$$T = OpGroupNonUniformBitwiseOr Subgroup PartitionedExclusiveScanNV $expr $mask + }; } } -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) +__generic +__glsl_extension(GL_NV_shader_subgroup_partitioned) __spirv_version(1.3) -[require(cuda_glsl_hlsl, waveprefix)] +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] vector WaveMultiPrefixBitOr(vector expr, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixOrMultiple(_getMultiPrefixMask(($1).x), $0)"; - case glsl: __intrinsic_asm "subgroupExclusiveOr($0)"; case hlsl: __intrinsic_asm "WaveMultiPrefixBitOr"; + case glsl: __intrinsic_asm "subgroupPartitionedExclusiveOrNV"; + case spirv: + return spirv_asm + { + OpExtension "SPV_NV_shader_subgroup_partitioned"; + OpCapability GroupNonUniformPartitionedNV; + result:$$vector = OpGroupNonUniformBitwiseOr Subgroup PartitionedExclusiveScanNV $expr $mask + }; } } -__generic -[require(cuda_hlsl, waveprefix)] +__generic +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] matrix WaveMultiPrefixBitOr(matrix expr, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixOrMultiple(_getMultiPrefixMask(($1).x), $0)"; case hlsl: __intrinsic_asm "WaveMultiPrefixBitOr"; + case glsl: + case spirv: + matrix result; + for (int i = 0; i < N; ++i) + result[i] = WaveMultiPrefixBitOr(expr[i], mask); + return result; } } /// @category wave -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) +__generic +__glsl_extension(GL_NV_shader_subgroup_partitioned) __spirv_version(1.3) -[require(cuda_glsl_hlsl, waveprefix)] +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] T WaveMultiPrefixBitXor(T expr, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixXor(_getMultiPrefixMask(($1).x), $0)"; - case glsl: __intrinsic_asm "subgroupExclusiveXor($0)"; case hlsl: __intrinsic_asm "WaveMultiPrefixBitXor"; + case glsl: __intrinsic_asm "subgroupPartitionedExclusiveXorNV"; + case spirv: + return spirv_asm + { + OpExtension "SPV_NV_shader_subgroup_partitioned"; + OpCapability GroupNonUniformPartitionedNV; + result:$$T = OpGroupNonUniformBitwiseXor Subgroup PartitionedExclusiveScanNV $expr $mask + }; } } -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) +__generic +__glsl_extension(GL_NV_shader_subgroup_partitioned) __spirv_version(1.3) -[require(cuda_glsl_hlsl, waveprefix)] +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] vector WaveMultiPrefixBitXor(vector expr, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixXorMultiple(_getMultiPrefixMask(($1).x), $0)"; - case glsl: __intrinsic_asm "subgroupExclusiveXor($0)"; case hlsl: __intrinsic_asm "WaveMultiPrefixBitXor"; + case glsl: __intrinsic_asm "subgroupPartitionedExclusiveXorNV"; + case spirv: + return spirv_asm + { + OpExtension "SPV_NV_shader_subgroup_partitioned"; + OpCapability GroupNonUniformPartitionedNV; + result:$$vector = OpGroupNonUniformBitwiseXor Subgroup PartitionedExclusiveScanNV $expr $mask + }; } } -__generic -[require(cuda_hlsl, waveprefix)] +__generic +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] matrix WaveMultiPrefixBitXor(matrix expr, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixXorMultiple(_getMultiPrefixMask(($1).x), $0)"; case hlsl: __intrinsic_asm "WaveMultiPrefixBitXor"; + case glsl: + case spirv: + matrix result; + for (int i = 0; i < N; ++i) + result[i] = WaveMultiPrefixBitXor(expr[i], mask); + return result; } } /// @category wave __generic -[require(cuda_hlsl, waveprefix)] +__glsl_extension(GL_NV_shader_subgroup_partitioned) +__spirv_version(1.3) +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] T WaveMultiPrefixProduct(T value, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixProduct(_getMultiPrefixMask(($1).x), $0)"; case hlsl: __intrinsic_asm "WaveMultiPrefixProduct"; + case glsl: __intrinsic_asm "subgroupPartitionedExclusiveMulNV"; + case spirv: + { + spirv_asm + { + OpExtension "SPV_NV_shader_subgroup_partitioned"; + OpCapability GroupNonUniformPartitionedNV; + }; + + if (__isFloat()) + { + return spirv_asm + { + result:$$T = OpGroupNonUniformFMul Subgroup PartitionedExclusiveScanNV $value $mask + }; + } + else + { + return spirv_asm + { + result:$$T = OpGroupNonUniformIMul Subgroup PartitionedExclusiveScanNV $value $mask + }; + } + } } } __generic -[require(cuda_hlsl, waveprefix)] +__glsl_extension(GL_NV_shader_subgroup_partitioned) +__spirv_version(1.3) +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] vector WaveMultiPrefixProduct(vector value, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixProductMultiple(_getMultiPrefixMask(($1).x), $0)"; case hlsl: __intrinsic_asm "WaveMultiPrefixProduct"; + case glsl: __intrinsic_asm "subgroupPartitionedExclusiveMulNV"; + case spirv: + { + spirv_asm + { + OpExtension "SPV_NV_shader_subgroup_partitioned"; + OpCapability GroupNonUniformPartitionedNV; + }; + + if (__isFloat()) + { + return spirv_asm + { + result:$$vector = OpGroupNonUniformFMul Subgroup PartitionedExclusiveScanNV $value $mask + }; + } + else + { + return spirv_asm + { + result:$$vector = OpGroupNonUniformIMul Subgroup PartitionedExclusiveScanNV $value $mask + }; + } + } } } __generic -[require(cuda_hlsl, waveprefix)] +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] matrix WaveMultiPrefixProduct(matrix value, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixProductMultiple(_getMultiPrefixMask(($1).x), $0)"; case hlsl: __intrinsic_asm "WaveMultiPrefixProduct"; + case glsl: + case spirv: + matrix result; + for (int i = 0; i < N; ++i) + result[i] = WaveMultiPrefixProduct(value[i], mask); + return result; } } /// @category wave __generic -[require(cuda_hlsl, waveprefix)] +__glsl_extension(GL_NV_shader_subgroup_partitioned) +__spirv_version(1.3) +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] T WaveMultiPrefixSum(T value, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixSum(_getMultiPrefixMask(($1).x), $0)"; case hlsl: __intrinsic_asm "WaveMultiPrefixSum"; + case glsl: __intrinsic_asm "subgroupPartitionedExclusiveAddNV"; + case spirv: + { + spirv_asm + { + OpExtension "SPV_NV_shader_subgroup_partitioned"; + OpCapability GroupNonUniformPartitionedNV; + }; + + if (__isFloat()) + { + return spirv_asm + { + result:$$T = OpGroupNonUniformFAdd Subgroup PartitionedExclusiveScanNV $value $mask + }; + } + else + { + return spirv_asm + { + result:$$T = OpGroupNonUniformIAdd Subgroup PartitionedExclusiveScanNV $value $mask + }; + } + } } } __generic -[require(cuda_hlsl, waveprefix)] +__glsl_extension(GL_NV_shader_subgroup_partitioned) +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] +__spirv_version(1.3) vector WaveMultiPrefixSum(vector value, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixSumMultiple(_getMultiPrefixMask(($1).x), $0 )"; case hlsl: __intrinsic_asm "WaveMultiPrefixSum"; + case glsl: __intrinsic_asm "subgroupPartitionedExclusiveAddNV"; + case spirv: + { + spirv_asm + { + OpExtension "SPV_NV_shader_subgroup_partitioned"; + OpCapability GroupNonUniformPartitionedNV; + }; + + if (__isFloat()) + { + return spirv_asm + { + result:$$vector = OpGroupNonUniformFAdd Subgroup PartitionedExclusiveScanNV $value $mask + }; + } + else + { + return spirv_asm + { + result:$$vector = OpGroupNonUniformIAdd Subgroup PartitionedExclusiveScanNV $value $mask + }; + } + } } } __generic -[require(cuda_hlsl, waveprefix)] +[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] matrix WaveMultiPrefixSum(matrix value, uint4 mask) { __target_switch { case cuda: __intrinsic_asm "_wavePrefixSumMultiple(_getMultiPrefixMask(($1).x), $0)"; case hlsl: __intrinsic_asm "WaveMultiPrefixSum"; + case glsl: + case spirv: + matrix result; + for (int i = 0; i < N; ++i) + result[i] = WaveMultiPrefixSum(value[i], mask); + return result; } } diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef index 4f6357779..3bc54c080 100644 --- a/source/slang/slang-capabilities.capdef +++ b/source/slang/slang-capabilities.capdef @@ -1024,7 +1024,9 @@ alias fragmentshaderbarycentric = GL_EXT_fragment_shader_barycentric | _sm_6_1; alias shadermemorycontrol = glsl | _spirv_1_0 | _sm_5_0; /// Capabilities needed to use HLSL tier wave operations /// [Compound] -alias waveprefix = _sm_6_5 | _cuda_sm_7_0 | GL_KHR_shader_subgroup_arithmetic; +alias wave_multi_prefix = _sm_6_5 + | _cuda_sm_7_0 + | GL_KHR_shader_subgroup_ballot + GL_KHR_shader_subgroup_arithmetic + GL_NV_shader_subgroup_partitioned; /// Capabilities needed to use GLSL buffer-reference's /// [Compound] alias bufferreference = GL_EXT_buffer_reference; -- cgit v1.2.3