From 0476b57faad96bee61f59f27ddd48c6cb067cfa2 Mon Sep 17 00:00:00 2001 From: Darren Wihandi <65404740+fairywreath@users.noreply.github.com> Date: Sun, 25 May 2025 12:58:08 -0400 Subject: Add full support for SPV_NV_shader_subgroup_partitioned (#7103) * Properly implement WaveMask* variants of WaveMultiPrefix* intrinsics * More partitioned intrinsics * More partitioned intrinsics and cleaned up non-prefixed WaveMask* implementations * Refactor HLSL WaveMultiPrefix* implementations * fix cap atoms * Clean up implementation * Add GLSL intrinsics and cleanup * Add tests * Fix affected capability test * Update and fix tests * Move expected.txt file * Refactor WaveMask* to call WaveMulti* * Refactor SPIRV/GLSL preamble code * Enable emit-via-glsl tests * remove wave_multi_prefix capability in favor of subgroup_partitioned * Update docs * Update cap atoms doc --- source/slang/glsl.meta.slang | 351 ++++++++ source/slang/hlsl.meta.slang | 1470 ++++++++++++++++---------------- source/slang/slang-capabilities.capdef | 9 +- 3 files changed, 1080 insertions(+), 750 deletions(-) (limited to 'source') diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang index 588396251..88c90a777 100644 --- a/source/slang/glsl.meta.slang +++ b/source/slang/glsl.meta.slang @@ -8280,6 +8280,7 @@ public vector subgroupQuadSwapDiagonal(vector value) // GL_KHR_shader_subgroup_rotate __generic +[ForceInline] [require(glsl_metal_spirv, subgroup_rotate)] public T subgroupRotate(T value, uint delta) { @@ -8287,6 +8288,7 @@ public T subgroupRotate(T value, uint delta) } __generic +[ForceInline] [require(glsl_metal_spirv, subgroup_rotate)] public vector subgroupRotate(vector value, uint delta) { @@ -8294,6 +8296,7 @@ public vector subgroupRotate(vector value, uint delta) } __generic +[ForceInline] [require(glsl_spirv, subgroup_rotate)] public T subgroupClusteredRotate(T value, uint delta, constexpr uint clusterSize) { @@ -8302,12 +8305,360 @@ public T subgroupClusteredRotate(T value, 
uint delta, constexpr uint clusterSize } __generic +[ForceInline] [require(glsl_spirv, subgroup_rotate)] public vector subgroupClusteredRotate(vector value, uint delta, constexpr uint clusterSize) { return WaveClusteredRotate(value, delta, clusterSize); } + +// GL_NV_shader_subgroup_partitioned + +__generic +[ForceInline] +[require(cuda_glsl_spirv, subgroup_partitioned)] +public T subgroupPartitionedAddNV(T value, uvec4 ballot) +{ + return WaveMultiSum(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_spirv, subgroup_partitioned)] +public vector subgroupPartitionedAddNV(vector value, uvec4 ballot) +{ + return WaveMultiSum(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_spirv, subgroup_partitioned)] +public T subgroupPartitionedMulNV(T value, uvec4 ballot) +{ + return WaveMultiProduct(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_spirv, subgroup_partitioned)] +public vector subgroupPartitionedMulNV(vector value, uvec4 ballot) +{ + return WaveMultiProduct(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_spirv, subgroup_partitioned)] +public T subgroupPartitionedMinNV(T value, uvec4 ballot) +{ + return WaveMultiMin(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_spirv, subgroup_partitioned)] +public vector subgroupPartitionedMinNV(vector value, uvec4 ballot) +{ + return WaveMultiMin(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_spirv, subgroup_partitioned)] +public T subgroupPartitionedMaxNV(T value, uvec4 ballot) +{ + return WaveMultiMax(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_spirv, subgroup_partitioned)] +public vector subgroupPartitionedMaxNV(vector value, uvec4 ballot) +{ + return WaveMultiMax(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_spirv, subgroup_partitioned)] +public T subgroupPartitionedAndNV(T value, uvec4 ballot) +{ + return WaveMultiBitAnd(value, ballot); +} + +__generic +[ForceInline] 
+[require(cuda_glsl_spirv, subgroup_partitioned)] +public vector subgroupPartitionedAndNV(vector value, uvec4 ballot) +{ + return WaveMultiBitAnd(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_spirv, subgroup_partitioned)] +public T subgroupPartitionedOrNV(T value, uvec4 ballot) +{ + return WaveMultiBitOr(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_spirv, subgroup_partitioned)] +public vector subgroupPartitionedOrNV(vector value, uvec4 ballot) +{ + return WaveMultiBitOr(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_spirv, subgroup_partitioned)] +public T subgroupPartitionedXorNV(T value, uvec4 ballot) +{ + return WaveMultiBitXor(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_spirv, subgroup_partitioned)] +public vector subgroupPartitionedXorNV(vector value, uvec4 ballot) +{ + return WaveMultiBitXor(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] +public T subgroupPartitionedInclusiveAddNV(T value, uvec4 ballot) +{ + return WaveMultiPrefixInclusiveSum(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] +public vector subgroupPartitionedInclusiveAddNV(vector value, uvec4 ballot) +{ + return WaveMultiPrefixInclusiveSum(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] +public T subgroupPartitionedInclusiveMulNV(T value, uvec4 ballot) +{ + return WaveMultiPrefixInclusiveProduct(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] +public vector subgroupPartitionedInclusiveMulNV(vector value, uvec4 ballot) +{ + return WaveMultiPrefixInclusiveProduct(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_spirv, subgroup_partitioned)] +public T subgroupPartitionedInclusiveMinNV(T value, uvec4 ballot) +{ + return WaveMultiPrefixInclusiveMin(value, ballot); +} + +__generic 
+[ForceInline] +[require(cuda_glsl_spirv, subgroup_partitioned)] +public vector subgroupPartitionedInclusiveMinNV(vector value, uvec4 ballot) +{ + return WaveMultiPrefixInclusiveMin(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_spirv, subgroup_partitioned)] +public T subgroupPartitionedInclusiveMaxNV(T value, uvec4 ballot) +{ + return WaveMultiPrefixInclusiveMax(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_spirv, subgroup_partitioned)] +public vector subgroupPartitionedInclusiveMaxNV(vector value, uvec4 ballot) +{ + return WaveMultiPrefixInclusiveMax(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] +public T subgroupPartitionedInclusiveAndNV(T value, uvec4 ballot) +{ + return WaveMultiPrefixInclusiveBitAnd(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] +public vector subgroupPartitionedInclusiveAndNV(vector value, uvec4 ballot) +{ + return WaveMultiPrefixInclusiveBitAnd(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] +public T subgroupPartitionedInclusiveOrNV(T value, uvec4 ballot) +{ + return WaveMultiPrefixInclusiveBitOr(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] +public vector subgroupPartitionedInclusiveOrNV(vector value, uvec4 ballot) +{ + return WaveMultiPrefixInclusiveBitOr(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] +public T subgroupPartitionedInclusiveXorNV(T value, uvec4 ballot) +{ + return WaveMultiPrefixInclusiveBitXor(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] +public vector subgroupPartitionedInclusiveXorNV(vector value, uvec4 ballot) +{ + return WaveMultiPrefixInclusiveBitXor(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] 
+public T subgroupPartitionedExclusiveAddNV(T value, uvec4 ballot) +{ + return WaveMultiPrefixExclusiveSum(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] +public vector subgroupPartitionedExclusiveAddNV(vector value, uvec4 ballot) +{ + return WaveMultiPrefixExclusiveSum(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] +public T subgroupPartitionedExclusiveMulNV(T value, uvec4 ballot) +{ + return WaveMultiPrefixExclusiveProduct(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] +public vector subgroupPartitionedExclusiveMulNV(vector value, uvec4 ballot) +{ + return WaveMultiPrefixExclusiveProduct(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_spirv, subgroup_partitioned)] +public T subgroupPartitionedExclusiveMinNV(T value, uvec4 ballot) +{ + return WaveMultiPrefixExclusiveMin(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_spirv, subgroup_partitioned)] +public vector subgroupPartitionedExclusiveMinNV(vector value, uvec4 ballot) +{ + return WaveMultiPrefixExclusiveMin(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_spirv, subgroup_partitioned)] +public T subgroupPartitionedExclusiveMaxNV(T value, uvec4 ballot) +{ + return WaveMultiPrefixExclusiveMax(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_spirv, subgroup_partitioned)] +public vector subgroupPartitionedExclusiveMaxNV(vector value, uvec4 ballot) +{ + return WaveMultiPrefixExclusiveMax(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] +public T subgroupPartitionedExclusiveAndNV(T value, uvec4 ballot) +{ + return WaveMultiPrefixExclusiveBitAnd(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] +public vector subgroupPartitionedExclusiveAndNV(vector value, uvec4 ballot) +{ + 
return WaveMultiPrefixExclusiveBitAnd(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] +public T subgroupPartitionedExclusiveOrNV(T value, uvec4 ballot) +{ + return WaveMultiPrefixExclusiveBitOr(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] +public vector subgroupPartitionedExclusiveOrNV(vector value, uvec4 ballot) +{ + return WaveMultiPrefixExclusiveBitOr(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] +public T subgroupPartitionedExclusiveXorNV(T value, uvec4 ballot) +{ + return WaveMultiPrefixExclusiveBitXor(value, ballot); +} + +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] +public vector subgroupPartitionedExclusiveXorNV(vector value, uvec4 ballot) +{ + return WaveMultiPrefixExclusiveBitXor(value, ballot); +} + +__generic +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] +public uvec4 subgroupPartitionNV(T value) +{ + return WaveMatch(value); +} + + //// GLSL atomic // The following type internally is a Shader Storage Buffer diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 87f98adaf..cb050dd51 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -14138,382 +14138,294 @@ uint WaveMaskPrefixCountBits(WaveMask mask, bool value) // Across lane ops -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] T WaveMaskBitAnd(WaveMask mask, T expr) { __target_switch { - case glsl: __intrinsic_asm "subgroupAnd($1)"; - case cuda: __intrinsic_asm "_waveAnd($0, $1)"; - case hlsl: __intrinsic_asm "WaveActiveBitAnd($1)"; - case spirv: - return spirv_asm { - OpCapability GroupNonUniformArithmetic; - OpGroupNonUniformBitwiseAnd $$T result Subgroup 0 $expr - }; + 
case hlsl: + __intrinsic_asm "WaveActiveBitAnd($1)"; + default: + return WaveMultiBitAnd(expr, uint4(mask, 0, 0, 0)); } } -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] vector WaveMaskBitAnd(WaveMask mask, vector expr) { __target_switch { - case glsl: __intrinsic_asm "subgroupAnd($1)"; - case cuda: __intrinsic_asm "_waveAndMultiple($0, $1)"; - case hlsl: __intrinsic_asm "WaveActiveBitAnd($1)"; - case spirv: - return spirv_asm { - OpCapability GroupNonUniformArithmetic; - OpGroupNonUniformBitwiseAnd $$vector result Subgroup 0 $expr - }; + case hlsl: + __intrinsic_asm "WaveActiveBitAnd($1)"; + default: + return WaveMultiBitAnd(expr, uint4(mask, 0, 0, 0)); } } -__generic -[require(cuda_hlsl, subgroup_arithmetic)] +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] matrix WaveMaskBitAnd(WaveMask mask, matrix expr) { __target_switch { - case cuda: __intrinsic_asm "_waveAndMultiple($0, $1)"; - case hlsl: __intrinsic_asm "WaveActiveBitAnd($1)"; + case hlsl: + __intrinsic_asm "WaveActiveBitAnd($1)"; + default: + return WaveMultiBitAnd(expr, uint4(mask, 0, 0, 0)); } } -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] T WaveMaskBitOr(WaveMask mask, T expr) { __target_switch { - case glsl: __intrinsic_asm "subgroupOr($1)"; - case cuda: __intrinsic_asm "_waveOr($0, $1)"; - case hlsl: __intrinsic_asm "WaveActiveBitOr($1)"; - case spirv: - return spirv_asm { - OpCapability GroupNonUniformArithmetic; - OpGroupNonUniformBitwiseOr $$T result Subgroup 0 $expr - }; + case hlsl: + __intrinsic_asm "WaveActiveBitOr($1)"; + default: + return WaveMultiBitOr(expr, uint4(mask, 0, 0, 0)); } } -__generic 
-__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +__generic +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] +[ForceInline] vector WaveMaskBitOr(WaveMask mask, vector expr) { __target_switch { - case glsl: __intrinsic_asm "subgroupOr($1)"; - case cuda: __intrinsic_asm "_waveOrMultiple($0, $1)"; - case hlsl: __intrinsic_asm "WaveActiveBitOr($1)"; - case spirv: - return spirv_asm { - OpCapability GroupNonUniformArithmetic; - OpGroupNonUniformBitwiseOr $$vector result Subgroup 0 $expr - }; + case hlsl: + __intrinsic_asm "WaveActiveBitOr($1)"; + default: + return WaveMultiBitOr(expr, uint4(mask, 0, 0, 0)); } } -__generic -[require(cuda_hlsl, subgroup_arithmetic)] + +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] matrix WaveMaskBitOr(WaveMask mask, matrix expr) { __target_switch { - case cuda: __intrinsic_asm "_waveOrMultiple($0, $1)"; - case hlsl: __intrinsic_asm "WaveActiveBitOr($1)"; + case hlsl: + __intrinsic_asm "WaveActiveBitOr($1)"; + default: + return WaveMultiBitOr(expr, uint4(mask, 0, 0, 0)); } } -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] T WaveMaskBitXor(WaveMask mask, T expr) { __target_switch { - case glsl: __intrinsic_asm "subgroupXor($1)"; - case cuda: __intrinsic_asm "_waveXor($0, $1)"; - case hlsl: __intrinsic_asm "WaveActiveBitXor($1)"; - case spirv: - return spirv_asm { - OpCapability GroupNonUniformArithmetic; - OpGroupNonUniformBitwiseXor $$T result Subgroup 0 $expr - }; + case hlsl: + __intrinsic_asm "WaveActiveBitXor($1)"; + default: + return WaveMultiBitXor(expr, uint4(mask, 0, 0, 0)); } } -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] + +__generic +[ForceInline] 
+[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] vector WaveMaskBitXor(WaveMask mask, vector expr) { __target_switch { - case glsl: __intrinsic_asm "subgroupXor($1)"; - case cuda: __intrinsic_asm "_waveXorMultiple($0, $1)"; - case hlsl: __intrinsic_asm "WaveActiveBitXor($1)"; - case spirv: - return spirv_asm { - OpCapability GroupNonUniformArithmetic; - OpGroupNonUniformBitwiseXor $$vector result Subgroup 0 $expr - }; + case hlsl: + __intrinsic_asm "WaveActiveBitXor($1)"; + default: + return WaveMultiBitXor(expr, uint4(mask, 0, 0, 0)); } } -__generic -[require(cuda_hlsl, subgroup_arithmetic)] + +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] matrix WaveMaskBitXor(WaveMask mask, matrix expr) { __target_switch { - case cuda: __intrinsic_asm "_waveXorMultiple($0, $1)"; - case hlsl: __intrinsic_asm "WaveActiveBitXor($1)"; + case hlsl: + __intrinsic_asm "WaveActiveBitXor($1)"; + default: + return WaveMultiBitXor(expr, uint4(mask, 0, 0, 0)); } } __generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] T WaveMaskMax(WaveMask mask, T expr) { __target_switch { - case glsl: __intrinsic_asm "subgroupMax($1)"; - case cuda: __intrinsic_asm "_waveMax($0, $1)"; - case hlsl: __intrinsic_asm "WaveActiveMax($1)"; - case spirv: - if (__isFloat()) - return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMax $$T result Subgroup 0 $expr}; - else if (__isSignedInt()) - return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformSMax $$T result Subgroup 0 $expr}; - else if (__isUnsignedInt()) - return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformUMax $$T result Subgroup 0 $expr}; - else return expr; + case hlsl: + __intrinsic_asm "WaveActiveMax($1)"; + default: + return WaveMultiMax(expr, uint4(mask, 0, 0, 0)); } } + __generic 
-__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] vector WaveMaskMax(WaveMask mask, vector expr) { __target_switch { - case glsl: __intrinsic_asm "subgroupMax($1)"; - case cuda: __intrinsic_asm "_waveMaxMultiple($0, $1)"; - case hlsl: __intrinsic_asm "WaveActiveMax($1)"; - case spirv: - if (__isFloat()) - return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMax $$vector result Subgroup 0 $expr}; - else if (__isSignedInt()) - return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformSMax $$vector result Subgroup 0 $expr}; - else if (__isUnsignedInt()) - return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformUMax $$vector result Subgroup 0 $expr}; - else return expr; + case hlsl: + __intrinsic_asm "WaveActiveMax($1)"; + default: + return WaveMultiMax(expr, uint4(mask, 0, 0, 0)); } } __generic -[require(cuda_hlsl, subgroup_arithmetic)] +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] matrix WaveMaskMax(WaveMask mask, matrix expr) { __target_switch { - case cuda: __intrinsic_asm "_waveMaxMultiple($0, $1)"; - case hlsl: __intrinsic_asm "WaveActiveMax($1)"; + case hlsl: + __intrinsic_asm "WaveActiveMax($1)"; + default: + return WaveMultiMax(expr, uint4(mask, 0, 0, 0)); } } __generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] T WaveMaskMin(WaveMask mask, T expr) { __target_switch { - case glsl: __intrinsic_asm "subgroupMin($1)"; - case cuda: __intrinsic_asm "_waveMin($0, $1)"; - case hlsl: __intrinsic_asm "WaveActiveMin($1)"; - case spirv: - if (__isFloat()) - return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMin $$T result Subgroup 0 $expr}; - else if 
(__isSignedInt()) - return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformSMin $$T result Subgroup 0 $expr}; - else if (__isUnsignedInt()) - return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformUMin $$T result Subgroup 0 $expr}; - else return expr; + case hlsl: + __intrinsic_asm "WaveActiveMin($1)"; + default: + return WaveMultiMin(expr, uint4(mask, 0, 0, 0)); } } __generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] vector WaveMaskMin(WaveMask mask, vector expr) { __target_switch { - case glsl: __intrinsic_asm "subgroupMin($1)"; - case cuda: __intrinsic_asm "_waveMinMultiple($0, $1)"; - case hlsl: __intrinsic_asm "WaveActiveMin($1)"; - case spirv: - if (__isFloat()) - return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMin $$vector result Subgroup 0 $expr}; - else if (__isSignedInt()) - return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformSMin $$vector result Subgroup 0 $expr}; - else if (__isUnsignedInt()) - return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformUMin $$vector result Subgroup 0 $expr}; - else return expr; + case hlsl: + __intrinsic_asm "WaveActiveMin($1)"; + default: + return WaveMultiMin(expr, uint4(mask, 0, 0, 0)); } } __generic -[require(cuda_hlsl, subgroup_arithmetic)] +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] matrix WaveMaskMin(WaveMask mask, matrix expr) { __target_switch { - case cuda: __intrinsic_asm "_waveMinMultiple($0, $1)"; - case hlsl: __intrinsic_asm "WaveActiveMin($1)"; + case hlsl: + __intrinsic_asm "WaveActiveMin($1)"; + default: + return WaveMultiMin(expr, uint4(mask, 0, 0, 0)); } } __generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[ForceInline] 
+[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] T WaveMaskProduct(WaveMask mask, T expr) { __target_switch { - case glsl: __intrinsic_asm "subgroupMul($1)"; - case cuda: __intrinsic_asm "_waveProduct($0, $1)"; - case hlsl: __intrinsic_asm "WaveActiveProduct($1)"; - case spirv: - if (__isFloat()) - return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMul $$T result Subgroup 0 $expr}; - else if (__isInt()) - { - return spirv_asm - { - OpCapability GroupNonUniformArithmetic; - OpGroupNonUniformIMul $$T result Subgroup 0 $expr; - }; - } - else return expr; + case hlsl: + __intrinsic_asm "WaveActiveProduct($1)"; + default: + return WaveMultiProduct(expr, uint4(mask, 0, 0, 0)); } } __generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] vector WaveMaskProduct(WaveMask mask, vector expr) { __target_switch { - case glsl: __intrinsic_asm "subgroupMul($1)"; - case cuda: __intrinsic_asm "_waveProductMultiple($0, $1)"; - case hlsl: __intrinsic_asm "WaveActiveProduct($1)"; - case spirv: - if (__isFloat()) - return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMul $$vector result Subgroup 0 $expr}; - else if (__isInt()) - { - return spirv_asm - { - OpCapability GroupNonUniformArithmetic; - OpGroupNonUniformIMul $$vector result Subgroup 0 $expr; - }; - } - else return expr; + case hlsl: + __intrinsic_asm "WaveActiveProduct($1)"; + default: + return WaveMultiProduct(expr, uint4(mask, 0, 0, 0)); } } __generic -[require(cuda_hlsl, subgroup_arithmetic)] +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] matrix WaveMaskProduct(WaveMask mask, matrix expr) { __target_switch { - case cuda: __intrinsic_asm "_waveProductMultiple($0, $1)"; - case hlsl: __intrinsic_asm "WaveActiveProduct($1)"; + case hlsl: + __intrinsic_asm "WaveActiveProduct($1)"; + default: + return 
WaveMultiProduct(expr, uint4(mask, 0, 0, 0)); } } __generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] T WaveMaskSum(WaveMask mask, T expr) { __target_switch { - case glsl: - if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); - __intrinsic_asm "subgroupAdd($1)"; - case cuda: __intrinsic_asm "_waveSum($0, $1)"; - case hlsl: __intrinsic_asm "WaveActiveSum($1)"; - case spirv: - if (__isFloat()) - return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFAdd $$T result Subgroup 0 $expr}; - else if (__isInt()) - { - return spirv_asm - { - OpCapability GroupNonUniformArithmetic; - OpGroupNonUniformIAdd $$T result Subgroup 0 $expr; - }; - } - else return expr; + case hlsl: + __intrinsic_asm "WaveActiveSum($1)"; + default: + return WaveMultiSum(expr, uint4(mask, 0, 0, 0)); } } __generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] vector WaveMaskSum(WaveMask mask, vector expr) { __target_switch { - case glsl: - if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); - __intrinsic_asm "subgroupAdd($1)"; - case cuda: __intrinsic_asm "_waveSumMultiple($0, $1)"; - case hlsl: __intrinsic_asm "WaveActiveSum($1)"; - case spirv: - if (__isFloat()) - return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFAdd $$vector result Subgroup 0 $expr}; - else if (__isInt()) - { - return spirv_asm - { - OpCapability GroupNonUniformArithmetic; - OpGroupNonUniformIAdd $$vector result Subgroup 0 $expr; - }; - } - else return expr; + case hlsl: + __intrinsic_asm "WaveActiveSum($1)"; + default: + return WaveMultiSum(expr, uint4(mask, 0, 0, 0)); } } + __generic -[require(cuda_hlsl, 
subgroup_arithmetic)] +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] matrix WaveMaskSum(WaveMask mask, matrix expr) { __target_switch { - case cuda: __intrinsic_asm "_waveSumMultiple($0, $1)"; - case hlsl: __intrinsic_asm "WaveActiveSum($1)"; + case hlsl: + __intrinsic_asm "WaveActiveSum($1)"; + default: + return WaveMultiSum(expr, uint4(mask, 0, 0, 0)); } } @@ -14580,134 +14492,48 @@ bool WaveMaskAllEqual(WaveMask mask, matrix value) // Prefix __generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] T WaveMaskPrefixProduct(WaveMask mask, T expr) { - __target_switch - { - case glsl: - if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); - __intrinsic_asm "subgroupExclusiveMul($1)"; - case cuda: __intrinsic_asm "_wavePrefixProduct($0, $1)"; - case hlsl: __intrinsic_asm "WavePrefixProduct($1)"; - case spirv: - if (__isFloat()) - return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMul $$T result Subgroup ExclusiveScan $expr}; - else if (__isInt()) - { - return spirv_asm - { - OpCapability GroupNonUniformArithmetic; - OpGroupNonUniformIMul $$T result Subgroup ExclusiveScan $expr; - }; - } - else return expr; - } + return WaveMultiPrefixProduct(expr, uint4(mask, 0, 0, 0)); } __generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] vector WaveMaskPrefixProduct(WaveMask mask, vector expr) { - __target_switch - { - case glsl: - if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); - __intrinsic_asm "subgroupExclusiveMul($1)"; - case cuda: __intrinsic_asm "_wavePrefixProductMultiple($0, $1)"; - case hlsl: __intrinsic_asm "WavePrefixProduct($1)"; - case 
spirv: - if (__isFloat()) - return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMul $$vector result Subgroup ExclusiveScan $expr}; - else if (__isInt()) - { - return spirv_asm - { - OpCapability GroupNonUniformArithmetic; - OpGroupNonUniformIMul $$vector result Subgroup ExclusiveScan $expr; - }; - } - else return expr; - } + return WaveMultiPrefixProduct(expr, uint4(mask, 0, 0, 0)); } __generic -[require(cuda_hlsl, subgroup_arithmetic)] +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] matrix WaveMaskPrefixProduct(WaveMask mask, matrix expr) { - __target_switch - { - case cuda: __intrinsic_asm "_wavePrefixProductMultiple($0, $1)"; - case hlsl: __intrinsic_asm "WavePrefixProduct($1)"; - } + return WaveMultiPrefixProduct(expr, uint4(mask, 0, 0, 0)); } __generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] T WaveMaskPrefixSum(WaveMask mask, T expr) { - __target_switch - { - case glsl: - if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); - __intrinsic_asm "subgroupExclusiveAdd($1)"; - case cuda: __intrinsic_asm "_wavePrefixSum($0, $1)"; - case hlsl: __intrinsic_asm "WavePrefixSum($1)"; - case spirv: - if (__isFloat()) - return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFAdd $$T result Subgroup ExclusiveScan $expr}; - else if (__isInt()) - { - return spirv_asm - { - OpCapability GroupNonUniformArithmetic; - result:$$T = OpGroupNonUniformIAdd Subgroup ExclusiveScan $expr; - }; - } - else return expr; - } + return WaveMultiPrefixSum(expr, uint4(mask, 0, 0, 0)); } __generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] vector WaveMaskPrefixSum(WaveMask mask, vector 
expr) { - __target_switch - { - case glsl: - if (__isHalf()) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); - __intrinsic_asm "subgroupExclusiveAdd($1)"; - case cuda: __intrinsic_asm "_wavePrefixSumMultiple($0, $1)"; - case hlsl: __intrinsic_asm "WavePrefixSum($1)"; - case spirv: - if (__isFloat()) - return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFAdd $$vector result Subgroup ExclusiveScan $expr}; - else if (__isInt()) - { - return spirv_asm - { - OpCapability GroupNonUniformArithmetic; - result:$$vector = OpGroupNonUniformIAdd Subgroup ExclusiveScan $expr; - }; - } - else return expr; - } + return WaveMultiPrefixSum(expr, uint4(mask, 0, 0, 0)); } __generic -[require(cuda_hlsl, subgroup_arithmetic)] +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] matrix WaveMaskPrefixSum(WaveMask mask, matrix expr) { - __target_switch - { - case cuda: __intrinsic_asm "_wavePrefixSumMultiple($0, $1)"; - case hlsl: __intrinsic_asm "WavePrefixSum($1)"; - } + return WaveMultiPrefixSum(expr, uint4(mask, 0, 0, 0)); } __generic @@ -14813,133 +14639,76 @@ WaveMask WaveMaskMatch(WaveMask mask, matrix value) } } -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] T WaveMaskPrefixBitAnd(WaveMask mask, T expr) { - __target_switch - { - case glsl: __intrinsic_asm "subgroupExclusiveAnd($1)"; - case cuda: __intrinsic_asm "_wavePrefixAnd($0, $1)"; - case hlsl: __intrinsic_asm "WaveMultiPrefixBitAnd($1, uint4($0, 0, 0, 0))"; - case spirv: - return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformBitwiseAnd $$T result Subgroup ExclusiveScan $expr}; - } + return WaveMultiPrefixBitAnd(expr, uint4(mask, 0, 0, 0)); } -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, 
subgroup_arithmetic)] +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] vector WaveMaskPrefixBitAnd(WaveMask mask, vector expr) { - __target_switch - { - case glsl: __intrinsic_asm "subgroupExclusiveAnd($1)"; - case cuda: __intrinsic_asm "_wavePrefixAndMultiple($0, $1)"; - case hlsl: __intrinsic_asm "WaveMultiPrefixBitAnd($1, uint4($0, 0, 0, 0))"; - case spirv: - return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformBitwiseAnd $$vector result Subgroup ExclusiveScan $expr}; - } + return WaveMultiPrefixBitAnd(expr, uint4(mask, 0, 0, 0)); } -__generic -[require(cuda_hlsl, subgroup_arithmetic)] +__generic +[ForceInline] +[require(cuda_hlsl, subgroup_partitioned)] matrix WaveMaskPrefixBitAnd(WaveMask mask, matrix expr) { - __target_switch - { - case cuda: __intrinsic_asm "_wavePrefixAndMultiple(_getMultiPrefixMask($0, $1)"; - case hlsl: __intrinsic_asm "WaveMultiPrefixBitAnd($1, uint4($0, 0, 0, 0))"; - } + return WaveMultiPrefixBitAnd(expr, uint4(mask, 0, 0, 0)); } -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] T WaveMaskPrefixBitOr(WaveMask mask, T expr) { - __target_switch - { - case glsl: __intrinsic_asm "subgroupExclusiveOr($1)"; - case cuda: __intrinsic_asm "_wavePrefixOr($0, $1)"; - case hlsl: __intrinsic_asm "WaveMultiPrefixBitOr($1, uint4($0, 0, 0, 0))"; - case spirv: - return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformBitwiseAnd $$T result Subgroup ExclusiveScan $expr}; - } + return WaveMultiPrefixBitOr(expr, uint4(mask, 0, 0, 0)); } -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] vector WaveMaskPrefixBitOr(WaveMask mask, vector expr) { - 
__target_switch - { - case glsl: __intrinsic_asm "subgroupExclusiveOr($1)"; - case cuda: __intrinsic_asm "_wavePrefixOrMultiple($0, $1)"; - case hlsl: __intrinsic_asm "WaveMultiPrefixBitOr($1, uint4($0, 0, 0, 0))"; - case spirv: - return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformBitwiseOr $$vector result Subgroup ExclusiveScan $expr}; - } + return WaveMultiPrefixBitOr(expr, uint4(mask, 0, 0, 0)); } -__generic -[require(cuda_hlsl, subgroup_arithmetic)] +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] matrix WaveMaskPrefixBitOr(WaveMask mask, matrix expr) { - __target_switch - { - case cuda: __intrinsic_asm "_wavePrefixOrMultiple($0, $1)"; - case hlsl: __intrinsic_asm "WaveMultiPrefixBitOr($1, uint4($0, 0, 0, 0))"; - } + return WaveMultiPrefixBitOr(expr, uint4(mask, 0, 0, 0)); } -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] T WaveMaskPrefixBitXor(WaveMask mask, T expr) { - __target_switch - { - case glsl: __intrinsic_asm "subgroupExclusiveXor($1)"; - case cuda: __intrinsic_asm "_wavePrefixXor($0, $1)"; - case hlsl: __intrinsic_asm "WaveMultiPrefixBitXor($1, uint4($0, 0, 0, 0))"; - case spirv: - return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformBitwiseXor $$T result Subgroup ExclusiveScan $expr}; - } + return WaveMultiPrefixBitXor(expr, uint4(mask, 0, 0, 0)); } -__generic -__glsl_extension(GL_KHR_shader_subgroup_arithmetic) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] vector WaveMaskPrefixBitXor(WaveMask mask, vector expr) { - __target_switch - { - case glsl: __intrinsic_asm "subgroupExclusiveXor($1)"; - case cuda: __intrinsic_asm "_wavePrefixXorMultiple($0, $1)"; - case hlsl: __intrinsic_asm 
"WaveMultiPrefixBitXor($1, uint4($0, 0, 0, 0))"; - case spirv: - return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformBitwiseXor $$vector result Subgroup ExclusiveScan $expr}; - } + return WaveMultiPrefixBitXor(expr, uint4(mask, 0, 0, 0)); } -__generic -[require(cuda_hlsl, subgroup_arithmetic)] +__generic +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] matrix WaveMaskPrefixBitXor(WaveMask mask, matrix expr) { - __target_switch - { - case cuda: __intrinsic_asm "_wavePrefixXorMultiple($0, $1)"; - case hlsl: __intrinsic_asm "WaveMultiPrefixBitXor($1, uint4($0, 0, 0, 0))"; - } + return WaveMultiPrefixBitXor(expr, uint4(mask, 0, 0, 0)); } //@public: @@ -15156,7 +14925,7 @@ const WaveActiveBitOpEntry kWaveActiveBitOpEntries[] = {{"BitAnd", "And", "Bitwi for (auto opName : kWaveActiveBitOpEntries) { }}}} /// @category wave -__generic +__generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) __wgsl_extension(subgroups) @@ -15179,7 +14948,7 @@ T WaveActive$(opName.hlslName)(T expr) } } -__generic +__generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) __wgsl_extension(subgroups) @@ -15202,7 +14971,7 @@ vector WaveActive$(opName.hlslName)(vector expr) } } -__generic +__generic [require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_arithmetic)] matrix WaveActive$(opName.hlslName)(matrix expr) { @@ -16238,7 +16007,7 @@ uint4 WaveMatch(matrix value) } /// @category wave -[require(cuda_hlsl, wave_multi_prefix)] +[require(cuda_hlsl, subgroup_partitioned)] uint WaveMultiPrefixCountBits(bool value, uint4 mask) { __target_switch @@ -16248,537 +16017,750 @@ uint WaveMultiPrefixCountBits(bool value, uint4 mask) } } -/// @category wave -__generic -__glsl_extension(GL_NV_shader_subgroup_partitioned) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] -T WaveMultiPrefixBitAnd(T expr, uint4 mask) +__glsl_extension(GL_EXT_demote_to_helper_invocation) 
+[ForceInline] +[require(glsl_hlsl_metal_spirv, helper_lane)] +bool IsHelperLane() { - __target_switch - { - case cuda: __intrinsic_asm "_wavePrefixAnd(_getMultiPrefixMask(($1).x), $0)"; - case hlsl: __intrinsic_asm "WaveMultiPrefixBitAnd"; - case glsl: __intrinsic_asm "subgroupPartitionedExclusiveAndNV"; + __target_switch { + case hlsl: __intrinsic_asm "IsHelperLane()"; + case glsl: __intrinsic_asm "gl_HelperInvocation"; + case metal: __intrinsic_asm "simd_is_helper_thread()"; case spirv: - return spirv_asm - { - OpExtension "SPV_NV_shader_subgroup_partitioned"; - OpCapability GroupNonUniformPartitionedNV; - result:$$T = OpGroupNonUniformBitwiseAnd Subgroup PartitionedExclusiveScanNV $expr $mask + return spirv_asm { + OpExtension "SPV_EXT_demote_to_helper_invocation"; + OpCapability DemoteToHelperInvocationEXT; + result:$$bool = OpIsHelperInvocationEXT }; } } -__generic -__glsl_extension(GL_NV_shader_subgroup_partitioned) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] -vector WaveMultiPrefixBitAnd(vector expr, uint4 mask) +//@hidden: + +__generic +[ForceInline] +[require(glsl)] +void __requireGLSLShaderSubgroupTypeExtension() { - __target_switch + // the following is a seperate function call, since else the `__requireTargetExtension` and associated __intrinsic_asm is ignored if the calling function also calls an __intrinsic_asm + if (__type_equals() + || __type_equals() + ) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); + else if (__type_equals() + || __type_equals() + ) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_int8"); + else if (__type_equals() + || __type_equals() + ) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_int16"); + else if (__type_equals() + || __type_equals() + ) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_int64"); + + __intrinsic_asm ""; +} + +__generic +[ForceInline] +[require(metal)] +void __checkMetalShaderSubgroupType() +{ + // These 
builtin types are not supported for Metal's `simd` operations. + if (__type_equals() + || __type_equals() + || __type_equals() + || __type_equals() + || __isBool() + ) { - case cuda: __intrinsic_asm "_wavePrefixAndMultiple(_getMultiPrefixMask(($1).x), $0)"; - case hlsl: __intrinsic_asm "WaveMultiPrefixBitAnd"; - case glsl: __intrinsic_asm "subgroupPartitionedExclusiveAndNV"; - case spirv: - return spirv_asm - { - OpExtension "SPV_NV_shader_subgroup_partitioned"; - OpCapability GroupNonUniformPartitionedNV; - result:$$vector = OpGroupNonUniformBitwiseAnd Subgroup PartitionedExclusiveScanNV $expr $mask - }; + static_assert(false, "Unsupported type for subgroup operations in Metal. Valid types include scalars and vectors of uint/uint32_t, int/int32_t, uint16_t, int16_t, float, and half."); } } -__generic -[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] -matrix WaveMultiPrefixBitAnd(matrix expr, uint4 mask) +__generic +[ForceInline] +void shader_subgroup_preamble() { + // checks needed for shader_subgroup functions; __requireTargetExtension does not work + // (does not add the ext specified correctly to the compile output; using extended type + // will result in error for using the type) __target_switch { - case cuda: __intrinsic_asm "_wavePrefixAndMultiple(_getMultiPrefixMask(($1).x), $0)"; - case hlsl: __intrinsic_asm "WaveMultiPrefixBitAnd"; case glsl: - case spirv: - matrix result; - for (int i = 0; i < N; ++i) - result[i] = WaveMultiPrefixBitAnd(expr[i], mask); - return result; + __requireGLSLShaderSubgroupTypeExtension(); + case metal: + __checkMetalShaderSubgroupType(); + default: + return; } } -/// @category wave -__generic -__glsl_extension(GL_NV_shader_subgroup_partitioned) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] -T WaveMultiPrefixBitOr(T expr, uint4 mask) +//@public: + +// +// Wave Rotate intrinsics. +// These are Slang specific intrinsics to rotate values within a subgroup. 
+// + +__generic +__glsl_extension(GL_KHR_shader_subgroup_rotate) +[require(glsl_metal_spirv, subgroup_rotate)] +T WaveRotate(T value, uint delta) { + shader_subgroup_preamble(); __target_switch { - case cuda: __intrinsic_asm "_wavePrefixOr(, _getMultiPrefixMask(($1).x), $0)"; - case hlsl: __intrinsic_asm "WaveMultiPrefixBitOr"; - case glsl: __intrinsic_asm "subgroupPartitionedExclusiveOrNV"; + case glsl: + __intrinsic_asm "subgroupRotate"; + case metal: + __intrinsic_asm "simd_shuffle_rotate_down"; case spirv: return spirv_asm { - OpExtension "SPV_NV_shader_subgroup_partitioned"; - OpCapability GroupNonUniformPartitionedNV; - result:$$T = OpGroupNonUniformBitwiseOr Subgroup PartitionedExclusiveScanNV $expr $mask + OpExtension "SPV_KHR_subgroup_rotate"; + OpCapability GroupNonUniformRotateKHR; + result:$$T = OpGroupNonUniformRotateKHR Subgroup $value $delta; }; } } -__generic -__glsl_extension(GL_NV_shader_subgroup_partitioned) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] -vector WaveMultiPrefixBitOr(vector expr, uint4 mask) +__generic +__glsl_extension(GL_KHR_shader_subgroup_rotate) +[require(glsl_metal_spirv, subgroup_rotate)] +vector WaveRotate(vector value, uint delta) { + shader_subgroup_preamble(); __target_switch { - case cuda: __intrinsic_asm "_wavePrefixOrMultiple(_getMultiPrefixMask(($1).x), $0)"; - case hlsl: __intrinsic_asm "WaveMultiPrefixBitOr"; - case glsl: __intrinsic_asm "subgroupPartitionedExclusiveOrNV"; + case glsl: + __intrinsic_asm "subgroupRotate"; + case metal: + __intrinsic_asm "simd_shuffle_rotate_down"; case spirv: return spirv_asm { - OpExtension "SPV_NV_shader_subgroup_partitioned"; - OpCapability GroupNonUniformPartitionedNV; - result:$$vector = OpGroupNonUniformBitwiseOr Subgroup PartitionedExclusiveScanNV $expr $mask + OpExtension "SPV_KHR_subgroup_rotate"; + OpCapability GroupNonUniformRotateKHR; + result:$$vector = OpGroupNonUniformRotateKHR Subgroup $value $delta; }; } } -__generic 
-[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] -matrix WaveMultiPrefixBitOr(matrix expr, uint4 mask) +__generic +__glsl_extension(GL_KHR_shader_subgroup_rotate) +[require(glsl_spirv, subgroup_rotate)] +T WaveClusteredRotate(T value, uint delta, constexpr uint clusterSize) { + shader_subgroup_preamble(); __target_switch { - case cuda: __intrinsic_asm "_wavePrefixOrMultiple(_getMultiPrefixMask(($1).x), $0)"; - case hlsl: __intrinsic_asm "WaveMultiPrefixBitOr"; case glsl: + __intrinsic_asm "subgroupClusteredRotate"; case spirv: - matrix result; - for (int i = 0; i < N; ++i) - result[i] = WaveMultiPrefixBitOr(expr[i], mask); - return result; + return spirv_asm + { + OpExtension "SPV_KHR_subgroup_rotate"; + OpCapability GroupNonUniformRotateKHR; + result:$$T = OpGroupNonUniformRotateKHR Subgroup $value $delta $clusterSize; + }; } } -/// @category wave -__generic -__glsl_extension(GL_NV_shader_subgroup_partitioned) -__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] -T WaveMultiPrefixBitXor(T expr, uint4 mask) +__generic +__glsl_extension(GL_KHR_shader_subgroup_rotate) +[require(glsl_spirv, subgroup_rotate)] +vector WaveClusteredRotate(vector value, uint delta, constexpr uint clusterSize) { + shader_subgroup_preamble(); __target_switch { - case cuda: __intrinsic_asm "_wavePrefixXor(_getMultiPrefixMask(($1).x), $0)"; - case hlsl: __intrinsic_asm "WaveMultiPrefixBitXor"; - case glsl: __intrinsic_asm "subgroupPartitionedExclusiveXorNV"; + case glsl: + __intrinsic_asm "subgroupClusteredRotate"; case spirv: return spirv_asm { - OpExtension "SPV_NV_shader_subgroup_partitioned"; - OpCapability GroupNonUniformPartitionedNV; - result:$$T = OpGroupNonUniformBitwiseXor Subgroup PartitionedExclusiveScanNV $expr $mask + OpExtension "SPV_KHR_subgroup_rotate"; + OpCapability GroupNonUniformRotateKHR; + result:$$vector = OpGroupNonUniformRotateKHR Subgroup $value $delta $clusterSize; }; } } -__generic -__glsl_extension(GL_NV_shader_subgroup_partitioned) 
-__spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] -vector WaveMultiPrefixBitXor(vector expr, uint4 mask) + +// +// WaveMulti intrinsics are subgroup operations that operate on a 128-bit `uint4` mask. +// They are equivalent to SPIRV/GLSL's subgroup partitioned operations and HLSL's `WaveMultiPrefix*` operations. +// +// SPIRV/GLSL natively supports masked subgroup operations for both reductions and exclusive/inclusive scans. +// HLSL only natively supports exclusive scans (prefix operations) on arithmetic operations. Inclusive scans +// are emulated by applying one additional operation to the exclusive scan result. Reductions are not supported. +// + +__generic +[ForceInline] +void __shaderSubgroupPartitionedPreamble() { + shader_subgroup_preamble(); __target_switch { - case cuda: __intrinsic_asm "_wavePrefixXorMultiple(_getMultiPrefixMask(($1).x), $0)"; - case hlsl: __intrinsic_asm "WaveMultiPrefixBitXor"; - case glsl: __intrinsic_asm "subgroupPartitionedExclusiveXorNV"; + case glsl: + __requireTargetExtension("GL_NV_shader_subgroup_partitioned"); case spirv: - return spirv_asm + spirv_asm { OpExtension "SPV_NV_shader_subgroup_partitioned"; OpCapability GroupNonUniformPartitionedNV; - result:$$vector = OpGroupNonUniformBitwiseXor Subgroup PartitionedExclusiveScanNV $expr $mask }; + default: + return; } } -__generic -[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] -matrix WaveMultiPrefixBitXor(matrix expr, uint4 mask) +// +// WaveMultiSum/WaveMultiProduct. 
+// +${{{{ +struct WaveMultiSumProductEntry { const char* name; const char* spirvName; }; +const WaveMultiSumProductEntry kWaveMultiSumProductNames[] = { {"Sum", "Add"}, {"Product", "Mul"} }; +for (auto opName : kWaveMultiSumProductNames) { +}}}} + +__generic +__spirv_version(1.3) +[ForceInline] +[require(cuda_glsl_spirv, subgroup_partitioned)] +T WaveMulti$(opName.name)(T value, uint4 mask) { + __shaderSubgroupPartitionedPreamble(); __target_switch { - case cuda: __intrinsic_asm "_wavePrefixXorMultiple(_getMultiPrefixMask(($1).x), $0)"; - case hlsl: __intrinsic_asm "WaveMultiPrefixBitXor"; + case cuda: + __intrinsic_asm "_wave$(opName.name)($1.x, $0)"; case glsl: + __intrinsic_asm "subgroupPartitioned$(opName.spirvName)NV"; case spirv: + { + if (__isFloat()) + return spirv_asm { result:$$T = OpGroupNonUniformF$(opName.spirvName) Subgroup PartitionedReduceNV $value $mask }; + else + return spirv_asm { result:$$T = OpGroupNonUniformI$(opName.spirvName) Subgroup PartitionedReduceNV $value $mask }; + } + } +} + +__generic +__spirv_version(1.3) +[ForceInline] +[require(cuda_glsl_spirv, subgroup_partitioned)] +vector WaveMulti$(opName.name)(vector value, uint4 mask) +{ + __shaderSubgroupPartitionedPreamble(); + __target_switch + { + case cuda: + __intrinsic_asm "_wave$(opName.name)Multiple($1.x, $0)"; + case glsl: + __intrinsic_asm "subgroupPartitioned$(opName.spirvName)NV"; + case spirv: + { + if (__isFloat()) + return spirv_asm { result:$$vector = OpGroupNonUniformF$(opName.spirvName) Subgroup PartitionedReduceNV $value $mask }; + else + return spirv_asm { result:$$vector = OpGroupNonUniformI$(opName.spirvName) Subgroup PartitionedReduceNV $value $mask }; + } + } +} + +__generic +[require(cuda_glsl_spirv, subgroup_partitioned)] +matrix WaveMulti$(opName.name)(matrix value, uint4 mask) +{ + __target_switch + { + case cuda: + __intrinsic_asm "_wave$(opName.name)Multiple($1.x, $0)"; + default: matrix result; for (int i = 0; i < N; ++i) - result[i] = 
WaveMultiPrefixBitXor(expr[i], mask); + result[i] = WaveMulti$(opName.name)(value[i], mask); return result; } } -/// @category wave +${{{{ +} // WaveMultiSum/WaveMultiProduct. +}}}} + + +// +// WaveMultiPrefixInclusiveSum/WaveMultiPrefixInclusiveProduct. +// WaveMultiPrefixExclusiveSum/WaveMultiPrefixExclusiveProduct. +// WaveMultiPrefixSum/WaveMultiPrefixProduct. +// +${{{{ +struct WaveMultiPrefixSumProductEntry +{ + const char* name; + const char* spirvName; + const char* spirvGroupOperation; + const char* glslName; + const char* hlslName; + const char* cudaName; + const char* cudaExtraOperation; + + // Inclusive operations are not implemented by the CUDA prelude functions. + // They are implemented here by calling the exclusive implementation and performing an additional operations + // with the current invocation's value. This works for all cases except for element-wise matrix multiplication. + bool cudaMatrixVariantSupport; +}; + +const WaveMultiPrefixSumProductEntry kWaveMultiPrefixSumProductNames[] = +{ + // name spirvName spirvGroupOperation glslName hlslName cudaName cudaExtraOperation cudaMatrixVariantSupport + { "InclusiveSum", "Add", "PartitionedInclusiveScanNV", "InclusiveAdd", "Sum($0, $1) + $0", "Sum", "+ $0", false }, + { "InclusiveProduct", "Mul", "PartitionedInclusiveScanNV", "InclusiveMul", "Product($0, $1) * $0", "Product", "* $0", false }, + { "ExclusiveSum", "Add", "PartitionedExclusiveScanNV", "ExclusiveAdd", "Sum($0, $1)", "Sum", "", true }, + { "ExclusiveProduct", "Mul", "PartitionedExclusiveScanNV", "ExclusiveMul", "Product($0, $1)", "Product", "", true }, + + // These are HLSL SM 6.5 intrinsics and are equal to the exclusive variants. 
+ { "Sum", "Add", "PartitionedExclusiveScanNV", "ExclusiveAdd", "Sum($0, $1)", "Sum", "", true }, + { "Product", "Mul", "PartitionedExclusiveScanNV", "ExclusiveMul", "Product($0, $1)", "Product", "", true }, +}; + +for (auto opName : kWaveMultiPrefixSumProductNames) { +}}}} + __generic -__glsl_extension(GL_NV_shader_subgroup_partitioned) __spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] -T WaveMultiPrefixProduct(T value, uint4 mask) +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] +T WaveMultiPrefix$(opName.name)(T value, uint4 mask) { + __shaderSubgroupPartitionedPreamble(); __target_switch { - case cuda: __intrinsic_asm "_wavePrefixProduct(_getMultiPrefixMask(($1).x), $0)"; - case hlsl: __intrinsic_asm "WaveMultiPrefixProduct"; - case glsl: __intrinsic_asm "subgroupPartitionedExclusiveMulNV"; + case cuda: + __intrinsic_asm "_wavePrefix$(opName.cudaName)($1.x, $0) $(opName.cudaExtraOperation)"; + case glsl: + __intrinsic_asm "subgroupPartitioned$(opName.glslName)NV"; + case hlsl: + __intrinsic_asm "WaveMultiPrefix$(opName.hlslName)"; case spirv: { - spirv_asm - { - OpExtension "SPV_NV_shader_subgroup_partitioned"; - OpCapability GroupNonUniformPartitionedNV; - }; - if (__isFloat()) - { - return spirv_asm - { - result:$$T = OpGroupNonUniformFMul Subgroup PartitionedExclusiveScanNV $value $mask - }; - } + return spirv_asm { result:$$T = OpGroupNonUniformF$(opName.spirvName) Subgroup $(opName.spirvGroupOperation) $value $mask }; else - { - return spirv_asm - { - result:$$T = OpGroupNonUniformIMul Subgroup PartitionedExclusiveScanNV $value $mask - }; - } + return spirv_asm { result:$$T = OpGroupNonUniformI$(opName.spirvName) Subgroup $(opName.spirvGroupOperation) $value $mask }; } } } __generic -__glsl_extension(GL_NV_shader_subgroup_partitioned) __spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] -vector WaveMultiPrefixProduct(vector value, uint4 mask) +[ForceInline] +[require(cuda_glsl_hlsl_spirv, 
subgroup_partitioned)] +vector WaveMultiPrefix$(opName.name)(vector value, uint4 mask) { + __shaderSubgroupPartitionedPreamble(); __target_switch { - case cuda: __intrinsic_asm "_wavePrefixProductMultiple(_getMultiPrefixMask(($1).x), $0)"; - case hlsl: __intrinsic_asm "WaveMultiPrefixProduct"; - case glsl: __intrinsic_asm "subgroupPartitionedExclusiveMulNV"; + case cuda: + __intrinsic_asm "_wavePrefix$(opName.cudaName)Multiple($1.x, $0) $(opName.cudaExtraOperation)"; + case glsl: + __intrinsic_asm "subgroupPartitioned$(opName.glslName)NV"; + case hlsl: + __intrinsic_asm "WaveMultiPrefix$(opName.hlslName)"; case spirv: { - spirv_asm - { - OpExtension "SPV_NV_shader_subgroup_partitioned"; - OpCapability GroupNonUniformPartitionedNV; - }; - if (__isFloat()) - { - return spirv_asm - { - result:$$vector = OpGroupNonUniformFMul Subgroup PartitionedExclusiveScanNV $value $mask - }; - } + return spirv_asm { result:$$vector = OpGroupNonUniformF$(opName.spirvName) Subgroup $(opName.spirvGroupOperation) $value $mask }; else - { - return spirv_asm - { - result:$$vector = OpGroupNonUniformIMul Subgroup PartitionedExclusiveScanNV $value $mask - }; - } + return spirv_asm { result:$$vector = OpGroupNonUniformI$(opName.spirvName) Subgroup $(opName.spirvGroupOperation) $value $mask }; } } } __generic -[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] -matrix WaveMultiPrefixProduct(matrix value, uint4 mask) +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] +matrix WaveMultiPrefix$(opName.name)(matrix value, uint4 mask) { __target_switch { - case cuda: __intrinsic_asm "_wavePrefixProductMultiple(_getMultiPrefixMask(($1).x), $0)"; - case hlsl: __intrinsic_asm "WaveMultiPrefixProduct"; - case glsl: - case spirv: + ${{{{ if(opName.cudaMatrixVariantSupport) { }}}} + case cuda: + __intrinsic_asm "_wavePrefix$(opName.cudaName)Multiple($1.x, $0) $(opName.cudaExtraOperation)"; + ${{{{ } }}}} + default: matrix result; for (int i = 0; i < N; ++i) - result[i] = 
WaveMultiPrefixProduct(value[i], mask); + result[i] = WaveMultiPrefix$(opName.name)(value[i], mask); return result; } } -/// @category wave +${{{{ +} +// WaveMultiPrefixInclusiveSum/WaveMultiPrefixInclusiveProduct. +// WaveMultiPrefixExclusiveSum/WaveMultiPrefixExclusiveProduct. +// WaveMultiPrefixSum/WaveMultiPrefixProduct. +}}}} + + +// +// WaveMultiMin/WaveMultiMax. +// +${{{{ +struct WaveMultiMinMaxEntry { const char* name; }; +const WaveMultiMinMaxEntry kWaveMultiMinMaxNames[] = { {"Min"}, {"Max"} }; +for (auto opName : kWaveMultiMinMaxNames) { +}}}} + __generic -__glsl_extension(GL_NV_shader_subgroup_partitioned) __spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] -T WaveMultiPrefixSum(T value, uint4 mask) +[ForceInline] +[require(cuda_glsl_spirv, subgroup_partitioned)] +T WaveMulti$(opName.name)(T value, uint4 mask) { + __shaderSubgroupPartitionedPreamble(); __target_switch { - case cuda: __intrinsic_asm "_wavePrefixSum(_getMultiPrefixMask(($1).x), $0)"; - case hlsl: __intrinsic_asm "WaveMultiPrefixSum"; - case glsl: __intrinsic_asm "subgroupPartitionedExclusiveAddNV"; + case cuda: + __intrinsic_asm "_wave$(opName.name)($1.x, $0)"; + case glsl: + __intrinsic_asm "subgroupPartitioned$(opName.name)NV"; case spirv: { - spirv_asm - { - OpExtension "SPV_NV_shader_subgroup_partitioned"; - OpCapability GroupNonUniformPartitionedNV; - }; - if (__isFloat()) - { - return spirv_asm - { - result:$$T = OpGroupNonUniformFAdd Subgroup PartitionedExclusiveScanNV $value $mask - }; - } + return spirv_asm { result:$$T = OpGroupNonUniformF$(opName.name) Subgroup PartitionedReduceNV $value $mask }; + else if (__isUnsignedInt()) + return spirv_asm { result:$$T = OpGroupNonUniformU$(opName.name) Subgroup PartitionedReduceNV $value $mask }; else - { - return spirv_asm - { - result:$$T = OpGroupNonUniformIAdd Subgroup PartitionedExclusiveScanNV $value $mask - }; - } + return spirv_asm { result:$$T = OpGroupNonUniformS$(opName.name) Subgroup PartitionedReduceNV 
$value $mask }; } } } __generic -__glsl_extension(GL_NV_shader_subgroup_partitioned) -[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] __spirv_version(1.3) -vector WaveMultiPrefixSum(vector value, uint4 mask) +[ForceInline] +[require(cuda_glsl_spirv, subgroup_partitioned)] +vector WaveMulti$(opName.name)(vector value, uint4 mask) { + __shaderSubgroupPartitionedPreamble(); __target_switch { - case cuda: __intrinsic_asm "_wavePrefixSumMultiple(_getMultiPrefixMask(($1).x), $0 )"; - case hlsl: __intrinsic_asm "WaveMultiPrefixSum"; - case glsl: __intrinsic_asm "subgroupPartitionedExclusiveAddNV"; + case cuda: + __intrinsic_asm "_wave$(opName.name)Multiple($1.x, $0)"; + case glsl: + __intrinsic_asm "subgroupPartitioned$(opName.name)NV"; case spirv: { - spirv_asm - { - OpExtension "SPV_NV_shader_subgroup_partitioned"; - OpCapability GroupNonUniformPartitionedNV; - }; - if (__isFloat()) - { - return spirv_asm - { - result:$$vector = OpGroupNonUniformFAdd Subgroup PartitionedExclusiveScanNV $value $mask - }; - } + return spirv_asm { result:$$vector = OpGroupNonUniformF$(opName.name) Subgroup PartitionedReduceNV $value $mask }; + else if (__isUnsignedInt()) + return spirv_asm { result:$$vector = OpGroupNonUniformU$(opName.name) Subgroup PartitionedReduceNV $value $mask }; else - { - return spirv_asm - { - result:$$vector = OpGroupNonUniformIAdd Subgroup PartitionedExclusiveScanNV $value $mask - }; - } + return spirv_asm { result:$$vector = OpGroupNonUniformS$(opName.name) Subgroup PartitionedReduceNV $value $mask }; } } } __generic -[require(cuda_glsl_hlsl_spirv, wave_multi_prefix)] -matrix WaveMultiPrefixSum(matrix value, uint4 mask) +[require(cuda_glsl_spirv, subgroup_partitioned)] +matrix WaveMulti$(opName.name)(matrix value, uint4 mask) { __target_switch { - case cuda: __intrinsic_asm "_wavePrefixSumMultiple(_getMultiPrefixMask(($1).x), $0)"; - case hlsl: __intrinsic_asm "WaveMultiPrefixSum"; - case glsl: - case spirv: + case cuda: + __intrinsic_asm 
"_wave$(opName.name)Multiple($1.x, $0)"; + default: matrix result; + [ForceUnroll] for (int i = 0; i < N; ++i) - result[i] = WaveMultiPrefixSum(value[i], mask); + result[i] = WaveMulti$(opName.name)(value[i], mask); return result; } } -__glsl_extension(GL_EXT_demote_to_helper_invocation) -[ForceInline] -[require(glsl_hlsl_metal_spirv, helper_lane)] -bool IsHelperLane() -{ - __target_switch { - case hlsl: __intrinsic_asm "IsHelperLane()"; - case glsl: __intrinsic_asm "gl_HelperInvocation"; - case metal: __intrinsic_asm "simd_is_helper_thread()"; - case spirv: - return spirv_asm { - OpExtension "SPV_EXT_demote_to_helper_invocation"; - OpCapability DemoteToHelperInvocationEXT; - result:$$bool = OpIsHelperInvocationEXT - }; - } -} +${{{{ +} // WaveMultiMin/WaveMultiMax. +}}}} -//@hidden: -__generic -[ForceInline] -[require(glsl)] -void __requireGLSLShaderSubgroupTypeExtension() +// +// WaveMultiPrefixInclusiveMin/WaveMultiPrefixInclusiveMax. +// WaveMultiPrefixExclusiveMin/WaveMultiPrefixExclusiveMax. 
+// +${{{{ +struct WaveMultiPrefixMinMaxEntry { - // the following is a seperate function call, since else the `__requireTargetExtension` and associated __intrinsic_asm is ignored if the calling function also calls an __intrinsic_asm - if (__type_equals() - || __type_equals() - ) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_float16"); - else if (__type_equals() - || __type_equals() - ) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_int8"); - else if (__type_equals() - || __type_equals() - ) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_int16"); - else if (__type_equals() - || __type_equals() - ) __requireTargetExtension("GL_EXT_shader_subgroup_extended_types_int64"); + const char* name; + const char* spirvName; + const char* spirvGroupOperation; + const char* glslName; +}; - __intrinsic_asm ""; -} +const WaveMultiPrefixMinMaxEntry kWaveMultiPrefixMinMaxNames[] = +{ + // name spirvName spirvGroupOperation glslName + { "InclusiveMin", "Min", "PartitionedInclusiveScanNV", "InclusiveMin" }, + { "InclusiveMax", "Max", "PartitionedInclusiveScanNV", "InclusiveMax" }, + { "ExclusiveMin", "Min", "PartitionedExclusiveScanNV", "ExclusiveMin" }, + { "ExclusiveMax", "Max", "PartitionedExclusiveScanNV", "ExclusiveMax" }, +}; -__generic +for (auto opName : kWaveMultiPrefixMinMaxNames) { +}}}} + +__generic +__spirv_version(1.3) [ForceInline] -[require(metal)] -void __checkMetalShaderSubgroupType() +[require(glsl_spirv, subgroup_partitioned)] +T WaveMultiPrefix$(opName.name)(T value, uint4 mask) { - // These builtin types are not supported for Metal's `simd` operations. - if (__type_equals() - || __type_equals() - || __type_equals() - || __type_equals() - || __isBool() - ) + __shaderSubgroupPartitionedPreamble(); + __target_switch { - static_assert(false, "Unsupported type for subgroup operations in Metal. 
Valid types include scalars and vectors of uint/uint32_t, int/int32_t, uint16_t, int16_t, float, and half."); + case glsl: + __intrinsic_asm "subgroupPartitioned$(opName.glslName)NV"; + case spirv: + { + if (__isFloat()) + return spirv_asm { result:$$T = OpGroupNonUniformF$(opName.spirvName) Subgroup $(opName.spirvGroupOperation) $value $mask }; + else if (__isUnsignedInt()) + return spirv_asm { result:$$T = OpGroupNonUniformU$(opName.spirvName) Subgroup $(opName.spirvGroupOperation) $value $mask }; + else + return spirv_asm { result:$$T = OpGroupNonUniformS$(opName.spirvName) Subgroup $(opName.spirvGroupOperation) $value $mask }; + } } } -__generic -void shader_subgroup_preamble() +__generic +__spirv_version(1.3) +[ForceInline] +[require(glsl_spirv, subgroup_partitioned)] +vector WaveMultiPrefix$(opName.name)(vector value, uint4 mask) { - // checks needed for shader_subgroup functions; __requireTargetExtension does not work - // (does not add the ext specified correctly to the compile output; using extended type - // will result in error for using the type) + __shaderSubgroupPartitionedPreamble(); __target_switch { case glsl: - __requireGLSLShaderSubgroupTypeExtension(); - case metal: - __checkMetalShaderSubgroupType(); - default: - return; + __intrinsic_asm "subgroupPartitioned$(opName.glslName)NV"; + case spirv: + { + if (__isFloat()) + return spirv_asm { result:$$vector = OpGroupNonUniformF$(opName.spirvName) Subgroup $(opName.spirvGroupOperation) $value $mask }; + else if (__isUnsignedInt()) + return spirv_asm { result:$$vector = OpGroupNonUniformU$(opName.spirvName) Subgroup $(opName.spirvGroupOperation) $value $mask }; + else + return spirv_asm { result:$$vector = OpGroupNonUniformS$(opName.spirvName) Subgroup $(opName.spirvGroupOperation) $value $mask }; + } } } -//@public: +__generic +[require(glsl_spirv, subgroup_partitioned)] +matrix WaveMultiPrefix$(opName.name)(matrix value, uint4 mask) +{ + matrix result; + [ForceUnroll] + for (int i = 0; i < N; ++i) 
+ result[i] = WaveMultiPrefix$(opName.name)(value[i], mask); + return result; +} + +${{{{ +} +// WaveMultiPrefixInclusiveMin/WaveMultiPrefixInclusiveMax. +// WaveMultiPrefixExclusiveMin/WaveMultiPrefixExclusiveMax. +}}}} + // -// Wave Rotate intrinsics. -// These are Slang specific intrinsics to rotate values within a subgroup. +// WaveMultiBitAnd/WaveMultiBitOr/WaveMultiBitXor. // +${{{{ +struct WaveMultiBitsEntry { const char* name; }; +const WaveMultiBitsEntry kWaveMultiBitsNames[] = { {"And"}, {"Or"} , {"Xor"} }; +for (auto opName : kWaveMultiBitsNames) { +}}}} -__generic -__glsl_extension(GL_KHR_shader_subgroup_rotate) -[require(glsl_metal_spirv, subgroup_rotate)] -T WaveRotate(T value, uint delta) +__generic +__spirv_version(1.3) +[ForceInline] +[require(cuda_glsl_spirv, subgroup_partitioned)] +T WaveMultiBit$(opName.name)(T value, uint4 mask) { - shader_subgroup_preamble(); + __shaderSubgroupPartitionedPreamble(); __target_switch { + case cuda: + __intrinsic_asm "_wave$(opName.name)($1.x, $0)"; case glsl: - __intrinsic_asm "subgroupRotate"; - case metal: - __intrinsic_asm "simd_shuffle_rotate_down"; + __intrinsic_asm "subgroupPartitioned$(opName.name)NV"; case spirv: return spirv_asm { - OpExtension "SPV_KHR_subgroup_rotate"; - OpCapability GroupNonUniformRotateKHR; - result:$$T = OpGroupNonUniformRotateKHR Subgroup $value $delta; + result:$$T = OpGroupNonUniformBitwise$(opName.name) Subgroup PartitionedReduceNV $value $mask; }; } } -__generic -__glsl_extension(GL_KHR_shader_subgroup_rotate) -[require(glsl_metal_spirv, subgroup_rotate)] -vector WaveRotate(vector value, uint delta) +__generic +__spirv_version(1.3) +[ForceInline] +[require(cuda_glsl_spirv, subgroup_partitioned)] +vector WaveMultiBit$(opName.name)(vector value, uint4 mask) { - shader_subgroup_preamble(); + __shaderSubgroupPartitionedPreamble(); __target_switch { + case cuda: + __intrinsic_asm "_wave$(opName.name)Multiple($1.x, $0)"; case glsl: - __intrinsic_asm "subgroupRotate"; - case metal: - 
__intrinsic_asm "simd_shuffle_rotate_down"; + __intrinsic_asm "subgroupPartitioned$(opName.name)NV"; case spirv: return spirv_asm { - OpExtension "SPV_KHR_subgroup_rotate"; - OpCapability GroupNonUniformRotateKHR; - result:$$vector = OpGroupNonUniformRotateKHR Subgroup $value $delta; + result:$$vector = OpGroupNonUniformBitwise$(opName.name) Subgroup PartitionedReduceNV $value $mask; }; } } -__generic -__glsl_extension(GL_KHR_shader_subgroup_rotate) -[require(glsl_spirv, subgroup_rotate)] -T WaveClusteredRotate(T value, uint delta, constexpr uint clusterSize) +__generic +[require(cuda_glsl_spirv, subgroup_partitioned)] +matrix WaveMultiBit$(opName.name)(matrix value, uint4 mask) { - shader_subgroup_preamble(); __target_switch { + case cuda: + __intrinsic_asm "_wave$(opName.name)Multiple($1.x, $0)"; + default: + matrix result; + [ForceUnroll] + for (int i = 0; i < N; ++i) + result[i] = WaveMultiBit$(opName.name)(value[i], mask); + return result; + } +} + +${{{{ +} // WaveMultiBitAnd/WaveMultiBitOr/WaveMultiBitXor. +}}}} + + +// +// WaveMultiPrefixInclusiveBitAnd/WaveMultiPrefixInclusiveBitOr/WaveMultiPrefixInclusiveBitXor. +// WaveMultiPrefixExclusiveBitAnd/WaveMultiPrefixExclusiveBitOr/WaveMultiPrefixExclusiveBitXor. +// WaveMultiPrefixBitAnd/WaveMultiPrefixBitOr/WaveMultiPrefixBitXor. 
+// +${{{{ +struct WaveMultiPrefixBitwiseEntry +{ + const char* name; // Suffix of the generated function: WaveMultiPrefix<name>. + const char* spirvName; // Suffix of OpGroupNonUniformBitwise<spirvName>; also of the CUDA _wavePrefix<spirvName> call. + const char* spirvGroupOperation; // Group operation operand placed after Subgroup in the SPIR-V instruction. + const char* glslName; // Spliced into subgroupPartitioned<glslName>NV. + const char* hlslName; // Text appended to WaveMultiPrefixBit in the hlsl intrinsic string. + const char* cudaExtraOperation; // Text appended after the CUDA call; empty when nothing is appended. + + bool cudaMatrixVariantSupport; // When false, the matrix overload emits no cuda case and uses its default branch. +}; + +const WaveMultiPrefixBitwiseEntry kWaveMultiPrefixBitwiseNames[] = +{ + // name spirvName spirvGroupOperation glslName hlslName cudaExtraOperation cudaMatrixVariantSupport + { "InclusiveBitAnd", "And", "PartitionedInclusiveScanNV", "InclusiveAnd", "And($0, $1) & $0", "& $0", false }, + { "InclusiveBitOr", "Or", "PartitionedInclusiveScanNV", "InclusiveOr", "Or($0, $1) | $0", "| $0", false }, + { "InclusiveBitXor", "Xor", "PartitionedInclusiveScanNV", "InclusiveXor", "Xor($0, $1) ^ $0", "^ $0", false }, + { "ExclusiveBitAnd", "And", "PartitionedExclusiveScanNV", "ExclusiveAnd", "And", "", true }, + { "ExclusiveBitOr", "Or", "PartitionedExclusiveScanNV", "ExclusiveOr", "Or", "", true }, + { "ExclusiveBitXor", "Xor", "PartitionedExclusiveScanNV", "ExclusiveXor", "Xor", "", true }, + + // These are HLSL SM 6.5 intrinsics and are equal to the exclusive variants. 
+ { "BitAnd", "And", "PartitionedExclusiveScanNV", "ExclusiveAnd", "And", "", true }, + { "BitOr", "Or", "PartitionedExclusiveScanNV", "ExclusiveOr", "Or", "", true }, + { "BitXor", "Xor", "PartitionedExclusiveScanNV", "ExclusiveXor", "Xor", "", true }, +}; + +for (auto opName : kWaveMultiPrefixBitwiseNames) { +}}}} + +__generic +__spirv_version(1.3) +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] +T WaveMultiPrefix$(opName.name)(T value, uint4 mask) +{ + __shaderSubgroupPartitionedPreamble(); + __target_switch + { + case cuda: + __intrinsic_asm "_wavePrefix$(opName.spirvName)($1.x, $0) $(opName.cudaExtraOperation)"; case glsl: - __intrinsic_asm "subgroupClusteredRotate"; + __intrinsic_asm "subgroupPartitioned$(opName.glslName)NV"; + case hlsl: + __intrinsic_asm "WaveMultiPrefixBit$(opName.hlslName)"; case spirv: return spirv_asm { - OpExtension "SPV_KHR_subgroup_rotate"; - OpCapability GroupNonUniformRotateKHR; - result:$$T = OpGroupNonUniformRotateKHR Subgroup $value $delta $clusterSize; + result:$$T = OpGroupNonUniformBitwise$(opName.spirvName) Subgroup $(opName.spirvGroupOperation) $value $mask; }; } } -__generic -__glsl_extension(GL_KHR_shader_subgroup_rotate) -[require(glsl_spirv, subgroup_rotate)] -vector WaveClusteredRotate(vector value, uint delta, constexpr uint clusterSize) +__generic +__spirv_version(1.3) +[ForceInline] +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] +vector WaveMultiPrefix$(opName.name)(vector value, uint4 mask) { - shader_subgroup_preamble(); + __shaderSubgroupPartitionedPreamble(); __target_switch { + case cuda: + __intrinsic_asm "_wavePrefix$(opName.spirvName)Multiple($1.x, $0) $(opName.cudaExtraOperation)"; case glsl: - __intrinsic_asm "subgroupClusteredRotate"; + __intrinsic_asm "subgroupPartitioned$(opName.glslName)NV"; + case hlsl: + __intrinsic_asm "WaveMultiPrefixBit$(opName.hlslName)"; case spirv: return spirv_asm { - OpExtension "SPV_KHR_subgroup_rotate"; - OpCapability GroupNonUniformRotateKHR; 
- result:$$vector = OpGroupNonUniformRotateKHR Subgroup $value $delta $clusterSize; + result:$$vector = OpGroupNonUniformBitwise$(opName.spirvName) Subgroup $(opName.spirvGroupOperation) $value $mask; }; } } +__generic +[require(cuda_glsl_hlsl_spirv, subgroup_partitioned)] +matrix WaveMultiPrefix$(opName.name)(matrix value, uint4 mask) +{ + __target_switch + { +${{{{ + if (opName.cudaMatrixVariantSupport) { +}}}} + case cuda: + __intrinsic_asm "_wavePrefix$(opName.spirvName)Multiple($1.x, $0) $(opName.cudaExtraOperation)"; +${{{{ + } +}}}} + default: + matrix result; + [ForceUnroll] + for (int i = 0; i < N; ++i) + result[i] = WaveMultiPrefix$(opName.name)(value[i], mask); + return result; + } +} +${{{{ +} +// WaveMultiPrefixInclusiveBitAnd/WaveMultiPrefixInclusiveBitOr/WaveMultiPrefixInclusiveBitXor. +// WaveMultiPrefixExclusiveBitAnd/WaveMultiPrefixExclusiveBitOr/WaveMultiPrefixExclusiveBitXor. +// WaveMultiPrefixBitAnd/WaveMultiPrefixBitOr/WaveMultiPrefixBitXor. +}}}} + + // // Quad Control intrinsics // diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef index 1799d4bfc..48617c54d 100644 --- a/source/slang/slang-capabilities.capdef +++ b/source/slang/slang-capabilities.capdef @@ -1157,11 +1157,6 @@ alias fragmentshaderbarycentric = GL_EXT_fragment_shader_barycentric | _sm_6_1; /// (gfx targets) Capabilities needed to use memory barriers /// [Compound] alias shadermemorycontrol = glsl | _spirv_1_0 | _sm_5_0; -/// Capabilities needed to use HLSL tier wave operations -/// [Compound] -alias wave_multi_prefix = _sm_6_5 - | _cuda_sm_7_0 - | GL_KHR_shader_subgroup_ballot + GL_KHR_shader_subgroup_arithmetic + GL_NV_shader_subgroup_partitioned; /// Capabilities needed to use GLSL buffer-reference's /// [Compound] alias bufferreference = GL_EXT_buffer_reference; @@ -2186,7 +2181,9 @@ alias subgroup_quad = GL_KHR_shader_subgroup_quad ; /// Capabilities required to use GLSL-style subgroup operations 'subgroup_partitioned' /// [Compound] -alias 
subgroup_partitioned = GL_NV_shader_subgroup_partitioned + subgroup_ballot_activemask | _sm_6_5 | _cuda_sm_7_0; +alias subgroup_partitioned = _sm_6_5 + | _cuda_sm_7_0 + | GL_KHR_shader_subgroup_ballot + GL_KHR_shader_subgroup_arithmetic + GL_NV_shader_subgroup_partitioned; /// Capabilities required to use GLSL-style subgroup rotate operations 'subgroup_rotate' -- cgit v1.2.3