summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2024-05-03 15:23:23 -0700
committerGitHub <noreply@github.com>2024-05-03 15:23:23 -0700
commit54153a3681c7c6ef86c6f7a864719d32a1934240 (patch)
treea79bfcce08f754fd0187fb09fe014d98b1cd0c27 /source
parent47a917c964f4fda32d75f200efe863f6d68c737c (diff)
Don't bottleneck Wave intrinsics through `WaveMask*` for spirv. (#4099)
* Don't bottleneck Wave intrinsics through `WaveMask*` for spirv. * Fix.
Diffstat (limited to 'source')
-rw-r--r-- source/slang/hlsl.meta.slang | 78
1 files changed, 78 insertions, 0 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index dc64705dd..c96be49e1 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -11493,6 +11493,13 @@ matrix<T, N, M> WaveActive$(opName.hlslName)(matrix<T, N, M> expr)
__target_switch
{
case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)";
+ case glsl:
+ case spirv:
+ matrix<T,N,M> result;
+ [ForceUnroll]
+ for (int i = 0; i < N; ++i)
+ result[i] = WaveActive$(opName.hlslName)(expr[i]);
+ return result;
default:
return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr);
}
@@ -11558,6 +11565,13 @@ matrix<T, N, M> WaveActive$(opName)(matrix<T, N, M> expr)
__target_switch
{
case hlsl: __intrinsic_asm "WaveActive$(opName)";
+ case glsl:
+ case spirv:
+ matrix<T, N, M> result;
+ [ForceUnroll]
+ for (int i = 0; i < N; ++i)
+ result[i] = WaveActive$(opName)(expr[i]);
+ return result;
default:
return WaveMask$(opName)(WaveGetActiveMask(), expr);
}
@@ -11645,6 +11659,13 @@ matrix<T, N, M> WaveActive$(opName.hlslName)(matrix<T, N, M> expr)
__target_switch
{
case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)";
+ case glsl:
+ case spirv:
+ matrix<T, N, M> result;
+ [ForceUnroll]
+ for (int i = 0; i < N; ++i)
+ result[i] = WaveActive$(opName.hlslName)(expr[i]);
+ return result;
default:
return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr);
}
@@ -11784,6 +11805,9 @@ uint WaveActiveCountBits(bool value)
__target_switch
{
case hlsl: __intrinsic_asm "WaveActiveCountBits";
+ case glsl:
+ case spirv:
+ return _WaveCountBits(WaveActiveBallot(value));
default:
return WaveMaskCountBits(WaveGetActiveMask(), value);
}
@@ -11952,6 +11976,12 @@ matrix<T, N, M> WavePrefixProduct(matrix<T, N, M> expr)
__target_switch
{
case hlsl: __intrinsic_asm "WavePrefixProduct";
+ case glsl:
+ case spirv:
+ matrix<T, N, M> result;
+ for (int i = 0; i < N; ++i)
+ result[i] = WavePrefixProduct(expr[i]);
+ return result;
default:
return WaveMaskPrefixProduct(WaveGetActiveMask(), expr);
}
@@ -12022,6 +12052,12 @@ matrix<T,N,M> WavePrefixSum(matrix<T,N,M> expr)
__target_switch
{
case hlsl: __intrinsic_asm "WavePrefixSum";
+ case glsl:
+ case spirv:
+ matrix<T, N, M> result;
+ for (int i = 0; i < N; ++i)
+ result[i] = WavePrefixSum(expr[i]);
+ return result;
default:
return WaveMaskPrefixSum(WaveGetActiveMask(), expr);
}
@@ -12072,6 +12108,12 @@ matrix<T,N,M> WaveReadLaneFirst(matrix<T,N,M> expr)
__target_switch
{
case hlsl: __intrinsic_asm "WaveReadLaneFirst";
+ case glsl:
+ case spirv:
+ matrix<T, N, M> result;
+ for (int i = 0; i < N; ++i)
+ result[i] = WaveReadLaneFirst(expr[i]);
+ return result;
default:
return WaveMaskReadLaneFirst(WaveGetActiveMask(), expr);
}
@@ -12131,6 +12173,12 @@ matrix<T, N, M> WaveBroadcastLaneAt(matrix<T, N, M> value, constexpr int lane)
{
case cuda: __intrinsic_asm "_waveShuffleMultiple(_getActiveMask(), $0, $1)";
case hlsl: __intrinsic_asm "WaveReadLaneAt";
+ case glsl:
+ case spirv:
+ matrix<T, N, M> result;
+ for (int i = 0; i < N; ++i)
+ result[i] = WaveBroadcastLaneAt(value[i], lane);
+ return result;
default:
return WaveMaskBroadcastLaneAt(WaveGetActiveMask(), value, lane);
}
@@ -12186,6 +12234,12 @@ matrix<T, N, M> WaveReadLaneAt(matrix<T, N, M> value, int lane)
{
case cuda: __intrinsic_asm "_waveShuffleMultiple(_getActiveMask(), $0, $1)";
case hlsl: __intrinsic_asm "WaveReadLaneAt";
+ case glsl:
+ case spirv:
+ matrix<T,N,M> result;
+ for (int i = 0; i < N; ++i)
+ result[i] = WaveReadLaneAt(value[i], lane);
+ return result;
default:
return WaveMaskReadLaneAt(WaveGetActiveMask(), value, lane);
}
@@ -12305,6 +12359,14 @@ uint4 WaveMatch(T value)
__target_switch
{
case hlsl: __intrinsic_asm "WaveMatch";
+ case glsl: __intrinsic_asm "subgroupPartitionNV($0)";
+ case spirv:
+ return spirv_asm
+ {
+ OpCapability GroupNonUniformPartitionedNV;
+ OpExtension "SPV_NV_shader_subgroup_partitioned";
+ OpGroupNonUniformPartitionNV $$uint4 result $value
+ };
default:
return WaveMaskMatch(WaveGetActiveMask(), value);
}
@@ -12317,6 +12379,14 @@ uint4 WaveMatch(vector<T,N> value)
__target_switch
{
case hlsl: __intrinsic_asm "WaveMatch";
+ case glsl: __intrinsic_asm "subgroupPartitionNV($0)";
+ case spirv:
+ return spirv_asm
+ {
+ OpCapability GroupNonUniformPartitionedNV;
+ OpExtension "SPV_NV_shader_subgroup_partitioned";
+ OpGroupNonUniformPartitionNV $$uint4 result $value
+ };
default:
return WaveMaskMatch(WaveGetActiveMask(), value);
}
@@ -12329,6 +12399,14 @@ uint4 WaveMatch(matrix<T,N,M> value)
__target_switch
{
case hlsl: __intrinsic_asm "WaveMatch";
+ case glsl:
+ case cuda:
+ case spirv:
+ uint4 result = uint4(0xFFFFFFFF); // start with every bit set; each row's mask narrows it
+ [ForceUnroll]
+ for (int i = 0; i < N; i++)
+ result &= WaveMatch(value[i]); // AND the match mask of row i (was value[0], which only ever compared row 0)
+ return result;
default:
return WaveMaskMatch(WaveGetActiveMask(), value);
}