diff options
| author | Yong He <yonghe@outlook.com> | 2024-05-03 15:23:23 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-05-03 15:23:23 -0700 |
| commit | 54153a3681c7c6ef86c6f7a864719d32a1934240 (patch) | |
| tree | a79bfcce08f754fd0187fb09fe014d98b1cd0c27 /source | |
| parent | 47a917c964f4fda32d75f200efe863f6d68c737c (diff) | |
Don't bottleneck Wave intrinsics through `WaveMask*` for spirv. (#4099)
* Don't bottleneck Wave intrinsics through `WaveMask*` for spirv.
* Fix.
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/hlsl.meta.slang | 78 |
1 files changed, 78 insertions, 0 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index dc64705dd..c96be49e1 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -11493,6 +11493,13 @@ matrix<T, N, M> WaveActive$(opName.hlslName)(matrix<T, N, M> expr) __target_switch { case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)"; + case glsl: + case spirv: + matrix<T,N,M> result; + [ForceUnroll] + for (int i = 0; i < N; ++i) + result[i] = WaveActive$(opName.hlslName)(expr[i]); + return result; default: return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr); } @@ -11558,6 +11565,13 @@ matrix<T, N, M> WaveActive$(opName)(matrix<T, N, M> expr) __target_switch { case hlsl: __intrinsic_asm "WaveActive$(opName)"; + case glsl: + case spirv: + matrix<T, N, M> result; + [ForceUnroll] + for (int i = 0; i < N; ++i) + result[i] = WaveActive$(opName)(expr[i]); + return result; default: return WaveMask$(opName)(WaveGetActiveMask(), expr); } @@ -11645,6 +11659,13 @@ matrix<T, N, M> WaveActive$(opName.hlslName)(matrix<T, N, M> expr) __target_switch { case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)"; + case glsl: + case spirv: + matrix<T, N, M> result; + [ForceUnroll] + for (int i = 0; i < N; ++i) + result[i] = WaveActive$(opName.hlslName)(expr[i]); + return result; default: return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr); } @@ -11784,6 +11805,9 @@ uint WaveActiveCountBits(bool value) __target_switch { case hlsl: __intrinsic_asm "WaveActiveCountBits"; + case glsl: + case spirv: + return _WaveCountBits(WaveActiveBallot(value)); default: return WaveMaskCountBits(WaveGetActiveMask(), value); } @@ -11952,6 +11976,12 @@ matrix<T, N, M> WavePrefixProduct(matrix<T, N, M> expr) __target_switch { case hlsl: __intrinsic_asm "WavePrefixProduct"; + case glsl: + case spirv: + matrix<T, N, M> result; + for (int i = 0; i < N; ++i) + result[i] = WavePrefixProduct(expr[i]); + return result; default: return WaveMaskPrefixProduct(WaveGetActiveMask(), expr); } @@ -12022,6 +12052,12 @@ matrix<T,N,M> WavePrefixSum(matrix<T,N,M> expr) __target_switch { case hlsl: __intrinsic_asm "WavePrefixSum"; + case glsl: + case spirv: + matrix<T, N, M> result; + for (int i = 0; i < N; ++i) + result[i] = WavePrefixSum(expr[i]); + return result; default: return WaveMaskPrefixSum(WaveGetActiveMask(), expr); } @@ -12072,6 +12108,12 @@ matrix<T,N,M> WaveReadLaneFirst(matrix<T,N,M> expr) __target_switch { case hlsl: __intrinsic_asm "WaveReadLaneFirst"; + case glsl: + case spirv: + matrix<T, N, M> result; + for (int i = 0; i < N; ++i) + result[i] = WaveReadLaneFirst(expr[i]); + return result; default: return WaveMaskReadLaneFirst(WaveGetActiveMask(), expr); } @@ -12131,6 +12173,12 @@ matrix<T, N, M> WaveBroadcastLaneAt(matrix<T, N, M> value, constexpr int lane) { case cuda: __intrinsic_asm "_waveShuffleMultiple(_getActiveMask(), $0, $1)"; case hlsl: __intrinsic_asm "WaveReadLaneAt"; + case glsl: + case spirv: + matrix<T, N, M> result; + for (int i = 0; i < N; ++i) + result[i] = WaveBroadcastLaneAt(value[i], lane); + return result; default: return WaveMaskBroadcastLaneAt(WaveGetActiveMask(), value, lane); } @@ -12186,6 +12234,12 @@ matrix<T, N, M> WaveReadLaneAt(matrix<T, N, M> value, int lane) { case cuda: __intrinsic_asm "_waveShuffleMultiple(_getActiveMask(), $0, $1)"; case hlsl: __intrinsic_asm "WaveReadLaneAt"; + case glsl: + case spirv: + matrix<T,N,M> result; + for (int i = 0; i < N; ++i) + result[i] = WaveReadLaneAt(value[i], lane); + return result; default: return WaveMaskReadLaneAt(WaveGetActiveMask(), value, lane); } @@ -12305,6 +12359,14 @@ uint4 WaveMatch(T value) __target_switch { case hlsl: __intrinsic_asm "WaveMatch"; + case glsl: __intrinsic_asm "subgroupPartitionNV($0)"; + case spirv: + return spirv_asm + { + OpCapability GroupNonUniformPartitionedNV; + OpExtension "SPV_NV_shader_subgroup_partitioned"; + OpGroupNonUniformPartitionNV $$uint4 result $value + }; default: return WaveMaskMatch(WaveGetActiveMask(), value); } @@ -12317,6 +12379,14 @@ uint4 WaveMatch(vector<T,N> value) __target_switch { case hlsl: __intrinsic_asm "WaveMatch"; + case glsl: __intrinsic_asm "subgroupPartitionNV($0)"; + case spirv: + return spirv_asm + { + OpCapability GroupNonUniformPartitionedNV; + OpExtension "SPV_NV_shader_subgroup_partitioned"; + OpGroupNonUniformPartitionNV $$uint4 result $value + }; default: return WaveMaskMatch(WaveGetActiveMask(), value); } @@ -12329,6 +12399,14 @@ uint4 WaveMatch(matrix<T,N,M> value) __target_switch { case hlsl: __intrinsic_asm "WaveMatch"; + case glsl: + case cuda: + case spirv: + uint4 result = uint4(0xFFFFFFFF); + [ForceUnroll] + for (int i = 0; i < N; i++) + result &= WaveMatch(value[0]); + return result; default: return WaveMaskMatch(WaveGetActiveMask(), value); } |
