summary refs log tree commit diff stats
path: root/source
diff options
context:
space:
mode:
author Darren Wihandi <65404740+fairywreath@users.noreply.github.com> 2025-04-04 19:46:28 -0400
committer GitHub <noreply@github.com> 2025-04-04 19:46:28 -0400
commit e3e84a1682c9e2d371f3f50f6425374c8b04828d (patch)
tree f89f00045acb0dfa3cf03740040f9d78ae22c0b5 /source
parent 41e7e565eb3dfa13562cbfa3e8641874c2c6d66c (diff)
Implement subgroup quad operations for Metal (#6745)
Diffstat (limited to 'source')
-rw-r--r-- source/slang/glsl.meta.slang 12
-rw-r--r-- source/slang/hlsl.meta.slang 31
-rw-r--r-- source/slang/slang-ir-legalize-varying-params.cpp 7
-rw-r--r-- source/slang/slang-ir-legalize-varying-params.h 3
4 files changed, 40 insertions, 13 deletions
diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang
index de5a3fb66..8ffc50f6d 100644
--- a/source/slang/glsl.meta.slang
+++ b/source/slang/glsl.meta.slang
@@ -8100,7 +8100,7 @@ __generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_quad)
[ForceInline]
-[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
+[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)]
public T subgroupQuadSwapHorizontal(T value)
{
shader_subgroup_preamble<T>();
@@ -8111,7 +8111,7 @@ __generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_quad)
[ForceInline]
-[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
+[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)]
public T subgroupQuadSwapVertical(T value)
{
shader_subgroup_preamble<T>();
@@ -8122,7 +8122,7 @@ __generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_quad)
[ForceInline]
-[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
+[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)]
public T subgroupQuadSwapDiagonal(T value)
{
shader_subgroup_preamble<T>();
@@ -8145,7 +8145,7 @@ __generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_quad)
[ForceInline]
-[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
+[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)]
public vector<T,N> subgroupQuadSwapHorizontal(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -8156,7 +8156,7 @@ __generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_quad)
[ForceInline]
-[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
+[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)]
public vector<T,N> subgroupQuadSwapVertical(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -8167,7 +8167,7 @@ __generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_quad)
[ForceInline]
-[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
+[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)]
public vector<T,N> subgroupQuadSwapDiagonal(vector<T,N> value)
{
shader_subgroup_preamble<T>();
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 5f0edb3f9..ae1f6da98 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -11,6 +11,7 @@ void __requireTargetExtension(constexpr String extensionName);
/// explicitly passed as entry point parameters.
in uint __builtinWaveLaneIndex : SV_WaveLaneIndex;
in uint __builtinWaveLaneCount : SV_WaveLaneCount;
+in uint __builtinQuadLaneIndex : SV_QuadLaneIndex;
//@public:
/// Represents an interface for buffer data layout.
@@ -14699,13 +14700,16 @@ __generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_quad)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
+[ForceInline]
+[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)]
T QuadReadAcrossX(T localValue)
{
__target_switch
{
case hlsl: __intrinsic_asm "QuadReadAcrossX";
case glsl: __intrinsic_asm "subgroupQuadSwapHorizontal($0)";
+ case metal:
+ return QuadReadLaneAt(localValue, __builtinQuadLaneIndex ^ 1U);
case spirv:
uint direction = 0u;
return spirv_asm
@@ -14721,13 +14725,16 @@ __generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_quad)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
+[ForceInline]
+[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)]
vector<T,N> QuadReadAcrossX(vector<T,N> localValue)
{
__target_switch
{
case hlsl: __intrinsic_asm "QuadReadAcrossX";
case glsl: __intrinsic_asm "subgroupQuadSwapHorizontal($0)";
+ case metal:
+ return QuadReadLaneAt(localValue, __builtinQuadLaneIndex ^ 1U);
case spirv:
uint direction = 0u;
return spirv_asm
@@ -14745,13 +14752,16 @@ __generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_quad)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
+[ForceInline]
+[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)]
T QuadReadAcrossY(T localValue)
{
__target_switch
{
case hlsl: __intrinsic_asm "QuadReadAcrossY";
case glsl: __intrinsic_asm "subgroupQuadSwapVertical($0)";
+ case metal:
+ return QuadReadLaneAt(localValue, __builtinQuadLaneIndex ^ 2U);
case spirv:
uint direction = 1u;
return spirv_asm
@@ -14766,13 +14776,16 @@ __generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_quad)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
+[ForceInline]
+[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)]
vector<T,N> QuadReadAcrossY(vector<T,N> localValue)
{
__target_switch
{
case hlsl: __intrinsic_asm "QuadReadAcrossY";
case glsl: __intrinsic_asm "subgroupQuadSwapVertical($0)";
+ case metal:
+ return QuadReadLaneAt(localValue, __builtinQuadLaneIndex ^ 2U);
case spirv:
uint direction = 1u;
return spirv_asm
@@ -14790,13 +14803,16 @@ __generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_quad)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
+[ForceInline]
+[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)]
T QuadReadAcrossDiagonal(T localValue)
{
__target_switch
{
case hlsl: __intrinsic_asm "QuadReadAcrossDiagonal";
case glsl: __intrinsic_asm "subgroupQuadSwapDiagonal($0)";
+ case metal:
+ return QuadReadLaneAt(localValue, __builtinQuadLaneIndex ^ 3U);
case spirv:
uint direction = 2u;
return spirv_asm
@@ -14811,13 +14827,16 @@ __generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_quad)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
+[ForceInline]
+[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)]
vector<T,N> QuadReadAcrossDiagonal(vector<T,N> localValue)
{
__target_switch
{
case hlsl: __intrinsic_asm "QuadReadAcrossDiagonal";
case glsl: __intrinsic_asm "subgroupQuadSwapDiagonal($0)";
+ case metal:
+ return QuadReadLaneAt(localValue, __builtinQuadLaneIndex ^ 3U);
case spirv:
uint direction = 2u;
return spirv_asm
diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp
index ad1e89ede..a0d0ef91d 100644
--- a/source/slang/slang-ir-legalize-varying-params.cpp
+++ b/source/slang/slang-ir-legalize-varying-params.cpp
@@ -3242,6 +3242,13 @@ protected:
result.permittedTypes.add(builder.getUInt16Type());
break;
}
+ case SystemValueSemanticName::QuadLaneIndex:
+ {
+ result.systemValueName = toSlice("thread_index_in_quadgroup");
+ result.permittedTypes.add(builder.getUInt16Type());
+ result.permittedTypes.add(builder.getUIntType());
+ break;
+ }
default:
m_sink->diagnose(
parentVar,
diff --git a/source/slang/slang-ir-legalize-varying-params.h b/source/slang/slang-ir-legalize-varying-params.h
index 0a7c3be8e..ae88fbad1 100644
--- a/source/slang/slang-ir-legalize-varying-params.h
+++ b/source/slang/slang-ir-legalize-varying-params.h
@@ -70,7 +70,8 @@ void depointerizeInputParams(IRFunc* entryPoint);
M(StartInstanceLocation, SV_StartInstanceLocation) \
M(WaveLaneCount, SV_WaveLaneCount) \
M(WaveLaneIndex, SV_WaveLaneIndex) \
- /* end */
+ M(QuadLaneIndex, SV_QuadLaneIndex) \
+/* end */
/// A known system-value semantic name that can be applied to a parameter
///