summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTianyu Li <ltyucb@gmail.com>2025-08-21 01:32:13 +0800
committerGitHub <noreply@github.com>2025-08-20 17:32:13 +0000
commit82e3fc9c1064f06780f6449154c7cf8f663fceac (patch)
tree99ad135852e66f8f6d20384ca479f936e2475fcf
parenta26a11ff5c3f14401bbbbbc4b6e3731056618138 (diff)
Add Metal support for WaveGetActiveMask and WaveActiveCountBits (#8218)
## Summary - Add Metal platform support for `WaveGetActiveMask()` and `WaveActiveCountBits()` wave intrinsics - Update capability requirements to include Metal platform for subgroup ballot operations - Implement Metal-specific intrinsic assembly using `simd_ballot()` and `simd_vote` APIs ## Changes - **source/slang/hlsl.meta.slang**: - Add Metal target case for `WaveGetActiveMask()` using `simd_ballot(true)` - Update capability requirements from `cuda_glsl_hlsl_spirv` to `cuda_glsl_hlsl_metal_spirv` for wave ballot functions - **source/slang/slang-capabilities.capdef**: - Add `metal` to `subgroup_ballot_activemask` capability alias
-rw-r--r--source/slang/hlsl.meta.slang12
-rw-r--r--source/slang/slang-capabilities.capdef1
-rw-r--r--tests/hlsl-intrinsic/wave-active-count-bits.slang1
-rw-r--r--tests/hlsl-intrinsic/wave-mask/wave-get-active.slang1
4 files changed, 11 insertions, 4 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index d2e98529b..c2b3fc436 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -13936,7 +13936,7 @@ WaveMask __WaveGetActiveMask();
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot_activemask)]
+[require(cuda_glsl_hlsl_metal_spirv, subgroup_ballot_activemask)]
WaveMask WaveGetActiveMask()
{
__target_switch
@@ -13945,6 +13945,8 @@ WaveMask WaveGetActiveMask()
__intrinsic_asm "subgroupBallot(true).x";
case hlsl:
__intrinsic_asm "WaveActiveBallot(true).x";
+ case metal:
+ __intrinsic_asm "((uint32_t)((simd_vote::vote_t)simd_ballot(true)))";
case spirv:
let _true = true;
return (spirv_asm
@@ -15503,7 +15505,7 @@ uint4 WaveActiveBallot(bool condition)
}
/// @category wave
-[require(cuda_glsl_hlsl_spirv, subgroup_basic_ballot)]
+[require(cuda_glsl_hlsl_metal_spirv, subgroup_basic_ballot)]
uint WaveActiveCountBits(bool value)
{
__target_switch
@@ -15511,6 +15513,7 @@ uint WaveActiveCountBits(bool value)
case hlsl: __intrinsic_asm "WaveActiveCountBits";
case glsl:
case spirv:
+ case metal:
return _WaveCountBits(WaveActiveBallot(value));
default:
return WaveMaskCountBits(WaveGetActiveMask(), value);
@@ -15600,7 +15603,7 @@ bool WaveIsFirstLane()
// This implementation tries to limit the amount of work required by the actual lane count.
/// @category wave
__spirv_version(1.3)
-[require(cpp_cuda_glsl_hlsl_spirv, subgroup_basic_ballot)]
+[require(cpp_cuda_glsl_hlsl_metal_spirv, subgroup_basic_ballot)]
uint _WaveCountBits(uint4 value)
{
__target_switch
@@ -16083,7 +16086,7 @@ uint WavePrefixCountBits(bool value)
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
+[require(cuda_glsl_hlsl_metal_spirv_wgsl, subgroup_ballot)]
uint4 WaveGetConvergedMulti()
{
__target_switch
@@ -16093,6 +16096,7 @@ uint4 WaveGetConvergedMulti()
__intrinsic_asm "subgroupBallot(true)";
case hlsl: __intrinsic_asm "WaveActiveBallot(true)";
case cuda: __intrinsic_asm "make_uint4(__activemask(), 0, 0, 0)";
+ case metal: __intrinsic_asm "((uint4)((simd_vote::vote_t)simd_ballot(true)))";
case spirv:
let _true = true;
return spirv_asm
diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef
index afad1b984..0ea43a8df 100644
--- a/source/slang/slang-capabilities.capdef
+++ b/source/slang/slang-capabilities.capdef
@@ -2153,6 +2153,7 @@ alias subgroup_ballot_activemask = spirv_1_0 + GL_KHR_shader_subgroup_ballot
| glsl + GL_KHR_shader_subgroup_ballot
| _sm_6_0
| _cuda_sm_7_0
+ | metal
;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_basic_ballot'
/// [Compound]
diff --git a/tests/hlsl-intrinsic/wave-active-count-bits.slang b/tests/hlsl-intrinsic/wave-active-count-bits.slang
index a7aa48687..7e8da0907 100644
--- a/tests/hlsl-intrinsic/wave-active-count-bits.slang
+++ b/tests/hlsl-intrinsic/wave-active-count-bits.slang
@@ -5,6 +5,7 @@
//TEST:COMPARE_COMPUTE_EX:-slang -compute -cuda -profile cs_6_0 -shaderobj -render-feature hardware-device
//TEST(vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -render-feature hardware-device
//TEST:COMPARE_COMPUTE_EX:-cuda -compute -render-features cuda_sm_7_0 -shaderobj
+//TEST:COMPARE_COMPUTE_EX:-metal -compute -shaderobj
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
RWStructuredBuffer<int> outputBuffer;
diff --git a/tests/hlsl-intrinsic/wave-mask/wave-get-active.slang b/tests/hlsl-intrinsic/wave-mask/wave-get-active.slang
index 91debe999..2ff9ef9e3 100644
--- a/tests/hlsl-intrinsic/wave-mask/wave-get-active.slang
+++ b/tests/hlsl-intrinsic/wave-mask/wave-get-active.slang
@@ -4,6 +4,7 @@
//TEST:COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile cs_6_0 -shaderobj
//TEST(vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj
//TEST:COMPARE_COMPUTE_EX:-cuda -compute -shaderobj
+//TEST:COMPARE_COMPUTE_EX:-metal -compute -shaderobj
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
RWStructuredBuffer<int> outputBuffer;