summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorTianyu Li <ltyucb@gmail.com>2025-08-21 01:32:13 +0800
committerGitHub <noreply@github.com>2025-08-20 17:32:13 +0000
commit82e3fc9c1064f06780f6449154c7cf8f663fceac (patch)
tree99ad135852e66f8f6d20384ca479f936e2475fcf /source/slang
parenta26a11ff5c3f14401bbbbbc4b6e3731056618138 (diff)
Add Metal support for WaveGetActiveMask and WaveActiveCountBits (#8218)
## Summary - Add Metal platform support for `WaveGetActiveMask()` and `WaveActiveCountBits()` wave intrinsics - Update capability requirements to include Metal platform for subgroup ballot operations - Implement Metal-specific intrinsic assembly using `simd_ballot()` and `simd_vote` APIs ## Changes - **source/slang/hlsl.meta.slang**: - Add Metal target case for `WaveGetActiveMask()` using `simd_ballot(true)` - Update capability requirements from `cuda_glsl_hlsl_spirv` to `cuda_glsl_hlsl_metal_spirv` for wave ballot functions - **source/slang/slang-capabilities.capdef**: - Add `metal` to `subgroup_ballot_activemask` capability alias
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/hlsl.meta.slang12
-rw-r--r--source/slang/slang-capabilities.capdef1
2 files changed, 9 insertions, 4 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index d2e98529b..c2b3fc436 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -13936,7 +13936,7 @@ WaveMask __WaveGetActiveMask();
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot_activemask)]
+[require(cuda_glsl_hlsl_metal_spirv, subgroup_ballot_activemask)]
WaveMask WaveGetActiveMask()
{
__target_switch
@@ -13945,6 +13945,8 @@ WaveMask WaveGetActiveMask()
__intrinsic_asm "subgroupBallot(true).x";
case hlsl:
__intrinsic_asm "WaveActiveBallot(true).x";
+ case metal:
+ __intrinsic_asm "((uint32_t)((simd_vote::vote_t)simd_ballot(true)))";
case spirv:
let _true = true;
return (spirv_asm
@@ -15503,7 +15505,7 @@ uint4 WaveActiveBallot(bool condition)
}
/// @category wave
-[require(cuda_glsl_hlsl_spirv, subgroup_basic_ballot)]
+[require(cuda_glsl_hlsl_metal_spirv, subgroup_basic_ballot)]
uint WaveActiveCountBits(bool value)
{
__target_switch
@@ -15511,6 +15513,7 @@ uint WaveActiveCountBits(bool value)
case hlsl: __intrinsic_asm "WaveActiveCountBits";
case glsl:
case spirv:
+ case metal:
return _WaveCountBits(WaveActiveBallot(value));
default:
return WaveMaskCountBits(WaveGetActiveMask(), value);
@@ -15600,7 +15603,7 @@ bool WaveIsFirstLane()
// This implementation tries to limit the amount of work required by the actual lane count.
/// @category wave
__spirv_version(1.3)
-[require(cpp_cuda_glsl_hlsl_spirv, subgroup_basic_ballot)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv, subgroup_basic_ballot)]
uint _WaveCountBits(uint4 value)
{
__target_switch
@@ -16083,7 +16086,7 @@ uint WavePrefixCountBits(bool value)
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_ballot)]
uint4 WaveGetConvergedMulti()
{
__target_switch
@@ -16093,6 +16096,7 @@ uint4 WaveGetConvergedMulti()
__intrinsic_asm "subgroupBallot(true)";
case hlsl: __intrinsic_asm "WaveActiveBallot(true)";
case cuda: __intrinsic_asm "make_uint4(__activemask(), 0, 0, 0)";
+ case metal: __intrinsic_asm "((uint4)((simd_vote::vote_t)simd_ballot(true)))";
case spirv:
let _true = true;
return spirv_asm
diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef
index afad1b984..0ea43a8df 100644
--- a/source/slang/slang-capabilities.capdef
+++ b/source/slang/slang-capabilities.capdef
@@ -2153,6 +2153,7 @@ alias subgroup_ballot_activemask = spirv_1_0 + GL_KHR_shader_subgroup_ballot
| glsl + GL_KHR_shader_subgroup_ballot
| _sm_6_0
| _cuda_sm_7_0
+ | metal
;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_basic_ballot'
/// [Compound]