summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/glsl.meta.slang12
-rw-r--r--source/slang/hlsl.meta.slang31
-rw-r--r--source/slang/slang-ir-legalize-varying-params.cpp7
-rw-r--r--source/slang/slang-ir-legalize-varying-params.h3
-rw-r--r--tests/glsl-intrinsic/shader-subgroup/shader-subgroup-quad.slang33
-rw-r--r--tests/hlsl-intrinsic/subgroup-quad.slang13
6 files changed, 73 insertions, 26 deletions
diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang
index de5a3fb66..8ffc50f6d 100644
--- a/source/slang/glsl.meta.slang
+++ b/source/slang/glsl.meta.slang
@@ -8100,7 +8100,7 @@ __generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_quad)
[ForceInline]
-[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
+[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)]
public T subgroupQuadSwapHorizontal(T value)
{
shader_subgroup_preamble<T>();
@@ -8111,7 +8111,7 @@ __generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_quad)
[ForceInline]
-[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
+[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)]
public T subgroupQuadSwapVertical(T value)
{
shader_subgroup_preamble<T>();
@@ -8122,7 +8122,7 @@ __generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_quad)
[ForceInline]
-[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
+[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)]
public T subgroupQuadSwapDiagonal(T value)
{
shader_subgroup_preamble<T>();
@@ -8145,7 +8145,7 @@ __generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_quad)
[ForceInline]
-[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
+[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)]
public vector<T,N> subgroupQuadSwapHorizontal(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -8156,7 +8156,7 @@ __generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_quad)
[ForceInline]
-[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
+[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)]
public vector<T,N> subgroupQuadSwapVertical(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -8167,7 +8167,7 @@ __generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_quad)
[ForceInline]
-[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
+[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)]
public vector<T,N> subgroupQuadSwapDiagonal(vector<T,N> value)
{
shader_subgroup_preamble<T>();
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 5f0edb3f9..ae1f6da98 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -11,6 +11,7 @@ void __requireTargetExtension(constexpr String extensionName);
/// explicitly passed as entry point parameters.
in uint __builtinWaveLaneIndex : SV_WaveLaneIndex;
in uint __builtinWaveLaneCount : SV_WaveLaneCount;
+in uint __builtinQuadLaneIndex : SV_QuadLaneIndex;
//@public:
/// Represents an interface for buffer data layout.
@@ -14699,13 +14700,16 @@ __generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_quad)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
+[ForceInline]
+[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)]
T QuadReadAcrossX(T localValue)
{
__target_switch
{
case hlsl: __intrinsic_asm "QuadReadAcrossX";
case glsl: __intrinsic_asm "subgroupQuadSwapHorizontal($0)";
+ case metal:
+ return QuadReadLaneAt(localValue, __builtinQuadLaneIndex ^ 1U);
case spirv:
uint direction = 0u;
return spirv_asm
@@ -14721,13 +14725,16 @@ __generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_quad)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
+[ForceInline]
+[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)]
vector<T,N> QuadReadAcrossX(vector<T,N> localValue)
{
__target_switch
{
case hlsl: __intrinsic_asm "QuadReadAcrossX";
case glsl: __intrinsic_asm "subgroupQuadSwapHorizontal($0)";
+ case metal:
+ return QuadReadLaneAt(localValue, __builtinQuadLaneIndex ^ 1U);
case spirv:
uint direction = 0u;
return spirv_asm
@@ -14745,13 +14752,16 @@ __generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_quad)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
+[ForceInline]
+[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)]
T QuadReadAcrossY(T localValue)
{
__target_switch
{
case hlsl: __intrinsic_asm "QuadReadAcrossY";
case glsl: __intrinsic_asm "subgroupQuadSwapVertical($0)";
+ case metal:
+ return QuadReadLaneAt(localValue, __builtinQuadLaneIndex ^ 2U);
case spirv:
uint direction = 1u;
return spirv_asm
@@ -14766,13 +14776,16 @@ __generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_quad)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
+[ForceInline]
+[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)]
vector<T,N> QuadReadAcrossY(vector<T,N> localValue)
{
__target_switch
{
case hlsl: __intrinsic_asm "QuadReadAcrossY";
case glsl: __intrinsic_asm "subgroupQuadSwapVertical($0)";
+ case metal:
+ return QuadReadLaneAt(localValue, __builtinQuadLaneIndex ^ 2U);
case spirv:
uint direction = 1u;
return spirv_asm
@@ -14790,13 +14803,16 @@ __generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_quad)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
+[ForceInline]
+[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)]
T QuadReadAcrossDiagonal(T localValue)
{
__target_switch
{
case hlsl: __intrinsic_asm "QuadReadAcrossDiagonal";
case glsl: __intrinsic_asm "subgroupQuadSwapDiagonal($0)";
+ case metal:
+ return QuadReadLaneAt(localValue, __builtinQuadLaneIndex ^ 3U);
case spirv:
uint direction = 2u;
return spirv_asm
@@ -14811,13 +14827,16 @@ __generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_quad)
__spirv_version(1.3)
__wgsl_extension(subgroups)
-[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
+[ForceInline]
+[require(glsl_hlsl_metal_spirv_wgsl, subgroup_quad)]
vector<T,N> QuadReadAcrossDiagonal(vector<T,N> localValue)
{
__target_switch
{
case hlsl: __intrinsic_asm "QuadReadAcrossDiagonal";
case glsl: __intrinsic_asm "subgroupQuadSwapDiagonal($0)";
+ case metal:
+ return QuadReadLaneAt(localValue, __builtinQuadLaneIndex ^ 3U);
case spirv:
uint direction = 2u;
return spirv_asm
diff --git a/source/slang/slang-ir-legalize-varying-params.cpp b/source/slang/slang-ir-legalize-varying-params.cpp
index ad1e89ede..a0d0ef91d 100644
--- a/source/slang/slang-ir-legalize-varying-params.cpp
+++ b/source/slang/slang-ir-legalize-varying-params.cpp
@@ -3242,6 +3242,13 @@ protected:
result.permittedTypes.add(builder.getUInt16Type());
break;
}
+ case SystemValueSemanticName::QuadLaneIndex:
+ {
+ result.systemValueName = toSlice("thread_index_in_quadgroup");
+ result.permittedTypes.add(builder.getUInt16Type());
+ result.permittedTypes.add(builder.getUIntType());
+ break;
+ }
default:
m_sink->diagnose(
parentVar,
diff --git a/source/slang/slang-ir-legalize-varying-params.h b/source/slang/slang-ir-legalize-varying-params.h
index 0a7c3be8e..ae88fbad1 100644
--- a/source/slang/slang-ir-legalize-varying-params.h
+++ b/source/slang/slang-ir-legalize-varying-params.h
@@ -70,7 +70,8 @@ void depointerizeInputParams(IRFunc* entryPoint);
M(StartInstanceLocation, SV_StartInstanceLocation) \
M(WaveLaneCount, SV_WaveLaneCount) \
M(WaveLaneIndex, SV_WaveLaneIndex) \
- /* end */
+ M(QuadLaneIndex, SV_QuadLaneIndex) \
+/* end */
/// A known system-value semantic name that can be applied to a parameter
///
diff --git a/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-quad.slang b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-quad.slang
index b847cf460..e5d4c9de0 100644
--- a/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-quad.slang
+++ b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-quad.slang
@@ -10,6 +10,7 @@
//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl
//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl -emit-spirv-directly
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-wgpu -compute -entry computeMain -allow-glsl -xslang -DWGPU
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-metal -compute -entry computeMain -allow-glsl -xslang -DMETAL
#version 430
@@ -27,7 +28,7 @@ bool test1QuadX() {
& subgroupQuadSwapHorizontal(T(2)) == T(2)
& subgroupQuadSwapVertical(T(2)) == T(2)
& subgroupQuadSwapDiagonal(T(3)) == T(3)
- // subgroupQuadBroadcast is not implemented for WGSL as the WGSL intrinsic only accepts const integers expressions, as of time of writing.
+ // subgroupQuadBroadcast is not implemented for WGSL and Metal as their intrinsics only accepts const integers expressions.
#if !defined(WGPU)
& subgroupQuadBroadcast(T(1), 1) == T(1)
#endif
@@ -41,7 +42,7 @@ bool testVQuadX() {
& subgroupQuadSwapHorizontal(gvec(T(2))) == gvec(T(2))
& subgroupQuadSwapVertical(gvec(T(2))) == gvec(T(2))
& subgroupQuadSwapDiagonal(gvec(T(3))) == gvec(T(3))
- // subgroupQuadBroadcast is not implemented for WGSL as the WGSL intrinsic only accepts const integers expressions, as of time of writing.
+ // subgroupQuadBroadcast is not implemented for WGSL and Metal as their intrinsics only accepts const integers expressions.
#if !defined(WGPU)
& subgroupQuadBroadcast(gvec(T(1)), 1) == gvec(T(1))
#endif
@@ -54,7 +55,7 @@ bool test1QuadX() {
& subgroupQuadSwapHorizontal(T(2)) == T(2)
& subgroupQuadSwapVertical(T(2)) == T(2)
& subgroupQuadSwapDiagonal(T(3)) == T(3)
- // subgroupQuadBroadcast is not implemented for WGSL as the WGSL intrinsic only accepts const integers expressions, as of time of writing.
+ // subgroupQuadBroadcast is not implemented for WGSL and Metal as their intrinsics only accepts const integers expressions.
#if !defined(WGPU)
& subgroupQuadBroadcast(T(1), 1) == T(1)
#endif
@@ -68,7 +69,7 @@ bool testVQuadX() {
& subgroupQuadSwapHorizontal(gvec(T(2))) == gvec(T(2))
& subgroupQuadSwapVertical(gvec(T(2))) == gvec(T(2))
& subgroupQuadSwapDiagonal(gvec(T(3))) == gvec(T(3))
- // subgroupQuadBroadcast is not implemented for WGSL as the WGSL intrinsic only accepts const integers expressions, as of time of writing.
+ // subgroupQuadBroadcast is not implemented for WGSL and Metal as their intrinsics only accepts const integers expressions.
#if !defined(WGPU)
& subgroupQuadBroadcast(gvec(T(1)), 1) == gvec(T(1))
#endif
@@ -93,8 +94,8 @@ bool testQuadSwapX() {
& testVQuadX<uint, 3>()
& testVQuadX<uint, 4>()
- // Disabled on WGPU as these built-in types are not supported as of time of writing.
-#if !defined (WGPU)
+ // Disabled on WGSL and Metal as these built-in types are not supported.
+#if !defined(WGPU) && !defined(METAL)
& test1QuadX<double>() // WARNING: intel GPU's lack FP64 support
& testVQuadX<double, 2>()
& testVQuadX<double, 3>()
@@ -115,19 +116,25 @@ bool testQuadSwapX() {
& testVQuadX<uint8_t, 2>()
& testVQuadX<uint8_t, 3>()
& testVQuadX<uint8_t, 4>()
- & test1QuadX<uint16_t>()
- & testVQuadX<uint16_t, 2>()
- & testVQuadX<uint16_t, 3>()
- & testVQuadX<uint16_t, 4>()
& test1QuadX<uint64_t>()
& testVQuadX<uint64_t, 2>()
& testVQuadX<uint64_t, 3>()
- & testVQuadX<uint64_t, 4>()
- & test1QuadX<bool>()
- & testVQuadX<bool, 2>()
+ & testVQuadX<uint64_t, 4>() & test1QuadX<bool>() & testVQuadX<bool, 2>()
& testVQuadX<bool, 3>()
& testVQuadX<bool, 4>()
#endif
+
+
+#if !defined(WGPU)
+ & test1QuadX<int16_t>()
+ & testVQuadX<int16_t, 2>()
+ & testVQuadX<int16_t, 3>()
+ & testVQuadX<int16_t, 4>()
+ & test1QuadX<uint16_t>()
+ & testVQuadX<uint16_t, 2>()
+ & testVQuadX<uint16_t, 3>()
+ & testVQuadX<uint16_t, 4>()
+#endif
;
}
diff --git a/tests/hlsl-intrinsic/subgroup-quad.slang b/tests/hlsl-intrinsic/subgroup-quad.slang
index 928431a45..1cfbffb49 100644
--- a/tests/hlsl-intrinsic/subgroup-quad.slang
+++ b/tests/hlsl-intrinsic/subgroup-quad.slang
@@ -1,6 +1,7 @@
//TEST:SIMPLE(filecheck=SPIRV): -entry main -stage compute -target spirv
//TEST:SIMPLE(filecheck=SPIRV): -entry main -stage compute -target spirv -emit-spirv-directly
//TEST:SIMPLE(filecheck=HLSL): -entry main -stage compute -target hlsl
+//TEST:SIMPLE(filecheck=METAL): -entry main -stage compute -target metal
RWStructuredBuffer<float> output;
@@ -38,4 +39,16 @@ void main()
// SPIRV: OpGroupNonUniformQuadSwap {{.*}} %{{u?int_3}} {{.*}} %{{u?int_1}}
// SPIRV: OpGroupNonUniformQuadSwap {{.*}} %{{u?int_3}} {{.*}} %{{u?int_2}}
// SPIRV: OpGroupNonUniformQuadSwap {{.*}} %{{u?int_3}} {{.*}} %{{u?int_2}}
+
+ // METAL: quad_shuffle
+ // METAL: quad_shuffle
+ // METAL: ^ 1
+ // METAL: quad_shuffle
+ // METAL: quad_shuffle
+ // METAL: ^ 2
+ // METAL: quad_shuffle
+ // METAL: quad_shuffle
+ // METAL: ^ 3
+ // METAL: quad_shuffle
+ // METAL: quad_shuffle
}