summaryrefslogtreecommitdiff
path: root/source/slang
diff options
context:
space:
mode:
authorDarren Wihandi <65404740+fairywreath@users.noreply.github.com>2025-02-10 19:40:39 -0500
committerGitHub <noreply@github.com>2025-02-10 16:40:39 -0800
commit133bd259c00984c6a01869f71951a7feb919463a (patch)
treea69f1a6b3caff0ac4d958453fde6176ab3c66c91 /source/slang
parentf761ab0586353da67bf7b3ae395ad7b090cd904f (diff)
Add support for Metal subgroup/simd operations (#6247)
* initial work for metal subgroups * add glsl intrinsics * enable wave tests * enable glsl subgroup tests, glsl barrier fixes * minor fixes * fix incorrect test target * disable some glsl functional tests * disable failing glsl test --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/glsl.meta.slang137
-rw-r--r--source/slang/hlsl.meta.slang249
-rw-r--r--source/slang/slang-capabilities.capdef10
3 files changed, 242 insertions, 154 deletions
diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang
index 0bad3c681..6f0ca1bf3 100644
--- a/source/slang/glsl.meta.slang
+++ b/source/slang/glsl.meta.slang
@@ -6525,7 +6525,7 @@ public property uvec4 gl_SubgroupLtMask
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_basic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_basic)]
+[require(cuda_glsl_hlsl_metal_spirv, subgroup_basic)]
public void subgroupBarrier()
{
__target_switch
@@ -6536,6 +6536,8 @@ public void subgroupBarrier()
__intrinsic_asm "AllMemoryBarrierWithGroupSync()";
case glsl:
__intrinsic_asm "subgroupBarrier()";
+ case metal:
+ __intrinsic_asm "simdgroup_barrier(mem_flags::mem_none)";
case spirv:
spirv_asm {
OpCapability Shader;
@@ -6548,7 +6550,7 @@ public void subgroupBarrier()
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_basic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_basic)]
+[require(cuda_glsl_hlsl_metal_spirv, subgroup_basic)]
public void subgroupMemoryBarrier()
{
__target_switch
@@ -6559,6 +6561,8 @@ public void subgroupMemoryBarrier()
__intrinsic_asm "AllMemoryBarrier()";
case glsl:
__intrinsic_asm "subgroupMemoryBarrier()";
+ case metal:
+ __intrinsic_asm "simdgroup_barrier(mem_flags::mem_device)";
case spirv:
spirv_asm {
OpCapability Shader;
@@ -6571,7 +6575,7 @@ public void subgroupMemoryBarrier()
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_basic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_basic)]
+[require(cuda_glsl_hlsl_metal_spirv, subgroup_basic)]
public void subgroupMemoryBarrierBuffer()
{
// the following implementation is NOT the same as DeviceMemoryBarrier
@@ -6584,6 +6588,8 @@ public void subgroupMemoryBarrierBuffer()
__intrinsic_asm "DeviceMemoryBarrier()";
case glsl:
__intrinsic_asm "subgroupMemoryBarrierBuffer()";
+ case metal:
+ __intrinsic_asm "simdgroup_barrier(mem_flags::mem_device)";
case spirv:
spirv_asm {
OpCapability Shader;
@@ -6596,7 +6602,7 @@ public void subgroupMemoryBarrierBuffer()
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_basic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_basic)]
+[require(cuda_glsl_hlsl_metal_spirv, subgroup_basic)]
public void subgroupMemoryBarrierImage()
{
__target_switch
@@ -6607,6 +6613,8 @@ public void subgroupMemoryBarrierImage()
__intrinsic_asm "DeviceMemoryBarrier()";
case glsl:
__intrinsic_asm "subgroupMemoryBarrierImage()";
+ case metal:
+ __intrinsic_asm "simdgroup_barrier(mem_flags::mem_texture)";
case spirv:
spirv_asm {
OpMemoryBarrier Subgroup AcquireRelease|ImageMemory
@@ -6618,7 +6626,7 @@ public void subgroupMemoryBarrierImage()
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_basic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_basic)]
+[require(cuda_glsl_hlsl_metal_spirv, subgroup_basic)]
public void subgroupMemoryBarrierShared()
{
__target_switch
@@ -6629,6 +6637,8 @@ public void subgroupMemoryBarrierShared()
__intrinsic_asm "GroupMemoryBarrier()";
case glsl:
__intrinsic_asm "subgroupMemoryBarrierShared()";
+ case metal:
+ __intrinsic_asm "simdgroup_barrier(mem_flags::mem_threadgroup)";
case spirv:
spirv_asm {
// SubgroupMemory triggers vulkan validation layer error;
@@ -6642,17 +6652,14 @@ public void subgroupMemoryBarrierShared()
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_basic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_basic)]
public bool subgroupElect()
{
__target_switch
{
case cuda:
__intrinsic_asm "( (__activemask() & (__activemask()*-1)) == _getLaneId())";
- case glsl:
- case spirv:
- case hlsl:
- case wgsl:
+ default:
return WaveIsFirstLane();
}
@@ -6663,7 +6670,7 @@ public bool subgroupElect()
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_vote)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_vote)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_vote)]
public bool subgroupAll(bool value)
{
return WaveActiveAllTrue(value);
@@ -6672,7 +6679,7 @@ public bool subgroupAll(bool value)
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_vote)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_vote)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_vote)]
public bool subgroupAny(bool value)
{
return WaveActiveAnyTrue(value);
@@ -6706,7 +6713,7 @@ __generic<T : __BuiltinArithmeticType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_arithmetic)]
public T subgroupAdd(T value)
{
shader_subgroup_preamble<T>();
@@ -6717,7 +6724,7 @@ __generic<T : __BuiltinArithmeticType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_arithmetic)]
public T subgroupMul(T value)
{
shader_subgroup_preamble<T>();
@@ -6728,7 +6735,7 @@ __generic<T : __BuiltinArithmeticType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_arithmetic)]
public T subgroupMin(T value)
{
shader_subgroup_preamble<T>();
@@ -6739,7 +6746,7 @@ __generic<T : __BuiltinArithmeticType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_arithmetic)]
public T subgroupMax(T value)
{
shader_subgroup_preamble<T>();
@@ -6751,7 +6758,7 @@ __spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv_wgsl, subgroup_arithmetic)]
+[require(glsl_metal_spirv_wgsl, subgroup_arithmetic)]
public T subgroupAnd(T value)
{
shader_subgroup_preamble<T>();
@@ -6760,6 +6767,8 @@ public T subgroupAnd(T value)
case glsl:
case wgsl:
__intrinsic_asm "subgroupAnd($0)";
+ case metal:
+ __intrinsic_asm "simd_and";
case spirv:
if (__isBool<T>()) {
return spirv_asm {
@@ -6781,15 +6790,17 @@ __spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv_wgsl, subgroup_arithmetic)]
+[require(glsl_metal_spirv_wgsl, subgroup_arithmetic)]
public T subgroupOr(T value)
{
shader_subgroup_preamble<T>();
__target_switch
{
- case glsl:
+ case glsl:
case wgsl:
__intrinsic_asm "subgroupOr($0)";
+ case metal:
+ __intrinsic_asm "simd_or";
case spirv:
if (__isBool<T>()) {
return spirv_asm {
@@ -6811,7 +6822,7 @@ __spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv_wgsl, subgroup_arithmetic)]
+[require(glsl_metal_spirv_wgsl, subgroup_arithmetic)]
public T subgroupXor(T value)
{
shader_subgroup_preamble<T>();
@@ -6820,6 +6831,8 @@ public T subgroupXor(T value)
case glsl:
case wgsl:
__intrinsic_asm "subgroupXor($0)";
+ case metal:
+ __intrinsic_asm "simd_xor";
case spirv:
if (__isBool<T>()) {
return spirv_asm {
@@ -6841,7 +6854,7 @@ __spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv_wgsl, subgroup_arithmetic)]
+[require(glsl_metal_spirv_wgsl, subgroup_arithmetic)]
public T subgroupInclusiveAdd(T value)
{
shader_subgroup_preamble<T>();
@@ -6850,6 +6863,8 @@ public T subgroupInclusiveAdd(T value)
case glsl:
case wgsl:
__intrinsic_asm "subgroupInclusiveAdd($0)";
+ case metal:
+ __intrinsic_asm "simd_prefix_inclusive_sum";
case spirv:
if (__isFloat<T>())
return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFAdd $$T result Subgroup InclusiveScan $value};
@@ -6864,7 +6879,7 @@ __spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv_wgsl, subgroup_arithmetic)]
+[require(glsl_metal_spirv_wgsl, subgroup_arithmetic)]
public T subgroupInclusiveMul(T value)
{
shader_subgroup_preamble<T>();
@@ -6873,6 +6888,8 @@ public T subgroupInclusiveMul(T value)
case glsl:
case wgsl:
__intrinsic_asm "subgroupInclusiveMul($0)";
+ case metal:
+ __intrinsic_asm "simd_prefix_inclusive_product";
case spirv:
if (__isFloat<T>())
return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMul $$T result Subgroup InclusiveScan $value};
@@ -7005,7 +7022,7 @@ __generic<T : __BuiltinArithmeticType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_arithmetic)]
public T subgroupExclusiveAdd(T value)
{
shader_subgroup_preamble<T>();
@@ -7017,7 +7034,7 @@ __generic<T : __BuiltinArithmeticType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_arithmetic)]
public T subgroupExclusiveMul(T value)
{
shader_subgroup_preamble<T>();
@@ -7128,7 +7145,7 @@ __generic<T : __BuiltinArithmeticType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupAdd(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -7139,7 +7156,7 @@ __generic<T : __BuiltinArithmeticType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupMul(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -7150,7 +7167,7 @@ __generic<T : __BuiltinArithmeticType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupMin(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -7161,7 +7178,7 @@ __generic<T : __BuiltinArithmeticType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupMax(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -7173,7 +7190,7 @@ __spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv_wgsl, subgroup_arithmetic)]
+[require(glsl_metal_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupAnd(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -7181,8 +7198,10 @@ public vector<T,N> subgroupAnd(vector<T,N> value)
{
case glsl:
case wgsl:
- // TODO: Bool inputs are invalid for WGSL, cast them to int or don't allow them to compile.
+ // TODO: Bool inputs are invalid for Metal and WGSL, cast them to int or don't allow them to compile.
__intrinsic_asm "subgroupAnd($0)";
+ case metal:
+ __intrinsic_asm "simd_and";
case spirv:
if (__isBool<T>()) {
return spirv_asm {
@@ -7205,7 +7224,7 @@ __spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv_wgsl, subgroup_arithmetic)]
+[require(glsl_metal_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupOr(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -7214,6 +7233,8 @@ public vector<T,N> subgroupOr(vector<T,N> value)
case glsl:
case wgsl:
__intrinsic_asm "subgroupOr($0)";
+ case metal:
+ __intrinsic_asm "simd_or";
case spirv:
if (__isBool<T>()) {
return spirv_asm {
@@ -7236,7 +7257,7 @@ __spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv_wgsl, subgroup_arithmetic)]
+[require(glsl_metal_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupXor(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -7245,6 +7266,8 @@ public vector<T,N> subgroupXor(vector<T,N> value)
case glsl:
case wgsl:
__intrinsic_asm "subgroupXor($0)";
+ case metal:
+ __intrinsic_asm "simd_xor";
case spirv:
if (__isBool<T>()) {
return spirv_asm {
@@ -7266,7 +7289,7 @@ __spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv_wgsl, subgroup_arithmetic)]
+[require(glsl_metal_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupInclusiveAdd(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -7275,6 +7298,8 @@ public vector<T,N> subgroupInclusiveAdd(vector<T,N> value)
case glsl:
case wgsl:
__intrinsic_asm "subgroupInclusiveAdd($0)";
+ case metal:
+ __intrinsic_asm "simd_prefix_inclusive_sum";
case spirv:
if (__isFloat<T>())
return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFAdd $$vector<T,N> result Subgroup InclusiveScan $value};
@@ -7289,7 +7314,7 @@ __spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv_wgsl, subgroup_arithmetic)]
+[require(glsl_metal_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupInclusiveMul(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -7298,6 +7323,8 @@ public vector<T,N> subgroupInclusiveMul(vector<T,N> value)
case glsl:
case wgsl:
__intrinsic_asm "subgroupInclusiveMul($0)";
+ case metal:
+ __intrinsic_asm "simd_prefix_inclusive_product";
case spirv:
if (__isFloat<T>())
return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMul $$vector<T,N> result Subgroup InclusiveScan $value};
@@ -7411,7 +7438,7 @@ __generic<T : __BuiltinArithmeticType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupExclusiveAdd(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -7423,7 +7450,7 @@ __generic<T : __BuiltinArithmeticType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupExclusiveMul(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -7533,7 +7560,7 @@ __generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_ballot)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_ballot)]
public T subgroupBroadcast(T value, uint id)
{
shader_subgroup_preamble<T>();
@@ -7551,7 +7578,7 @@ __generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_ballot)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_ballot)]
public vector<T,N> subgroupBroadcast(vector<T,N> value, uint id)
{
shader_subgroup_preamble<T>();
@@ -7569,7 +7596,7 @@ __generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_ballot)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_ballot)]
public T subgroupBroadcastFirst(T value)
{
shader_subgroup_preamble<T>();
@@ -7580,7 +7607,7 @@ __generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_ballot)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_ballot)]
public vector<T,N> subgroupBroadcastFirst(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -7591,7 +7618,7 @@ public vector<T,N> subgroupBroadcastFirst(vector<T,N> value)
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_ballot)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_ballot)]
public uvec4 subgroupBallot(bool value)
{
return WaveActiveBallot(value);
@@ -7772,7 +7799,7 @@ __generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_shuffle)]
public T subgroupShuffle(T value, uint index)
{
shader_subgroup_preamble<T>();
@@ -7783,7 +7810,7 @@ __generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
__wgsl_extension(subgroups)
-[require(glsl_spirv_wgsl, subgroup_shuffle)]
+[require(glsl_metal_spirv_wgsl, subgroup_shuffle)]
[ForceInline] public T subgroupShuffleXor(T value, uint mask)
{
shader_subgroup_preamble<T>();
@@ -7792,6 +7819,8 @@ __wgsl_extension(subgroups)
case glsl:
case wgsl:
__intrinsic_asm "subgroupShuffleXor($0,$1)";
+ case metal:
+ __intrinsic_asm "simd_shuffle_xor($0, ushort($1))";
case spirv:
return spirv_asm {
OpCapability GroupNonUniformBallot;
@@ -7804,7 +7833,7 @@ __generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_shuffle)]
public vector<T,N> subgroupShuffle(vector<T,N> value, uint index)
{
shader_subgroup_preamble<T>();
@@ -7816,7 +7845,7 @@ __spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv_wgsl, subgroup_shuffle)]
+[require(glsl_metal_spirv_wgsl, subgroup_shuffle)]
public vector<T,N> subgroupShuffleXor(vector<T,N> value, uint mask)
{
shader_subgroup_preamble<T>();
@@ -7825,6 +7854,8 @@ public vector<T,N> subgroupShuffleXor(vector<T,N> value, uint mask)
case glsl:
case wgsl:
__intrinsic_asm "subgroupShuffleXor($0,$1)";
+ case metal:
+ __intrinsic_asm "simd_shuffle_xor($0, ushort($1))";
case spirv:
return spirv_asm {
OpCapability GroupNonUniformBallot;
@@ -7841,7 +7872,7 @@ __spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_shuffle_relative)
__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv_wgsl, subgroup_shufflerelative)]
+[require(glsl_metal_spirv_wgsl, subgroup_shufflerelative)]
public T subgroupShuffleUp(T value, uint delta)
{
shader_subgroup_preamble<T>();
@@ -7850,6 +7881,8 @@ public T subgroupShuffleUp(T value, uint delta)
case glsl:
case wgsl:
__intrinsic_asm "subgroupShuffleUp($0, $1)";
+ case metal:
+ __intrinsic_asm "simd_shuffle_up($0, ushort($1))";
case spirv:
return spirv_asm {
OpCapability GroupNonUniformShuffleRelative;
@@ -7863,7 +7896,7 @@ __spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_shuffle_relative)
__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv_wgsl, subgroup_shufflerelative)]
+[require(glsl_metal_spirv_wgsl, subgroup_shufflerelative)]
public T subgroupShuffleDown(T value, uint delta)
{
shader_subgroup_preamble<T>();
@@ -7872,6 +7905,8 @@ public T subgroupShuffleDown(T value, uint delta)
case glsl:
case wgsl:
__intrinsic_asm "subgroupShuffleDown($0, $1)";
+ case metal:
+ __intrinsic_asm "simd_shuffle_down($0, ushort($1))";
case spirv:
return spirv_asm {
OpCapability GroupNonUniformShuffleRelative;
@@ -7886,7 +7921,7 @@ __spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_shuffle_relative)
__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv_wgsl, subgroup_shufflerelative)]
+[require(glsl_metal_spirv_wgsl, subgroup_shufflerelative)]
public vector<T,N> subgroupShuffleUp(vector<T,N> value, uint delta)
{
shader_subgroup_preamble<T>();
@@ -7895,6 +7930,8 @@ public vector<T,N> subgroupShuffleUp(vector<T,N> value, uint delta)
case glsl:
case wgsl:
__intrinsic_asm "subgroupShuffleUp($0, $1)";
+ case metal:
+ __intrinsic_asm "simd_shuffle_up($0, ushort($1))";
case spirv:
return spirv_asm {
OpCapability GroupNonUniformShuffleRelative;
@@ -7908,7 +7945,7 @@ __spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_shuffle_relative)
__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv_wgsl, subgroup_shufflerelative)]
+[require(glsl_metal_spirv_wgsl, subgroup_shufflerelative)]
public vector<T,N> subgroupShuffleDown(vector<T,N> value, uint delta)
{
shader_subgroup_preamble<T>();
@@ -7917,6 +7954,8 @@ public vector<T,N> subgroupShuffleDown(vector<T,N> value, uint delta)
case glsl:
case wgsl:
__intrinsic_asm "subgroupShuffleDown($0, $1)";
+ case metal:
+ __intrinsic_asm "simd_shuffle_down($0, ushort($1))";
case spirv:
return spirv_asm {
OpCapability GroupNonUniformShuffleRelative;
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 491f0ef4d..884621960 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -9782,7 +9782,7 @@ void __subgroupBarrier()
case glsl: __intrinsic_asm "subgroupBarrier";
case hlsl: __intrinsic_asm "GroupMemoryBarrierWithGroupSync";
case cuda: __intrinsic_asm "__syncthreads()";
- case metal: __intrinsic_asm "threadgroup_barrier(mem_flags::mem_threadgroup)";
+ case metal: __intrinsic_asm "simdgroup_barrier(mem_flags::none)";
case spirv:
spirv_asm
{
@@ -14423,7 +14423,7 @@ matrix<T,N,M> WaveMaskPrefixBitXor(WaveMask mask, matrix<T,N,M> expr)
__generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_quad)
__spirv_version(1.3)
-[require(glsl_hlsl_spirv, subgroup_quad)]
+[require(glsl_hlsl_metal_spirv, subgroup_quad)]
T QuadReadLaneAt(T sourceValue, uint quadLaneID)
{
__target_switch
@@ -14432,6 +14432,9 @@ T QuadReadLaneAt(T sourceValue, uint quadLaneID)
__intrinsic_asm "QuadReadLaneAt";
case glsl:
__intrinsic_asm "subgroupQuadBroadcast";
+ case metal:
+ // TODO: Need to add intrinsics to access Metal and WGSL's broadcast variant where lane is const for all threads.
+ __intrinsic_asm "quad_shuffle($0, ushort($1))";
case spirv:
return spirv_asm {
OpCapability GroupNonUniformQuad;
@@ -14442,7 +14445,7 @@ T QuadReadLaneAt(T sourceValue, uint quadLaneID)
__generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_quad)
__spirv_version(1.3)
-[require(glsl_hlsl_spirv, subgroup_quad)]
+[require(glsl_hlsl_metal_spirv, subgroup_quad)]
vector<T,N> QuadReadLaneAt(vector<T,N> sourceValue, uint quadLaneID)
{
__target_switch
@@ -14451,6 +14454,8 @@ vector<T,N> QuadReadLaneAt(vector<T,N> sourceValue, uint quadLaneID)
__intrinsic_asm "QuadReadLaneAt";
case glsl:
__intrinsic_asm "subgroupQuadBroadcast";
+ case metal:
+ __intrinsic_asm "quad_shuffle($0, ushort($1))";
case spirv:
return spirv_asm {
OpCapability GroupNonUniformQuad;
@@ -14598,8 +14603,8 @@ __generic<T : __BuiltinType, let N : int, let M : int> matrix<T,N,M> QuadReadAcr
// WaveActiveBitAnd, WaveActiveBitOr, WaveActiveBitXor
${{{{
-struct WaveActiveBitOpEntry { const char* hlslName; const char* glslName; const char* spirvName; };
-const WaveActiveBitOpEntry kWaveActiveBitOpEntries[] = {{"BitAnd", "And", "BitwiseAnd"}, {"BitOr", "Or", "BitwiseOr"}, {"BitXor", "Xor", "BitwiseXor"}};
+struct WaveActiveBitOpEntry { const char* hlslName; const char* glslName; const char* spirvName; const char* metalName; };
+const WaveActiveBitOpEntry kWaveActiveBitOpEntries[] = {{"BitAnd", "And", "BitwiseAnd", "and"}, {"BitOr", "Or", "BitwiseOr", "or"}, {"BitXor", "Xor", "BitwiseXor", "xor"}};
for (auto opName : kWaveActiveBitOpEntries) {
}}}}
/// @category wave Wave and quad functions
@@ -14607,7 +14612,7 @@ __generic<T : __BuiltinIntegerType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_arithmetic)]
T WaveActive$(opName.hlslName)(T expr)
{
__target_switch
@@ -14615,7 +14620,10 @@ T WaveActive$(opName.hlslName)(T expr)
case glsl:
case wgsl:
__intrinsic_asm "subgroup$(opName.glslName)";
- case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)";
+ case hlsl:
+ __intrinsic_asm "WaveActive$(opName.hlslName)";
+ case metal:
+ __intrinsic_asm "simd_$(opName.metalName)";
case spirv:
return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniform$(opName.spirvName) $$T result Subgroup Reduce $expr};
case cuda:
@@ -14627,7 +14635,7 @@ __generic<T : __BuiltinIntegerType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_arithmetic)]
vector<T, N> WaveActive$(opName.hlslName)(vector<T, N> expr)
{
__target_switch
@@ -14635,7 +14643,10 @@ vector<T, N> WaveActive$(opName.hlslName)(vector<T, N> expr)
case glsl:
case wgsl:
__intrinsic_asm "subgroup$(opName.glslName)";
- case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)";
+ case hlsl:
+ __intrinsic_asm "WaveActive$(opName.hlslName)";
+ case metal:
+ __intrinsic_asm "simd_$(opName.metalName)";
case spirv:
return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniform$(opName.spirvName) $$vector<T, N> result Subgroup Reduce $expr};
case cuda:
@@ -14644,22 +14655,21 @@ vector<T, N> WaveActive$(opName.hlslName)(vector<T, N> expr)
}
__generic<T : __BuiltinIntegerType, let N : int, let M : int>
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_arithmetic)]
matrix<T, N, M> WaveActive$(opName.hlslName)(matrix<T, N, M> expr)
{
__target_switch
{
- case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)";
- case glsl:
- case spirv:
- case wgsl:
+ case cuda:
+ return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr);
+ case hlsl:
+ __intrinsic_asm "WaveActive$(opName.hlslName)";
+ default:
matrix<T,N,M> result;
[ForceUnroll]
for (int i = 0; i < N; ++i)
result[i] = WaveActive$(opName.hlslName)(expr[i]);
return result;
- case cuda:
- return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr);
}
}
${{{{
@@ -14668,32 +14678,36 @@ ${{{{
// WaveActiveMin/Max
${{{{
-const char* kWaveActiveMinMaxNames[] = {"Min", "Max"};
-for (const char* opName : kWaveActiveMinMaxNames) {
+struct WaveActiveMinMaxEntry { const char* name; const char* metalName; };
+const WaveActiveMinMaxEntry kWaveActiveMinMaxNames[] = {{"Min", "min"}, {"Max", "max"}};
+for (const auto opName : kWaveActiveMinMaxNames) {
}}}}
/// @category wave
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
-T WaveActive$(opName)(T expr)
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_arithmetic)]
+T WaveActive$(opName.name)(T expr)
{
__target_switch
{
- case glsl:
+ case glsl:
case wgsl:
- __intrinsic_asm "subgroup$(opName)";
- case hlsl: __intrinsic_asm "WaveActive$(opName)";
+ __intrinsic_asm "subgroup$(opName.name)";
+ case hlsl:
+ __intrinsic_asm "WaveActive$(opName.name)";
+ case metal:
+ __intrinsic_asm "simd_$(opName.metalName)";
case spirv:
if (__isFloat<T>())
- return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformF$(opName) $$T result Subgroup Reduce $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformF$(opName.name) $$T result Subgroup Reduce $expr};
else if (__isUnsignedInt<T>())
- return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformU$(opName) $$T result Subgroup Reduce $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformU$(opName.name) $$T result Subgroup Reduce $expr};
else
- return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformS$(opName) $$T result Subgroup Reduce $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformS$(opName.name) $$T result Subgroup Reduce $expr};
case cuda:
- return WaveMask$(opName)(WaveGetActiveMask(), expr);
+ return WaveMask$(opName.name)(WaveGetActiveMask(), expr);
}
}
@@ -14701,44 +14715,46 @@ __generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
-vector<T, N> WaveActive$(opName)(vector<T, N> expr)
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_arithmetic)]
+vector<T, N> WaveActive$(opName.name)(vector<T, N> expr)
{
__target_switch
{
- case glsl:
+ case glsl:
case wgsl:
- __intrinsic_asm "subgroup$(opName)";
- case hlsl: __intrinsic_asm "WaveActive$(opName)";
+ __intrinsic_asm "subgroup$(opName.name)";
+ case hlsl:
+ __intrinsic_asm "WaveActive$(opName.name)";
+ case metal:
+ __intrinsic_asm "simd_$(opName.metalName)";
case spirv:
if (__isFloat<T>())
- return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformF$(opName) $$vector<T, N> result Subgroup Reduce $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformF$(opName.name) $$vector<T, N> result Subgroup Reduce $expr};
else if (__isUnsignedInt<T>())
- return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformU$(opName) $$vector<T, N> result Subgroup Reduce $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformU$(opName.name) $$vector<T, N> result Subgroup Reduce $expr};
else
- return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformS$(opName) $$vector<T, N> result Subgroup Reduce $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformS$(opName.name) $$vector<T, N> result Subgroup Reduce $expr};
case cuda:
- return WaveMask$(opName)(WaveGetActiveMask(), expr);
+ return WaveMask$(opName.name)(WaveGetActiveMask(), expr);
}
}
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
-matrix<T, N, M> WaveActive$(opName)(matrix<T, N, M> expr)
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_arithmetic)]
+matrix<T, N, M> WaveActive$(opName.name)(matrix<T, N, M> expr)
{
__target_switch
{
- case hlsl: __intrinsic_asm "WaveActive$(opName)";
- case glsl:
- case spirv:
- case wgsl:
+ case cuda:
+ return WaveMask$(opName.name)(WaveGetActiveMask(), expr);
+ case hlsl:
+ __intrinsic_asm "WaveActive$(opName.name)";
+ default:
matrix<T, N, M> result;
[ForceUnroll]
for (int i = 0; i < N; ++i)
- result[i] = WaveActive$(opName)(expr[i]);
+ result[i] = WaveActive$(opName.name)(expr[i]);
return result;
- case cuda:
- return WaveMask$(opName)(WaveGetActiveMask(), expr);
}
}
@@ -14748,8 +14764,8 @@ ${{{{
// WaveActiveProduct/Sum
${{{{
-struct WaveActiveProductSumEntry { const char* hlslName; const char* glslName; };
-const WaveActiveProductSumEntry kWaveActivProductSumNames[] = {{"Product", "Mul"}, {"Sum", "Add"}};
+struct WaveActiveProductSumEntry { const char* hlslName; const char* glslName; const char* metalName; };
+const WaveActiveProductSumEntry kWaveActivProductSumNames[] = {{"Product", "Mul", "product"}, {"Sum", "Add", "sum"}};
for (auto opName : kWaveActivProductSumNames) {
}}}}
/// @category wave
@@ -14757,7 +14773,7 @@ __generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_arithmetic)]
T WaveActive$(opName.hlslName)(T expr)
{
__target_switch
@@ -14766,6 +14782,7 @@ T WaveActive$(opName.hlslName)(T expr)
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
__intrinsic_asm "subgroup$(opName.glslName)($0)";
case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)";
+ case metal: __intrinsic_asm "simd_$(opName.metalName)";
case spirv:
if (__isFloat<T>())
return spirv_asm {
@@ -14791,7 +14808,7 @@ __generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_arithmetic)]
vector<T,N> WaveActive$(opName.hlslName)(vector<T,N> expr)
{
__target_switch
@@ -14800,6 +14817,7 @@ vector<T,N> WaveActive$(opName.hlslName)(vector<T,N> expr)
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
__intrinsic_asm "subgroup$(opName.glslName)($0)";
case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)";
+ case metal: __intrinsic_asm "simd_$(opName.metalName)";
case spirv:
if (__isFloat<T>())
return spirv_asm {
@@ -14822,27 +14840,27 @@ vector<T,N> WaveActive$(opName.hlslName)(vector<T,N> expr)
}
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_arithmetic)]
matrix<T, N, M> WaveActive$(opName.hlslName)(matrix<T, N, M> expr)
{
__target_switch
{
- case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)";
- case glsl:
- case spirv:
- case wgsl:
+ case cuda:
+ return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr);
+ case hlsl:
+ __intrinsic_asm "WaveActive$(opName.hlslName)";
+ default:
matrix<T, N, M> result;
[ForceUnroll]
for (int i = 0; i < N; ++i)
result[i] = WaveActive$(opName.hlslName)(expr[i]);
return result;
- case cuda:
- return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr);
}
}
${{{{
} // WaveActiveProduct/WaveActiveProductSum.
}}}}
+
/// @category wave
__generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_vote)
@@ -14906,7 +14924,7 @@ bool WaveActiveAllEqual(matrix<T, N, M> value)
__glsl_extension(GL_KHR_shader_subgroup_vote)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_vote)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_vote)]
bool WaveActiveAllTrue(bool condition)
{
__target_switch
@@ -14914,7 +14932,10 @@ bool WaveActiveAllTrue(bool condition)
case glsl:
case wgsl:
__intrinsic_asm "subgroupAll";
- case hlsl: __intrinsic_asm "WaveActiveAllTrue($0)";
+ case hlsl:
+ __intrinsic_asm "WaveActiveAllTrue($0)";
+ case metal:
+ __intrinsic_asm "simd_all";
case spirv:
return spirv_asm
{
@@ -14930,7 +14951,7 @@ bool WaveActiveAllTrue(bool condition)
__glsl_extension(GL_KHR_shader_subgroup_vote)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_vote)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_vote)]
bool WaveActiveAnyTrue(bool condition)
{
__target_switch
@@ -14940,6 +14961,8 @@ bool WaveActiveAnyTrue(bool condition)
__intrinsic_asm "subgroupAny";
case hlsl:
__intrinsic_asm "WaveActiveAnyTrue($0)";
+ case metal:
+ __intrinsic_asm "simd_any";
case spirv:
return spirv_asm
{
@@ -14951,12 +14974,28 @@ bool WaveActiveAnyTrue(bool condition)
}
}
+
+//@hidden:
+[ForceInline]
+uint64_t __metal_simd_ballot(bool expr)
+{
+ __intrinsic_asm "uint64_t(simd_ballot($0))";
+}
+
+[ForceInline]
+uint4 __metal_simd_vote_mask_to_uint4(uint64_t mask)
+{
+ return uint4(uint(mask & 0xFFFFFFFF), uint(mask >> 32), 0, 0);
+}
+
+//@public:
+
/// @category wave
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
__wgsl_extension(subgroups)
[NonUniformReturn]
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_ballot)]
uint4 WaveActiveBallot(bool condition)
{
__target_switch
@@ -14966,6 +15005,7 @@ uint4 WaveActiveBallot(bool condition)
__intrinsic_asm "subgroupBallot";
case hlsl:
__intrinsic_asm "WaveActiveBallot";
+ case metal: return __metal_simd_vote_mask_to_uint4(__metal_simd_ballot(condition));
case spirv:
return spirv_asm
{
@@ -15039,13 +15079,14 @@ __glsl_extension(GL_KHR_shader_subgroup_basic)
__spirv_version(1.3)
__wgsl_extension(subgroups)
[NonUniformReturn]
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_basic)]
bool WaveIsFirstLane()
{
__target_switch
{
case glsl: __intrinsic_asm "subgroupElect()";
case hlsl: __intrinsic_asm "WaveIsFirstLane()";
+ case metal: __intrinsic_asm "simd_is_first";
case spirv:
return spirv_asm
{
@@ -15093,7 +15134,7 @@ __generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_arithmetic)]
T WavePrefixProduct(T expr)
{
__target_switch
@@ -15102,6 +15143,7 @@ T WavePrefixProduct(T expr)
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
__intrinsic_asm "subgroupExclusiveMul($0)";
case hlsl: __intrinsic_asm "WavePrefixProduct";
+ case metal: __intrinsic_asm "simd_prefix_exclusive_product";
case spirv:
if (__isFloat<T>())
return spirv_asm {
@@ -15128,7 +15170,7 @@ __generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_arithmetic)]
vector<T,N> WavePrefixProduct(vector<T,N> expr)
{
__target_switch
@@ -15137,6 +15179,7 @@ vector<T,N> WavePrefixProduct(vector<T,N> expr)
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
__intrinsic_asm "subgroupExclusiveMul($0)";
case hlsl: __intrinsic_asm "WavePrefixProduct";
+ case metal: __intrinsic_asm "simd_prefix_exclusive_product";
case spirv:
if (__isFloat<T>())
return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMul $$vector<T,N> result Subgroup ExclusiveScan $expr};
@@ -15161,16 +15204,15 @@ matrix<T, N, M> WavePrefixProduct(matrix<T, N, M> expr)
{
__target_switch
{
- case hlsl: __intrinsic_asm "WavePrefixProduct";
- case glsl:
- case spirv:
- case wgsl:
+ case cuda:
+ return WaveMaskPrefixProduct(WaveGetActiveMask(), expr);
+ case hlsl:
+ __intrinsic_asm "WavePrefixProduct";
+ default:
matrix<T, N, M> result;
for (int i = 0; i < N; ++i)
result[i] = WavePrefixProduct(expr[i]);
return result;
- case cuda:
- return WaveMaskPrefixProduct(WaveGetActiveMask(), expr);
}
}
@@ -15179,7 +15221,7 @@ __generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_arithmetic)]
T WavePrefixSum(T expr)
{
__target_switch
@@ -15188,6 +15230,7 @@ T WavePrefixSum(T expr)
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
__intrinsic_asm "subgroupExclusiveAdd($0)";
case hlsl: __intrinsic_asm "WavePrefixSum";
+ case metal: __intrinsic_asm "simd_prefix_exclusive_sum";
case spirv:
if (__isFloat<T>())
return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFAdd $$T result Subgroup ExclusiveScan $expr};
@@ -15210,7 +15253,7 @@ __generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_arithmetic)]
vector<T,N> WavePrefixSum(vector<T,N> expr)
{
__target_switch
@@ -15219,6 +15262,7 @@ vector<T,N> WavePrefixSum(vector<T,N> expr)
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
__intrinsic_asm "subgroupExclusiveAdd($0)";
case hlsl: __intrinsic_asm "WavePrefixSum";
+ case metal: __intrinsic_asm "simd_prefix_exclusive_sum";
case spirv:
if (__isFloat<T>())
return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFAdd $$vector<T,N> result Subgroup ExclusiveScan $expr};
@@ -15238,21 +15282,20 @@ vector<T,N> WavePrefixSum(vector<T,N> expr)
}
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_arithmetic)]
matrix<T,N,M> WavePrefixSum(matrix<T,N,M> expr)
{
__target_switch
{
- case hlsl: __intrinsic_asm "WavePrefixSum";
- case glsl:
- case spirv:
- case wgsl:
+ case cuda:
+ return WaveMaskPrefixSum(WaveGetActiveMask(), expr);
+ case hlsl:
+ __intrinsic_asm "WavePrefixSum";
+ default:
matrix<T, N, M> result;
for (int i = 0; i < N; ++i)
result[i] = WavePrefixSum(expr[i]);
return result;
- case cuda:
- return WaveMaskPrefixSum(WaveGetActiveMask(), expr);
}
}
@@ -15261,7 +15304,7 @@ __generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_ballot)]
T WaveReadLaneFirst(T expr)
{
__target_switch
@@ -15270,6 +15313,7 @@ T WaveReadLaneFirst(T expr)
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
__intrinsic_asm "subgroupBroadcastFirst($0)";
case hlsl: __intrinsic_asm "WaveReadLaneFirst";
+ case metal: __intrinsic_asm "simd_broadcast_first";
case spirv:
return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcastFirst $$T result Subgroup $expr};
case wgsl: __intrinsic_asm "subgroupBroadcastFirst";
@@ -15282,7 +15326,7 @@ __generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_ballot)]
vector<T,N> WaveReadLaneFirst(vector<T,N> expr)
{
__target_switch
@@ -15291,6 +15335,7 @@ vector<T,N> WaveReadLaneFirst(vector<T,N> expr)
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
__intrinsic_asm "subgroupBroadcastFirst($0)";
case hlsl: __intrinsic_asm "WaveReadLaneFirst";
+ case metal: __intrinsic_asm "simd_broadcast_first";
case spirv:
return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcastFirst $$vector<T,N> result Subgroup $expr};
case wgsl: __intrinsic_asm "subgroupBroadcastFirst";
@@ -15300,21 +15345,19 @@ vector<T,N> WaveReadLaneFirst(vector<T,N> expr)
}
__generic<T : __BuiltinType, let N : int, let M : int>
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_ballot)]
matrix<T,N,M> WaveReadLaneFirst(matrix<T,N,M> expr)
{
__target_switch
{
+ case cuda:
+ return WaveMaskReadLaneFirst(WaveGetActiveMask(), expr);
case hlsl: __intrinsic_asm "WaveReadLaneFirst";
- case glsl:
- case spirv:
- case wgsl:
+ default:
matrix<T, N, M> result;
for (int i = 0; i < N; ++i)
result[i] = WaveReadLaneFirst(expr[i]);
return result;
- case cuda:
- return WaveMaskReadLaneFirst(WaveGetActiveMask(), expr);
}
}
@@ -15329,7 +15372,7 @@ __generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_ballot)]
T WaveBroadcastLaneAt(T value, constexpr int lane)
{
__target_switch
@@ -15338,6 +15381,7 @@ T WaveBroadcastLaneAt(T value, constexpr int lane)
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
__intrinsic_asm "subgroupBroadcast($0, $1)";
case hlsl: __intrinsic_asm "WaveReadLaneAt";
+ case metal: __intrinsic_asm "simd_broadcast($0, ushort($1))";
case spirv:
let ulane = uint(lane);
return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcast $$T result Subgroup $value $ulane};
@@ -15352,7 +15396,7 @@ __generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_ballot)]
vector<T,N> WaveBroadcastLaneAt(vector<T,N> value, constexpr int lane)
{
__target_switch
@@ -15361,6 +15405,7 @@ vector<T,N> WaveBroadcastLaneAt(vector<T,N> value, constexpr int lane)
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
__intrinsic_asm "subgroupBroadcast($0, $1)";
case hlsl: __intrinsic_asm "WaveReadLaneAt";
+ case metal: __intrinsic_asm "simd_broadcast($0, ushort($1))";
case spirv:
let ulane = uint(lane);
return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcast $$vector<T,N> result Subgroup $value $ulane};
@@ -15371,22 +15416,18 @@ vector<T,N> WaveBroadcastLaneAt(vector<T,N> value, constexpr int lane)
}
__generic<T : __BuiltinType, let N : int, let M : int>
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_ballot)]
matrix<T, N, M> WaveBroadcastLaneAt(matrix<T, N, M> value, constexpr int lane)
{
__target_switch
{
case cuda: __intrinsic_asm "_waveShuffleMultiple(_getActiveMask(), $0, $1)";
case hlsl: __intrinsic_asm "WaveReadLaneAt";
- case glsl:
- case spirv:
- case wgsl:
+ default:
matrix<T, N, M> result;
for (int i = 0; i < N; ++i)
result[i] = WaveBroadcastLaneAt(value[i], lane);
return result;
- case cuda:
- return WaveMaskBroadcastLaneAt(WaveGetActiveMask(), value, lane);
}
}
@@ -15397,7 +15438,7 @@ __generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_shuffle)]
T WaveReadLaneAt(T value, int lane)
{
__target_switch
@@ -15406,6 +15447,7 @@ T WaveReadLaneAt(T value, int lane)
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
__intrinsic_asm "subgroupShuffle($0, $1)";
case hlsl: __intrinsic_asm "WaveReadLaneAt";
+ case metal: __intrinsic_asm "simd_shuffle($0, ushort($1))";
case spirv:
let ulane = uint(lane);
return spirv_asm {OpCapability GroupNonUniformShuffle; OpGroupNonUniformShuffle $$T result Subgroup $value $ulane};
@@ -15419,7 +15461,7 @@ __generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_shuffle)]
vector<T,N> WaveReadLaneAt(vector<T,N> value, int lane)
{
__target_switch
@@ -15428,6 +15470,7 @@ vector<T,N> WaveReadLaneAt(vector<T,N> value, int lane)
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
__intrinsic_asm "subgroupShuffle($0, $1)";
case hlsl: __intrinsic_asm "WaveReadLaneAt";
+ case metal: __intrinsic_asm "simd_shuffle($0, ushort($1))";
case spirv:
let ulane = uint(lane);
return spirv_asm {OpCapability GroupNonUniformShuffle; OpGroupNonUniformShuffle $$vector<T,N> result Subgroup $value $ulane};
@@ -15438,22 +15481,18 @@ vector<T,N> WaveReadLaneAt(vector<T,N> value, int lane)
}
__generic<T : __BuiltinType, let N : int, let M : int>
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_shuffle)]
matrix<T, N, M> WaveReadLaneAt(matrix<T, N, M> value, int lane)
{
__target_switch
{
case cuda: __intrinsic_asm "_waveShuffleMultiple(_getActiveMask(), $0, $1)";
case hlsl: __intrinsic_asm "WaveReadLaneAt";
- case glsl:
- case spirv:
- case wgsl:
+ default:
matrix<T,N,M> result;
for (int i = 0; i < N; ++i)
result[i] = WaveReadLaneAt(value[i], lane);
return result;
- case cuda:
- return WaveMaskReadLaneAt(WaveGetActiveMask(), value, lane);
}
}
@@ -15465,7 +15504,7 @@ __generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_shuffle)]
T WaveShuffle(T value, int lane)
{
__target_switch
@@ -15474,6 +15513,7 @@ T WaveShuffle(T value, int lane)
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
__intrinsic_asm "subgroupShuffle($0, $1)";
case hlsl: __intrinsic_asm "WaveReadLaneAt";
+ case metal: __intrinsic_asm "simd_shuffle($0, ushort($1))";
case spirv:
let ulane = uint(lane);
return spirv_asm {OpCapability GroupNonUniformShuffle; OpGroupNonUniformShuffle $$T result Subgroup $value $ulane};
@@ -15488,7 +15528,7 @@ __generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_shuffle)]
vector<T,N> WaveShuffle(vector<T,N> value, int lane)
{
__target_switch
@@ -15497,6 +15537,7 @@ vector<T,N> WaveShuffle(vector<T,N> value, int lane)
if (__isHalf<T>()) __requireGLSLExtension("GL_EXT_shader_subgroup_extended_types_float16");
__intrinsic_asm "subgroupShuffle($0, $1)";
case hlsl: __intrinsic_asm "WaveReadLaneAt";
+ case metal: __intrinsic_asm "simd_shuffle($0, ushort($1))";
case spirv:
let ulane = uint(lane);
return spirv_asm {OpCapability GroupNonUniformShuffle; OpGroupNonUniformShuffle $$vector<T,N> result Subgroup $value $ulane};
diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef
index f98be0e32..130439fe1 100644
--- a/source/slang/slang-capabilities.capdef
+++ b/source/slang/slang-capabilities.capdef
@@ -1939,10 +1939,11 @@ alias shader5_sm_5_0 = GL_ARB_gpu_shader5 | sm_5_0_version;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_basic'
/// [Compound]
-alias subgroup_basic = GL_KHR_shader_subgroup_basic
+alias subgroup_basic = GL_KHR_shader_subgroup_basic
| _sm_6_0
| _cuda_sm_7_0
| wgsl
+ | metal
;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_ballot'
/// [Compound]
@@ -1951,6 +1952,7 @@ alias subgroup_ballot = spirv_1_0 + GL_KHR_shader_subgroup_ballot
| _sm_6_0 + shader5_sm_5_0
| _cuda_sm_7_0 + shader5_sm_5_0
| wgsl
+ | metal
;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_ballot_activemask'
/// [Compound]
@@ -1966,6 +1968,7 @@ alias subgroup_basic_ballot = glsl + GL_KHR_shader_subgroup_basic + subgroup_bal
| hlsl + subgroup_ballot
| cuda + subgroup_ballot
| wgsl
+ | metal
;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_vote'
/// [Compound]
@@ -1973,6 +1976,7 @@ alias subgroup_vote = GL_KHR_shader_subgroup_vote
| _sm_6_0
| _cuda_sm_7_0
| wgsl
+ | metal
;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_vote'
/// [Compound]
@@ -1983,6 +1987,7 @@ alias subgroup_arithmetic = GL_KHR_shader_subgroup_arithmetic
| _sm_6_0
| _cuda_sm_7_0
| wgsl
+ | metal
;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_shuffle'
@@ -1991,6 +1996,7 @@ alias subgroup_shuffle = GL_KHR_shader_subgroup_shuffle
| _sm_6_0
| _cuda_sm_7_0
| wgsl
+ | metal
;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_shuffle_relative'
/// [Compound]
@@ -1998,6 +2004,7 @@ alias subgroup_shufflerelative = GL_KHR_shader_subgroup_shuffle_relative
| _sm_6_0
| _cuda_sm_7_0
| wgsl
+ | metal
;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_clustered'
/// [Compound]
@@ -2008,6 +2015,7 @@ alias subgroup_quad = GL_KHR_shader_subgroup_quad
| _sm_6_0
| _cuda_sm_7_0
| wgsl
+ | metal
;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_partitioned'
/// [Compound]