diff options
| field | value | date |
|---|---|---|
| author | Darren Wihandi <65404740+fairywreath@users.noreply.github.com> | 2025-04-04 19:46:28 -0400 |
| committer | GitHub <noreply@github.com> | 2025-04-04 19:46:28 -0400 |
| commit | e3e84a1682c9e2d371f3f50f6425374c8b04828d (patch) | |
| tree | f89f00045acb0dfa3cf03740040f9d78ae22c0b5 /source | |
| parent | 41e7e565eb3dfa13562cbfa3e8641874c2c6d66c (diff) | |
Implement subgroup quad operations for Metal (#6745)
Diffstat (limited to 'source')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | source/slang/glsl.meta.slang | 12 |
| -rw-r--r-- | source/slang/hlsl.meta.slang | 31 |
| -rw-r--r-- | source/slang/slang-ir-legalize-varying-params.cpp | 7 |
| -rw-r--r-- | source/slang/slang-ir-legalize-varying-params.h | 3 |
4 files changed, 40 insertions, 13 deletions
diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang index de5a3fb66..8ffc50f6d 100644 --- a/source/slang/glsl.meta.slang +++ b/source/slang/glsl.meta.slang @@ -8100,7 +8100,7 @@ __generic<T : __BuiltinType> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_quad) [ForceInline] -[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] +[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)] public T subgroupQuadSwapHorizontal(T value) { shader_subgroup_preamble<T>(); @@ -8111,7 +8111,7 @@ __generic<T : __BuiltinType> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_quad) [ForceInline] -[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] +[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)] public T subgroupQuadSwapVertical(T value) { shader_subgroup_preamble<T>(); @@ -8122,7 +8122,7 @@ __generic<T : __BuiltinType> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_quad) [ForceInline] -[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] +[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)] public T subgroupQuadSwapDiagonal(T value) { shader_subgroup_preamble<T>(); @@ -8145,7 +8145,7 @@ __generic<T : __BuiltinType, let N : int> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_quad) [ForceInline] -[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] +[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)] public vector<T,N> subgroupQuadSwapHorizontal(vector<T,N> value) { shader_subgroup_preamble<T>(); @@ -8156,7 +8156,7 @@ __generic<T : __BuiltinType, let N : int> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_quad) [ForceInline] -[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] +[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)] public vector<T,N> subgroupQuadSwapVertical(vector<T,N> value) { shader_subgroup_preamble<T>(); @@ -8167,7 +8167,7 @@ __generic<T : __BuiltinType, let N : int> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_quad) [ForceInline] -[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] 
+[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)] public vector<T,N> subgroupQuadSwapDiagonal(vector<T,N> value) { shader_subgroup_preamble<T>(); diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 5f0edb3f9..ae1f6da98 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -11,6 +11,7 @@ void __requireTargetExtension(constexpr String extensionName); /// explicitly passed as entry point parameters. in uint __builtinWaveLaneIndex : SV_WaveLaneIndex; in uint __builtinWaveLaneCount : SV_WaveLaneCount; +in uint __builtinQuadLaneIndex : SV_QuadLaneIndex; //@public: /// Represents an interface for buffer data layout. @@ -14699,13 +14700,16 @@ __generic<T : __BuiltinType> __glsl_extension(GL_KHR_shader_subgroup_quad) __spirv_version(1.3) __wgsl_extension(subgroups) -[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] +[ForceInline] +[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)] T QuadReadAcrossX(T localValue) { __target_switch { case hlsl: __intrinsic_asm "QuadReadAcrossX"; case glsl: __intrinsic_asm "subgroupQuadSwapHorizontal($0)"; + case metal: + return QuadReadLaneAt(localValue, __builtinQuadLaneIndex ^ 1U); case spirv: uint direction = 0u; return spirv_asm @@ -14721,13 +14725,16 @@ __generic<T : __BuiltinType, let N : int> __glsl_extension(GL_KHR_shader_subgroup_quad) __spirv_version(1.3) __wgsl_extension(subgroups) -[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] +[ForceInline] +[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)] vector<T,N> QuadReadAcrossX(vector<T,N> localValue) { __target_switch { case hlsl: __intrinsic_asm "QuadReadAcrossX"; case glsl: __intrinsic_asm "subgroupQuadSwapHorizontal($0)"; + case metal: + return QuadReadLaneAt(localValue, __builtinQuadLaneIndex ^ 1U); case spirv: uint direction = 0u; return spirv_asm @@ -14745,13 +14752,16 @@ __generic<T : __BuiltinType> __glsl_extension(GL_KHR_shader_subgroup_quad) __spirv_version(1.3) __wgsl_extension(subgroups) 
-[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] +[ForceInline] +[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)] T QuadReadAcrossY(T localValue) { __target_switch { case hlsl: __intrinsic_asm "QuadReadAcrossY"; case glsl: __intrinsic_asm "subgroupQuadSwapVertical($0)"; + case metal: + return QuadReadLaneAt(localValue, __builtinQuadLaneIndex ^ 2U); case spirv: uint direction = 1u; return spirv_asm @@ -14766,13 +14776,16 @@ __generic<T : __BuiltinType, let N : int> __glsl_extension(GL_KHR_shader_subgroup_quad) __spirv_version(1.3) __wgsl_extension(subgroups) -[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] +[ForceInline] +[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)] vector<T,N> QuadReadAcrossY(vector<T,N> localValue) { __target_switch { case hlsl: __intrinsic_asm "QuadReadAcrossY"; case glsl: __intrinsic_asm "subgroupQuadSwapVertical($0)"; + case metal: + return QuadReadLaneAt(localValue, __builtinQuadLaneIndex ^ 2U); case spirv: uint direction = 1u; return spirv_asm @@ -14790,13 +14803,16 @@ __generic<T : __BuiltinType> __glsl_extension(GL_KHR_shader_subgroup_quad) __spirv_version(1.3) __wgsl_extension(subgroups) -[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] +[ForceInline] +[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)] T QuadReadAcrossDiagonal(T localValue) { __target_switch { case hlsl: __intrinsic_asm "QuadReadAcrossDiagonal"; case glsl: __intrinsic_asm "subgroupQuadSwapDiagonal($0)"; + case metal: + return QuadReadLaneAt(localValue, __builtinQuadLaneIndex ^ 3U); case spirv: uint direction = 2u; return spirv_asm @@ -14811,13 +14827,16 @@ __generic<T : __BuiltinType, let N : int> __glsl_extension(GL_KHR_shader_subgroup_quad) __spirv_version(1.3) __wgsl_extension(subgroups) -[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] +[ForceInline] +[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)] vector<T,N> QuadReadAcrossDiagonal(vector<T,N> localValue) { __target_switch { case hlsl: __intrinsic_asm "QuadReadAcrossDiagonal"; case glsl: __intrinsic_asm 
"subgroupQuadSwapDiagonal($0)"; + case metal: + return QuadReadLaneAt(localValue, __builtinQuadLaneIndex ^ 3U); case spirv: uint direction = 2u; return spirv_asm diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp index ad1e89ede..a0d0ef91d 100644 --- a/source/slang/slang-ir-legalize-varying-params.cpp +++ b/source/slang/slang-ir-legalize-varying-params.cpp @@ -3242,6 +3242,13 @@ protected: result.permittedTypes.add(builder.getUInt16Type()); break; } + case SystemValueSemanticName::QuadLaneIndex: + { + result.systemValueName = toSlice("thread_index_in_quadgroup"); + result.permittedTypes.add(builder.getUInt16Type()); + result.permittedTypes.add(builder.getUIntType()); + break; + } default: m_sink->diagnose( parentVar, diff --git a/source/slang/slang-ir-legalize-varying-params.h b/source/slang/slang-ir-legalize-varying-params.h index 0a7c3be8e..ae88fbad1 100644 --- a/source/slang/slang-ir-legalize-varying-params.h +++ b/source/slang/slang-ir-legalize-varying-params.h @@ -70,7 +70,8 @@ void depointerizeInputParams(IRFunc* entryPoint); M(StartInstanceLocation, SV_StartInstanceLocation) \ M(WaveLaneCount, SV_WaveLaneCount) \ M(WaveLaneIndex, SV_WaveLaneIndex) \ - /* end */ + M(QuadLaneIndex, SV_QuadLaneIndex) \ +/* end */ /// A known system-value semantic name that can be applied to a parameter /// |
