From e3e84a1682c9e2d371f3f50f6425374c8b04828d Mon Sep 17 00:00:00 2001 From: Darren Wihandi <65404740+fairywreath@users.noreply.github.com> Date: Fri, 4 Apr 2025 19:46:28 -0400 Subject: Implement subgroup quad operations for Metal (#6745) --- source/slang/glsl.meta.slang | 12 ++++---- source/slang/hlsl.meta.slang | 31 ++++++++++++++++---- source/slang/slang-ir-legalize-varying-params.cpp | 7 +++++ source/slang/slang-ir-legalize-varying-params.h | 3 +- .../shader-subgroup/shader-subgroup-quad.slang | 33 +++++++++++++--------- tests/hlsl-intrinsic/subgroup-quad.slang | 13 +++++++++ 6 files changed, 73 insertions(+), 26 deletions(-) diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang index de5a3fb66..8ffc50f6d 100644 --- a/source/slang/glsl.meta.slang +++ b/source/slang/glsl.meta.slang @@ -8100,7 +8100,7 @@ __generic __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_quad) [ForceInline] -[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] +[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)] public T subgroupQuadSwapHorizontal(T value) { shader_subgroup_preamble(); @@ -8111,7 +8111,7 @@ __generic __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_quad) [ForceInline] -[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] +[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)] public T subgroupQuadSwapVertical(T value) { shader_subgroup_preamble(); @@ -8122,7 +8122,7 @@ __generic __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_quad) [ForceInline] -[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] +[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)] public T subgroupQuadSwapDiagonal(T value) { shader_subgroup_preamble(); @@ -8145,7 +8145,7 @@ __generic __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_quad) [ForceInline] -[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] +[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)] public vector subgroupQuadSwapHorizontal(vector value) { shader_subgroup_preamble(); @@ -8156,7 +8156,7 
@@ __generic __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_quad) [ForceInline] -[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] +[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)] public vector subgroupQuadSwapVertical(vector value) { shader_subgroup_preamble(); @@ -8167,7 +8167,7 @@ __generic __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_quad) [ForceInline] -[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] +[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)] public vector subgroupQuadSwapDiagonal(vector value) { shader_subgroup_preamble(); diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 5f0edb3f9..ae1f6da98 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -11,6 +11,7 @@ void __requireTargetExtension(constexpr String extensionName); /// explicitly passed as entry point parameters. in uint __builtinWaveLaneIndex : SV_WaveLaneIndex; in uint __builtinWaveLaneCount : SV_WaveLaneCount; +in uint __builtinQuadLaneIndex : SV_QuadLaneIndex; //@public: /// Represents an interface for buffer data layout. 
@@ -14699,13 +14700,16 @@ __generic __glsl_extension(GL_KHR_shader_subgroup_quad) __spirv_version(1.3) __wgsl_extension(subgroups) -[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] +[ForceInline] +[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)] T QuadReadAcrossX(T localValue) { __target_switch { case hlsl: __intrinsic_asm "QuadReadAcrossX"; case glsl: __intrinsic_asm "subgroupQuadSwapHorizontal($0)"; + case metal: + return QuadReadLaneAt(localValue, __builtinQuadLaneIndex ^ 1U); case spirv: uint direction = 0u; return spirv_asm @@ -14721,13 +14725,16 @@ __generic __glsl_extension(GL_KHR_shader_subgroup_quad) __spirv_version(1.3) __wgsl_extension(subgroups) -[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] +[ForceInline] +[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)] vector QuadReadAcrossX(vector localValue) { __target_switch { case hlsl: __intrinsic_asm "QuadReadAcrossX"; case glsl: __intrinsic_asm "subgroupQuadSwapHorizontal($0)"; + case metal: + return QuadReadLaneAt(localValue, __builtinQuadLaneIndex ^ 1U); case spirv: uint direction = 0u; return spirv_asm @@ -14745,13 +14752,16 @@ __generic __glsl_extension(GL_KHR_shader_subgroup_quad) __spirv_version(1.3) __wgsl_extension(subgroups) -[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] +[ForceInline] +[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)] T QuadReadAcrossY(T localValue) { __target_switch { case hlsl: __intrinsic_asm "QuadReadAcrossY"; case glsl: __intrinsic_asm "subgroupQuadSwapVertical($0)"; + case metal: + return QuadReadLaneAt(localValue, __builtinQuadLaneIndex ^ 2U); case spirv: uint direction = 1u; return spirv_asm @@ -14766,13 +14776,16 @@ __generic __glsl_extension(GL_KHR_shader_subgroup_quad) __spirv_version(1.3) __wgsl_extension(subgroups) -[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] +[ForceInline] +[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)] vector QuadReadAcrossY(vector localValue) { __target_switch { case hlsl: __intrinsic_asm "QuadReadAcrossY"; case glsl: 
__intrinsic_asm "subgroupQuadSwapVertical($0)"; + case metal: + return QuadReadLaneAt(localValue, __builtinQuadLaneIndex ^ 2U); case spirv: uint direction = 1u; return spirv_asm @@ -14790,13 +14803,16 @@ __generic __glsl_extension(GL_KHR_shader_subgroup_quad) __spirv_version(1.3) __wgsl_extension(subgroups) -[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] +[ForceInline] +[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)] T QuadReadAcrossDiagonal(T localValue) { __target_switch { case hlsl: __intrinsic_asm "QuadReadAcrossDiagonal"; case glsl: __intrinsic_asm "subgroupQuadSwapDiagonal($0)"; + case metal: + return QuadReadLaneAt(localValue, __builtinQuadLaneIndex ^ 3U); case spirv: uint direction = 2u; return spirv_asm @@ -14811,13 +14827,16 @@ __generic __glsl_extension(GL_KHR_shader_subgroup_quad) __spirv_version(1.3) __wgsl_extension(subgroups) -[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] +[ForceInline] +[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)] vector QuadReadAcrossDiagonal(vector localValue) { __target_switch { case hlsl: __intrinsic_asm "QuadReadAcrossDiagonal"; case glsl: __intrinsic_asm "subgroupQuadSwapDiagonal($0)"; + case metal: + return QuadReadLaneAt(localValue, __builtinQuadLaneIndex ^ 3U); case spirv: uint direction = 2u; return spirv_asm diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index ad1e89ede..a0d0ef91d 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -3242,6 +3242,13 @@ protected: result.permittedTypes.add(builder.getUInt16Type()); break; } + case SystemValueSemanticName::QuadLaneIndex: + { + result.systemValueName = toSlice("thread_index_in_quadgroup"); + result.permittedTypes.add(builder.getUInt16Type()); + result.permittedTypes.add(builder.getUIntType()); + break; + } default: m_sink->diagnose( parentVar, diff --git a/source/slang/slang-ir-legalize-varying-params.h 
b/source/slang/slang-ir-legalize-varying-params.h index 0a7c3be8e..ae88fbad1 100644 --- a/source/slang/slang-ir-legalize-varying-params.h +++ b/source/slang/slang-ir-legalize-varying-params.h @@ -70,7 +70,8 @@ void depointerizeInputParams(IRFunc* entryPoint); M(StartInstanceLocation, SV_StartInstanceLocation) \ M(WaveLaneCount, SV_WaveLaneCount) \ M(WaveLaneIndex, SV_WaveLaneIndex) \ - /* end */ + M(QuadLaneIndex, SV_QuadLaneIndex) \ +/* end */ /// A known system-value semantic name that can be applied to a parameter /// diff --git a/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-quad.slang b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-quad.slang index b847cf460..e5d4c9de0 100644 --- a/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-quad.slang +++ b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-quad.slang @@ -10,6 +10,7 @@ //TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl //TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl -emit-spirv-directly //TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-wgpu -compute -entry computeMain -allow-glsl -xslang -DWGPU +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-metal -compute -entry computeMain -allow-glsl -xslang -DMETAL #version 430 @@ -27,7 +28,7 @@ bool test1QuadX() { & subgroupQuadSwapHorizontal(T(2)) == T(2) & subgroupQuadSwapVertical(T(2)) == T(2) & subgroupQuadSwapDiagonal(T(3)) == T(3) - // subgroupQuadBroadcast is not implemented for WGSL as the WGSL intrinsic only accepts const integers expressions, as of time of writing. + // subgroupQuadBroadcast is not implemented for WGSL and Metal as their intrinsics only accept constant integer expressions. 
#if !defined(WGPU) & subgroupQuadBroadcast(T(1), 1) == T(1) #endif @@ -41,7 +42,7 @@ bool testVQuadX() { & subgroupQuadSwapHorizontal(gvec(T(2))) == gvec(T(2)) & subgroupQuadSwapVertical(gvec(T(2))) == gvec(T(2)) & subgroupQuadSwapDiagonal(gvec(T(3))) == gvec(T(3)) - // subgroupQuadBroadcast is not implemented for WGSL as the WGSL intrinsic only accepts const integers expressions, as of time of writing. + // subgroupQuadBroadcast is not implemented for WGSL and Metal as their intrinsics only accept constant integer expressions. #if !defined(WGPU) & subgroupQuadBroadcast(gvec(T(1)), 1) == gvec(T(1)) #endif @@ -54,7 +55,7 @@ bool test1QuadX() { & subgroupQuadSwapHorizontal(T(2)) == T(2) & subgroupQuadSwapVertical(T(2)) == T(2) & subgroupQuadSwapDiagonal(T(3)) == T(3) - // subgroupQuadBroadcast is not implemented for WGSL as the WGSL intrinsic only accepts const integers expressions, as of time of writing. + // subgroupQuadBroadcast is not implemented for WGSL and Metal as their intrinsics only accept constant integer expressions. #if !defined(WGPU) & subgroupQuadBroadcast(T(1), 1) == T(1) #endif @@ -68,7 +69,7 @@ bool testVQuadX() { & subgroupQuadSwapHorizontal(gvec(T(2))) == gvec(T(2)) & subgroupQuadSwapVertical(gvec(T(2))) == gvec(T(2)) & subgroupQuadSwapDiagonal(gvec(T(3))) == gvec(T(3)) - // subgroupQuadBroadcast is not implemented for WGSL as the WGSL intrinsic only accepts const integers expressions, as of time of writing. + // subgroupQuadBroadcast is not implemented for WGSL and Metal as their intrinsics only accept constant integer expressions. #if !defined(WGPU) & subgroupQuadBroadcast(gvec(T(1)), 1) == gvec(T(1)) #endif @@ -93,8 +94,8 @@ bool testQuadSwapX() { & testVQuadX() & testVQuadX() - // Disabled on WGPU as these built-in types are not supported as of time of writing. -#if !defined (WGPU) + // Disabled on WGSL and Metal as these built-in types are not supported. 
+#if !defined(WGPU) && !defined(METAL) & test1QuadX() // WARNING: intel GPU's lack FP64 support & testVQuadX() & testVQuadX() @@ -115,19 +116,25 @@ bool testQuadSwapX() { & testVQuadX() & testVQuadX() & testVQuadX() - & test1QuadX() - & testVQuadX() - & testVQuadX() - & testVQuadX() & test1QuadX() & testVQuadX() & testVQuadX() - & testVQuadX() - & test1QuadX() - & testVQuadX() + & testVQuadX() & test1QuadX() & testVQuadX() & testVQuadX() & testVQuadX() #endif + + +#if !defined(WGPU) + & test1QuadX() + & testVQuadX() + & testVQuadX() + & testVQuadX() + & test1QuadX() + & testVQuadX() + & testVQuadX() + & testVQuadX() +#endif ; } diff --git a/tests/hlsl-intrinsic/subgroup-quad.slang b/tests/hlsl-intrinsic/subgroup-quad.slang index 928431a45..1cfbffb49 100644 --- a/tests/hlsl-intrinsic/subgroup-quad.slang +++ b/tests/hlsl-intrinsic/subgroup-quad.slang @@ -1,6 +1,7 @@ //TEST:SIMPLE(filecheck=SPIRV): -entry main -stage compute -target spirv //TEST:SIMPLE(filecheck=SPIRV): -entry main -stage compute -target spirv -emit-spirv-directly //TEST:SIMPLE(filecheck=HLSL): -entry main -stage compute -target hlsl +//TEST:SIMPLE(filecheck=METAL): -entry main -stage compute -target metal RWStructuredBuffer output; @@ -38,4 +39,16 @@ void main() // SPIRV: OpGroupNonUniformQuadSwap {{.*}} %{{u?int_3}} {{.*}} %{{u?int_1}} // SPIRV: OpGroupNonUniformQuadSwap {{.*}} %{{u?int_3}} {{.*}} %{{u?int_2}} // SPIRV: OpGroupNonUniformQuadSwap {{.*}} %{{u?int_3}} {{.*}} %{{u?int_2}} + + // METAL: quad_shuffle + // METAL: quad_shuffle + // METAL: ^ 1 + // METAL: quad_shuffle + // METAL: quad_shuffle + // METAL: ^ 2 + // METAL: quad_shuffle + // METAL: quad_shuffle + // METAL: ^ 3 + // METAL: quad_shuffle + // METAL: quad_shuffle } -- cgit v1.2.3