From 54153a3681c7c6ef86c6f7a864719d32a1934240 Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 3 May 2024 15:23:23 -0700 Subject: Don't bottleneck Wave intrinsics through `WaveMask*` for spirv. (#4099) * Don't bottleneck Wave intrinsics through `WaveMask*` for spirv. * Fix. --- source/slang/hlsl.meta.slang | 78 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 78 insertions(+) (limited to 'source') diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index dc64705dd..c96be49e1 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -11493,6 +11493,13 @@ matrix WaveActive$(opName.hlslName)(matrix expr) __target_switch { case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)"; + case glsl: + case spirv: /* glsl and spirv share the per-index path below */ + matrix result; + [ForceUnroll] + for (int i = 0; i < N; ++i) + result[i] = WaveActive$(opName.hlslName)(expr[i]); /* forwards expr[i] to a WaveActive overload instead of the WaveMask path used by the default case */ + return result; default: return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr); } @@ -11558,6 +11565,13 @@ matrix WaveActive$(opName)(matrix expr) __target_switch { case hlsl: __intrinsic_asm "WaveActive$(opName)"; + case glsl: + case spirv: /* glsl and spirv share the per-index path below */ + matrix result; + [ForceUnroll] + for (int i = 0; i < N; ++i) + result[i] = WaveActive$(opName)(expr[i]); /* forwards expr[i] to a WaveActive overload; result is assembled index by index */ + return result; default: return WaveMask$(opName)(WaveGetActiveMask(), expr); } @@ -11645,6 +11659,13 @@ matrix WaveActive$(opName.hlslName)(matrix expr) __target_switch { case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)"; + case glsl: + case spirv: /* glsl and spirv share the per-index path below */ + matrix result; + [ForceUnroll] + for (int i = 0; i < N; ++i) + result[i] = WaveActive$(opName.hlslName)(expr[i]); /* forwards expr[i] to a WaveActive overload; result is assembled index by index */ + return result; default: return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr); } @@ -11784,6 +11805,9 @@ uint WaveActiveCountBits(bool value) __target_switch { case hlsl: __intrinsic_asm "WaveActiveCountBits"; + case glsl: + case spirv: + return _WaveCountBits(WaveActiveBallot(value)); /* NOTE(review): presumably counts the set bits of the ballot result - confirm against _WaveCountBits */ default: return WaveMaskCountBits(WaveGetActiveMask(), value); } @@ -11952,6 +11976,12 
@@ matrix WavePrefixProduct(matrix expr) __target_switch { case hlsl: __intrinsic_asm "WavePrefixProduct"; + case glsl: + case spirv: + matrix result; + for (int i = 0; i < N; ++i) /* NOTE(review): no [ForceUnroll] here, unlike the WaveActive matrix loops earlier in this patch - confirm that is intentional */ + result[i] = WavePrefixProduct(expr[i]); /* forwards expr[i] to a WavePrefixProduct overload instead of the WaveMask path used by the default case */ + return result; default: return WaveMaskPrefixProduct(WaveGetActiveMask(), expr); } @@ -12022,6 +12052,12 @@ matrix WavePrefixSum(matrix expr) __target_switch { case hlsl: __intrinsic_asm "WavePrefixSum"; + case glsl: + case spirv: + matrix result; + for (int i = 0; i < N; ++i) + result[i] = WavePrefixSum(expr[i]); /* forwards expr[i] to a WavePrefixSum overload; result is assembled index by index */ + return result; default: return WaveMaskPrefixSum(WaveGetActiveMask(), expr); } @@ -12072,6 +12108,12 @@ matrix WaveReadLaneFirst(matrix expr) __target_switch { case hlsl: __intrinsic_asm "WaveReadLaneFirst"; + case glsl: + case spirv: + matrix result; + for (int i = 0; i < N; ++i) + result[i] = WaveReadLaneFirst(expr[i]); /* forwards expr[i] to a WaveReadLaneFirst overload; result is assembled index by index */ + return result; default: return WaveMaskReadLaneFirst(WaveGetActiveMask(), expr); } @@ -12131,6 +12173,12 @@ matrix WaveBroadcastLaneAt(matrix value, constexpr int lane) { case cuda: __intrinsic_asm "_waveShuffleMultiple(_getActiveMask(), $0, $1)"; case hlsl: __intrinsic_asm "WaveReadLaneAt"; + case glsl: + case spirv: + matrix result; + for (int i = 0; i < N; ++i) + result[i] = WaveBroadcastLaneAt(value[i], lane); /* same lane argument is passed for every index i */ + return result; default: return WaveMaskBroadcastLaneAt(WaveGetActiveMask(), value, lane); } @@ -12186,6 +12234,12 @@ matrix WaveReadLaneAt(matrix value, int lane) { case cuda: __intrinsic_asm "_waveShuffleMultiple(_getActiveMask(), $0, $1)"; case hlsl: __intrinsic_asm "WaveReadLaneAt"; + case glsl: + case spirv: + matrix result; + for (int i = 0; i < N; ++i) + result[i] = WaveReadLaneAt(value[i], lane); /* same lane argument is passed for every index i */ + return result; default: return WaveMaskReadLaneAt(WaveGetActiveMask(), value, lane); } @@ -12305,6 +12359,14 @@ uint4 WaveMatch(T value) __target_switch { case hlsl: __intrinsic_asm "WaveMatch"; + case glsl: __intrinsic_asm "subgroupPartitionNV($0)"; /* glsl emits subgroupPartitionNV directly on the argument */ + case spirv: + return spirv_asm + { 
OpCapability GroupNonUniformPartitionedNV; + OpExtension "SPV_NV_shader_subgroup_partitioned"; + OpGroupNonUniformPartitionNV $$uint4 result $value + }; default: return WaveMaskMatch(WaveGetActiveMask(), value); } @@ -12317,6 +12379,14 @@ uint4 WaveMatch(vector value) __target_switch { case hlsl: __intrinsic_asm "WaveMatch"; + case glsl: __intrinsic_asm "subgroupPartitionNV($0)"; /* glsl emits subgroupPartitionNV directly on the argument */ + case spirv: + return spirv_asm + { + OpCapability GroupNonUniformPartitionedNV; + OpExtension "SPV_NV_shader_subgroup_partitioned"; + OpGroupNonUniformPartitionNV $$uint4 result $value + }; default: return WaveMaskMatch(WaveGetActiveMask(), value); } @@ -12329,6 +12399,14 @@ uint4 WaveMatch(matrix value) __target_switch { case hlsl: __intrinsic_asm "WaveMatch"; + case glsl: + case cuda: + case spirv: + uint4 result = uint4(0xFFFFFFFF); /* start from an all-ones mask so the AND below can only clear bits */ + [ForceUnroll] + for (int i = 0; i < N; i++) + result &= WaveMatch(value[i]); /* AND in the match mask of every index i (was value[0], which compared only the first element); a lane stays set only if all indices match */ + return result; default: return WaveMaskMatch(WaveGetActiveMask(), value); } -- cgit v1.2.3