summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-11-03 17:10:09 -0700
committerGitHub <noreply@github.com>2023-11-03 17:10:09 -0700
commit79677b83870577fbad9ce65a731d3ae8a4c553c1 (patch)
tree82e10e5eaef138cf61f21c8842d5c8e4b199c198
parent111de4d5527a07877edd971e8be335e067ff9a1b (diff)
Add SubgroupQuad intrinsics for glsl/spirv. (#3310)
* Add SubgroupQuad intrinsics for glsl/spirv. * Fix. * Add test for quad intrinsics. * fix. * improve diagnostics text. * Fix. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/hlsl.meta.slang161
-rw-r--r--source/slang/slang-check-decl.cpp4
-rw-r--r--source/slang/slang-diagnostic-defs.h4
-rw-r--r--tests/hlsl-intrinsic/subgroup-quad.slang41
4 files changed, 197 insertions, 13 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 43e40778c..7ce266e2c 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -6240,20 +6240,165 @@ matrix<T,N,M> WaveMaskPrefixBitXor(WaveMask mask, matrix<T,N,M> expr);
// Information for GLSL wave/subgroup support
// https://github.com/KhronosGroup/GLSL/blob/master/extensions/khr/GL_KHR_shader_subgroup.txt
-__generic<T : __BuiltinType> T QuadReadLaneAt(T sourceValue, uint quadLaneID);
-__generic<T : __BuiltinType, let N : int> vector<T,N> QuadReadLaneAt(vector<T,N> sourceValue, uint quadLaneID);
+__generic<T : __BuiltinType> // Scalar overload of QuadReadLaneAt; each target gets its own implementation below.
+__glsl_extension(GL_KHR_shader_subgroup_quad) // GLSL extension associated with this function.
+__spirv_version(1.3) // SPIR-V version associated with this function.
+T QuadReadLaneAt(T sourceValue, uint quadLaneID)
+{
+ __target_switch
+ {
+ case hlsl: // HLSL output uses the intrinsic of the same name.
+ __intrinsic_asm "QuadReadLaneAt";
+ case glsl: // GLSL output uses subgroupQuadBroadcast instead.
+ __intrinsic_asm "subgroupQuadBroadcast";
+ case spirv: // Inline SPIR-V: declare the quad capability, then broadcast at Subgroup scope.
+ return spirv_asm {
+ OpCapability GroupNonUniformQuad;
+ result:$$T = OpGroupNonUniformQuadBroadcast Subgroup $sourceValue $quadLaneID;
+ };
+ }
+}
+__generic<T : __BuiltinType, let N : int> // Vector overload of QuadReadLaneAt; same per-target mapping as the scalar form.
+__glsl_extension(GL_KHR_shader_subgroup_quad) // GLSL extension associated with this function.
+__spirv_version(1.3) // SPIR-V version associated with this function.
+vector<T,N> QuadReadLaneAt(vector<T,N> sourceValue, uint quadLaneID)
+{
+ __target_switch
+ {
+ case hlsl: // HLSL output uses the intrinsic of the same name.
+ __intrinsic_asm "QuadReadLaneAt";
+ case glsl: // GLSL output uses subgroupQuadBroadcast instead.
+ __intrinsic_asm "subgroupQuadBroadcast";
+ case spirv: // Inline SPIR-V: the result type is the whole vector type, not its element type.
+ return spirv_asm {
+ OpCapability GroupNonUniformQuad;
+ result:$$vector<T,N> = OpGroupNonUniformQuadBroadcast Subgroup $sourceValue $quadLaneID;
+ };
+ }
+}
__generic<T : __BuiltinType, let N : int, let M : int> matrix<T,N,M> QuadReadLaneAt(matrix<T,N,M> sourceValue, uint quadLaneID); // Matrix overload: left as a bare declaration here, with no per-target body.
-__generic<T : __BuiltinType> T QuadReadAcrossX(T localValue);
-__generic<T : __BuiltinType, let N : int> vector<T,N> QuadReadAcrossX(vector<T,N> localValue);
+
+__generic<T : __BuiltinType> // Scalar overload of QuadReadAcrossX; each target gets its own implementation below.
+__glsl_extension(GL_KHR_shader_subgroup_quad) // GLSL extension associated with this function.
+__spirv_version(1.3) // SPIR-V version associated with this function.
+T QuadReadAcrossX(T localValue)
+{
+ __target_switch
+ {
+ case hlsl: // HLSL output uses the intrinsic of the same name.
+ __intrinsic_asm "QuadReadAcrossX";
+ case glsl: // GLSL output uses the horizontal quad swap.
+ __intrinsic_asm "subgroupQuadSwapHorizontal($0)";
+ case spirv: // Inline SPIR-V: one quad-swap opcode, with the direction passed as an operand.
+ uint direction = 0u; // Direction 0 here pairs with the Horizontal swap used in the glsl case.
+ return spirv_asm {
+ OpCapability GroupNonUniformQuad;
+ result:$$T = OpGroupNonUniformQuadSwap Subgroup $localValue $direction;
+ };
+ }
+}
+
+__generic<T : __BuiltinType, let N : int> // Vector overload of QuadReadAcrossX; same per-target mapping as the scalar form.
+__glsl_extension(GL_KHR_shader_subgroup_quad) // GLSL extension associated with this function.
+__spirv_version(1.3) // SPIR-V version associated with this function.
+vector<T,N> QuadReadAcrossX(vector<T,N> localValue)
+{
+ __target_switch
+ {
+ case hlsl: // HLSL output uses the intrinsic of the same name.
+ __intrinsic_asm "QuadReadAcrossX";
+ case glsl: // GLSL output uses the horizontal quad swap.
+ __intrinsic_asm "subgroupQuadSwapHorizontal($0)";
+ case spirv: // Inline SPIR-V: the result type is the whole vector type.
+ uint direction = 0u; // Direction 0 here pairs with the Horizontal swap used in the glsl case.
+ return spirv_asm {
+ OpCapability GroupNonUniformQuad;
+ result:$$vector<T,N> = OpGroupNonUniformQuadSwap Subgroup $localValue $direction;
+ };
+ }
+}
__generic<T : __BuiltinType, let N : int, let M : int> matrix<T,N,M> QuadReadAcrossX(matrix<T,N,M> localValue); // Matrix overload: left as a bare declaration here, with no per-target body.
-__generic<T : __BuiltinType> T QuadReadAcrossY(T localValue);
-__generic<T : __BuiltinType, let N : int> vector<T,N> QuadReadAcrossY(vector<T,N> localValue);
+__generic<T : __BuiltinType> // Scalar overload of QuadReadAcrossY; each target gets its own implementation below.
+__glsl_extension(GL_KHR_shader_subgroup_quad) // GLSL extension associated with this function.
+__spirv_version(1.3) // SPIR-V version associated with this function.
+T QuadReadAcrossY(T localValue)
+{
+ __target_switch
+ {
+ case hlsl: // HLSL output uses the intrinsic of the same name.
+ __intrinsic_asm "QuadReadAcrossY";
+ case glsl: // GLSL output uses the vertical quad swap.
+ __intrinsic_asm "subgroupQuadSwapVertical($0)";
+ case spirv: // Inline SPIR-V: one quad-swap opcode, with the direction passed as an operand.
+ uint direction = 1u; // Direction 1 here pairs with the Vertical swap used in the glsl case.
+ return spirv_asm {
+ OpCapability GroupNonUniformQuad;
+ result:$$T = OpGroupNonUniformQuadSwap Subgroup $localValue $direction;
+ };
+ }
+}
+__generic<T : __BuiltinType, let N : int> // Vector overload of QuadReadAcrossY; same per-target mapping as the scalar form.
+__glsl_extension(GL_KHR_shader_subgroup_quad) // GLSL extension associated with this function.
+__spirv_version(1.3) // SPIR-V version associated with this function.
+vector<T,N> QuadReadAcrossY(vector<T,N> localValue)
+{
+ __target_switch
+ {
+ case hlsl: // HLSL output uses the intrinsic of the same name.
+ __intrinsic_asm "QuadReadAcrossY";
+ case glsl: // GLSL output uses the vertical quad swap.
+ __intrinsic_asm "subgroupQuadSwapVertical($0)";
+ case spirv: // Inline SPIR-V: the result type is the whole vector type.
+ uint direction = 1u; // Direction 1 here pairs with the Vertical swap used in the glsl case.
+ return spirv_asm {
+ OpCapability GroupNonUniformQuad;
+ result:$$vector<T,N> = OpGroupNonUniformQuadSwap Subgroup $localValue $direction;
+ };
+ }
+}
+
__generic<T : __BuiltinType, let N : int, let M : int> matrix<T,N,M> QuadReadAcrossY(matrix<T,N,M> localValue); // Matrix overload: left as a bare declaration here, with no per-target body.
-__generic<T : __BuiltinType> T QuadReadAcrossDiagonal(T localValue);
-__generic<T : __BuiltinType, let N : int> vector<T,N> QuadReadAcrossDiagonal(vector<T,N> localValue);
+__generic<T : __BuiltinType> // Scalar overload of QuadReadAcrossDiagonal; each target gets its own implementation below.
+__glsl_extension(GL_KHR_shader_subgroup_quad) // GLSL extension associated with this function.
+__spirv_version(1.3) // SPIR-V version associated with this function.
+T QuadReadAcrossDiagonal(T localValue)
+{
+ __target_switch
+ {
+ case hlsl: // HLSL output uses the intrinsic of the same name.
+ __intrinsic_asm "QuadReadAcrossDiagonal";
+ case glsl: // GLSL output uses the diagonal quad swap.
+ __intrinsic_asm "subgroupQuadSwapDiagonal($0)";
+ case spirv: // Inline SPIR-V: one quad-swap opcode, with the direction passed as an operand.
+ uint direction = 2u; // Direction 2 here pairs with the Diagonal swap used in the glsl case.
+ return spirv_asm {
+ OpCapability GroupNonUniformQuad;
+ result:$$T = OpGroupNonUniformQuadSwap Subgroup $localValue $direction;
+ };
+ }
+}
+__generic<T : __BuiltinType, let N : int> // Vector overload of QuadReadAcrossDiagonal; same per-target mapping as the scalar form.
+__glsl_extension(GL_KHR_shader_subgroup_quad) // GLSL extension associated with this function.
+__spirv_version(1.3) // SPIR-V version associated with this function.
+vector<T,N> QuadReadAcrossDiagonal(vector<T,N> localValue)
+{
+ __target_switch
+ {
+ case hlsl: // HLSL output uses the intrinsic of the same name.
+ __intrinsic_asm "QuadReadAcrossDiagonal";
+ case glsl: // GLSL output uses the diagonal quad swap.
+ __intrinsic_asm "subgroupQuadSwapDiagonal($0)";
+ case spirv: // Inline SPIR-V: the result type is the whole vector type.
+ uint direction = 2u; // Direction 2 here pairs with the Diagonal swap used in the glsl case.
+ return spirv_asm {
+ OpCapability GroupNonUniformQuad;
+ result:$$vector<T,N> = OpGroupNonUniformQuadSwap Subgroup $localValue $direction;
+ };
+ }
+}
__generic<T : __BuiltinType, let N : int, let M : int> matrix<T,N,M> QuadReadAcrossDiagonal(matrix<T,N,M> localValue); // Matrix overload: left as a bare declaration here, with no per-target body.
// WaveActiveBitAnd, WaveActiveBitOr, WaveActiveBitXor
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 64db4cdc5..03d9f0d9e 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -1685,8 +1685,8 @@ namespace Slang
if (!differentialType->equals(diffDiffType))
{
SourceLoc sourceLoc = differentialType->getDeclRef().getDecl()->loc;
- getSink()->diagnose(sourceLoc, Diagnostics::differentialTypeShouldServeAsItsOwnDifferentialType, differentialType);
- getSink()->diagnose(inheritanceDecl, Diagnostics::noteSeeUseOfDifferentialType, differentialType, inheritanceDecl->getSup());
+ getSink()->diagnose(inheritanceDecl, Diagnostics::differentialTypeShouldServeAsItsOwnDifferentialType, differentialType, diffDiffType);
+ getSink()->diagnose(sourceLoc, Diagnostics::seeDefinitionOf, differentialType);
}
}
};
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 2dc3c7388..15f9af156 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -341,14 +341,12 @@ DIAGNOSTIC(30093, Error, uncaughtTryCallInNonThrowFunc, "the current function or
DIAGNOSTIC(30094, Error, mustUseTryClauseToCallAThrowFunc, "the callee may throw an error, and therefore must be called within a 'try' clause")
DIAGNOSTIC(30095, Error, errorTypeOfCalleeIncompatibleWithCaller, "the error type `$1` of callee `$0` is not compatible with the caller's error type `$2`.")
-DIAGNOSTIC(30096, Error, differentialTypeShouldServeAsItsOwnDifferentialType, "type '$0' is used as a `Differential` type, therefore it must serve as its own `Differential` type.")
+DIAGNOSTIC(30096, Error, differentialTypeShouldServeAsItsOwnDifferentialType, "cannot use type '$0' a `Differential` type. A differential type's differential must be itself. However, '$0.Differential' is '$1'.")
DIAGNOSTIC(30097, Error, functionNotMarkedAsDifferentiable, "function '$0' is not marked as $1-differentiable.")
DIAGNOSTIC(30098, Error, nonStaticMemberFunctionNotAllowedAsDiffOperand, "non-static function reference '$0' is not allowed here.")
DIAGNOSTIC(30099, Error, sizeOfArgumentIsInvalid, "argument to sizeof is invalid")
-DIAGNOSTIC(-1, Note, noteSeeUseOfDifferentialType, "see use of '$0' as Differential of '$1'.")
-
// Attributes
DIAGNOSTIC(31000, Error, unknownAttributeName, "unknown attribute '$0'")
DIAGNOSTIC(31001, Error, attributeArgumentCountMismatch, "attribute '$0' expects $1 arguments ($2 provided)")
diff --git a/tests/hlsl-intrinsic/subgroup-quad.slang b/tests/hlsl-intrinsic/subgroup-quad.slang
new file mode 100644
index 000000000..928431a45
--- /dev/null
+++ b/tests/hlsl-intrinsic/subgroup-quad.slang
@@ -0,0 +1,41 @@
+//TEST:SIMPLE(filecheck=SPIRV): -entry main -stage compute -target spirv
+//TEST:SIMPLE(filecheck=SPIRV): -entry main -stage compute -target spirv -emit-spirv-directly
+//TEST:SIMPLE(filecheck=HLSL): -entry main -stage compute -target hlsl
+
+RWStructuredBuffer<float> output;
+
+[numthreads(1,1,1)]
+void main() // Entry point that calls each quad intrinsic once with a scalar and once with a float3.
+{
+ float x = output[0];
+ float3 vx = float3(x, x, x); // Vector input so the vector<T,N> overloads are exercised as well.
+
+ float v1 = QuadReadLaneAt(x, 1); // Lane index 1 is what the broadcast checks below look for.
+ float v11 = QuadReadLaneAt(vx, 1).x;
+ float v2 = QuadReadAcrossX(x);
+ float v21 = QuadReadAcrossX(vx).x;
+ float v3 = QuadReadAcrossY(x);
+ float v31 = QuadReadAcrossY(vx).x;
+ float v4 = QuadReadAcrossDiagonal(x);
+ float v41 = QuadReadAcrossDiagonal(vx).x;
+
+ output[0] = v1 + v2 + v3 + v4 + v11 + v21 + v31 + v41; // Every result is written out, presumably so no call is dropped as unused.
+
+ // HLSL: QuadReadLaneAt
+ // HLSL: QuadReadLaneAt
+ // HLSL: QuadReadAcrossX
+ // HLSL: QuadReadAcrossX
+ // HLSL: QuadReadAcrossY
+ // HLSL: QuadReadAcrossY
+ // HLSL: QuadReadAcrossDiagonal
+ // HLSL: QuadReadAcrossDiagonal
+
+ // SPIRV: OpGroupNonUniformQuadBroadcast {{.*}} %{{u?int_3}} {{.*}} %{{u?int_1}}
+ // SPIRV: OpGroupNonUniformQuadBroadcast {{.*}} %{{u?int_3}} {{.*}} %{{u?int_1}}
+ // SPIRV: OpGroupNonUniformQuadSwap {{.*}} %{{u?int_3}} {{.*}} %{{u?int_0}}
+ // SPIRV: OpGroupNonUniformQuadSwap {{.*}} %{{u?int_3}} {{.*}} %{{u?int_0}}
+ // SPIRV: OpGroupNonUniformQuadSwap {{.*}} %{{u?int_3}} {{.*}} %{{u?int_1}}
+ // SPIRV: OpGroupNonUniformQuadSwap {{.*}} %{{u?int_3}} {{.*}} %{{u?int_1}}
+ // SPIRV: OpGroupNonUniformQuadSwap {{.*}} %{{u?int_3}} {{.*}} %{{u?int_2}}
+ // SPIRV: OpGroupNonUniformQuadSwap {{.*}} %{{u?int_3}} {{.*}} %{{u?int_2}}
+}