summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorDarren Wihandi <65404740+fairywreath@users.noreply.github.com>2025-02-02 15:27:11 -0500
committerGitHub <noreply@github.com>2025-02-02 12:27:11 -0800
commit0a6828572aa4cc1f0f99993e77c321799eb88cca (patch)
treed18f1950074958ff3276e303425eed15067ea2bc
parent2949b786a7f04ad31c113b622039fb5b72bc8622 (diff)
Add support for WGSL subgroup operations (#6213)
* initial work * more work * more work on glsl intrinsics * add subgroup broadcast for glsl * wip add wgsl extension tracking * enable tests, enable extensions and added some todos * format and warning fixes * fix wgsl extension tracker --------- Co-authored-by: Yong He <yonghe@outlook.com>
-rw-r--r--docs/user-guide/a3-02-reference-capability-atoms.md6
-rw-r--r--source/slang/glsl.meta.slang174
-rw-r--r--source/slang/hlsl.meta.slang280
-rw-r--r--source/slang/slang-ast-modifier.h9
-rw-r--r--source/slang/slang-ast-print.cpp2
-rw-r--r--source/slang/slang-capabilities.capdef47
-rw-r--r--source/slang/slang-compiler.cpp13
-rw-r--r--source/slang/slang-emit-glsl.cpp6
-rw-r--r--source/slang/slang-emit-glsl.h4
-rw-r--r--source/slang/slang-emit-wgsl.cpp36
-rw-r--r--source/slang/slang-emit-wgsl.h15
-rw-r--r--source/slang/slang-emit.cpp6
-rw-r--r--source/slang/slang-extension-tracker.cpp (renamed from source/slang/slang-glsl-extension-tracker.cpp)22
-rw-r--r--source/slang/slang-extension-tracker.h (renamed from source/slang/slang-glsl-extension-tracker.h)15
-rw-r--r--source/slang/slang-ir-glsl-legalize.cpp8
-rw-r--r--source/slang/slang-ir-glsl-legalize.h4
-rw-r--r--source/slang/slang-ir-inst-defs.h1
-rw-r--r--source/slang/slang-ir-insts.h14
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp1
-rw-r--r--source/slang/slang-lower-to-ir.cpp10
-rw-r--r--source/slang/slang-parser.cpp12
-rw-r--r--tests/glsl-intrinsic/shader-subgroup/shader-subgroup-arithmetic_Exclusive.slang33
-rw-r--r--tests/glsl-intrinsic/shader-subgroup/shader-subgroup-arithmetic_Inclusive.slang72
-rw-r--r--tests/glsl-intrinsic/shader-subgroup/shader-subgroup-arithmetic_None.slang34
-rw-r--r--tests/glsl-intrinsic/shader-subgroup/shader-subgroup-ballot.slang25
-rw-r--r--tests/glsl-intrinsic/shader-subgroup/shader-subgroup-basic.slang60
-rw-r--r--tests/glsl-intrinsic/shader-subgroup/shader-subgroup-quad.slang34
-rw-r--r--tests/glsl-intrinsic/shader-subgroup/shader-subgroup-shuffle-relative.slang22
-rw-r--r--tests/glsl-intrinsic/shader-subgroup/shader-subgroup-shuffle.slang22
-rw-r--r--tests/glsl-intrinsic/shader-subgroup/shader-subgroup-vote.slang31
-rw-r--r--tests/hlsl-intrinsic/wave-active-product.slang1
-rw-r--r--tests/hlsl-intrinsic/wave-broadcast-lane-at-vk.slang1
-rw-r--r--tests/hlsl-intrinsic/wave-diverge.slang1
-rw-r--r--tests/hlsl-intrinsic/wave-is-first-lane.slang1
-rw-r--r--tests/hlsl-intrinsic/wave-prefix-product.slang1
-rw-r--r--tests/hlsl-intrinsic/wave-prefix-sum-fp16.slang8
-rw-r--r--tests/hlsl-intrinsic/wave-prefix-sum.slang1
-rw-r--r--tests/hlsl-intrinsic/wave-read-lane-at-vk.slang1
-rw-r--r--tests/hlsl-intrinsic/wave-shuffle-vk.slang1
-rw-r--r--tests/hlsl-intrinsic/wave-vector.slang1
-rw-r--r--tests/hlsl-intrinsic/wave.slang1
41 files changed, 727 insertions, 309 deletions
diff --git a/docs/user-guide/a3-02-reference-capability-atoms.md b/docs/user-guide/a3-02-reference-capability-atoms.md
index 7cdf89bff..e7de2cfa9 100644
--- a/docs/user-guide/a3-02-reference-capability-atoms.md
+++ b/docs/user-guide/a3-02-reference-capability-atoms.md
@@ -788,6 +788,9 @@ Compound Capabilities
`cuda_glsl_hlsl_spirv`
> CUDA, GLSL, HLSL, and SPIRV code-gen targets
+`cuda_glsl_hlsl_spirv_wgsl`
+> CUDA, GLSL, HLSL, SPIRV, and WGSL code-gen targets
+
`cuda_glsl_hlsl_metal_spirv`
> CUDA, GLSL, HLSL, Metal, and SPIRV code-gen targets
@@ -830,6 +833,9 @@ Compound Capabilities
`glsl_spirv`
> GLSL, and SPIRV code-gen targets
+`glsl_spirv_wgsl`
+> GLSL, SPIRV, and WGSL code-gen targets
+
`hlsl_spirv`
> HLSL, and SPIRV code-gen targets
diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang
index dd1c5a907..ef3bfd683 100644
--- a/source/slang/glsl.meta.slang
+++ b/source/slang/glsl.meta.slang
@@ -6305,7 +6305,7 @@ void shader_subgroup_preamble() {
// GL_KHR_shader_subgroup_basic Built-in Variables
-[require(cpp_cuda_glsl_hlsl_spirv, subgroup_basic)]
+[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
void requireGLSLExtForSubgroupBasicBuiltin() {
__target_switch
{
@@ -6317,7 +6317,7 @@ void requireGLSLExtForSubgroupBasicBuiltin() {
}
}
-[require(cpp_cuda_glsl_hlsl_spirv, subgroup_basic)]
+[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
void setupExtForSubgroupBasicBuiltIn() {
__target_switch
{
@@ -6329,7 +6329,7 @@ void setupExtForSubgroupBasicBuiltIn() {
}
__spirv_version(1.3)
-[require(cpp_cuda_glsl_hlsl_spirv, subgroup_ballot)]
+[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
void requireGLSLExtForSubgroupBallotBuiltin() {
__target_switch
{
@@ -6342,7 +6342,7 @@ void requireGLSLExtForSubgroupBallotBuiltin() {
}
__spirv_version(1.3)
-[require(cpp_cuda_glsl_hlsl_spirv, subgroup_ballot)]
+[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
void setupExtForSubgroupBallotBuiltIn() {
__target_switch
{
@@ -6392,7 +6392,7 @@ public property uint gl_SubgroupID
public property uint gl_SubgroupSize
{
- [require(cpp_cuda_glsl_hlsl_spirv, subgroup_basic)]
+ [require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
get {
setupExtForSubgroupBasicBuiltIn();
return WaveGetLaneCount();
@@ -6401,7 +6401,7 @@ public property uint gl_SubgroupSize
public property uint gl_SubgroupInvocationID
{
- [require(cpp_cuda_glsl_hlsl_spirv, subgroup_basic)]
+ [require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
get {
setupExtForSubgroupBasicBuiltIn();
return WaveGetLaneIndex();
@@ -6625,7 +6625,7 @@ public void subgroupMemoryBarrierShared()
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_basic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_basic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
public bool subgroupElect()
{
__target_switch
@@ -6635,6 +6635,7 @@ public bool subgroupElect()
case glsl:
case spirv:
case hlsl:
+ case wgsl:
return WaveIsFirstLane();
}
@@ -6645,7 +6646,7 @@ public bool subgroupElect()
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_vote)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_vote)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_vote)]
public bool subgroupAll(bool value)
{
return WaveActiveAllTrue(value);
@@ -6654,7 +6655,7 @@ public bool subgroupAll(bool value)
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_vote)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_vote)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_vote)]
public bool subgroupAny(bool value)
{
return WaveActiveAnyTrue(value);
@@ -6688,7 +6689,7 @@ __generic<T : __BuiltinArithmeticType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
public T subgroupAdd(T value)
{
shader_subgroup_preamble<T>();
@@ -6699,7 +6700,7 @@ __generic<T : __BuiltinArithmeticType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
public T subgroupMul(T value)
{
shader_subgroup_preamble<T>();
@@ -6710,7 +6711,7 @@ __generic<T : __BuiltinArithmeticType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
public T subgroupMin(T value)
{
shader_subgroup_preamble<T>();
@@ -6721,7 +6722,7 @@ __generic<T : __BuiltinArithmeticType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
public T subgroupMax(T value)
{
shader_subgroup_preamble<T>();
@@ -6731,14 +6732,17 @@ public T subgroupMax(T value)
__generic<T : __BuiltinLogicalType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_arithmetic)]
+[require(glsl_spirv_wgsl, subgroup_arithmetic)]
public T subgroupAnd(T value)
{
shader_subgroup_preamble<T>();
__target_switch
{
- case glsl: __intrinsic_asm "subgroupAnd($0)";
+ case glsl:
+ case wgsl:
+ __intrinsic_asm "subgroupAnd($0)";
case spirv:
if (__isBool<T>()) {
return spirv_asm {
@@ -6758,14 +6762,17 @@ public T subgroupAnd(T value)
__generic<T : __BuiltinLogicalType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_arithmetic)]
+[require(glsl_spirv_wgsl, subgroup_arithmetic)]
public T subgroupOr(T value)
{
shader_subgroup_preamble<T>();
__target_switch
{
- case glsl: __intrinsic_asm "subgroupOr($0)";
+ case glsl:
+ case wgsl:
+ __intrinsic_asm "subgroupOr($0)";
case spirv:
if (__isBool<T>()) {
return spirv_asm {
@@ -6785,14 +6792,17 @@ public T subgroupOr(T value)
__generic<T : __BuiltinLogicalType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_arithmetic)]
+[require(glsl_spirv_wgsl, subgroup_arithmetic)]
public T subgroupXor(T value)
{
shader_subgroup_preamble<T>();
__target_switch
{
- case glsl: __intrinsic_asm "subgroupXor($0)";
+ case glsl:
+ case wgsl:
+ __intrinsic_asm "subgroupXor($0)";
case spirv:
if (__isBool<T>()) {
return spirv_asm {
@@ -6812,14 +6822,16 @@ public T subgroupXor(T value)
__generic<T : __BuiltinArithmeticType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_arithmetic)]
+[require(glsl_spirv_wgsl, subgroup_arithmetic)]
public T subgroupInclusiveAdd(T value)
{
shader_subgroup_preamble<T>();
__target_switch
{
case glsl:
+ case wgsl:
__intrinsic_asm "subgroupInclusiveAdd($0)";
case spirv:
if (__isFloat<T>())
@@ -6833,14 +6845,16 @@ public T subgroupInclusiveAdd(T value)
__generic<T : __BuiltinArithmeticType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_arithmetic)]
+[require(glsl_spirv_wgsl, subgroup_arithmetic)]
public T subgroupInclusiveMul(T value)
{
shader_subgroup_preamble<T>();
__target_switch
{
case glsl:
+ case wgsl:
__intrinsic_asm "subgroupInclusiveMul($0)";
case spirv:
if (__isFloat<T>())
@@ -6974,7 +6988,7 @@ __generic<T : __BuiltinArithmeticType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
public T subgroupExclusiveAdd(T value)
{
shader_subgroup_preamble<T>();
@@ -6986,7 +7000,7 @@ __generic<T : __BuiltinArithmeticType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
public T subgroupExclusiveMul(T value)
{
shader_subgroup_preamble<T>();
@@ -7097,7 +7111,7 @@ __generic<T : __BuiltinArithmeticType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupAdd(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -7108,7 +7122,7 @@ __generic<T : __BuiltinArithmeticType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupMul(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -7119,7 +7133,7 @@ __generic<T : __BuiltinArithmeticType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupMin(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -7130,7 +7144,7 @@ __generic<T : __BuiltinArithmeticType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupMax(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -7140,14 +7154,18 @@ public vector<T,N> subgroupMax(vector<T,N> value)
__generic<T : __BuiltinLogicalType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_arithmetic)]
+[require(glsl_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupAnd(vector<T,N> value)
{
shader_subgroup_preamble<T>();
__target_switch
{
- case glsl: __intrinsic_asm "subgroupAnd($0)";
+ case glsl:
+ case wgsl:
+ // TODO: Bool inputs are invalid for WGSL, cast them to int or don't allow them to compile.
+ __intrinsic_asm "subgroupAnd($0)";
case spirv:
if (__isBool<T>()) {
return spirv_asm {
@@ -7168,14 +7186,17 @@ public vector<T,N> subgroupAnd(vector<T,N> value)
__generic<T : __BuiltinLogicalType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_arithmetic)]
+[require(glsl_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupOr(vector<T,N> value)
{
shader_subgroup_preamble<T>();
__target_switch
{
- case glsl: __intrinsic_asm "subgroupOr($0)";
+ case glsl:
+ case wgsl:
+ __intrinsic_asm "subgroupOr($0)";
case spirv:
if (__isBool<T>()) {
return spirv_asm {
@@ -7196,14 +7217,17 @@ public vector<T,N> subgroupOr(vector<T,N> value)
__generic<T : __BuiltinLogicalType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_arithmetic)]
+[require(glsl_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupXor(vector<T,N> value)
{
shader_subgroup_preamble<T>();
__target_switch
{
- case glsl: __intrinsic_asm "subgroupXor($0)";
+ case glsl:
+ case wgsl:
+ __intrinsic_asm "subgroupXor($0)";
case spirv:
if (__isBool<T>()) {
return spirv_asm {
@@ -7223,14 +7247,16 @@ public vector<T,N> subgroupXor(vector<T,N> value)
__generic<T : __BuiltinArithmeticType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_arithmetic)]
+[require(glsl_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupInclusiveAdd(vector<T,N> value)
{
shader_subgroup_preamble<T>();
__target_switch
{
case glsl:
+ case wgsl:
__intrinsic_asm "subgroupInclusiveAdd($0)";
case spirv:
if (__isFloat<T>())
@@ -7244,14 +7270,16 @@ public vector<T,N> subgroupInclusiveAdd(vector<T,N> value)
__generic<T : __BuiltinArithmeticType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_arithmetic)]
+[require(glsl_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupInclusiveMul(vector<T,N> value)
{
shader_subgroup_preamble<T>();
__target_switch
{
case glsl:
+ case wgsl:
__intrinsic_asm "subgroupInclusiveMul($0)";
case spirv:
if (__isFloat<T>())
@@ -7366,7 +7394,7 @@ __generic<T : __BuiltinArithmeticType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupExclusiveAdd(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -7378,7 +7406,7 @@ __generic<T : __BuiltinArithmeticType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupExclusiveMul(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -7488,51 +7516,65 @@ __generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_ballot)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
public T subgroupBroadcast(T value, uint id)
{
shader_subgroup_preamble<T>();
- return WaveMaskBroadcastLaneAt(WaveGetActiveMask(), value, id);
+ __target_switch
+ {
+ case wgsl:
+ // WGSL's intrinsic does not accept non-const ids, do shuffle instead.
+ __intrinsic_asm "subgroupShuffle";
+ default:
+ return WaveBroadcastLaneAt(value, id);
+ }
}
__generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_ballot)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
public vector<T,N> subgroupBroadcast(vector<T,N> value, uint id)
{
shader_subgroup_preamble<T>();
- return WaveMaskBroadcastLaneAt(WaveGetActiveMask(), value, id);
+ __target_switch
+ {
+ case wgsl:
+ // WGSL's intrinsic does not accept non-const ids, do shuffle instead.
+ __intrinsic_asm "subgroupShuffle";
+ default:
+ return WaveBroadcastLaneAt(value, id);
+ }
}
__generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_ballot)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
public T subgroupBroadcastFirst(T value)
{
shader_subgroup_preamble<T>();
- return WaveMaskReadLaneFirst(WaveGetActiveMask(), value);
+ return WaveReadLaneFirst(value);
}
__generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_ballot)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
public vector<T,N> subgroupBroadcastFirst(vector<T,N> value)
{
shader_subgroup_preamble<T>();
- return WaveMaskReadLaneFirst(WaveGetActiveMask(), value);
+ return WaveReadLaneFirst(value);
}
// WaveMaskBallot is not the same; it force trunc's
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_ballot)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
public uvec4 subgroupBallot(bool value)
{
return WaveActiveBallot(value);
@@ -7713,7 +7755,7 @@ __generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_shuffle)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)]
public T subgroupShuffle(T value, uint index)
{
shader_subgroup_preamble<T>();
@@ -7723,13 +7765,15 @@ public T subgroupShuffle(T value, uint index)
__generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
-[require(glsl_spirv, subgroup_shuffle)]
+__wgsl_extension(subgroups)
+[require(glsl_spirv_wgsl, subgroup_shuffle)]
[ForceInline] public T subgroupShuffleXor(T value, uint mask)
{
shader_subgroup_preamble<T>();
__target_switch
{
case glsl:
+ case wgsl:
__intrinsic_asm "subgroupShuffleXor($0,$1)";
case spirv:
return spirv_asm {
@@ -7743,7 +7787,7 @@ __generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_shuffle)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)]
public vector<T,N> subgroupShuffle(vector<T,N> value, uint index)
{
shader_subgroup_preamble<T>();
@@ -7753,14 +7797,16 @@ public vector<T,N> subgroupShuffle(vector<T,N> value, uint index)
__generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_shuffle)]
+[require(glsl_spirv_wgsl, subgroup_shuffle)]
public vector<T,N> subgroupShuffleXor(vector<T,N> value, uint mask)
{
shader_subgroup_preamble<T>();
__target_switch
{
case glsl:
+ case wgsl:
__intrinsic_asm "subgroupShuffleXor($0,$1)";
case spirv:
return spirv_asm {
@@ -7776,14 +7822,16 @@ public vector<T,N> subgroupShuffleXor(vector<T,N> value, uint mask)
__generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_shuffle_relative)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_shufflerelative)]
+[require(glsl_spirv_wgsl, subgroup_shufflerelative)]
public T subgroupShuffleUp(T value, uint delta)
{
shader_subgroup_preamble<T>();
__target_switch
{
case glsl:
+ case wgsl:
__intrinsic_asm "subgroupShuffleUp($0, $1)";
case spirv:
return spirv_asm {
@@ -7796,14 +7844,16 @@ public T subgroupShuffleUp(T value, uint delta)
__generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_shuffle_relative)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_shufflerelative)]
+[require(glsl_spirv_wgsl, subgroup_shufflerelative)]
public T subgroupShuffleDown(T value, uint delta)
{
shader_subgroup_preamble<T>();
__target_switch
{
case glsl:
+ case wgsl:
__intrinsic_asm "subgroupShuffleDown($0, $1)";
case spirv:
return spirv_asm {
@@ -7817,14 +7867,16 @@ public T subgroupShuffleDown(T value, uint delta)
__generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_shuffle_relative)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_shufflerelative)]
+[require(glsl_spirv_wgsl, subgroup_shufflerelative)]
public vector<T,N> subgroupShuffleUp(vector<T,N> value, uint delta)
{
shader_subgroup_preamble<T>();
__target_switch
{
case glsl:
+ case wgsl:
__intrinsic_asm "subgroupShuffleUp($0, $1)";
case spirv:
return spirv_asm {
@@ -7837,14 +7889,16 @@ public vector<T,N> subgroupShuffleUp(vector<T,N> value, uint delta)
__generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_shuffle_relative)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_shufflerelative)]
+[require(glsl_spirv_wgsl, subgroup_shufflerelative)]
public vector<T,N> subgroupShuffleDown(vector<T,N> value, uint delta)
{
shader_subgroup_preamble<T>();
__target_switch
{
case glsl:
+ case wgsl:
__intrinsic_asm "subgroupShuffleDown($0, $1)";
case spirv:
return spirv_asm {
@@ -8161,7 +8215,7 @@ __generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_quad)
[ForceInline]
-[require(glsl_hlsl_spirv, subgroup_quad)]
+[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
public T subgroupQuadSwapHorizontal(T value)
{
shader_subgroup_preamble<T>();
@@ -8172,7 +8226,7 @@ __generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_quad)
[ForceInline]
-[require(glsl_hlsl_spirv, subgroup_quad)]
+[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
public T subgroupQuadSwapVertical(T value)
{
shader_subgroup_preamble<T>();
@@ -8183,7 +8237,7 @@ __generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_quad)
[ForceInline]
-[require(glsl_hlsl_spirv, subgroup_quad)]
+[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
public T subgroupQuadSwapDiagonal(T value)
{
shader_subgroup_preamble<T>();
@@ -8206,7 +8260,7 @@ __generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_quad)
[ForceInline]
-[require(glsl_hlsl_spirv, subgroup_quad)]
+[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
public vector<T,N> subgroupQuadSwapHorizontal(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -8217,7 +8271,7 @@ __generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_quad)
[ForceInline]
-[require(glsl_hlsl_spirv, subgroup_quad)]
+[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
public vector<T,N> subgroupQuadSwapVertical(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -8228,7 +8282,7 @@ __generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_quad)
[ForceInline]
-[require(glsl_hlsl_spirv, subgroup_quad)]
+[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
public vector<T,N> subgroupQuadSwapDiagonal(vector<T,N> value)
{
shader_subgroup_preamble<T>();
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index ecab7ff93..3baad5c10 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -14460,42 +14460,44 @@ __generic<T : __BuiltinType, let N : int, let M : int> matrix<T,N,M> QuadReadLan
__generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_quad)
__spirv_version(1.3)
-[require(glsl_hlsl_spirv, subgroup_quad)]
+__wgsl_extension(subgroups)
+[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
T QuadReadAcrossX(T localValue)
{
__target_switch
{
- case hlsl:
- __intrinsic_asm "QuadReadAcrossX";
- case glsl:
- __intrinsic_asm "subgroupQuadSwapHorizontal($0)";
+ case hlsl: __intrinsic_asm "QuadReadAcrossX";
+ case glsl: __intrinsic_asm "subgroupQuadSwapHorizontal($0)";
case spirv:
uint direction = 0u;
- return spirv_asm {
+ return spirv_asm
+ {
OpCapability GroupNonUniformQuad;
result:$$T = OpGroupNonUniformQuadSwap Subgroup $localValue $direction;
};
+ case wgsl: __intrinsic_asm "quadSwapX";
}
}
__generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_quad)
__spirv_version(1.3)
-[require(glsl_hlsl_spirv, subgroup_quad)]
+__wgsl_extension(subgroups)
+[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
vector<T,N> QuadReadAcrossX(vector<T,N> localValue)
{
__target_switch
{
- case hlsl:
- __intrinsic_asm "QuadReadAcrossX";
- case glsl:
- __intrinsic_asm "subgroupQuadSwapHorizontal($0)";
+ case hlsl: __intrinsic_asm "QuadReadAcrossX";
+ case glsl: __intrinsic_asm "subgroupQuadSwapHorizontal($0)";
case spirv:
uint direction = 0u;
- return spirv_asm {
+ return spirv_asm
+ {
OpCapability GroupNonUniformQuad;
result:$$vector<T,N> = OpGroupNonUniformQuadSwap Subgroup $localValue $direction;
};
+ case wgsl: __intrinsic_asm "quadSwapX";
}
}
__generic<T : __BuiltinType, let N : int, let M : int> matrix<T,N,M> QuadReadAcrossX(matrix<T,N,M> localValue);
@@ -14504,85 +14506,88 @@ __generic<T : __BuiltinType, let N : int, let M : int> matrix<T,N,M> QuadReadAcr
__generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_quad)
__spirv_version(1.3)
-[require(glsl_hlsl_spirv, subgroup_quad)]
+__wgsl_extension(subgroups)
+[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
T QuadReadAcrossY(T localValue)
{
__target_switch
{
- case hlsl:
- __intrinsic_asm "QuadReadAcrossY";
- case glsl:
- __intrinsic_asm "subgroupQuadSwapVertical($0)";
+ case hlsl: __intrinsic_asm "QuadReadAcrossY";
+ case glsl: __intrinsic_asm "subgroupQuadSwapVertical($0)";
case spirv:
uint direction = 1u;
- return spirv_asm {
+ return spirv_asm
+ {
OpCapability GroupNonUniformQuad;
result:$$T = OpGroupNonUniformQuadSwap Subgroup $localValue $direction;
};
+ case wgsl: __intrinsic_asm "quadSwapY";
}
}
__generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_quad)
__spirv_version(1.3)
-[require(glsl_hlsl_spirv, subgroup_quad)]
+__wgsl_extension(subgroups)
+[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
vector<T,N> QuadReadAcrossY(vector<T,N> localValue)
{
__target_switch
{
- case hlsl:
- __intrinsic_asm "QuadReadAcrossY";
- case glsl:
- __intrinsic_asm "subgroupQuadSwapVertical($0)";
+ case hlsl: __intrinsic_asm "QuadReadAcrossY";
+ case glsl: __intrinsic_asm "subgroupQuadSwapVertical($0)";
case spirv:
uint direction = 1u;
- return spirv_asm {
+ return spirv_asm
+ {
OpCapability GroupNonUniformQuad;
result:$$vector<T,N> = OpGroupNonUniformQuadSwap Subgroup $localValue $direction;
};
+ case wgsl: __intrinsic_asm "quadSwapY";
}
}
-
__generic<T : __BuiltinType, let N : int, let M : int> matrix<T,N,M> QuadReadAcrossY(matrix<T,N,M> localValue);
/// @category wave
__generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_quad)
__spirv_version(1.3)
-[require(glsl_hlsl_spirv, subgroup_quad)]
+__wgsl_extension(subgroups)
+[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
T QuadReadAcrossDiagonal(T localValue)
{
__target_switch
{
- case hlsl:
- __intrinsic_asm "QuadReadAcrossDiagonal";
- case glsl:
- __intrinsic_asm "subgroupQuadSwapDiagonal($0)";
+ case hlsl: __intrinsic_asm "QuadReadAcrossDiagonal";
+ case glsl: __intrinsic_asm "subgroupQuadSwapDiagonal($0)";
case spirv:
uint direction = 2u;
- return spirv_asm {
+ return spirv_asm
+ {
OpCapability GroupNonUniformQuad;
result:$$T = OpGroupNonUniformQuadSwap Subgroup $localValue $direction;
};
+ case wgsl: __intrinsic_asm "quadSwapDiagonal";
}
}
__generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_quad)
__spirv_version(1.3)
-[require(glsl_hlsl_spirv, subgroup_quad)]
+__wgsl_extension(subgroups)
+[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
vector<T,N> QuadReadAcrossDiagonal(vector<T,N> localValue)
{
__target_switch
{
- case hlsl:
- __intrinsic_asm "QuadReadAcrossDiagonal";
- case glsl:
- __intrinsic_asm "subgroupQuadSwapDiagonal($0)";
+ case hlsl: __intrinsic_asm "QuadReadAcrossDiagonal";
+ case glsl: __intrinsic_asm "subgroupQuadSwapDiagonal($0)";
case spirv:
uint direction = 2u;
- return spirv_asm {
+ return spirv_asm
+ {
OpCapability GroupNonUniformQuad;
result:$$vector<T,N> = OpGroupNonUniformQuadSwap Subgroup $localValue $direction;
};
+ case wgsl: __intrinsic_asm "quadSwapDiagonal";
}
}
__generic<T : __BuiltinType, let N : int, let M : int> matrix<T,N,M> QuadReadAcrossDiagonal(matrix<T,N,M> localValue);
@@ -14597,16 +14602,19 @@ for (auto opName : kWaveActiveBitOpEntries) {
__generic<T : __BuiltinIntegerType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
T WaveActive$(opName.hlslName)(T expr)
{
__target_switch
{
- case glsl: __intrinsic_asm "subgroup$(opName.glslName)($0)";
+ case glsl:
+ case wgsl:
+ __intrinsic_asm "subgroup$(opName.glslName)";
case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)";
case spirv:
return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniform$(opName.spirvName) $$T result Subgroup Reduce $expr};
- default:
+ case cuda:
return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr);
}
}
@@ -14614,22 +14622,25 @@ T WaveActive$(opName.hlslName)(T expr)
__generic<T : __BuiltinIntegerType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
vector<T, N> WaveActive$(opName.hlslName)(vector<T, N> expr)
{
__target_switch
{
- case glsl: __intrinsic_asm "subgroup$(opName.glslName)($0)";
+ case glsl:
+ case wgsl:
+ __intrinsic_asm "subgroup$(opName.glslName)";
case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)";
case spirv:
return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniform$(opName.spirvName) $$vector<T, N> result Subgroup Reduce $expr};
- default:
+ case cuda:
return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr);
}
}
__generic<T : __BuiltinIntegerType, let N : int, let M : int>
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
matrix<T, N, M> WaveActive$(opName.hlslName)(matrix<T, N, M> expr)
{
__target_switch
@@ -14637,12 +14648,13 @@ matrix<T, N, M> WaveActive$(opName.hlslName)(matrix<T, N, M> expr)
case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)";
case glsl:
case spirv:
+ case wgsl:
matrix<T,N,M> result;
[ForceUnroll]
for (int i = 0; i < N; ++i)
result[i] = WaveActive$(opName.hlslName)(expr[i]);
return result;
- default:
+ case cuda:
return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr);
}
}
@@ -14659,12 +14671,15 @@ for (const char* opName : kWaveActiveMinMaxNames) {
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
T WaveActive$(opName)(T expr)
{
__target_switch
{
- case glsl: __intrinsic_asm "subgroup$(opName)($0)";
+ case glsl:
+ case wgsl:
+ __intrinsic_asm "subgroup$(opName)";
case hlsl: __intrinsic_asm "WaveActive$(opName)";
case spirv:
if (__isFloat<T>())
@@ -14673,7 +14688,7 @@ T WaveActive$(opName)(T expr)
return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformU$(opName) $$T result Subgroup Reduce $expr};
else
return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformS$(opName) $$T result Subgroup Reduce $expr};
- default:
+ case cuda:
return WaveMask$(opName)(WaveGetActiveMask(), expr);
}
}
@@ -14681,12 +14696,15 @@ T WaveActive$(opName)(T expr)
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
vector<T, N> WaveActive$(opName)(vector<T, N> expr)
{
__target_switch
{
- case glsl: __intrinsic_asm "subgroup$(opName)($0)";
+ case glsl:
+ case wgsl:
+ __intrinsic_asm "subgroup$(opName)";
case hlsl: __intrinsic_asm "WaveActive$(opName)";
case spirv:
if (__isFloat<T>())
@@ -14695,13 +14713,13 @@ vector<T, N> WaveActive$(opName)(vector<T, N> expr)
return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformU$(opName) $$vector<T, N> result Subgroup Reduce $expr};
else
return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformS$(opName) $$vector<T, N> result Subgroup Reduce $expr};
- default:
+ case cuda:
return WaveMask$(opName)(WaveGetActiveMask(), expr);
}
}
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
matrix<T, N, M> WaveActive$(opName)(matrix<T, N, M> expr)
{
__target_switch
@@ -14709,12 +14727,13 @@ matrix<T, N, M> WaveActive$(opName)(matrix<T, N, M> expr)
case hlsl: __intrinsic_asm "WaveActive$(opName)";
case glsl:
case spirv:
+ case wgsl:
matrix<T, N, M> result;
[ForceUnroll]
for (int i = 0; i < N; ++i)
result[i] = WaveActive$(opName)(expr[i]);
return result;
- default:
+ case cuda:
return WaveMask$(opName)(WaveGetActiveMask(), expr);
}
}
@@ -14733,7 +14752,8 @@ for (auto opName : kWaveActivProductSumNames) {
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
T WaveActive$(opName.hlslName)(T expr)
{
__target_switch
@@ -14757,7 +14777,8 @@ T WaveActive$(opName.hlslName)(T expr)
};
}
else return expr;
- default:
+ case wgsl: __intrinsic_asm "subgroup$(opName.glslName)";
+ case cuda:
return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr);
}
}
@@ -14765,7 +14786,8 @@ T WaveActive$(opName.hlslName)(T expr)
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
vector<T,N> WaveActive$(opName.hlslName)(vector<T,N> expr)
{
__target_switch
@@ -14789,13 +14811,14 @@ vector<T,N> WaveActive$(opName.hlslName)(vector<T,N> expr)
};
}
else return expr;
- default:
+ case wgsl: __intrinsic_asm "subgroup$(opName.glslName)";
+ case cuda:
return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr);
}
}
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
matrix<T, N, M> WaveActive$(opName.hlslName)(matrix<T, N, M> expr)
{
__target_switch
@@ -14803,12 +14826,13 @@ matrix<T, N, M> WaveActive$(opName.hlslName)(matrix<T, N, M> expr)
case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)";
case glsl:
case spirv:
+ case wgsl:
matrix<T, N, M> result;
[ForceUnroll]
for (int i = 0; i < N; ++i)
result[i] = WaveActive$(opName.hlslName)(expr[i]);
return result;
- default:
+ case cuda:
return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr);
}
}
@@ -14877,22 +14901,23 @@ bool WaveActiveAllEqual(matrix<T, N, M> value)
/// @category wave
__glsl_extension(GL_KHR_shader_subgroup_vote)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_vote)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_vote)]
bool WaveActiveAllTrue(bool condition)
{
__target_switch
{
case glsl:
- __intrinsic_asm "subgroupAll($0)";
- case hlsl:
- __intrinsic_asm "WaveActiveAllTrue($0)";
+ case wgsl:
+ __intrinsic_asm "subgroupAll";
+ case hlsl: __intrinsic_asm "WaveActiveAllTrue($0)";
case spirv:
return spirv_asm
{
OpCapability GroupNonUniformVote;
OpGroupNonUniformAll $$bool result Subgroup $condition
};
- default:
+ case cuda:
return WaveMaskAllTrue(WaveGetActiveMask(), condition);
}
}
@@ -14900,13 +14925,15 @@ bool WaveActiveAllTrue(bool condition)
/// @category wave
__glsl_extension(GL_KHR_shader_subgroup_vote)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_vote)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_vote)]
bool WaveActiveAnyTrue(bool condition)
{
__target_switch
{
case glsl:
- __intrinsic_asm "subgroupAny($0)";
+ case wgsl:
+ __intrinsic_asm "subgroupAny";
case hlsl:
__intrinsic_asm "WaveActiveAnyTrue($0)";
case spirv:
@@ -14923,14 +14950,16 @@ bool WaveActiveAnyTrue(bool condition)
/// @category wave
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
+__wgsl_extension(subgroups)
[NonUniformReturn]
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
uint4 WaveActiveBallot(bool condition)
{
__target_switch
{
case glsl:
- __intrinsic_asm "subgroupBallot($0)";
+ case wgsl:
+ __intrinsic_asm "subgroupBallot";
case hlsl:
__intrinsic_asm "WaveActiveBallot";
case spirv:
@@ -15004,23 +15033,23 @@ uint WaveGetLaneIndex()
/// @category wave
__glsl_extension(GL_KHR_shader_subgroup_basic)
__spirv_version(1.3)
+__wgsl_extension(subgroups)
[NonUniformReturn]
-[require(cuda_glsl_hlsl_spirv, subgroup_basic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
bool WaveIsFirstLane()
{
__target_switch
{
- case glsl:
- __intrinsic_asm "subgroupElect()";
- case hlsl:
- __intrinsic_asm "WaveIsFirstLane()";
+ case glsl: __intrinsic_asm "subgroupElect()";
+ case hlsl: __intrinsic_asm "WaveIsFirstLane()";
case spirv:
return spirv_asm
{
OpCapability GroupNonUniformBallot;
OpGroupNonUniformElect $$bool result Subgroup
};
- default:
+ case wgsl: __intrinsic_asm "subgroupElect";
+ case cuda:
return WaveMaskIsFirstLane(WaveGetActiveMask());
}
}
@@ -15059,7 +15088,8 @@ uint _WaveCountBits(uint4 value)
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
T WavePrefixProduct(T expr)
{
__target_switch
@@ -15083,7 +15113,8 @@ T WavePrefixProduct(T expr)
};
}
else return expr;
- default:
+ case wgsl: __intrinsic_asm "subgroupExclusiveMul";
+ case cuda:
return WaveMaskPrefixProduct(WaveGetActiveMask(), expr);
}
}
@@ -15092,7 +15123,8 @@ T WavePrefixProduct(T expr)
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
vector<T,N> WavePrefixProduct(vector<T,N> expr)
{
__target_switch
@@ -15113,13 +15145,14 @@ vector<T,N> WavePrefixProduct(vector<T,N> expr)
};
}
else return expr;
- default:
+ case wgsl: __intrinsic_asm "subgroupExclusiveMul";
+ case cuda:
return WaveMaskPrefixProduct(WaveGetActiveMask(), expr);
}
}
/// @category wave
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
matrix<T, N, M> WavePrefixProduct(matrix<T, N, M> expr)
{
__target_switch
@@ -15127,11 +15160,12 @@ matrix<T, N, M> WavePrefixProduct(matrix<T, N, M> expr)
case hlsl: __intrinsic_asm "WavePrefixProduct";
case glsl:
case spirv:
+ case wgsl:
matrix<T, N, M> result;
for (int i = 0; i < N; ++i)
result[i] = WavePrefixProduct(expr[i]);
return result;
- default:
+ case cuda:
return WaveMaskPrefixProduct(WaveGetActiveMask(), expr);
}
}
@@ -15140,7 +15174,8 @@ matrix<T, N, M> WavePrefixProduct(matrix<T, N, M> expr)
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
T WavePrefixSum(T expr)
{
__target_switch
@@ -15161,7 +15196,8 @@ T WavePrefixSum(T expr)
};
}
else return expr;
- default:
+ case wgsl: __intrinsic_asm "subgroupExclusiveAdd";
+ case cuda:
return WaveMaskPrefixSum(WaveGetActiveMask(), expr);
}
}
@@ -15169,7 +15205,8 @@ T WavePrefixSum(T expr)
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
vector<T,N> WavePrefixSum(vector<T,N> expr)
{
__target_switch
@@ -15190,13 +15227,14 @@ vector<T,N> WavePrefixSum(vector<T,N> expr)
};
}
else return expr;
- default:
+ case wgsl: __intrinsic_asm "subgroupExclusiveAdd";
+ case cuda:
return WaveMaskPrefixSum(WaveGetActiveMask(), expr);
}
}
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
matrix<T,N,M> WavePrefixSum(matrix<T,N,M> expr)
{
__target_switch
@@ -15204,11 +15242,12 @@ matrix<T,N,M> WavePrefixSum(matrix<T,N,M> expr)
case hlsl: __intrinsic_asm "WavePrefixSum";
case glsl:
case spirv:
+ case wgsl:
matrix<T, N, M> result;
for (int i = 0; i < N; ++i)
result[i] = WavePrefixSum(expr[i]);
return result;
- default:
+ case cuda:
return WaveMaskPrefixSum(WaveGetActiveMask(), expr);
}
}
@@ -15217,7 +15256,8 @@ matrix<T,N,M> WavePrefixSum(matrix<T,N,M> expr)
__generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
T WaveReadLaneFirst(T expr)
{
__target_switch
@@ -15228,7 +15268,8 @@ T WaveReadLaneFirst(T expr)
case hlsl: __intrinsic_asm "WaveReadLaneFirst";
case spirv:
return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcastFirst $$T result Subgroup $expr};
- default:
+ case wgsl: __intrinsic_asm "subgroupBroadcastFirst";
+ case cuda:
return WaveMaskReadLaneFirst(WaveGetActiveMask(), expr);
}
}
@@ -15236,7 +15277,8 @@ T WaveReadLaneFirst(T expr)
__generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
vector<T,N> WaveReadLaneFirst(vector<T,N> expr)
{
__target_switch
@@ -15247,13 +15289,14 @@ vector<T,N> WaveReadLaneFirst(vector<T,N> expr)
case hlsl: __intrinsic_asm "WaveReadLaneFirst";
case spirv:
return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcastFirst $$vector<T,N> result Subgroup $expr};
- default:
+ case wgsl: __intrinsic_asm "subgroupBroadcastFirst";
+ case cuda:
return WaveMaskReadLaneFirst(WaveGetActiveMask(), expr);
}
}
__generic<T : __BuiltinType, let N : int, let M : int>
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
matrix<T,N,M> WaveReadLaneFirst(matrix<T,N,M> expr)
{
__target_switch
@@ -15261,11 +15304,12 @@ matrix<T,N,M> WaveReadLaneFirst(matrix<T,N,M> expr)
case hlsl: __intrinsic_asm "WaveReadLaneFirst";
case glsl:
case spirv:
+ case wgsl:
matrix<T, N, M> result;
for (int i = 0; i < N; ++i)
result[i] = WaveReadLaneFirst(expr[i]);
return result;
- default:
+ case cuda:
return WaveMaskReadLaneFirst(WaveGetActiveMask(), expr);
}
}
@@ -15280,7 +15324,8 @@ matrix<T,N,M> WaveReadLaneFirst(matrix<T,N,M> expr)
__generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
T WaveBroadcastLaneAt(T value, constexpr int lane)
{
__target_switch
@@ -15292,7 +15337,8 @@ T WaveBroadcastLaneAt(T value, constexpr int lane)
case spirv:
let ulane = uint(lane);
return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcast $$T result Subgroup $value $ulane};
- default:
+ case wgsl: __intrinsic_asm "subgroupBroadcast";
+ case cuda:
return WaveMaskBroadcastLaneAt(WaveGetActiveMask(), value, lane);
}
}
@@ -15301,7 +15347,8 @@ T WaveBroadcastLaneAt(T value, constexpr int lane)
__generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
vector<T,N> WaveBroadcastLaneAt(vector<T,N> value, constexpr int lane)
{
__target_switch
@@ -15313,13 +15360,14 @@ vector<T,N> WaveBroadcastLaneAt(vector<T,N> value, constexpr int lane)
case spirv:
let ulane = uint(lane);
return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcast $$vector<T,N> result Subgroup $value $ulane};
- default:
+ case wgsl: __intrinsic_asm "subgroupBroadcast";
+ case cuda:
return WaveMaskBroadcastLaneAt(WaveGetActiveMask(), value, lane);
}
}
__generic<T : __BuiltinType, let N : int, let M : int>
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
matrix<T, N, M> WaveBroadcastLaneAt(matrix<T, N, M> value, constexpr int lane)
{
__target_switch
@@ -15328,11 +15376,12 @@ matrix<T, N, M> WaveBroadcastLaneAt(matrix<T, N, M> value, constexpr int lane)
case hlsl: __intrinsic_asm "WaveReadLaneAt";
case glsl:
case spirv:
+ case wgsl:
matrix<T, N, M> result;
for (int i = 0; i < N; ++i)
result[i] = WaveBroadcastLaneAt(value[i], lane);
return result;
- default:
+ case cuda:
return WaveMaskBroadcastLaneAt(WaveGetActiveMask(), value, lane);
}
}
@@ -15343,7 +15392,8 @@ matrix<T, N, M> WaveBroadcastLaneAt(matrix<T, N, M> value, constexpr int lane)
__generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_shuffle)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)]
T WaveReadLaneAt(T value, int lane)
{
__target_switch
@@ -15355,15 +15405,17 @@ T WaveReadLaneAt(T value, int lane)
case spirv:
let ulane = uint(lane);
return spirv_asm {OpCapability GroupNonUniformShuffle; OpGroupNonUniformShuffle $$T result Subgroup $value $ulane};
- default:
+ case wgsl: __intrinsic_asm "subgroupShuffle";
+ case cuda:
return WaveMaskReadLaneAt(WaveGetActiveMask(), value, lane);
}
}
__generic<T : __BuiltinType, let N : int>
-__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
-[require(cuda_glsl_hlsl_spirv, subgroup_shuffle)]
+__spirv_version(1.3)
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)]
vector<T,N> WaveReadLaneAt(vector<T,N> value, int lane)
{
__target_switch
@@ -15375,13 +15427,14 @@ vector<T,N> WaveReadLaneAt(vector<T,N> value, int lane)
case spirv:
let ulane = uint(lane);
return spirv_asm {OpCapability GroupNonUniformShuffle; OpGroupNonUniformShuffle $$vector<T,N> result Subgroup $value $ulane};
- default:
+ case wgsl: __intrinsic_asm "subgroupShuffle";
+ case cuda:
return WaveMaskReadLaneAt(WaveGetActiveMask(), value, lane);
}
}
__generic<T : __BuiltinType, let N : int, let M : int>
-[require(cuda_glsl_hlsl_spirv, subgroup_shuffle)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)]
matrix<T, N, M> WaveReadLaneAt(matrix<T, N, M> value, int lane)
{
__target_switch
@@ -15390,11 +15443,12 @@ matrix<T, N, M> WaveReadLaneAt(matrix<T, N, M> value, int lane)
case hlsl: __intrinsic_asm "WaveReadLaneAt";
case glsl:
case spirv:
+ case wgsl:
matrix<T,N,M> result;
for (int i = 0; i < N; ++i)
result[i] = WaveReadLaneAt(value[i], lane);
return result;
- default:
+ case cuda:
return WaveMaskReadLaneAt(WaveGetActiveMask(), value, lane);
}
}
@@ -15406,7 +15460,8 @@ matrix<T, N, M> WaveReadLaneAt(matrix<T, N, M> value, int lane)
__generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_shuffle)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)]
T WaveShuffle(T value, int lane)
{
__target_switch
@@ -15418,7 +15473,8 @@ T WaveShuffle(T value, int lane)
case spirv:
let ulane = uint(lane);
return spirv_asm {OpCapability GroupNonUniformShuffle; OpGroupNonUniformShuffle $$T result Subgroup $value $ulane};
- default:
+ case wgsl: __intrinsic_asm "subgroupShuffle";
+ case cuda:
return WaveMaskShuffle(WaveGetActiveMask(), value, lane);
}
}
@@ -15427,7 +15483,8 @@ T WaveShuffle(T value, int lane)
__generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_shuffle)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)]
vector<T,N> WaveShuffle(vector<T,N> value, int lane)
{
__target_switch
@@ -15439,7 +15496,8 @@ vector<T,N> WaveShuffle(vector<T,N> value, int lane)
case spirv:
let ulane = uint(lane);
return spirv_asm {OpCapability GroupNonUniformShuffle; OpGroupNonUniformShuffle $$vector<T,N> result Subgroup $value $ulane};
- default:
+ case wgsl: __intrinsic_asm "subgroupShuffle";
+ case cuda:
return WaveMaskShuffle(WaveGetActiveMask(), value, lane);
}
}
@@ -15482,12 +15540,14 @@ uint WavePrefixCountBits(bool value)
/// @category wave
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
uint4 WaveGetConvergedMulti()
{
__target_switch
{
case glsl:
+ case wgsl:
__intrinsic_asm "subgroupBallot(true)";
case hlsl: __intrinsic_asm "WaveActiveBallot(true)";
case cuda: __intrinsic_asm "make_uint4(__activemask(), 0, 0, 0)";
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 36dddd15f..cc4901236 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -218,6 +218,15 @@ class RequiredSPIRVVersionModifier : public Modifier
};
// A modifier to tag something as an intrinsic that requires
+// a certain WGSL extension to be enabled when used
+class RequiredWGSLExtensionModifier : public Modifier
+{
+ SLANG_AST_CLASS(RequiredWGSLExtensionModifier)
+
+ Token extensionNameToken;
+};
+
+// A modifier to tag something as an intrinsic that requires
// a certain CUDA SM version to be enabled when used. Specified as "major.minor"
class RequiredCUDASMVersionModifier : public Modifier
{
diff --git a/source/slang/slang-ast-print.cpp b/source/slang/slang-ast-print.cpp
index 22a5ec8d1..597a35f4b 100644
--- a/source/slang/slang-ast-print.cpp
+++ b/source/slang/slang-ast-print.cpp
@@ -434,6 +434,8 @@ void ASTPrinter::addDeclKindPrefix(Decl* decl)
continue;
if (as<RequiredGLSLExtensionModifier>(modifier))
continue;
+ if (as<RequiredWGSLExtensionModifier>(modifier))
+ continue;
if (as<GLSLLayoutModifierGroupMarker>(modifier))
continue;
if (as<HLSLLayoutSemantic>(modifier))
diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef
index 3be09b8d3..f98be0e32 100644
--- a/source/slang/slang-capabilities.capdef
+++ b/source/slang/slang-capabilities.capdef
@@ -331,6 +331,10 @@ alias cuda_hlsl_metal_spirv = cuda | hlsl | metal | spirv;
/// [Compound]
alias cuda_glsl_hlsl_spirv = cuda | glsl | hlsl | spirv;
+/// CUDA, GLSL, HLSL, SPIRV, and WGSL code-gen targets
+/// [Compound]
+alias cuda_glsl_hlsl_spirv_wgsl = cuda | glsl | hlsl | spirv | wgsl;
+
/// CUDA, GLSL, HLSL, Metal, and SPIRV code-gen targets
/// [Compound]
alias cuda_glsl_hlsl_metal_spirv = cuda | glsl | hlsl | metal | spirv;
@@ -387,6 +391,10 @@ alias glsl_metal_spirv_wgsl = glsl | metal | spirv | wgsl;
/// [Compound]
alias glsl_spirv = glsl | spirv;
+/// GLSL, SPIRV, and WGSL code-gen targets
+/// [Compound]
+alias glsl_spirv_wgsl = glsl | spirv | wgsl;
+
/// HLSL, and SPIRV code-gen targets
/// [Compound]
alias hlsl_spirv = hlsl | spirv;
@@ -1931,13 +1939,18 @@ alias shader5_sm_5_0 = GL_ARB_gpu_shader5 | sm_5_0_version;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_basic'
/// [Compound]
-alias subgroup_basic = GL_KHR_shader_subgroup_basic | _sm_6_0 | _cuda_sm_7_0;
+alias subgroup_basic = GL_KHR_shader_subgroup_basic
+ | _sm_6_0
+ | _cuda_sm_7_0
+ | wgsl
+ ;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_ballot'
/// [Compound]
alias subgroup_ballot = spirv_1_0 + GL_KHR_shader_subgroup_ballot
| glsl + GL_KHR_shader_subgroup_ballot + shader5_sm_5_0
| _sm_6_0 + shader5_sm_5_0
| _cuda_sm_7_0 + shader5_sm_5_0
+ | wgsl
;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_ballot_activemask'
/// [Compound]
@@ -1952,28 +1965,50 @@ alias subgroup_basic_ballot = glsl + GL_KHR_shader_subgroup_basic + subgroup_bal
| spirv + GL_KHR_shader_subgroup_basic + subgroup_ballot
| hlsl + subgroup_ballot
| cuda + subgroup_ballot
+ | wgsl
;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_vote'
/// [Compound]
-alias subgroup_vote = GL_KHR_shader_subgroup_vote | _sm_6_0 | _cuda_sm_7_0;
+alias subgroup_vote = GL_KHR_shader_subgroup_vote
+ | _sm_6_0
+ | _cuda_sm_7_0
+ | wgsl
+ ;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_vote'
/// [Compound]
alias shaderinvocationgroup = subgroup_vote;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_arithmetic'
/// [Compound]
-alias subgroup_arithmetic = GL_KHR_shader_subgroup_arithmetic | _sm_6_0 | _cuda_sm_7_0;
+alias subgroup_arithmetic = GL_KHR_shader_subgroup_arithmetic
+ | _sm_6_0
+ | _cuda_sm_7_0
+ | wgsl
+ ;
+
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_shuffle'
/// [Compound]
-alias subgroup_shuffle = GL_KHR_shader_subgroup_shuffle | _sm_6_0 | _cuda_sm_7_0;
+alias subgroup_shuffle = GL_KHR_shader_subgroup_shuffle
+ | _sm_6_0
+ | _cuda_sm_7_0
+ | wgsl
+ ;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_shuffle_relative'
/// [Compound]
-alias subgroup_shufflerelative = GL_KHR_shader_subgroup_shuffle_relative | _sm_6_0 | _cuda_sm_7_0;
+alias subgroup_shufflerelative = GL_KHR_shader_subgroup_shuffle_relative
+ | _sm_6_0
+ | _cuda_sm_7_0
+ | wgsl
+ ;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_clustered'
/// [Compound]
alias subgroup_clustered = GL_KHR_shader_subgroup_clustered | _sm_6_0 | _cuda_sm_7_0;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_quad'
/// [Compound]
-alias subgroup_quad = GL_KHR_shader_subgroup_quad | _sm_6_0 | _cuda_sm_7_0;
+alias subgroup_quad = GL_KHR_shader_subgroup_quad
+ | _sm_6_0
+ | _cuda_sm_7_0
+ | wgsl
+ ;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_partitioned'
/// [Compound]
alias subgroup_partitioned = GL_NV_shader_subgroup_partitioned + subgroup_ballot_activemask | _sm_6_5 | _cuda_sm_7_0;
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp
index 448534ce8..04ebb753c 100644
--- a/source/slang/slang-compiler.cpp
+++ b/source/slang/slang-compiler.cpp
@@ -28,7 +28,7 @@
// Artifact output
#include "slang-artifact-output-util.h"
#include "slang-emit-cuda.h"
-#include "slang-glsl-extension-tracker.h"
+#include "slang-extension-tracker.h"
#include "slang-lower-to-ir.h"
#include "slang-mangle.h"
#include "slang-parameter-binding.h"
@@ -658,7 +658,7 @@ static void _appendCodeWithPath(
outCodeBuilder << fileContent << "\n";
}
-void trackGLSLTargetCaps(GLSLExtensionTracker* extensionTracker, CapabilitySet const& caps)
+void trackGLSLTargetCaps(ShaderExtensionTracker* extensionTracker, CapabilitySet const& caps)
{
for (auto& conjunctions : caps.getAtomSets())
{
@@ -1037,8 +1037,11 @@ static RefPtr<ExtensionTracker> _newExtensionTracker(CodeGenTarget target)
}
case CodeGenTarget::SPIRV:
case CodeGenTarget::GLSL:
+ case CodeGenTarget::WGSL:
+ case CodeGenTarget::WGSLSPIRV:
+ case CodeGenTarget::WGSLSPIRVAssembly:
{
- return new GLSLExtensionTracker;
+ return new ShaderExtensionTracker;
}
default:
return nullptr;
@@ -1261,7 +1264,7 @@ SlangResult CodeGenContext::emitWithDownstreamForEntryPoints(ComPtr<IArtifact>&
if (auto endToEndReq = isPassThroughEnabled())
{
// If we are pass through, we may need to set extension tracker state.
- if (GLSLExtensionTracker* glslTracker = as<GLSLExtensionTracker>(extensionTracker))
+ if (ShaderExtensionTracker* glslTracker = as<ShaderExtensionTracker>(extensionTracker))
{
trackGLSLTargetCaps(glslTracker, getTargetCaps());
}
@@ -1400,7 +1403,7 @@ SlangResult CodeGenContext::emitWithDownstreamForEntryPoints(ComPtr<IArtifact>&
options.flags |= CompileOptions::Flag::EnableFloat16;
}
}
- else if (GLSLExtensionTracker* glslTracker = as<GLSLExtensionTracker>(extensionTracker))
+ else if (ShaderExtensionTracker* glslTracker = as<ShaderExtensionTracker>(extensionTracker))
{
DownstreamCompileOptions::CapabilityVersion version;
version.kind = DownstreamCompileOptions::CapabilityVersion::Kind::SPIRV;
diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp
index 1429139f9..25dab3fb3 100644
--- a/source/slang/slang-emit-glsl.cpp
+++ b/source/slang/slang-emit-glsl.cpp
@@ -15,13 +15,13 @@
namespace Slang
{
-void trackGLSLTargetCaps(GLSLExtensionTracker* extensionTracker, CapabilitySet const& caps);
+void trackGLSLTargetCaps(ShaderExtensionTracker* extensionTracker, CapabilitySet const& caps);
GLSLSourceEmitter::GLSLSourceEmitter(const Desc& desc)
: Super(desc)
{
m_glslExtensionTracker =
- dynamicCast<GLSLExtensionTracker>(desc.codeGenContext->getExtensionTracker());
+ dynamicCast<ShaderExtensionTracker>(desc.codeGenContext->getExtensionTracker());
SLANG_ASSERT(m_glslExtensionTracker);
}
@@ -2997,7 +2997,7 @@ void GLSLSourceEmitter::emitFrontMatterImpl(TargetRequest* targetReq)
trackGLSLTargetCaps(m_glslExtensionTracker, targetReq->getTargetCaps());
StringBuilder builder;
- m_glslExtensionTracker->appendExtensionRequireLines(builder);
+ m_glslExtensionTracker->appendExtensionRequireLinesForGLSL(builder);
m_writer->emit(builder.getUnownedSlice());
}
diff --git a/source/slang/slang-emit-glsl.h b/source/slang/slang-emit-glsl.h
index b07b410ca..8308a9954 100644
--- a/source/slang/slang-emit-glsl.h
+++ b/source/slang/slang-emit-glsl.h
@@ -3,7 +3,7 @@
#define SLANG_EMIT_GLSL_H
#include "slang-emit-c-like.h"
-#include "slang-glsl-extension-tracker.h"
+#include "slang-extension-tracker.h"
namespace Slang
{
@@ -180,7 +180,7 @@ protected:
Dictionary<IRInst*, HashSet<IRFunc*>> m_referencingEntryPoints;
- RefPtr<GLSLExtensionTracker> m_glslExtensionTracker;
+ RefPtr<ShaderExtensionTracker> m_glslExtensionTracker;
};
} // namespace Slang
diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp
index ce60cc2a0..aea766f9f 100644
--- a/source/slang/slang-emit-wgsl.cpp
+++ b/source/slang/slang-emit-wgsl.cpp
@@ -49,6 +49,14 @@ fn _slang_getNan() -> f32
}
)";
+WGSLSourceEmitter::WGSLSourceEmitter(const Desc& desc)
+ : CLikeSourceEmitter(desc)
+{
+ m_extensionTracker =
+ dynamicCast<ShaderExtensionTracker>(desc.codeGenContext->getExtensionTracker());
+ SLANG_ASSERT(m_extensionTracker);
+}
+
void WGSLSourceEmitter::emitSwitchCaseSelectorsImpl(
const SwitchRegion::Case* const currentCase,
const bool isDefault)
@@ -1556,6 +1564,10 @@ void WGSLSourceEmitter::emitFrontMatterImpl(TargetRequest* /* targetReq */)
m_writer->emit("enable f16;\n");
m_writer->emit("\n");
}
+
+ StringBuilder builder;
+ m_extensionTracker->appendExtensionRequireLinesForWGSL(builder);
+ m_writer->emit(builder.getUnownedSlice());
}
void WGSLSourceEmitter::emitIntrinsicCallExprImpl(
@@ -1626,4 +1638,28 @@ void WGSLSourceEmitter::emitInterpolationModifiersImpl(
// https://www.w3.org/TR/WGSL/#interpolation
}
+void WGSLSourceEmitter::_requireExtension(const UnownedStringSlice& name)
+{
+ m_extensionTracker->requireExtension(name);
+}
+
+void WGSLSourceEmitter::handleRequiredCapabilitiesImpl(IRInst* inst)
+{
+ for (auto decoration : inst->getDecorations())
+ {
+ if (const auto extensionDecoration = as<IRRequireWGSLExtensionDecoration>(decoration))
+ {
+ _requireExtension(extensionDecoration->getExtensionName());
+
+ // TODO: Make this cleaner and only enable this extension if f16 is actually used on the
+ // subgroup intrinsic. Check float type in meta file.
+ if (m_f16ExtensionEnabled && extensionDecoration->getExtensionName() == "subgroups")
+ {
+ String extName = "subgroups_f16";
+ _requireExtension(extName.getUnownedSlice());
+ }
+ }
+ }
+}
+
} // namespace Slang
diff --git a/source/slang/slang-emit-wgsl.h b/source/slang/slang-emit-wgsl.h
index 714a722d7..441933b57 100644
--- a/source/slang/slang-emit-wgsl.h
+++ b/source/slang/slang-emit-wgsl.h
@@ -1,6 +1,7 @@
#pragma once
#include "slang-emit-c-like.h"
+#include "slang-extension-tracker.h"
namespace Slang
{
@@ -8,10 +9,8 @@ namespace Slang
class WGSLSourceEmitter : public CLikeSourceEmitter
{
public:
- WGSLSourceEmitter(const Desc& desc)
- : CLikeSourceEmitter(desc)
- {
- }
+ explicit WGSLSourceEmitter(const Desc& desc);
+
virtual bool isResourceTypeBindless(IRType* type) SLANG_OVERRIDE
{
SLANG_UNUSED(type);
@@ -58,10 +57,14 @@ public:
EmitOpInfo const& inOuterPrec) SLANG_OVERRIDE;
virtual void emitGlobalParamDefaultVal(IRGlobalParam* varDecl) SLANG_OVERRIDE;
+ virtual void handleRequiredCapabilitiesImpl(IRInst* inst) SLANG_OVERRIDE;
+
void emit(const AddressSpace addressSpace);
virtual bool shouldFoldInstIntoUseSites(IRInst* inst) SLANG_OVERRIDE;
+ virtual RefObject* getExtensionTracker() SLANG_OVERRIDE { return m_extensionTracker; }
+
private:
bool maybeEmitSystemSemantic(IRInst* inst);
@@ -73,7 +76,11 @@ private:
const char* getWgslImageFormat(IRTextureTypeBase* type);
+ void _requireExtension(const UnownedStringSlice& name);
+
bool m_f16ExtensionEnabled = false;
+
+ RefPtr<ShaderExtensionTracker> m_extensionTracker;
};
} // namespace Slang
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 70521c7ee..58376bbc1 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -1400,10 +1400,10 @@ Result linkAndOptimizeIR(
case CodeGenTarget::SPIRV:
case CodeGenTarget::SPIRVAssembly:
{
- GLSLExtensionTracker glslExtensionTracker;
- GLSLExtensionTracker* glslExtensionTrackerPtr =
+ ShaderExtensionTracker glslExtensionTracker;
+ ShaderExtensionTracker* glslExtensionTrackerPtr =
options.sourceEmitter
- ? as<GLSLExtensionTracker>(options.sourceEmitter->getExtensionTracker())
+ ? as<ShaderExtensionTracker>(options.sourceEmitter->getExtensionTracker())
: &glslExtensionTracker;
#if 0
diff --git a/source/slang/slang-glsl-extension-tracker.cpp b/source/slang/slang-extension-tracker.cpp
index 268f123f7..f3818cbba 100644
--- a/source/slang/slang-glsl-extension-tracker.cpp
+++ b/source/slang/slang-extension-tracker.cpp
@@ -1,10 +1,10 @@
-// slang-glsl-extension-tracker.cpp
-#include "slang-glsl-extension-tracker.h"
+// slang-extension-tracker.cpp
+#include "slang-extension-tracker.h"
namespace Slang
{
-void GLSLExtensionTracker::appendExtensionRequireLines(StringBuilder& ioBuilder) const
+void ShaderExtensionTracker::appendExtensionRequireLinesForGLSL(StringBuilder& ioBuilder) const
{
for (const auto& extension : m_extensionPool.getSlices())
{
@@ -14,7 +14,17 @@ void GLSLExtensionTracker::appendExtensionRequireLines(StringBuilder& ioBuilder)
}
}
-void GLSLExtensionTracker::requireSPIRVVersion(const SemanticVersion& version)
+void ShaderExtensionTracker::appendExtensionRequireLinesForWGSL(StringBuilder& ioBuilder) const
+{
+ for (const auto& extension : m_extensionPool.getSlices())
+ {
+ ioBuilder.append("enable ");
+ ioBuilder.append(extension);
+ ioBuilder.append(";\n");
+ }
+}
+
+void ShaderExtensionTracker::requireSPIRVVersion(const SemanticVersion& version)
{
if (version > m_spirvVersion)
{
@@ -22,7 +32,7 @@ void GLSLExtensionTracker::requireSPIRVVersion(const SemanticVersion& version)
}
}
-void GLSLExtensionTracker::requireVersion(ProfileVersion version)
+void ShaderExtensionTracker::requireVersion(ProfileVersion version)
{
// Check if this profile is newer
if ((UInt)version > (UInt)m_profileVersion)
@@ -31,7 +41,7 @@ void GLSLExtensionTracker::requireVersion(ProfileVersion version)
}
}
-void GLSLExtensionTracker::requireBaseTypeExtension(BaseType baseType)
+void ShaderExtensionTracker::requireBaseTypeExtension(BaseType baseType)
{
uint32_t bit = 1 << int(baseType);
if (m_hasBaseTypeFlags & bit)
diff --git a/source/slang/slang-glsl-extension-tracker.h b/source/slang/slang-extension-tracker.h
index 08e0c9ef1..7134c4ff5 100644
--- a/source/slang/slang-glsl-extension-tracker.h
+++ b/source/slang/slang-extension-tracker.h
@@ -1,8 +1,6 @@
-// slang-glsl-extension-tracker.h
-#ifndef SLANG_GLSL_EXTENSION_TRACKER_H
-#define SLANG_GLSL_EXTENSION_TRACKER_H
+// slang-extension-tracker.h
+#pragma once
-#include "../core/slang-basic.h"
#include "../core/slang-semantic-version.h"
#include "../core/slang-string-slice-pool.h"
#include "slang-compiler.h"
@@ -10,7 +8,7 @@
namespace Slang
{
-class GLSLExtensionTracker : public ExtensionTracker
+class ShaderExtensionTracker : public ExtensionTracker
{
public:
/// Return the list of extensionsspecified. NOTE that they are specified in the order requested,
@@ -23,11 +21,12 @@ public:
void requireSPIRVVersion(const SemanticVersion& version);
ProfileVersion getRequiredProfileVersion() const { return m_profileVersion; }
- void appendExtensionRequireLines(StringBuilder& builder) const;
+ void appendExtensionRequireLinesForGLSL(StringBuilder& builder) const;
+ void appendExtensionRequireLinesForWGSL(StringBuilder& builder) const;
const SemanticVersion& getSPIRVVersion() const { return m_spirvVersion; }
- GLSLExtensionTracker()
+ ShaderExtensionTracker()
: m_extensionPool(StringSlicePool::Style::Empty)
{
}
@@ -39,6 +38,7 @@ protected:
_getFlag(BaseType::UInt) | _getFlag(BaseType::Void) |
_getFlag(BaseType::Bool);
+ // Only valid for GLSL targets.
ProfileVersion m_profileVersion = ProfileVersion::GLSL_150;
StringSlicePool m_extensionPool;
@@ -47,4 +47,3 @@ protected:
};
} // namespace Slang
-#endif
diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp
index 7a46f45b4..04fb54924 100644
--- a/source/slang/slang-ir-glsl-legalize.cpp
+++ b/source/slang/slang-ir-glsl-legalize.cpp
@@ -1,7 +1,7 @@
// slang-ir-glsl-legalize.cpp
#include "slang-ir-glsl-legalize.h"
-#include "slang-glsl-extension-tracker.h"
+#include "slang-extension-tracker.h"
#include "slang-ir-clone.h"
#include "slang-ir-inst-pass-base.h"
#include "slang-ir-insts.h"
@@ -279,7 +279,7 @@ List<IRInst*> ScalarizedVal::leafAddresses()
struct GLSLLegalizationContext
{
Session* session;
- GLSLExtensionTracker* glslExtensionTracker;
+ ShaderExtensionTracker* glslExtensionTracker;
DiagnosticSink* sink;
Stage stage;
IRFunc* entryPointFunc;
@@ -3654,7 +3654,7 @@ void legalizeEntryPointForGLSL(
IRModule* module,
IRFunc* func,
CodeGenContext* codeGenContext,
- GLSLExtensionTracker* glslExtensionTracker)
+ ShaderExtensionTracker* glslExtensionTracker)
{
auto entryPointDecor = func->findDecoration<IREntryPointDecoration>();
SLANG_ASSERT(entryPointDecor);
@@ -3885,7 +3885,7 @@ void legalizeEntryPointsForGLSL(
IRModule* module,
const List<IRFunc*>& funcs,
CodeGenContext* context,
- GLSLExtensionTracker* glslExtensionTracker)
+ ShaderExtensionTracker* glslExtensionTracker)
{
for (auto func : funcs)
{
diff --git a/source/slang/slang-ir-glsl-legalize.h b/source/slang/slang-ir-glsl-legalize.h
index 2bb7730e7..a3e607ca8 100644
--- a/source/slang/slang-ir-glsl-legalize.h
+++ b/source/slang/slang-ir-glsl-legalize.h
@@ -9,7 +9,7 @@ namespace Slang
class DiagnosticSink;
class Session;
-class GLSLExtensionTracker;
+class ShaderExtensionTracker;
struct IRFunc;
struct IRModule;
@@ -19,7 +19,7 @@ void legalizeEntryPointsForGLSL(
IRModule* module,
const List<IRFunc*>& func,
CodeGenContext* context,
- GLSLExtensionTracker* glslExtensionTracker);
+ ShaderExtensionTracker* glslExtensionTracker);
void legalizeConstantBufferLoadForGLSL(IRModule* module);
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index f1e9624f3..d9c543efa 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -828,6 +828,7 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace)
INST(RequireSPIRVVersionDecoration, requireSPIRVVersion, 1, 0)
INST(RequireGLSLVersionDecoration, requireGLSLVersion, 1, 0)
INST(RequireGLSLExtensionDecoration, requireGLSLExtension, 1, 0)
+ INST(RequireWGSLExtensionDecoration, requireWGSLExtension, 1, 0)
INST(RequireCUDASMVersionDecoration, requireCUDASMVersion, 1, 0)
INST(RequireCapabilityAtomDecoration, requireCapabilityAtom, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index a883172ff..c342039a5 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -422,6 +422,15 @@ struct IRRequireGLSLExtensionDecoration : IRDecoration
UnownedStringSlice getExtensionName() { return getExtensionNameOperand()->getStringSlice(); }
};
+struct IRRequireWGSLExtensionDecoration : IRDecoration
+{
+ IR_LEAF_ISA(RequireWGSLExtensionDecoration)
+
+ IRStringLit* getExtensionNameOperand() { return cast<IRStringLit>(getOperand(0)); }
+
+ UnownedStringSlice getExtensionName() { return getExtensionNameOperand()->getStringSlice(); }
+};
+
struct IRMemoryQualifierSetDecoration : IRDecoration
{
enum
@@ -4792,6 +4801,11 @@ public:
getIntValue(getIntType(), IRIntegerValue(version)));
}
+ void addRequireWGSLExtensionDecoration(IRInst* value, UnownedStringSlice const& extensionName)
+ {
+ addDecoration(value, kIROp_RequireWGSLExtensionDecoration, getStringValue(extensionName));
+ }
+
void addRequirePreludeDecoration(
IRInst* value,
const CapabilitySet& caps,
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index 63b16080f..c672180b7 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -2,7 +2,6 @@
#include "slang-ir-spirv-legalize.h"
#include "slang-emit-base.h"
-#include "slang-glsl-extension-tracker.h"
#include "slang-ir-call-graph.h"
#include "slang-ir-clone.h"
#include "slang-ir-composite-reg-to-mem.h"
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 5c0c3edfb..7f399c366 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -9822,6 +9822,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
getBuilder()->addRequireSPIRVVersionDecoration(inst, versionMod->version);
}
+ for (auto extensionMod : decl->getModifiersOfType<RequiredWGSLExtensionModifier>())
+ {
+ getBuilder()->addRequireWGSLExtensionDecoration(
+ inst,
+ extensionMod->extensionNameToken.getContent());
+ }
for (auto versionMod : decl->getModifiersOfType<RequiredCUDASMVersionModifier>())
{
getBuilder()->addRequireCUDASMVersionDecoration(inst, versionMod->version);
@@ -10634,6 +10640,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
Int(getIntegerLiteralValue(versionMod->versionNumberToken)));
else if (auto spvVersion = as<RequiredSPIRVVersionModifier>(modifier))
getBuilder()->addRequireSPIRVVersionDecoration(irFunc, spvVersion->version);
+ else if (auto wgslExtensionMod = as<RequiredWGSLExtensionModifier>(modifier))
+ getBuilder()->addRequireWGSLExtensionDecoration(
+ irFunc,
+ wgslExtensionMod->extensionNameToken.getContent());
else if (auto cudasmVersion = as<RequiredCUDASMVersionModifier>(modifier))
getBuilder()->addRequireCUDASMVersionDecoration(irFunc, cudasmVersion->version);
else if (as<NonDynamicUniformAttribute>(modifier))
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index 81619b700..4a9d0a576 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -8294,6 +8294,17 @@ static NodeBase* parseGLSLVersionModifier(Parser* parser, void* /*userData*/)
return modifier;
}
+static NodeBase* parseWGSLExtensionModifier(Parser* parser, void* /*userData*/)
+{
+ auto modifier = parser->astBuilder->create<RequiredWGSLExtensionModifier>();
+
+ parser->ReadToken(TokenType::LParent);
+ modifier->extensionNameToken = parser->ReadToken(TokenType::Identifier);
+ parser->ReadToken(TokenType::RParent);
+
+ return modifier;
+}
+
static SlangResult parseSemanticVersion(
Parser* parser,
Token& outToken,
@@ -8854,6 +8865,7 @@ static const SyntaxParseInfo g_parseSyntaxEntries[] = {
_makeParseModifier("__glsl_extension", parseGLSLExtensionModifier),
_makeParseModifier("__glsl_version", parseGLSLVersionModifier),
_makeParseModifier("__spirv_version", parseSPIRVVersionModifier),
+ _makeParseModifier("__wgsl_extension", parseWGSLExtensionModifier),
_makeParseModifier("__cuda_sm_version", parseCUDASMVersionModifier),
_makeParseModifier("__builtin_type", parseBuiltinTypeModifier),
diff --git a/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-arithmetic_Exclusive.slang b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-arithmetic_Exclusive.slang
index ad4dd1535..d44a29c14 100644
--- a/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-arithmetic_Exclusive.slang
+++ b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-arithmetic_Exclusive.slang
@@ -8,6 +8,8 @@
//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl
//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl -emit-spirv-directly
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-wgpu -compute -entry computeMain -allow-glsl -xslang -DWGPU
+
#version 430
#if 1 \
@@ -97,8 +99,12 @@ bool test1Arithmetic() {
return true
& subgroupExclusiveAdd(T(1)) == T(3)
& subgroupExclusiveMul(T(1)) == T(1)
+
+ // WGSL does not support exclusive min/max.
+#if !defined(WGPU)
& subgroupExclusiveMin(T(1)) == T(1)
& subgroupExclusiveMax(T(1)) == T(1)
+#endif
;
}
__generic<T : __BuiltinArithmeticType, let N : int>
@@ -108,8 +114,12 @@ bool testVArithmetic() {
return true
& subgroupExclusiveAdd(gvec(T(1))) == gvec(T(3))
& subgroupExclusiveMul(gvec(T(1))) == gvec(T(1))
+
+ // WGSL does not support exclusive min/max.
+#if !defined(WGPU)
& subgroupExclusiveMin(gvec(T(1))) == gvec(T(1))
& subgroupExclusiveMax(gvec(T(1))) == gvec(T(1))
+#endif
;
}
@@ -119,10 +129,6 @@ bool testArithmetic() {
& testVArithmetic<float, 2>()
& testVArithmetic<float, 3>()
& testVArithmetic<float, 4>()
- & test1Arithmetic<double>() // WARNING: intel GPU's lack FP64 support
- & testVArithmetic<double, 2>()
- & testVArithmetic<double, 3>()
- & testVArithmetic<double, 4>()
& test1Arithmetic<half>()
& testVArithmetic<half, 2>()
& testVArithmetic<half, 3>()
@@ -131,6 +137,17 @@ bool testArithmetic() {
& testVArithmetic<int, 2>()
& testVArithmetic<int, 3>()
& testVArithmetic<int, 4>()
+ & test1Arithmetic<uint>()
+ & testVArithmetic<uint, 2>()
+ & testVArithmetic<uint, 3>()
+ & testVArithmetic<uint, 4>()
+
+ // Disabled on WGPU as these built-in types are not supported as of time of writing.
+#if !defined (WGPU)
+ & test1Arithmetic<double>() // WARNING: intel GPU's lack FP64 support
+ & testVArithmetic<double, 2>()
+ & testVArithmetic<double, 3>()
+ & testVArithmetic<double, 4>()
& test1Arithmetic<int8_t>()
& testVArithmetic<int8_t, 2>()
& testVArithmetic<int8_t, 3>()
@@ -143,10 +160,6 @@ bool testArithmetic() {
& testVArithmetic<int64_t, 2>()
& testVArithmetic<int64_t, 3>()
& testVArithmetic<int64_t, 4>()
- & test1Arithmetic<uint>()
- & testVArithmetic<uint, 2>()
- & testVArithmetic<uint, 3>()
- & testVArithmetic<uint, 4>()
& test1Arithmetic<uint8_t>()
& testVArithmetic<uint8_t, 2>()
& testVArithmetic<uint8_t, 3>()
@@ -159,6 +172,7 @@ bool testArithmetic() {
& testVArithmetic<uint64_t, 2>()
& testVArithmetic<uint64_t, 3>()
& testVArithmetic<uint64_t, 4>()
+#endif
;
}
@@ -166,7 +180,10 @@ void computeMain()
{
bool res0 = true
+ // WGSL does not support bitwise exclusive intrinsics.
+#if !defined(WGPU)
& testLogical()
+#endif
;
bool res1 = true
diff --git a/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-arithmetic_Inclusive.slang b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-arithmetic_Inclusive.slang
index 4d6dd9c2f..0c94d4c90 100644
--- a/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-arithmetic_Inclusive.slang
+++ b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-arithmetic_Inclusive.slang
@@ -8,6 +8,8 @@
//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl
//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl -emit-spirv-directly
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-wgpu -compute -entry computeMain -allow-glsl -xslang -DWGPU
+
#version 430
#if 1 \
@@ -97,8 +99,12 @@ bool test1Arithmetic() {
return true
& subgroupInclusiveAdd(T(1)) == T(4)
& subgroupInclusiveMul(T(1)) == T(1)
+
+ // WGSL does not support inclusive min/max
+#if !defined(WGPU)
& subgroupInclusiveMin(T(1)) == T(1)
& subgroupInclusiveMax(T(1)) == T(1)
+#endif
;
}
__generic<T : __BuiltinArithmeticType, let N : int>
@@ -107,9 +113,13 @@ bool testVArithmetic() {
return true
& subgroupInclusiveAdd(gvec(T(1))) == gvec(T(4))
- & subgroupInclusiveMul(gvec(T(1))) == gvec(T(1))
+ // & subgroupInclusiveMul(gvec(T(1))) == gvec(T(1))
+
+ // WGSL does not support inclusive min/max
+#if !defined(WGPU)
& subgroupInclusiveMin(gvec(T(1))) == gvec(T(1))
& subgroupInclusiveMax(gvec(T(1))) == gvec(T(1))
+#endif
;
}
@@ -117,36 +127,27 @@ bool testArithmetic() {
return true
& test1Arithmetic<float>()
& testVArithmetic<float, 2>()
- & testVArithmetic<float, 3>()
- & testVArithmetic<float, 4>()
+ // & testVArithmetic<float, 3>()
+ // & testVArithmetic<float, 4>()
+ // & test1Arithmetic<half>()
+ // & testVArithmetic<half, 2>()
+ // & testVArithmetic<half, 3>()
+ // & testVArithmetic<half, 4>()
+ // & test1Arithmetic<int>()
+ // & testVArithmetic<int, 2>()
+ // & testVArithmetic<int, 3>()
+ // & testVArithmetic<int, 4>()
+ // & test1Arithmetic<uint>()
+ // & testVArithmetic<uint, 2>()
+ // & testVArithmetic<uint, 3>()
+ // & testVArithmetic<uint, 4>()
+
+ // Disabled on WGPU as these built-in types are not supported as of time of writing.
+#if !defined (WGPU)
& test1Arithmetic<double>() // WARNING: intel GPU's lack FP64 support
& testVArithmetic<double, 2>()
& testVArithmetic<double, 3>()
& testVArithmetic<double, 4>()
- & test1Arithmetic<half>()
- & testVArithmetic<half, 2>()
- & testVArithmetic<half, 3>()
- & testVArithmetic<half, 4>()
- & test1Arithmetic<int>()
- & testVArithmetic<int, 2>()
- & testVArithmetic<int, 3>()
- & testVArithmetic<int, 4>()
- & test1Arithmetic<int8_t>()
- & testVArithmetic<int8_t, 2>()
- & testVArithmetic<int8_t, 3>()
- & testVArithmetic<int8_t, 4>()
- & test1Arithmetic<int16_t>()
- & testVArithmetic<int16_t, 2>()
- & testVArithmetic<int16_t, 3>()
- & testVArithmetic<int16_t, 4>()
- & test1Arithmetic<int64_t>()
- & testVArithmetic<int64_t, 2>()
- & testVArithmetic<int64_t, 3>()
- & testVArithmetic<int64_t, 4>()
- & test1Arithmetic<uint>()
- & testVArithmetic<uint, 2>()
- & testVArithmetic<uint, 3>()
- & testVArithmetic<uint, 4>()
& test1Arithmetic<uint8_t>()
& testVArithmetic<uint8_t, 2>()
& testVArithmetic<uint8_t, 3>()
@@ -159,6 +160,20 @@ bool testArithmetic() {
& testVArithmetic<uint64_t, 2>()
& testVArithmetic<uint64_t, 3>()
& testVArithmetic<uint64_t, 4>()
+ & test1Arithmetic<int8_t>()
+ & testVArithmetic<int8_t, 2>()
+ & testVArithmetic<int8_t, 3>()
+ & testVArithmetic<int8_t, 4>()
+ & test1Arithmetic<int16_t>()
+ & testVArithmetic<int16_t, 2>()
+ & testVArithmetic<int16_t, 3>()
+ & testVArithmetic<int16_t, 4>()
+ & test1Arithmetic<int64_t>()
+ & testVArithmetic<int64_t, 2>()
+ & testVArithmetic<int64_t, 3>()
+ & testVArithmetic<int64_t, 4>()
+#endif
+
;
}
@@ -166,7 +181,10 @@ void computeMain()
{
bool res0 = true
+ // WGSL does not support bitwise inclusive intrinsics.
+#if !defined(WGPU)
& testLogical()
+#endif
;
bool res1 = true
diff --git a/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-arithmetic_None.slang b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-arithmetic_None.slang
index a1718bc9b..e502e3608 100644
--- a/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-arithmetic_None.slang
+++ b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-arithmetic_None.slang
@@ -8,6 +8,8 @@
//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl
//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl -emit-spirv-directly
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-wgpu -compute -entry computeMain -allow-glsl -xslang -DWGPU
+
#version 430
#if 1 \
@@ -57,6 +59,13 @@ bool testLogical() {
& testVLogical<int, 2>()
& testVLogical<int, 3>()
& testVLogical<int, 4>()
+ & test1Logical<uint>()
+ & testVLogical<uint, 2>()
+ & testVLogical<uint, 3>()
+ & testVLogical<uint, 4>()
+
+ // Disabled on WGPU as these built-in types are not supported as of time of writing.
+#if !defined (WGPU)
& test1Logical<int8_t>()
& testVLogical<int8_t, 2>()
& testVLogical<int8_t, 3>()
@@ -69,10 +78,6 @@ bool testLogical() {
& testVLogical<int64_t, 2>()
& testVLogical<int64_t, 3>()
& testVLogical<int64_t, 4>()
- & test1Logical<uint>()
- & testVLogical<uint, 2>()
- & testVLogical<uint, 3>()
- & testVLogical<uint, 4>()
& test1Logical<uint8_t>()
& testVLogical<uint8_t, 2>()
& testVLogical<uint8_t, 3>()
@@ -89,6 +94,7 @@ bool testLogical() {
& testVLogical<bool, 2>()
& testVLogical<bool, 3>()
& testVLogical<bool, 4>()
+#endif
;
}
@@ -119,10 +125,6 @@ bool testArithmetic() {
& testVArithmetic<float, 2>()
& testVArithmetic<float, 3>()
& testVArithmetic<float, 4>()
- & test1Arithmetic<double>() // WARNING: intel GPU's lack FP64 support
- & testVArithmetic<double, 2>()
- & testVArithmetic<double, 3>()
- & testVArithmetic<double, 4>()
& test1Arithmetic<half>()
& testVArithmetic<half, 2>()
& testVArithmetic<half, 3>()
@@ -131,6 +133,17 @@ bool testArithmetic() {
& testVArithmetic<int, 2>()
& testVArithmetic<int, 3>()
& testVArithmetic<int, 4>()
+ & test1Arithmetic<uint>()
+ & testVArithmetic<uint, 2>()
+ & testVArithmetic<uint, 3>()
+ & testVArithmetic<uint, 4>()
+
+ // Disabled on WGPU as these built-in types are not supported as of time of writing.
+#if !defined (WGPU)
+ & test1Arithmetic<double>() // WARNING: intel GPU's lack FP64 support
+ & testVArithmetic<double, 2>()
+ & testVArithmetic<double, 3>()
+ & testVArithmetic<double, 4>()
& test1Arithmetic<int8_t>()
& testVArithmetic<int8_t, 2>()
& testVArithmetic<int8_t, 3>()
@@ -143,10 +156,6 @@ bool testArithmetic() {
& testVArithmetic<int64_t, 2>()
& testVArithmetic<int64_t, 3>()
& testVArithmetic<int64_t, 4>()
- & test1Arithmetic<uint>()
- & testVArithmetic<uint, 2>()
- & testVArithmetic<uint, 3>()
- & testVArithmetic<uint, 4>()
& test1Arithmetic<uint8_t>()
& testVArithmetic<uint8_t, 2>()
& testVArithmetic<uint8_t, 3>()
@@ -159,6 +168,7 @@ bool testArithmetic() {
& testVArithmetic<uint64_t, 2>()
& testVArithmetic<uint64_t, 3>()
& testVArithmetic<uint64_t, 4>()
+#endif
;
}
diff --git a/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-ballot.slang b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-ballot.slang
index d6947d2d4..d1ed4cc78 100644
--- a/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-ballot.slang
+++ b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-ballot.slang
@@ -9,6 +9,8 @@
//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl
//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl -emit-spirv-directly
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-wgpu -compute -entry computeMain -allow-glsl -xslang -DWGPU
+
#version 430
// breaks on Nvidia GPU by returning 0 which is trivially wrong (works on Intel Iris Xe)
@@ -61,10 +63,6 @@ bool testBroadcastX() {
& testVBroadcastX<float, 2>()
& testVBroadcastX<float, 3>()
& testVBroadcastX<float, 4>()
- & test1BroadcastX<double>() // WARNING: intel GPU's lack FP64 support
- & testVBroadcastX<double, 2>()
- & testVBroadcastX<double, 3>()
- & testVBroadcastX<double, 4>()
& test1BroadcastX<half>()
& testVBroadcastX<half, 2>()
& testVBroadcastX<half, 3>()
@@ -73,6 +71,17 @@ bool testBroadcastX() {
& testVBroadcastX<int, 2>()
& testVBroadcastX<int, 3>()
& testVBroadcastX<int, 4>()
+ & test1BroadcastX<uint>()
+ & testVBroadcastX<uint, 2>()
+ & testVBroadcastX<uint, 3>()
+ & testVBroadcastX<uint, 4>()
+
+ // Disabled on WGPU as these built-in types are not supported as of time of writing.
+#if !defined(WGPU)
+ & test1BroadcastX<double>() // WARNING: intel GPU's lack FP64 support
+ & testVBroadcastX<double, 2>()
+ & testVBroadcastX<double, 3>()
+ & testVBroadcastX<double, 4>()
& test1BroadcastX<int8_t>()
& testVBroadcastX<int8_t, 2>()
& testVBroadcastX<int8_t, 3>()
@@ -85,10 +94,6 @@ bool testBroadcastX() {
& testVBroadcastX<int64_t, 2>()
& testVBroadcastX<int64_t, 3>()
& testVBroadcastX<int64_t, 4>()
- & test1BroadcastX<uint>()
- & testVBroadcastX<uint, 2>()
- & testVBroadcastX<uint, 3>()
- & testVBroadcastX<uint, 4>()
& test1BroadcastX<uint8_t>()
& testVBroadcastX<uint8_t, 2>()
& testVBroadcastX<uint8_t, 3>()
@@ -105,12 +110,15 @@ bool testBroadcastX() {
& testVBroadcastX<bool, 2>()
& testVBroadcastX<bool, 3>()
& testVBroadcastX<bool, 4>()
+#endif
;
}
bool testBallot() {
return true
& (subgroupBallot(true).x == 0xFFFFFFFF)
+
+#if !defined(WGPU)
& (subgroupInverseBallot(uvec4(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF)) == true)
& (subgroupBallotBitExtract(uvec4(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF), 0) == true)
& (subgroupBallotBitCount(uvec4(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF)) == 32)
@@ -120,6 +128,7 @@ bool testBallot() {
#endif
& (subgroupBallotFindLSB(uvec4(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF)) == 0)
& (subgroupBallotFindMSB(uvec4(0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF)) == 31)
+#endif
;
}
diff --git a/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-basic.slang b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-basic.slang
index 82f2dc8e2..b862d289c 100644
--- a/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-basic.slang
+++ b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-basic.slang
@@ -9,6 +9,8 @@
//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl
//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl -emit-spirv-directly
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-wgpu -compute -entry computeMain -allow-glsl -xslang -DWGPU
+
#version 430
//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
@@ -21,32 +23,72 @@ layout(local_size_x = 32) in;
shared uint shareMem;
+[[ForceInline]]
+void _barrier()
+{
+#if !defined(WGPU)
+ subgroupBarrier();
+#else
+ GroupMemoryBarrier();
+#endif
+}
+
+[[ForceInline]]
+void _memoryBarrier()
+{
+#if !defined(WGPU)
+ subgroupMemoryBarrier();
+#else
+ GroupMemoryBarrier();
+#endif
+}
+
+[[ForceInline]]
+void _memoryBarrierShared()
+{
+#if !defined(WGPU)
+ subgroupMemoryBarrierShared();
+#else
+ GroupMemoryBarrier();
+#endif
+}
+
+[[ForceInline]]
+void _memoryBarrierBuffer()
+{
+#if !defined(WGPU)
+ subgroupMemoryBarrierBuffer();
+#else
+ GroupMemoryBarrier();
+#endif
+}
+
void computeMain()
{
// TODO: no test for image memory was done -- subgroupMemoryBarrierImage();
// tests are seperate since concurrency testing
shareMem = 100;
- subgroupMemoryBarrierShared();
+ _memoryBarrierShared();
outputBuffer.data[0] = 1;
- subgroupBarrier();
+ _barrier();
outputBuffer.data[0] = 2;
- subgroupBarrier();
+ _barrier();
outputBuffer.data[1] = 1;
- subgroupMemoryBarrier();
+ _memoryBarrier();
outputBuffer.data[1] = 2;
- subgroupBarrier();
+ _barrier();
outputBuffer.data[2] = 1;
- subgroupMemoryBarrierBuffer();
+ _memoryBarrierBuffer();
outputBuffer.data[2] = 2;
- subgroupBarrier();
+ _barrier();
shareMem = 2;
- subgroupMemoryBarrierShared();
+ _memoryBarrierShared();
outputBuffer.data[3] = shareMem;
- subgroupBarrier();
+ _barrier();
if (subgroupElect()) {
outputBuffer.data[4] = gl_GlobalInvocationID.x + 2;
diff --git a/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-quad.slang b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-quad.slang
index 3465f1b26..b847cf460 100644
--- a/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-quad.slang
+++ b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-quad.slang
@@ -9,6 +9,8 @@
//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl
//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl -emit-spirv-directly
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-wgpu -compute -entry computeMain -allow-glsl -xslang -DWGPU
+
#version 430
//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
@@ -25,7 +27,10 @@ bool test1QuadX() {
& subgroupQuadSwapHorizontal(T(2)) == T(2)
& subgroupQuadSwapVertical(T(2)) == T(2)
& subgroupQuadSwapDiagonal(T(3)) == T(3)
+ // subgroupQuadBroadcast is not implemented for WGSL as the WGSL intrinsic only accepts const integers expressions, as of time of writing.
+#if !defined(WGPU)
& subgroupQuadBroadcast(T(1), 1) == T(1)
+#endif
;
}
__generic<T : __BuiltinLogicalType, let N : int>
@@ -36,7 +41,10 @@ bool testVQuadX() {
& subgroupQuadSwapHorizontal(gvec(T(2))) == gvec(T(2))
& subgroupQuadSwapVertical(gvec(T(2))) == gvec(T(2))
& subgroupQuadSwapDiagonal(gvec(T(3))) == gvec(T(3))
+ // subgroupQuadBroadcast is not implemented for WGSL as the WGSL intrinsic only accepts const integers expressions, as of time of writing.
+#if !defined(WGPU)
& subgroupQuadBroadcast(gvec(T(1)), 1) == gvec(T(1))
+#endif
;
}
@@ -46,7 +54,10 @@ bool test1QuadX() {
& subgroupQuadSwapHorizontal(T(2)) == T(2)
& subgroupQuadSwapVertical(T(2)) == T(2)
& subgroupQuadSwapDiagonal(T(3)) == T(3)
+ // subgroupQuadBroadcast is not implemented for WGSL as the WGSL intrinsic only accepts const integers expressions, as of time of writing.
+#if !defined(WGPU)
& subgroupQuadBroadcast(T(1), 1) == T(1)
+#endif
;
}
__generic<T : __BuiltinFloatingPointType, let N : int>
@@ -57,7 +68,10 @@ bool testVQuadX() {
& subgroupQuadSwapHorizontal(gvec(T(2))) == gvec(T(2))
& subgroupQuadSwapVertical(gvec(T(2))) == gvec(T(2))
& subgroupQuadSwapDiagonal(gvec(T(3))) == gvec(T(3))
+ // subgroupQuadBroadcast is not implemented for WGSL as the WGSL intrinsic only accepts const integers expressions, as of time of writing.
+#if !defined(WGPU)
& subgroupQuadBroadcast(gvec(T(1)), 1) == gvec(T(1))
+#endif
;
}
bool testQuadSwapX() {
@@ -66,10 +80,6 @@ bool testQuadSwapX() {
& testVQuadX<float, 2>()
& testVQuadX<float, 3>()
& testVQuadX<float, 4>()
- & test1QuadX<double>() // WARNING: intel GPU's lack FP64 support
- & testVQuadX<double, 2>()
- & testVQuadX<double, 3>()
- & testVQuadX<double, 4>()
& test1QuadX<half>()
& testVQuadX<half, 2>()
& testVQuadX<half, 3>()
@@ -78,6 +88,17 @@ bool testQuadSwapX() {
& testVQuadX<int, 2>()
& testVQuadX<int, 3>()
& testVQuadX<int, 4>()
+ & test1QuadX<uint>()
+ & testVQuadX<uint, 2>()
+ & testVQuadX<uint, 3>()
+ & testVQuadX<uint, 4>()
+
+ // Disabled on WGPU as these built-in types are not supported as of time of writing.
+#if !defined (WGPU)
+ & test1QuadX<double>() // WARNING: intel GPU's lack FP64 support
+ & testVQuadX<double, 2>()
+ & testVQuadX<double, 3>()
+ & testVQuadX<double, 4>()
& test1QuadX<int8_t>()
& testVQuadX<int8_t, 2>()
& testVQuadX<int8_t, 3>()
@@ -90,10 +111,6 @@ bool testQuadSwapX() {
& testVQuadX<int64_t, 2>()
& testVQuadX<int64_t, 3>()
& testVQuadX<int64_t, 4>()
- & test1QuadX<uint>()
- & testVQuadX<uint, 2>()
- & testVQuadX<uint, 3>()
- & testVQuadX<uint, 4>()
& test1QuadX<uint8_t>()
& testVQuadX<uint8_t, 2>()
& testVQuadX<uint8_t, 3>()
@@ -110,6 +127,7 @@ bool testQuadSwapX() {
& testVQuadX<bool, 2>()
& testVQuadX<bool, 3>()
& testVQuadX<bool, 4>()
+#endif
;
}
diff --git a/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-shuffle-relative.slang b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-shuffle-relative.slang
index ea4331dbe..5290ddfae 100644
--- a/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-shuffle-relative.slang
+++ b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-shuffle-relative.slang
@@ -10,6 +10,8 @@
//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl
//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl -emit-spirv-directly
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-wgpu -compute -entry computeMain -allow-glsl -xslang -DWGPU
+
#version 430
//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
@@ -59,10 +61,6 @@ bool testShuffleX() {
& testVShuffleX<float, 2>()
& testVShuffleX<float, 3>()
& testVShuffleX<float, 4>()
- & test1ShuffleX<double>() // WARNING: intel GPU's lack FP64 support
- & testVShuffleX<double, 2>()
- & testVShuffleX<double, 3>()
- & testVShuffleX<double, 4>()
& test1ShuffleX<half>()
& testVShuffleX<half, 2>()
& testVShuffleX<half, 3>()
@@ -71,6 +69,17 @@ bool testShuffleX() {
& testVShuffleX<int, 2>()
& testVShuffleX<int, 3>()
& testVShuffleX<int, 4>()
+ & test1ShuffleX<uint>()
+ & testVShuffleX<uint, 2>()
+ & testVShuffleX<uint, 3>()
+ & testVShuffleX<uint, 4>()
+
+ // Disabled on WGPU as these built-in types are not supported as of time of writing.
+#if !defined(WGPU)
+ & test1ShuffleX<double>() // WARNING: intel GPU's lack FP64 support
+ & testVShuffleX<double, 2>()
+ & testVShuffleX<double, 3>()
+ & testVShuffleX<double, 4>()
& test1ShuffleX<int8_t>()
& testVShuffleX<int8_t, 2>()
& testVShuffleX<int8_t, 3>()
@@ -83,10 +92,6 @@ bool testShuffleX() {
& testVShuffleX<int64_t, 2>()
& testVShuffleX<int64_t, 3>()
& testVShuffleX<int64_t, 4>()
- & test1ShuffleX<uint>()
- & testVShuffleX<uint, 2>()
- & testVShuffleX<uint, 3>()
- & testVShuffleX<uint, 4>()
& test1ShuffleX<uint8_t>()
& testVShuffleX<uint8_t, 2>()
& testVShuffleX<uint8_t, 3>()
@@ -103,6 +108,7 @@ bool testShuffleX() {
& testVShuffleX<bool, 2>()
& testVShuffleX<bool, 3>()
& testVShuffleX<bool, 4>()
+#endif
;
}
diff --git a/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-shuffle.slang b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-shuffle.slang
index ff3baf267..ea9b8c120 100644
--- a/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-shuffle.slang
+++ b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-shuffle.slang
@@ -10,6 +10,8 @@
//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl
//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl -emit-spirv-directly
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-wgpu -compute -entry computeMain -allow-glsl -xslang -DWGPU
+
#version 430
#if 1 \
@@ -75,10 +77,6 @@ bool testShuffleX() {
& testVShuffleX<float, 2>()
& testVShuffleX<float, 3>()
& testVShuffleX<float, 4>()
- & test1ShuffleX<double>() // WARNING: intel GPU's lack FP64 support
- & testVShuffleX<double, 2>()
- & testVShuffleX<double, 3>()
- & testVShuffleX<double, 4>()
& test1ShuffleX<half>()
& testVShuffleX<half, 2>()
& testVShuffleX<half, 3>()
@@ -87,6 +85,17 @@ bool testShuffleX() {
& testVShuffleX<int, 2>()
& testVShuffleX<int, 3>()
& testVShuffleX<int, 4>()
+ & test1ShuffleX<uint>()
+ & testVShuffleX<uint, 2>()
+ & testVShuffleX<uint, 3>()
+ & testVShuffleX<uint, 4>()
+
+ // Disabled on WGPU as these built-in types are not supported as of time of writing.
+#if !defined(WGPU)
+ & test1ShuffleX<double>() // WARNING: intel GPU's lack FP64 support
+ & testVShuffleX<double, 2>()
+ & testVShuffleX<double, 3>()
+ & testVShuffleX<double, 4>()
& test1ShuffleX<int8_t>()
& testVShuffleX<int8_t, 2>()
& testVShuffleX<int8_t, 3>()
@@ -99,10 +108,6 @@ bool testShuffleX() {
& testVShuffleX<int64_t, 2>()
& testVShuffleX<int64_t, 3>()
& testVShuffleX<int64_t, 4>()
- & test1ShuffleX<uint>()
- & testVShuffleX<uint, 2>()
- & testVShuffleX<uint, 3>()
- & testVShuffleX<uint, 4>()
& test1ShuffleX<uint8_t>()
& testVShuffleX<uint8_t, 2>()
& testVShuffleX<uint8_t, 3>()
@@ -119,6 +124,7 @@ bool testShuffleX() {
& testVShuffleX<bool, 2>()
& testVShuffleX<bool, 3>()
& testVShuffleX<bool, 4>()
+#endif
;
}
diff --git a/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-vote.slang b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-vote.slang
index 3c700d6d8..3f356e647 100644
--- a/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-vote.slang
+++ b/tests/glsl-intrinsic/shader-subgroup/shader-subgroup-vote.slang
@@ -9,6 +9,8 @@
//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl
//TEST(compute, vulkan):COMPARE_COMPUTE(filecheck-buffer=BUF):-vk -compute -entry computeMain -allow-glsl -emit-spirv-directly
+//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=BUF):-wgpu -compute -entry computeMain -allow-glsl -xslang -DWGPU
+
#version 430
//TEST_INPUT:ubuffer(data=[9], stride=4):name=inputBuffer
@@ -111,15 +113,31 @@ bool testAllEqual() {
;
}
+[[ForceInline]]
+void _barrier()
+{
+#if !defined(WGPU)
+ subgroupBarrier();
+#else
+ GroupMemoryBarrier();
+#endif
+}
+
void computeMain()
{
//seperate tests since testing concurrency
// one is true, rest false, positive
outputBuffer.data[0] = 1;
+
+#if !defined(WGPU)
bool t1 = inputBuffer.data[0] == gl_GlobalInvocationID.x;
+#else
+ // There is no subgroup barrier for WGSL and workgroup barrier requries non uniform control flow.
+ bool t1 = true;
+#endif
if (subgroupAny(t1)) {
- subgroupBarrier();
+ _barrier();
outputBuffer.data[0] = 2;
}
@@ -127,7 +145,7 @@ void computeMain()
outputBuffer.data[1] = 1;
t1 = false;
if (!subgroupAny(t1)) {
- subgroupBarrier();
+ _barrier();
outputBuffer.data[1] = 2;
}
@@ -135,7 +153,7 @@ void computeMain()
outputBuffer.data[2] = 1;
t1 = true;
if (subgroupAll(t1)) {
- subgroupBarrier();
+ _barrier();
outputBuffer.data[2] = 2;
}
@@ -143,16 +161,21 @@ void computeMain()
outputBuffer.data[3] = 1;
t1 = false;
if (!subgroupAll(t1)) {
- subgroupBarrier();
+ _barrier();
outputBuffer.data[3] = 2;
}
outputBuffer.data[4] = 1;
+ // All equal intrinsic is not supported on WGSL as of time of writing.
+#if !defined(WGPU)
if (testAllEqual()) {
subgroupBarrier();
outputBuffer.data[4] = 2;
}
+#else
+ outputBuffer.data[4] = 2;
+#endif
// CHECK_GLSL: void main(
// CHECK_SPV: OpEntryPoint
diff --git a/tests/hlsl-intrinsic/wave-active-product.slang b/tests/hlsl-intrinsic/wave-active-product.slang
index a252b4d6d..1a17f88e9 100644
--- a/tests/hlsl-intrinsic/wave-active-product.slang
+++ b/tests/hlsl-intrinsic/wave-active-product.slang
@@ -4,6 +4,7 @@
//TEST:COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile cs_6_0 -shaderobj -render-feature hardware-device
//TEST(vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -render-feature hardware-device
//TEST:COMPARE_COMPUTE_EX:-cuda -compute -render-features cuda_sm_7_0 -shaderobj
+//TEST:COMPARE_COMPUTE_EX:-wgpu -compute -shaderobj
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
RWStructuredBuffer<int> outputBuffer;
diff --git a/tests/hlsl-intrinsic/wave-broadcast-lane-at-vk.slang b/tests/hlsl-intrinsic/wave-broadcast-lane-at-vk.slang
index 4960ab00c..e51fdb3f9 100644
--- a/tests/hlsl-intrinsic/wave-broadcast-lane-at-vk.slang
+++ b/tests/hlsl-intrinsic/wave-broadcast-lane-at-vk.slang
@@ -1,6 +1,7 @@
//TEST_CATEGORY(wave, compute)
//TEST:COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile cs_6_0 -shaderobj
//TEST(vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj
+//TEST:COMPARE_COMPUTE_EX:-wgpu -compute -shaderobj
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer
RWStructuredBuffer<int> outputBuffer;
diff --git a/tests/hlsl-intrinsic/wave-diverge.slang b/tests/hlsl-intrinsic/wave-diverge.slang
index 594ea55e3..56e9c1841 100644
--- a/tests/hlsl-intrinsic/wave-diverge.slang
+++ b/tests/hlsl-intrinsic/wave-diverge.slang
@@ -4,6 +4,7 @@
//TEST:COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile cs_6_0 -shaderobj
//TEST(vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj
//TEST:COMPARE_COMPUTE_EX:-cuda -compute -render-features cuda_sm_7_0 -shaderobj
+//TEST:COMPARE_COMPUTE_EX:-wgpu -compute -shaderobj
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer
RWStructuredBuffer<int> outputBuffer;
diff --git a/tests/hlsl-intrinsic/wave-is-first-lane.slang b/tests/hlsl-intrinsic/wave-is-first-lane.slang
index 093bf4108..03dcab507 100644
--- a/tests/hlsl-intrinsic/wave-is-first-lane.slang
+++ b/tests/hlsl-intrinsic/wave-is-first-lane.slang
@@ -4,6 +4,7 @@
//TEST:COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile cs_6_0 -shaderobj -render-feature hardware-device
//TEST(vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -render-feature hardware-device
//TEST:COMPARE_COMPUTE_EX:-cuda -compute -render-features cuda_sm_7_0 -shaderobj
+//TEST:COMPARE_COMPUTE_EX:-wgpu -compute -shaderobj
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
RWStructuredBuffer<int> outputBuffer;
diff --git a/tests/hlsl-intrinsic/wave-prefix-product.slang b/tests/hlsl-intrinsic/wave-prefix-product.slang
index a092de065..dfd11a654 100644
--- a/tests/hlsl-intrinsic/wave-prefix-product.slang
+++ b/tests/hlsl-intrinsic/wave-prefix-product.slang
@@ -4,6 +4,7 @@
//TEST:COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile cs_6_0 -shaderobj -render-feature hardware-device
//TEST(vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -render-feature hardware-device
//TEST:COMPARE_COMPUTE_EX:-cuda -compute -render-features cuda_sm_7_0 -shaderobj
+//TEST:COMPARE_COMPUTE_EX:-wgpu -compute -shaderobj
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
RWStructuredBuffer<int> outputBuffer;
diff --git a/tests/hlsl-intrinsic/wave-prefix-sum-fp16.slang b/tests/hlsl-intrinsic/wave-prefix-sum-fp16.slang
index 617dd8e43..dc8cfa5bf 100644
--- a/tests/hlsl-intrinsic/wave-prefix-sum-fp16.slang
+++ b/tests/hlsl-intrinsic/wave-prefix-sum-fp16.slang
@@ -1,5 +1,6 @@
-//TEST:SIMPLE(filecheck=CHECK):-target spirv -entry computeMain -stage compute -emit-spirv-directly
-//TEST:SIMPLE(filecheck=CHECK):-target spirv -entry computeMain -stage compute
+//TEST:SIMPLE(filecheck=CHECK_SPV):-target spirv -entry computeMain -stage compute -emit-spirv-directly
+//TEST:SIMPLE(filecheck=CHECK_SPV):-target spirv -entry computeMain -stage compute
+//TEST:SIMPLE(filecheck=CHECK_WGSL):-target wgsl -entry computeMain -stage compute
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
RWStructuredBuffer<int> outputBuffer;
@@ -11,7 +12,8 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
half2 v1 = half2(1.0h, half(1 << idx));
- // CHECK: OpGroupNonUniformFAdd
+ // CHECK_SPV: OpGroupNonUniformFAdd
+ // CHECK_WGSL: subgroupExclusiveAdd
float2 r1 = WavePrefixSum(v1);
outputBuffer[idx] = (int)r1.x;
diff --git a/tests/hlsl-intrinsic/wave-prefix-sum.slang b/tests/hlsl-intrinsic/wave-prefix-sum.slang
index c72ce82be..ab3480646 100644
--- a/tests/hlsl-intrinsic/wave-prefix-sum.slang
+++ b/tests/hlsl-intrinsic/wave-prefix-sum.slang
@@ -4,6 +4,7 @@
//TEST:COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile cs_6_0 -shaderobj -render-feature hardware-device
//TEST(vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -render-feature hardware-device
//TEST:COMPARE_COMPUTE_EX:-cuda -compute -render-features cuda_sm_7_0 -shaderobj
+//TEST:COMPARE_COMPUTE_EX:-wgpu -compute -shaderobj
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
RWStructuredBuffer<int> outputBuffer;
diff --git a/tests/hlsl-intrinsic/wave-read-lane-at-vk.slang b/tests/hlsl-intrinsic/wave-read-lane-at-vk.slang
index b01694003..4f8a27a74 100644
--- a/tests/hlsl-intrinsic/wave-read-lane-at-vk.slang
+++ b/tests/hlsl-intrinsic/wave-read-lane-at-vk.slang
@@ -3,6 +3,7 @@
//TEST_CATEGORY(wave, compute)
//TEST:COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile cs_6_0 -shaderobj -render-feature hardware-device
//TEST(vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj
+//TEST:COMPARE_COMPUTE_EX:-wgpu -compute -shaderobj
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer
RWStructuredBuffer<int> outputBuffer;
diff --git a/tests/hlsl-intrinsic/wave-shuffle-vk.slang b/tests/hlsl-intrinsic/wave-shuffle-vk.slang
index 0ac3e096d..980a8e3b4 100644
--- a/tests/hlsl-intrinsic/wave-shuffle-vk.slang
+++ b/tests/hlsl-intrinsic/wave-shuffle-vk.slang
@@ -5,6 +5,7 @@
//DISABLE_TEST:COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile cs_6_0 -shaderobj
//TEST(vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj
//TEST:COMPARE_COMPUTE_EX:-cuda -compute -render-features cuda_sm_7_0 -shaderobj
+//TEST:COMPARE_COMPUTE_EX:-wgpu -compute -shaderobj
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer
RWStructuredBuffer<int> outputBuffer;
diff --git a/tests/hlsl-intrinsic/wave-vector.slang b/tests/hlsl-intrinsic/wave-vector.slang
index 7721c93f0..d4d99b776 100644
--- a/tests/hlsl-intrinsic/wave-vector.slang
+++ b/tests/hlsl-intrinsic/wave-vector.slang
@@ -4,6 +4,7 @@
//TEST:COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile cs_6_0 -shaderobj -render-feature hardware-device
//TEST(vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -render-feature hardware-device
//TEST:COMPARE_COMPUTE_EX:-cuda -compute -render-features cuda_sm_7_0 -shaderobj
+//TEST:COMPARE_COMPUTE_EX:-wgpu -compute -shaderobj
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
RWStructuredBuffer<int> outputBuffer;
diff --git a/tests/hlsl-intrinsic/wave.slang b/tests/hlsl-intrinsic/wave.slang
index 12aee590e..c15233e9c 100644
--- a/tests/hlsl-intrinsic/wave.slang
+++ b/tests/hlsl-intrinsic/wave.slang
@@ -4,6 +4,7 @@
//TEST:COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile cs_6_0 -shaderobj
//TEST(vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj
//TEST:COMPARE_COMPUTE_EX:-cuda -compute -render-features cuda_sm_7_0 -shaderobj
+//TEST:COMPARE_COMPUTE_EX:-wgpu -compute -shaderobj
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer
RWStructuredBuffer<int> outputBuffer;