summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/glsl.meta.slang174
-rw-r--r--source/slang/hlsl.meta.slang280
-rw-r--r--source/slang/slang-ast-modifier.h9
-rw-r--r--source/slang/slang-ast-print.cpp2
-rw-r--r--source/slang/slang-capabilities.capdef47
-rw-r--r--source/slang/slang-compiler.cpp13
-rw-r--r--source/slang/slang-emit-glsl.cpp6
-rw-r--r--source/slang/slang-emit-glsl.h4
-rw-r--r--source/slang/slang-emit-wgsl.cpp36
-rw-r--r--source/slang/slang-emit-wgsl.h15
-rw-r--r--source/slang/slang-emit.cpp6
-rw-r--r--source/slang/slang-extension-tracker.cpp (renamed from source/slang/slang-glsl-extension-tracker.cpp)22
-rw-r--r--source/slang/slang-extension-tracker.h (renamed from source/slang/slang-glsl-extension-tracker.h)15
-rw-r--r--source/slang/slang-ir-glsl-legalize.cpp8
-rw-r--r--source/slang/slang-ir-glsl-legalize.h4
-rw-r--r--source/slang/slang-ir-inst-defs.h1
-rw-r--r--source/slang/slang-ir-insts.h14
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp1
-rw-r--r--source/slang/slang-lower-to-ir.cpp10
-rw-r--r--source/slang/slang-parser.cpp12
20 files changed, 465 insertions, 214 deletions
diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang
index dd1c5a907..ef3bfd683 100644
--- a/source/slang/glsl.meta.slang
+++ b/source/slang/glsl.meta.slang
@@ -6305,7 +6305,7 @@ void shader_subgroup_preamble() {
// GL_KHR_shader_subgroup_basic Built-in Variables
-[require(cpp_cuda_glsl_hlsl_spirv, subgroup_basic)]
+[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
void requireGLSLExtForSubgroupBasicBuiltin() {
__target_switch
{
@@ -6317,7 +6317,7 @@ void requireGLSLExtForSubgroupBasicBuiltin() {
}
}
-[require(cpp_cuda_glsl_hlsl_spirv, subgroup_basic)]
+[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
void setupExtForSubgroupBasicBuiltIn() {
__target_switch
{
@@ -6329,7 +6329,7 @@ void setupExtForSubgroupBasicBuiltIn() {
}
__spirv_version(1.3)
-[require(cpp_cuda_glsl_hlsl_spirv, subgroup_ballot)]
+[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
void requireGLSLExtForSubgroupBallotBuiltin() {
__target_switch
{
@@ -6342,7 +6342,7 @@ void requireGLSLExtForSubgroupBallotBuiltin() {
}
__spirv_version(1.3)
-[require(cpp_cuda_glsl_hlsl_spirv, subgroup_ballot)]
+[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
void setupExtForSubgroupBallotBuiltIn() {
__target_switch
{
@@ -6392,7 +6392,7 @@ public property uint gl_SubgroupID
public property uint gl_SubgroupSize
{
- [require(cpp_cuda_glsl_hlsl_spirv, subgroup_basic)]
+ [require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
get {
setupExtForSubgroupBasicBuiltIn();
return WaveGetLaneCount();
@@ -6401,7 +6401,7 @@ public property uint gl_SubgroupSize
public property uint gl_SubgroupInvocationID
{
- [require(cpp_cuda_glsl_hlsl_spirv, subgroup_basic)]
+ [require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
get {
setupExtForSubgroupBasicBuiltIn();
return WaveGetLaneIndex();
@@ -6625,7 +6625,7 @@ public void subgroupMemoryBarrierShared()
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_basic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_basic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
public bool subgroupElect()
{
__target_switch
@@ -6635,6 +6635,7 @@ public bool subgroupElect()
case glsl:
case spirv:
case hlsl:
+ case wgsl:
return WaveIsFirstLane();
}
@@ -6645,7 +6646,7 @@ public bool subgroupElect()
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_vote)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_vote)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_vote)]
public bool subgroupAll(bool value)
{
return WaveActiveAllTrue(value);
@@ -6654,7 +6655,7 @@ public bool subgroupAll(bool value)
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_vote)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_vote)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_vote)]
public bool subgroupAny(bool value)
{
return WaveActiveAnyTrue(value);
@@ -6688,7 +6689,7 @@ __generic<T : __BuiltinArithmeticType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
public T subgroupAdd(T value)
{
shader_subgroup_preamble<T>();
@@ -6699,7 +6700,7 @@ __generic<T : __BuiltinArithmeticType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
public T subgroupMul(T value)
{
shader_subgroup_preamble<T>();
@@ -6710,7 +6711,7 @@ __generic<T : __BuiltinArithmeticType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
public T subgroupMin(T value)
{
shader_subgroup_preamble<T>();
@@ -6721,7 +6722,7 @@ __generic<T : __BuiltinArithmeticType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
public T subgroupMax(T value)
{
shader_subgroup_preamble<T>();
@@ -6731,14 +6732,17 @@ public T subgroupMax(T value)
__generic<T : __BuiltinLogicalType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_arithmetic)]
+[require(glsl_spirv_wgsl, subgroup_arithmetic)]
public T subgroupAnd(T value)
{
shader_subgroup_preamble<T>();
__target_switch
{
- case glsl: __intrinsic_asm "subgroupAnd($0)";
+ case glsl:
+ case wgsl:
+ __intrinsic_asm "subgroupAnd($0)";
case spirv:
if (__isBool<T>()) {
return spirv_asm {
@@ -6758,14 +6762,17 @@ public T subgroupAnd(T value)
__generic<T : __BuiltinLogicalType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_arithmetic)]
+[require(glsl_spirv_wgsl, subgroup_arithmetic)]
public T subgroupOr(T value)
{
shader_subgroup_preamble<T>();
__target_switch
{
- case glsl: __intrinsic_asm "subgroupOr($0)";
+ case glsl:
+ case wgsl:
+ __intrinsic_asm "subgroupOr($0)";
case spirv:
if (__isBool<T>()) {
return spirv_asm {
@@ -6785,14 +6792,17 @@ public T subgroupOr(T value)
__generic<T : __BuiltinLogicalType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_arithmetic)]
+[require(glsl_spirv_wgsl, subgroup_arithmetic)]
public T subgroupXor(T value)
{
shader_subgroup_preamble<T>();
__target_switch
{
- case glsl: __intrinsic_asm "subgroupXor($0)";
+ case glsl:
+ case wgsl:
+ __intrinsic_asm "subgroupXor($0)";
case spirv:
if (__isBool<T>()) {
return spirv_asm {
@@ -6812,14 +6822,16 @@ public T subgroupXor(T value)
__generic<T : __BuiltinArithmeticType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_arithmetic)]
+[require(glsl_spirv_wgsl, subgroup_arithmetic)]
public T subgroupInclusiveAdd(T value)
{
shader_subgroup_preamble<T>();
__target_switch
{
case glsl:
+ case wgsl:
__intrinsic_asm "subgroupInclusiveAdd($0)";
case spirv:
if (__isFloat<T>())
@@ -6833,14 +6845,16 @@ public T subgroupInclusiveAdd(T value)
__generic<T : __BuiltinArithmeticType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_arithmetic)]
+[require(glsl_spirv_wgsl, subgroup_arithmetic)]
public T subgroupInclusiveMul(T value)
{
shader_subgroup_preamble<T>();
__target_switch
{
case glsl:
+ case wgsl:
__intrinsic_asm "subgroupInclusiveMul($0)";
case spirv:
if (__isFloat<T>())
@@ -6974,7 +6988,7 @@ __generic<T : __BuiltinArithmeticType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
public T subgroupExclusiveAdd(T value)
{
shader_subgroup_preamble<T>();
@@ -6986,7 +7000,7 @@ __generic<T : __BuiltinArithmeticType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
public T subgroupExclusiveMul(T value)
{
shader_subgroup_preamble<T>();
@@ -7097,7 +7111,7 @@ __generic<T : __BuiltinArithmeticType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupAdd(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -7108,7 +7122,7 @@ __generic<T : __BuiltinArithmeticType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupMul(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -7119,7 +7133,7 @@ __generic<T : __BuiltinArithmeticType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupMin(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -7130,7 +7144,7 @@ __generic<T : __BuiltinArithmeticType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupMax(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -7140,14 +7154,18 @@ public vector<T,N> subgroupMax(vector<T,N> value)
__generic<T : __BuiltinLogicalType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_arithmetic)]
+[require(glsl_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupAnd(vector<T,N> value)
{
shader_subgroup_preamble<T>();
__target_switch
{
- case glsl: __intrinsic_asm "subgroupAnd($0)";
+ case glsl:
+ case wgsl:
+ // TODO: Bool inputs are invalid for WGSL, cast them to int or don't allow them to compile.
+ __intrinsic_asm "subgroupAnd($0)";
case spirv:
if (__isBool<T>()) {
return spirv_asm {
@@ -7168,14 +7186,17 @@ public vector<T,N> subgroupAnd(vector<T,N> value)
__generic<T : __BuiltinLogicalType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_arithmetic)]
+[require(glsl_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupOr(vector<T,N> value)
{
shader_subgroup_preamble<T>();
__target_switch
{
- case glsl: __intrinsic_asm "subgroupOr($0)";
+ case glsl:
+ case wgsl:
+ __intrinsic_asm "subgroupOr($0)";
case spirv:
if (__isBool<T>()) {
return spirv_asm {
@@ -7196,14 +7217,17 @@ public vector<T,N> subgroupOr(vector<T,N> value)
__generic<T : __BuiltinLogicalType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_arithmetic)]
+[require(glsl_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupXor(vector<T,N> value)
{
shader_subgroup_preamble<T>();
__target_switch
{
- case glsl: __intrinsic_asm "subgroupXor($0)";
+ case glsl:
+ case wgsl:
+ __intrinsic_asm "subgroupXor($0)";
case spirv:
if (__isBool<T>()) {
return spirv_asm {
@@ -7223,14 +7247,16 @@ public vector<T,N> subgroupXor(vector<T,N> value)
__generic<T : __BuiltinArithmeticType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_arithmetic)]
+[require(glsl_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupInclusiveAdd(vector<T,N> value)
{
shader_subgroup_preamble<T>();
__target_switch
{
case glsl:
+ case wgsl:
__intrinsic_asm "subgroupInclusiveAdd($0)";
case spirv:
if (__isFloat<T>())
@@ -7244,14 +7270,16 @@ public vector<T,N> subgroupInclusiveAdd(vector<T,N> value)
__generic<T : __BuiltinArithmeticType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_arithmetic)]
+[require(glsl_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupInclusiveMul(vector<T,N> value)
{
shader_subgroup_preamble<T>();
__target_switch
{
case glsl:
+ case wgsl:
__intrinsic_asm "subgroupInclusiveMul($0)";
case spirv:
if (__isFloat<T>())
@@ -7366,7 +7394,7 @@ __generic<T : __BuiltinArithmeticType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupExclusiveAdd(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -7378,7 +7406,7 @@ __generic<T : __BuiltinArithmeticType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
public vector<T,N> subgroupExclusiveMul(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -7488,51 +7516,65 @@ __generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_ballot)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
public T subgroupBroadcast(T value, uint id)
{
shader_subgroup_preamble<T>();
- return WaveMaskBroadcastLaneAt(WaveGetActiveMask(), value, id);
+ __target_switch
+ {
+ case wgsl:
+ // WGSL's intrinsic does not accept non-const ids, do shuffle instead.
+ __intrinsic_asm "subgroupShuffle";
+ default:
+ return WaveBroadcastLaneAt(value, id);
+ }
}
__generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_ballot)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
public vector<T,N> subgroupBroadcast(vector<T,N> value, uint id)
{
shader_subgroup_preamble<T>();
- return WaveMaskBroadcastLaneAt(WaveGetActiveMask(), value, id);
+ __target_switch
+ {
+ case wgsl:
+ // WGSL's intrinsic does not accept non-const ids, do shuffle instead.
+ __intrinsic_asm "subgroupShuffle";
+ default:
+ return WaveBroadcastLaneAt(value, id);
+ }
}
__generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_ballot)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
public T subgroupBroadcastFirst(T value)
{
shader_subgroup_preamble<T>();
- return WaveMaskReadLaneFirst(WaveGetActiveMask(), value);
+ return WaveReadLaneFirst(value);
}
__generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_ballot)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
public vector<T,N> subgroupBroadcastFirst(vector<T,N> value)
{
shader_subgroup_preamble<T>();
- return WaveMaskReadLaneFirst(WaveGetActiveMask(), value);
+ return WaveReadLaneFirst(value);
}
// WaveMaskBallot is not the same; it force trunc's
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_ballot)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
public uvec4 subgroupBallot(bool value)
{
return WaveActiveBallot(value);
@@ -7713,7 +7755,7 @@ __generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_shuffle)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)]
public T subgroupShuffle(T value, uint index)
{
shader_subgroup_preamble<T>();
@@ -7723,13 +7765,15 @@ public T subgroupShuffle(T value, uint index)
__generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
-[require(glsl_spirv, subgroup_shuffle)]
+__wgsl_extension(subgroups)
+[require(glsl_spirv_wgsl, subgroup_shuffle)]
[ForceInline] public T subgroupShuffleXor(T value, uint mask)
{
shader_subgroup_preamble<T>();
__target_switch
{
case glsl:
+ case wgsl:
__intrinsic_asm "subgroupShuffleXor($0,$1)";
case spirv:
return spirv_asm {
@@ -7743,7 +7787,7 @@ __generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
[ForceInline]
-[require(cuda_glsl_hlsl_spirv, subgroup_shuffle)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)]
public vector<T,N> subgroupShuffle(vector<T,N> value, uint index)
{
shader_subgroup_preamble<T>();
@@ -7753,14 +7797,16 @@ public vector<T,N> subgroupShuffle(vector<T,N> value, uint index)
__generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_shuffle)]
+[require(glsl_spirv_wgsl, subgroup_shuffle)]
public vector<T,N> subgroupShuffleXor(vector<T,N> value, uint mask)
{
shader_subgroup_preamble<T>();
__target_switch
{
case glsl:
+ case wgsl:
__intrinsic_asm "subgroupShuffleXor($0,$1)";
case spirv:
return spirv_asm {
@@ -7776,14 +7822,16 @@ public vector<T,N> subgroupShuffleXor(vector<T,N> value, uint mask)
__generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_shuffle_relative)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_shufflerelative)]
+[require(glsl_spirv_wgsl, subgroup_shufflerelative)]
public T subgroupShuffleUp(T value, uint delta)
{
shader_subgroup_preamble<T>();
__target_switch
{
case glsl:
+ case wgsl:
__intrinsic_asm "subgroupShuffleUp($0, $1)";
case spirv:
return spirv_asm {
@@ -7796,14 +7844,16 @@ public T subgroupShuffleUp(T value, uint delta)
__generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_shuffle_relative)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_shufflerelative)]
+[require(glsl_spirv_wgsl, subgroup_shufflerelative)]
public T subgroupShuffleDown(T value, uint delta)
{
shader_subgroup_preamble<T>();
__target_switch
{
case glsl:
+ case wgsl:
__intrinsic_asm "subgroupShuffleDown($0, $1)";
case spirv:
return spirv_asm {
@@ -7817,14 +7867,16 @@ public T subgroupShuffleDown(T value, uint delta)
__generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_shuffle_relative)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_shufflerelative)]
+[require(glsl_spirv_wgsl, subgroup_shufflerelative)]
public vector<T,N> subgroupShuffleUp(vector<T,N> value, uint delta)
{
shader_subgroup_preamble<T>();
__target_switch
{
case glsl:
+ case wgsl:
__intrinsic_asm "subgroupShuffleUp($0, $1)";
case spirv:
return spirv_asm {
@@ -7837,14 +7889,16 @@ public vector<T,N> subgroupShuffleUp(vector<T,N> value, uint delta)
__generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_shuffle_relative)
+__wgsl_extension(subgroups)
[ForceInline]
-[require(glsl_spirv, subgroup_shufflerelative)]
+[require(glsl_spirv_wgsl, subgroup_shufflerelative)]
public vector<T,N> subgroupShuffleDown(vector<T,N> value, uint delta)
{
shader_subgroup_preamble<T>();
__target_switch
{
case glsl:
+ case wgsl:
__intrinsic_asm "subgroupShuffleDown($0, $1)";
case spirv:
return spirv_asm {
@@ -8161,7 +8215,7 @@ __generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_quad)
[ForceInline]
-[require(glsl_hlsl_spirv, subgroup_quad)]
+[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
public T subgroupQuadSwapHorizontal(T value)
{
shader_subgroup_preamble<T>();
@@ -8172,7 +8226,7 @@ __generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_quad)
[ForceInline]
-[require(glsl_hlsl_spirv, subgroup_quad)]
+[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
public T subgroupQuadSwapVertical(T value)
{
shader_subgroup_preamble<T>();
@@ -8183,7 +8237,7 @@ __generic<T : __BuiltinType>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_quad)
[ForceInline]
-[require(glsl_hlsl_spirv, subgroup_quad)]
+[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
public T subgroupQuadSwapDiagonal(T value)
{
shader_subgroup_preamble<T>();
@@ -8206,7 +8260,7 @@ __generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_quad)
[ForceInline]
-[require(glsl_hlsl_spirv, subgroup_quad)]
+[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
public vector<T,N> subgroupQuadSwapHorizontal(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -8217,7 +8271,7 @@ __generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_quad)
[ForceInline]
-[require(glsl_hlsl_spirv, subgroup_quad)]
+[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
public vector<T,N> subgroupQuadSwapVertical(vector<T,N> value)
{
shader_subgroup_preamble<T>();
@@ -8228,7 +8282,7 @@ __generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_quad)
[ForceInline]
-[require(glsl_hlsl_spirv, subgroup_quad)]
+[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
public vector<T,N> subgroupQuadSwapDiagonal(vector<T,N> value)
{
shader_subgroup_preamble<T>();
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index ecab7ff93..3baad5c10 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -14460,42 +14460,44 @@ __generic<T : __BuiltinType, let N : int, let M : int> matrix<T,N,M> QuadReadLan
__generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_quad)
__spirv_version(1.3)
-[require(glsl_hlsl_spirv, subgroup_quad)]
+__wgsl_extension(subgroups)
+[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
T QuadReadAcrossX(T localValue)
{
__target_switch
{
- case hlsl:
- __intrinsic_asm "QuadReadAcrossX";
- case glsl:
- __intrinsic_asm "subgroupQuadSwapHorizontal($0)";
+ case hlsl: __intrinsic_asm "QuadReadAcrossX";
+ case glsl: __intrinsic_asm "subgroupQuadSwapHorizontal($0)";
case spirv:
uint direction = 0u;
- return spirv_asm {
+ return spirv_asm
+ {
OpCapability GroupNonUniformQuad;
result:$$T = OpGroupNonUniformQuadSwap Subgroup $localValue $direction;
};
+ case wgsl: __intrinsic_asm "quadSwapX";
}
}
__generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_quad)
__spirv_version(1.3)
-[require(glsl_hlsl_spirv, subgroup_quad)]
+__wgsl_extension(subgroups)
+[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
vector<T,N> QuadReadAcrossX(vector<T,N> localValue)
{
__target_switch
{
- case hlsl:
- __intrinsic_asm "QuadReadAcrossX";
- case glsl:
- __intrinsic_asm "subgroupQuadSwapHorizontal($0)";
+ case hlsl: __intrinsic_asm "QuadReadAcrossX";
+ case glsl: __intrinsic_asm "subgroupQuadSwapHorizontal($0)";
case spirv:
uint direction = 0u;
- return spirv_asm {
+ return spirv_asm
+ {
OpCapability GroupNonUniformQuad;
result:$$vector<T,N> = OpGroupNonUniformQuadSwap Subgroup $localValue $direction;
};
+ case wgsl: __intrinsic_asm "quadSwapX";
}
}
__generic<T : __BuiltinType, let N : int, let M : int> matrix<T,N,M> QuadReadAcrossX(matrix<T,N,M> localValue);
@@ -14504,85 +14506,88 @@ __generic<T : __BuiltinType, let N : int, let M : int> matrix<T,N,M> QuadReadAcr
__generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_quad)
__spirv_version(1.3)
-[require(glsl_hlsl_spirv, subgroup_quad)]
+__wgsl_extension(subgroups)
+[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
T QuadReadAcrossY(T localValue)
{
__target_switch
{
- case hlsl:
- __intrinsic_asm "QuadReadAcrossY";
- case glsl:
- __intrinsic_asm "subgroupQuadSwapVertical($0)";
+ case hlsl: __intrinsic_asm "QuadReadAcrossY";
+ case glsl: __intrinsic_asm "subgroupQuadSwapVertical($0)";
case spirv:
uint direction = 1u;
- return spirv_asm {
+ return spirv_asm
+ {
OpCapability GroupNonUniformQuad;
result:$$T = OpGroupNonUniformQuadSwap Subgroup $localValue $direction;
};
+ case wgsl: __intrinsic_asm "quadSwapY";
}
}
__generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_quad)
__spirv_version(1.3)
-[require(glsl_hlsl_spirv, subgroup_quad)]
+__wgsl_extension(subgroups)
+[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
vector<T,N> QuadReadAcrossY(vector<T,N> localValue)
{
__target_switch
{
- case hlsl:
- __intrinsic_asm "QuadReadAcrossY";
- case glsl:
- __intrinsic_asm "subgroupQuadSwapVertical($0)";
+ case hlsl: __intrinsic_asm "QuadReadAcrossY";
+ case glsl: __intrinsic_asm "subgroupQuadSwapVertical($0)";
case spirv:
uint direction = 1u;
- return spirv_asm {
+ return spirv_asm
+ {
OpCapability GroupNonUniformQuad;
result:$$vector<T,N> = OpGroupNonUniformQuadSwap Subgroup $localValue $direction;
};
+ case wgsl: __intrinsic_asm "quadSwapY";
}
}
-
__generic<T : __BuiltinType, let N : int, let M : int> matrix<T,N,M> QuadReadAcrossY(matrix<T,N,M> localValue);
/// @category wave
__generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_quad)
__spirv_version(1.3)
-[require(glsl_hlsl_spirv, subgroup_quad)]
+__wgsl_extension(subgroups)
+[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
T QuadReadAcrossDiagonal(T localValue)
{
__target_switch
{
- case hlsl:
- __intrinsic_asm "QuadReadAcrossDiagonal";
- case glsl:
- __intrinsic_asm "subgroupQuadSwapDiagonal($0)";
+ case hlsl: __intrinsic_asm "QuadReadAcrossDiagonal";
+ case glsl: __intrinsic_asm "subgroupQuadSwapDiagonal($0)";
case spirv:
uint direction = 2u;
- return spirv_asm {
+ return spirv_asm
+ {
OpCapability GroupNonUniformQuad;
result:$$T = OpGroupNonUniformQuadSwap Subgroup $localValue $direction;
};
+ case wgsl: __intrinsic_asm "quadSwapDiagonal";
}
}
__generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_quad)
__spirv_version(1.3)
-[require(glsl_hlsl_spirv, subgroup_quad)]
+__wgsl_extension(subgroups)
+[require(glsl_hlsl_spirv_wgsl, subgroup_quad)]
vector<T,N> QuadReadAcrossDiagonal(vector<T,N> localValue)
{
__target_switch
{
- case hlsl:
- __intrinsic_asm "QuadReadAcrossDiagonal";
- case glsl:
- __intrinsic_asm "subgroupQuadSwapDiagonal($0)";
+ case hlsl: __intrinsic_asm "QuadReadAcrossDiagonal";
+ case glsl: __intrinsic_asm "subgroupQuadSwapDiagonal($0)";
case spirv:
uint direction = 2u;
- return spirv_asm {
+ return spirv_asm
+ {
OpCapability GroupNonUniformQuad;
result:$$vector<T,N> = OpGroupNonUniformQuadSwap Subgroup $localValue $direction;
};
+ case wgsl: __intrinsic_asm "quadSwapDiagonal";
}
}
__generic<T : __BuiltinType, let N : int, let M : int> matrix<T,N,M> QuadReadAcrossDiagonal(matrix<T,N,M> localValue);
@@ -14597,16 +14602,19 @@ for (auto opName : kWaveActiveBitOpEntries) {
__generic<T : __BuiltinIntegerType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
T WaveActive$(opName.hlslName)(T expr)
{
__target_switch
{
- case glsl: __intrinsic_asm "subgroup$(opName.glslName)($0)";
+ case glsl:
+ case wgsl:
+ __intrinsic_asm "subgroup$(opName.glslName)";
case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)";
case spirv:
return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniform$(opName.spirvName) $$T result Subgroup Reduce $expr};
- default:
+ case cuda:
return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr);
}
}
@@ -14614,22 +14622,25 @@ T WaveActive$(opName.hlslName)(T expr)
__generic<T : __BuiltinIntegerType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
vector<T, N> WaveActive$(opName.hlslName)(vector<T, N> expr)
{
__target_switch
{
- case glsl: __intrinsic_asm "subgroup$(opName.glslName)($0)";
+ case glsl:
+ case wgsl:
+ __intrinsic_asm "subgroup$(opName.glslName)";
case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)";
case spirv:
return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniform$(opName.spirvName) $$vector<T, N> result Subgroup Reduce $expr};
- default:
+ case cuda:
return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr);
}
}
__generic<T : __BuiltinIntegerType, let N : int, let M : int>
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
matrix<T, N, M> WaveActive$(opName.hlslName)(matrix<T, N, M> expr)
{
__target_switch
@@ -14637,12 +14648,13 @@ matrix<T, N, M> WaveActive$(opName.hlslName)(matrix<T, N, M> expr)
case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)";
case glsl:
case spirv:
+ case wgsl:
matrix<T,N,M> result;
[ForceUnroll]
for (int i = 0; i < N; ++i)
result[i] = WaveActive$(opName.hlslName)(expr[i]);
return result;
- default:
+ case cuda:
return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr);
}
}
@@ -14659,12 +14671,15 @@ for (const char* opName : kWaveActiveMinMaxNames) {
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
T WaveActive$(opName)(T expr)
{
__target_switch
{
- case glsl: __intrinsic_asm "subgroup$(opName)($0)";
+ case glsl:
+ case wgsl:
+ __intrinsic_asm "subgroup$(opName)";
case hlsl: __intrinsic_asm "WaveActive$(opName)";
case spirv:
if (__isFloat<T>())
@@ -14673,7 +14688,7 @@ T WaveActive$(opName)(T expr)
return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformU$(opName) $$T result Subgroup Reduce $expr};
else
return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformS$(opName) $$T result Subgroup Reduce $expr};
- default:
+ case cuda:
return WaveMask$(opName)(WaveGetActiveMask(), expr);
}
}
@@ -14681,12 +14696,15 @@ T WaveActive$(opName)(T expr)
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
vector<T, N> WaveActive$(opName)(vector<T, N> expr)
{
__target_switch
{
- case glsl: __intrinsic_asm "subgroup$(opName)($0)";
+ case glsl:
+ case wgsl:
+ __intrinsic_asm "subgroup$(opName)";
case hlsl: __intrinsic_asm "WaveActive$(opName)";
case spirv:
if (__isFloat<T>())
@@ -14695,13 +14713,13 @@ vector<T, N> WaveActive$(opName)(vector<T, N> expr)
return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformU$(opName) $$vector<T, N> result Subgroup Reduce $expr};
else
return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformS$(opName) $$vector<T, N> result Subgroup Reduce $expr};
- default:
+ case cuda:
return WaveMask$(opName)(WaveGetActiveMask(), expr);
}
}
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
matrix<T, N, M> WaveActive$(opName)(matrix<T, N, M> expr)
{
__target_switch
@@ -14709,12 +14727,13 @@ matrix<T, N, M> WaveActive$(opName)(matrix<T, N, M> expr)
case hlsl: __intrinsic_asm "WaveActive$(opName)";
case glsl:
case spirv:
+ case wgsl:
matrix<T, N, M> result;
[ForceUnroll]
for (int i = 0; i < N; ++i)
result[i] = WaveActive$(opName)(expr[i]);
return result;
- default:
+ case cuda:
return WaveMask$(opName)(WaveGetActiveMask(), expr);
}
}
@@ -14733,7 +14752,8 @@ for (auto opName : kWaveActivProductSumNames) {
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
T WaveActive$(opName.hlslName)(T expr)
{
__target_switch
@@ -14757,7 +14777,8 @@ T WaveActive$(opName.hlslName)(T expr)
};
}
else return expr;
- default:
+ case wgsl: __intrinsic_asm "subgroup$(opName.glslName)";
+ case cuda:
return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr);
}
}
@@ -14765,7 +14786,8 @@ T WaveActive$(opName.hlslName)(T expr)
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
vector<T,N> WaveActive$(opName.hlslName)(vector<T,N> expr)
{
__target_switch
@@ -14789,13 +14811,14 @@ vector<T,N> WaveActive$(opName.hlslName)(vector<T,N> expr)
};
}
else return expr;
- default:
+ case wgsl: __intrinsic_asm "subgroup$(opName.glslName)";
+ case cuda:
return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr);
}
}
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
matrix<T, N, M> WaveActive$(opName.hlslName)(matrix<T, N, M> expr)
{
__target_switch
@@ -14803,12 +14826,13 @@ matrix<T, N, M> WaveActive$(opName.hlslName)(matrix<T, N, M> expr)
case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)";
case glsl:
case spirv:
+ case wgsl:
matrix<T, N, M> result;
[ForceUnroll]
for (int i = 0; i < N; ++i)
result[i] = WaveActive$(opName.hlslName)(expr[i]);
return result;
- default:
+ case cuda:
return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr);
}
}
@@ -14877,22 +14901,23 @@ bool WaveActiveAllEqual(matrix<T, N, M> value)
/// @category wave
__glsl_extension(GL_KHR_shader_subgroup_vote)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_vote)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_vote)]
bool WaveActiveAllTrue(bool condition)
{
__target_switch
{
case glsl:
- __intrinsic_asm "subgroupAll($0)";
- case hlsl:
- __intrinsic_asm "WaveActiveAllTrue($0)";
+ case wgsl:
+ __intrinsic_asm "subgroupAll";
+ case hlsl: __intrinsic_asm "WaveActiveAllTrue($0)";
case spirv:
return spirv_asm
{
OpCapability GroupNonUniformVote;
OpGroupNonUniformAll $$bool result Subgroup $condition
};
- default:
+ case cuda:
return WaveMaskAllTrue(WaveGetActiveMask(), condition);
}
}
@@ -14900,13 +14925,15 @@ bool WaveActiveAllTrue(bool condition)
/// @category wave
__glsl_extension(GL_KHR_shader_subgroup_vote)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_vote)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_vote)]
bool WaveActiveAnyTrue(bool condition)
{
__target_switch
{
case glsl:
- __intrinsic_asm "subgroupAny($0)";
+ case wgsl:
+ __intrinsic_asm "subgroupAny";
case hlsl:
__intrinsic_asm "WaveActiveAnyTrue($0)";
case spirv:
@@ -14923,14 +14950,16 @@ bool WaveActiveAnyTrue(bool condition)
/// @category wave
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
+__wgsl_extension(subgroups)
[NonUniformReturn]
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
uint4 WaveActiveBallot(bool condition)
{
__target_switch
{
case glsl:
- __intrinsic_asm "subgroupBallot($0)";
+ case wgsl:
+ __intrinsic_asm "subgroupBallot";
case hlsl:
__intrinsic_asm "WaveActiveBallot";
case spirv:
@@ -15004,23 +15033,23 @@ uint WaveGetLaneIndex()
/// @category wave
__glsl_extension(GL_KHR_shader_subgroup_basic)
__spirv_version(1.3)
+__wgsl_extension(subgroups)
[NonUniformReturn]
-[require(cuda_glsl_hlsl_spirv, subgroup_basic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)]
bool WaveIsFirstLane()
{
__target_switch
{
- case glsl:
- __intrinsic_asm "subgroupElect()";
- case hlsl:
- __intrinsic_asm "WaveIsFirstLane()";
+ case glsl: __intrinsic_asm "subgroupElect()";
+ case hlsl: __intrinsic_asm "WaveIsFirstLane()";
case spirv:
return spirv_asm
{
OpCapability GroupNonUniformBallot;
OpGroupNonUniformElect $$bool result Subgroup
};
- default:
+ case wgsl: __intrinsic_asm "subgroupElect";
+ case cuda:
return WaveMaskIsFirstLane(WaveGetActiveMask());
}
}
@@ -15059,7 +15088,8 @@ uint _WaveCountBits(uint4 value)
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
T WavePrefixProduct(T expr)
{
__target_switch
@@ -15083,7 +15113,8 @@ T WavePrefixProduct(T expr)
};
}
else return expr;
- default:
+ case wgsl: __intrinsic_asm "subgroupExclusiveMul";
+ case cuda:
return WaveMaskPrefixProduct(WaveGetActiveMask(), expr);
}
}
@@ -15092,7 +15123,8 @@ T WavePrefixProduct(T expr)
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
vector<T,N> WavePrefixProduct(vector<T,N> expr)
{
__target_switch
@@ -15113,13 +15145,14 @@ vector<T,N> WavePrefixProduct(vector<T,N> expr)
};
}
else return expr;
- default:
+ case wgsl: __intrinsic_asm "subgroupExclusiveMul";
+ case cuda:
return WaveMaskPrefixProduct(WaveGetActiveMask(), expr);
}
}
/// @category wave
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
matrix<T, N, M> WavePrefixProduct(matrix<T, N, M> expr)
{
__target_switch
@@ -15127,11 +15160,12 @@ matrix<T, N, M> WavePrefixProduct(matrix<T, N, M> expr)
case hlsl: __intrinsic_asm "WavePrefixProduct";
case glsl:
case spirv:
+ case wgsl:
matrix<T, N, M> result;
for (int i = 0; i < N; ++i)
result[i] = WavePrefixProduct(expr[i]);
return result;
- default:
+ case cuda:
return WaveMaskPrefixProduct(WaveGetActiveMask(), expr);
}
}
@@ -15140,7 +15174,8 @@ matrix<T, N, M> WavePrefixProduct(matrix<T, N, M> expr)
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
T WavePrefixSum(T expr)
{
__target_switch
@@ -15161,7 +15196,8 @@ T WavePrefixSum(T expr)
};
}
else return expr;
- default:
+ case wgsl: __intrinsic_asm "subgroupExclusiveAdd";
+ case cuda:
return WaveMaskPrefixSum(WaveGetActiveMask(), expr);
}
}
@@ -15169,7 +15205,8 @@ T WavePrefixSum(T expr)
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
vector<T,N> WavePrefixSum(vector<T,N> expr)
{
__target_switch
@@ -15190,13 +15227,14 @@ vector<T,N> WavePrefixSum(vector<T,N> expr)
};
}
else return expr;
- default:
+ case wgsl: __intrinsic_asm "subgroupExclusiveAdd";
+ case cuda:
return WaveMaskPrefixSum(WaveGetActiveMask(), expr);
}
}
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
-[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)]
matrix<T,N,M> WavePrefixSum(matrix<T,N,M> expr)
{
__target_switch
@@ -15204,11 +15242,12 @@ matrix<T,N,M> WavePrefixSum(matrix<T,N,M> expr)
case hlsl: __intrinsic_asm "WavePrefixSum";
case glsl:
case spirv:
+ case wgsl:
matrix<T, N, M> result;
for (int i = 0; i < N; ++i)
result[i] = WavePrefixSum(expr[i]);
return result;
- default:
+ case cuda:
return WaveMaskPrefixSum(WaveGetActiveMask(), expr);
}
}
@@ -15217,7 +15256,8 @@ matrix<T,N,M> WavePrefixSum(matrix<T,N,M> expr)
__generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
T WaveReadLaneFirst(T expr)
{
__target_switch
@@ -15228,7 +15268,8 @@ T WaveReadLaneFirst(T expr)
case hlsl: __intrinsic_asm "WaveReadLaneFirst";
case spirv:
return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcastFirst $$T result Subgroup $expr};
- default:
+ case wgsl: __intrinsic_asm "subgroupBroadcastFirst";
+ case cuda:
return WaveMaskReadLaneFirst(WaveGetActiveMask(), expr);
}
}
@@ -15236,7 +15277,8 @@ T WaveReadLaneFirst(T expr)
__generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
vector<T,N> WaveReadLaneFirst(vector<T,N> expr)
{
__target_switch
@@ -15247,13 +15289,14 @@ vector<T,N> WaveReadLaneFirst(vector<T,N> expr)
case hlsl: __intrinsic_asm "WaveReadLaneFirst";
case spirv:
return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcastFirst $$vector<T,N> result Subgroup $expr};
- default:
+ case wgsl: __intrinsic_asm "subgroupBroadcastFirst";
+ case cuda:
return WaveMaskReadLaneFirst(WaveGetActiveMask(), expr);
}
}
__generic<T : __BuiltinType, let N : int, let M : int>
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
matrix<T,N,M> WaveReadLaneFirst(matrix<T,N,M> expr)
{
__target_switch
@@ -15261,11 +15304,12 @@ matrix<T,N,M> WaveReadLaneFirst(matrix<T,N,M> expr)
case hlsl: __intrinsic_asm "WaveReadLaneFirst";
case glsl:
case spirv:
+ case wgsl:
matrix<T, N, M> result;
for (int i = 0; i < N; ++i)
result[i] = WaveReadLaneFirst(expr[i]);
return result;
- default:
+ case cuda:
return WaveMaskReadLaneFirst(WaveGetActiveMask(), expr);
}
}
@@ -15280,7 +15324,8 @@ matrix<T,N,M> WaveReadLaneFirst(matrix<T,N,M> expr)
__generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
T WaveBroadcastLaneAt(T value, constexpr int lane)
{
__target_switch
@@ -15292,7 +15337,8 @@ T WaveBroadcastLaneAt(T value, constexpr int lane)
case spirv:
let ulane = uint(lane);
return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcast $$T result Subgroup $value $ulane};
- default:
+ case wgsl: __intrinsic_asm "subgroupBroadcast";
+ case cuda:
return WaveMaskBroadcastLaneAt(WaveGetActiveMask(), value, lane);
}
}
@@ -15301,7 +15347,8 @@ T WaveBroadcastLaneAt(T value, constexpr int lane)
__generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
vector<T,N> WaveBroadcastLaneAt(vector<T,N> value, constexpr int lane)
{
__target_switch
@@ -15313,13 +15360,14 @@ vector<T,N> WaveBroadcastLaneAt(vector<T,N> value, constexpr int lane)
case spirv:
let ulane = uint(lane);
return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcast $$vector<T,N> result Subgroup $value $ulane};
- default:
+ case wgsl: __intrinsic_asm "subgroupBroadcast";
+ case cuda:
return WaveMaskBroadcastLaneAt(WaveGetActiveMask(), value, lane);
}
}
__generic<T : __BuiltinType, let N : int, let M : int>
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
matrix<T, N, M> WaveBroadcastLaneAt(matrix<T, N, M> value, constexpr int lane)
{
__target_switch
@@ -15328,11 +15376,12 @@ matrix<T, N, M> WaveBroadcastLaneAt(matrix<T, N, M> value, constexpr int lane)
case hlsl: __intrinsic_asm "WaveReadLaneAt";
case glsl:
case spirv:
+ case wgsl:
matrix<T, N, M> result;
for (int i = 0; i < N; ++i)
result[i] = WaveBroadcastLaneAt(value[i], lane);
return result;
- default:
+ case cuda:
return WaveMaskBroadcastLaneAt(WaveGetActiveMask(), value, lane);
}
}
@@ -15343,7 +15392,8 @@ matrix<T, N, M> WaveBroadcastLaneAt(matrix<T, N, M> value, constexpr int lane)
__generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_shuffle)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)]
T WaveReadLaneAt(T value, int lane)
{
__target_switch
@@ -15355,15 +15405,17 @@ T WaveReadLaneAt(T value, int lane)
case spirv:
let ulane = uint(lane);
return spirv_asm {OpCapability GroupNonUniformShuffle; OpGroupNonUniformShuffle $$T result Subgroup $value $ulane};
- default:
+ case wgsl: __intrinsic_asm "subgroupShuffle";
+ case cuda:
return WaveMaskReadLaneAt(WaveGetActiveMask(), value, lane);
}
}
__generic<T : __BuiltinType, let N : int>
-__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
-[require(cuda_glsl_hlsl_spirv, subgroup_shuffle)]
+__spirv_version(1.3)
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)]
vector<T,N> WaveReadLaneAt(vector<T,N> value, int lane)
{
__target_switch
@@ -15375,13 +15427,14 @@ vector<T,N> WaveReadLaneAt(vector<T,N> value, int lane)
case spirv:
let ulane = uint(lane);
return spirv_asm {OpCapability GroupNonUniformShuffle; OpGroupNonUniformShuffle $$vector<T,N> result Subgroup $value $ulane};
- default:
+ case wgsl: __intrinsic_asm "subgroupShuffle";
+ case cuda:
return WaveMaskReadLaneAt(WaveGetActiveMask(), value, lane);
}
}
__generic<T : __BuiltinType, let N : int, let M : int>
-[require(cuda_glsl_hlsl_spirv, subgroup_shuffle)]
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)]
matrix<T, N, M> WaveReadLaneAt(matrix<T, N, M> value, int lane)
{
__target_switch
@@ -15390,11 +15443,12 @@ matrix<T, N, M> WaveReadLaneAt(matrix<T, N, M> value, int lane)
case hlsl: __intrinsic_asm "WaveReadLaneAt";
case glsl:
case spirv:
+ case wgsl:
matrix<T,N,M> result;
for (int i = 0; i < N; ++i)
result[i] = WaveReadLaneAt(value[i], lane);
return result;
- default:
+ case cuda:
return WaveMaskReadLaneAt(WaveGetActiveMask(), value, lane);
}
}
@@ -15406,7 +15460,8 @@ matrix<T, N, M> WaveReadLaneAt(matrix<T, N, M> value, int lane)
__generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_shuffle)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)]
T WaveShuffle(T value, int lane)
{
__target_switch
@@ -15418,7 +15473,8 @@ T WaveShuffle(T value, int lane)
case spirv:
let ulane = uint(lane);
return spirv_asm {OpCapability GroupNonUniformShuffle; OpGroupNonUniformShuffle $$T result Subgroup $value $ulane};
- default:
+ case wgsl: __intrinsic_asm "subgroupShuffle";
+ case cuda:
return WaveMaskShuffle(WaveGetActiveMask(), value, lane);
}
}
@@ -15427,7 +15483,8 @@ T WaveShuffle(T value, int lane)
__generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_shuffle)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)]
vector<T,N> WaveShuffle(vector<T,N> value, int lane)
{
__target_switch
@@ -15439,7 +15496,8 @@ vector<T,N> WaveShuffle(vector<T,N> value, int lane)
case spirv:
let ulane = uint(lane);
return spirv_asm {OpCapability GroupNonUniformShuffle; OpGroupNonUniformShuffle $$vector<T,N> result Subgroup $value $ulane};
- default:
+ case wgsl: __intrinsic_asm "subgroupShuffle";
+ case cuda:
return WaveMaskShuffle(WaveGetActiveMask(), value, lane);
}
}
@@ -15482,12 +15540,14 @@ uint WavePrefixCountBits(bool value)
/// @category wave
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-[require(cuda_glsl_hlsl_spirv, subgroup_ballot)]
+__wgsl_extension(subgroups)
+[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)]
uint4 WaveGetConvergedMulti()
{
__target_switch
{
case glsl:
+ case wgsl:
__intrinsic_asm "subgroupBallot(true)";
case hlsl: __intrinsic_asm "WaveActiveBallot(true)";
case cuda: __intrinsic_asm "make_uint4(__activemask(), 0, 0, 0)";
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 36dddd15f..cc4901236 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -218,6 +218,15 @@ class RequiredSPIRVVersionModifier : public Modifier
};
// A modifier to tag something as an intrinsic that requires
+// a certain WGSL extension to be enabled when used
+class RequiredWGSLExtensionModifier : public Modifier // marks a declaration as needing a WGSL extension
+{
+ SLANG_AST_CLASS(RequiredWGSLExtensionModifier)
+
+ Token extensionNameToken; // token holding the extension name; presumably taken from the `__wgsl_extension(...)` argument — TODO confirm in the parser
+};
+
+// A modifier to tag something as an intrinsic that requires
// a certain CUDA SM version to be enabled when used. Specified as "major.minor"
class RequiredCUDASMVersionModifier : public Modifier
{
diff --git a/source/slang/slang-ast-print.cpp b/source/slang/slang-ast-print.cpp
index 22a5ec8d1..597a35f4b 100644
--- a/source/slang/slang-ast-print.cpp
+++ b/source/slang/slang-ast-print.cpp
@@ -434,6 +434,8 @@ void ASTPrinter::addDeclKindPrefix(Decl* decl)
continue;
if (as<RequiredGLSLExtensionModifier>(modifier))
continue;
+ if (as<RequiredWGSLExtensionModifier>(modifier))
+ continue;
if (as<GLSLLayoutModifierGroupMarker>(modifier))
continue;
if (as<HLSLLayoutSemantic>(modifier))
diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef
index 3be09b8d3..f98be0e32 100644
--- a/source/slang/slang-capabilities.capdef
+++ b/source/slang/slang-capabilities.capdef
@@ -331,6 +331,10 @@ alias cuda_hlsl_metal_spirv = cuda | hlsl | metal | spirv;
/// [Compound]
alias cuda_glsl_hlsl_spirv = cuda | glsl | hlsl | spirv;
+/// CUDA, GLSL, HLSL, SPIRV, and WGSL code-gen targets
+/// [Compound]
+alias cuda_glsl_hlsl_spirv_wgsl = cuda | glsl | hlsl | spirv | wgsl;
+
/// CUDA, GLSL, HLSL, Metal, and SPIRV code-gen targets
/// [Compound]
alias cuda_glsl_hlsl_metal_spirv = cuda | glsl | hlsl | metal | spirv;
@@ -387,6 +391,10 @@ alias glsl_metal_spirv_wgsl = glsl | metal | spirv | wgsl;
/// [Compound]
alias glsl_spirv = glsl | spirv;
+/// GLSL, SPIRV, and WGSL code-gen targets
+/// [Compound]
+alias glsl_spirv_wgsl = glsl | spirv | wgsl;
+
/// HLSL, and SPIRV code-gen targets
/// [Compound]
alias hlsl_spirv = hlsl | spirv;
@@ -1931,13 +1939,18 @@ alias shader5_sm_5_0 = GL_ARB_gpu_shader5 | sm_5_0_version;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_basic'
/// [Compound]
-alias subgroup_basic = GL_KHR_shader_subgroup_basic | _sm_6_0 | _cuda_sm_7_0;
+alias subgroup_basic = GL_KHR_shader_subgroup_basic
+ | _sm_6_0
+ | _cuda_sm_7_0
+ | wgsl
+ ;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_ballot'
/// [Compound]
alias subgroup_ballot = spirv_1_0 + GL_KHR_shader_subgroup_ballot
| glsl + GL_KHR_shader_subgroup_ballot + shader5_sm_5_0
| _sm_6_0 + shader5_sm_5_0
| _cuda_sm_7_0 + shader5_sm_5_0
+ | wgsl
;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_ballot_activemask'
/// [Compound]
@@ -1952,28 +1965,50 @@ alias subgroup_basic_ballot = glsl + GL_KHR_shader_subgroup_basic + subgroup_bal
| spirv + GL_KHR_shader_subgroup_basic + subgroup_ballot
| hlsl + subgroup_ballot
| cuda + subgroup_ballot
+ | wgsl
;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_vote'
/// [Compound]
-alias subgroup_vote = GL_KHR_shader_subgroup_vote | _sm_6_0 | _cuda_sm_7_0;
+alias subgroup_vote = GL_KHR_shader_subgroup_vote
+ | _sm_6_0
+ | _cuda_sm_7_0
+ | wgsl
+ ;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_vote'
/// [Compound]
alias shaderinvocationgroup = subgroup_vote;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_arithmetic'
/// [Compound]
-alias subgroup_arithmetic = GL_KHR_shader_subgroup_arithmetic | _sm_6_0 | _cuda_sm_7_0;
+alias subgroup_arithmetic = GL_KHR_shader_subgroup_arithmetic
+ | _sm_6_0
+ | _cuda_sm_7_0
+ | wgsl
+ ;
+
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_shuffle'
/// [Compound]
-alias subgroup_shuffle = GL_KHR_shader_subgroup_shuffle | _sm_6_0 | _cuda_sm_7_0;
+alias subgroup_shuffle = GL_KHR_shader_subgroup_shuffle
+ | _sm_6_0
+ | _cuda_sm_7_0
+ | wgsl
+ ;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_shuffle_relative'
/// [Compound]
-alias subgroup_shufflerelative = GL_KHR_shader_subgroup_shuffle_relative | _sm_6_0 | _cuda_sm_7_0;
+alias subgroup_shufflerelative = GL_KHR_shader_subgroup_shuffle_relative
+ | _sm_6_0
+ | _cuda_sm_7_0
+ | wgsl
+ ;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_clustered'
/// [Compound]
alias subgroup_clustered = GL_KHR_shader_subgroup_clustered | _sm_6_0 | _cuda_sm_7_0;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_quad'
/// [Compound]
-alias subgroup_quad = GL_KHR_shader_subgroup_quad | _sm_6_0 | _cuda_sm_7_0;
+alias subgroup_quad = GL_KHR_shader_subgroup_quad
+ | _sm_6_0
+ | _cuda_sm_7_0
+ | wgsl
+ ;
/// Capabilities required to use GLSL-style subgroup operations 'subgroup_partitioned'
/// [Compound]
alias subgroup_partitioned = GL_NV_shader_subgroup_partitioned + subgroup_ballot_activemask | _sm_6_5 | _cuda_sm_7_0;
diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp
index 448534ce8..04ebb753c 100644
--- a/source/slang/slang-compiler.cpp
+++ b/source/slang/slang-compiler.cpp
@@ -28,7 +28,7 @@
// Artifact output
#include "slang-artifact-output-util.h"
#include "slang-emit-cuda.h"
-#include "slang-glsl-extension-tracker.h"
+#include "slang-extension-tracker.h"
#include "slang-lower-to-ir.h"
#include "slang-mangle.h"
#include "slang-parameter-binding.h"
@@ -658,7 +658,7 @@ static void _appendCodeWithPath(
outCodeBuilder << fileContent << "\n";
}
-void trackGLSLTargetCaps(GLSLExtensionTracker* extensionTracker, CapabilitySet const& caps)
+void trackGLSLTargetCaps(ShaderExtensionTracker* extensionTracker, CapabilitySet const& caps)
{
for (auto& conjunctions : caps.getAtomSets())
{
@@ -1037,8 +1037,11 @@ static RefPtr<ExtensionTracker> _newExtensionTracker(CodeGenTarget target)
}
case CodeGenTarget::SPIRV:
case CodeGenTarget::GLSL:
+ case CodeGenTarget::WGSL:
+ case CodeGenTarget::WGSLSPIRV:
+ case CodeGenTarget::WGSLSPIRVAssembly:
{
- return new GLSLExtensionTracker;
+ return new ShaderExtensionTracker;
}
default:
return nullptr;
@@ -1261,7 +1264,7 @@ SlangResult CodeGenContext::emitWithDownstreamForEntryPoints(ComPtr<IArtifact>&
if (auto endToEndReq = isPassThroughEnabled())
{
// If we are pass through, we may need to set extension tracker state.
- if (GLSLExtensionTracker* glslTracker = as<GLSLExtensionTracker>(extensionTracker))
+ if (ShaderExtensionTracker* glslTracker = as<ShaderExtensionTracker>(extensionTracker))
{
trackGLSLTargetCaps(glslTracker, getTargetCaps());
}
@@ -1400,7 +1403,7 @@ SlangResult CodeGenContext::emitWithDownstreamForEntryPoints(ComPtr<IArtifact>&
options.flags |= CompileOptions::Flag::EnableFloat16;
}
}
- else if (GLSLExtensionTracker* glslTracker = as<GLSLExtensionTracker>(extensionTracker))
+ else if (ShaderExtensionTracker* glslTracker = as<ShaderExtensionTracker>(extensionTracker))
{
DownstreamCompileOptions::CapabilityVersion version;
version.kind = DownstreamCompileOptions::CapabilityVersion::Kind::SPIRV;
diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp
index 1429139f9..25dab3fb3 100644
--- a/source/slang/slang-emit-glsl.cpp
+++ b/source/slang/slang-emit-glsl.cpp
@@ -15,13 +15,13 @@
namespace Slang
{
-void trackGLSLTargetCaps(GLSLExtensionTracker* extensionTracker, CapabilitySet const& caps);
+void trackGLSLTargetCaps(ShaderExtensionTracker* extensionTracker, CapabilitySet const& caps);
GLSLSourceEmitter::GLSLSourceEmitter(const Desc& desc)
: Super(desc)
{
m_glslExtensionTracker =
- dynamicCast<GLSLExtensionTracker>(desc.codeGenContext->getExtensionTracker());
+ dynamicCast<ShaderExtensionTracker>(desc.codeGenContext->getExtensionTracker());
SLANG_ASSERT(m_glslExtensionTracker);
}
@@ -2997,7 +2997,7 @@ void GLSLSourceEmitter::emitFrontMatterImpl(TargetRequest* targetReq)
trackGLSLTargetCaps(m_glslExtensionTracker, targetReq->getTargetCaps());
StringBuilder builder;
- m_glslExtensionTracker->appendExtensionRequireLines(builder);
+ m_glslExtensionTracker->appendExtensionRequireLinesForGLSL(builder);
m_writer->emit(builder.getUnownedSlice());
}
diff --git a/source/slang/slang-emit-glsl.h b/source/slang/slang-emit-glsl.h
index b07b410ca..8308a9954 100644
--- a/source/slang/slang-emit-glsl.h
+++ b/source/slang/slang-emit-glsl.h
@@ -3,7 +3,7 @@
#define SLANG_EMIT_GLSL_H
#include "slang-emit-c-like.h"
-#include "slang-glsl-extension-tracker.h"
+#include "slang-extension-tracker.h"
namespace Slang
{
@@ -180,7 +180,7 @@ protected:
Dictionary<IRInst*, HashSet<IRFunc*>> m_referencingEntryPoints;
- RefPtr<GLSLExtensionTracker> m_glslExtensionTracker;
+ RefPtr<ShaderExtensionTracker> m_glslExtensionTracker;
};
} // namespace Slang
diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp
index ce60cc2a0..aea766f9f 100644
--- a/source/slang/slang-emit-wgsl.cpp
+++ b/source/slang/slang-emit-wgsl.cpp
@@ -49,6 +49,14 @@ fn _slang_getNan() -> f32
}
)";
+WGSLSourceEmitter::WGSLSourceEmitter(const Desc& desc) // builds the emitter and binds it to the code-gen context's extension tracker
+ : CLikeSourceEmitter(desc)
+{
+ m_extensionTracker =
+ dynamicCast<ShaderExtensionTracker>(desc.codeGenContext->getExtensionTracker()); // tracker is obtained from the context, not created here
+ SLANG_ASSERT(m_extensionTracker); // required: emitFrontMatterImpl and _requireExtension dereference it without a null check
+}
+
void WGSLSourceEmitter::emitSwitchCaseSelectorsImpl(
const SwitchRegion::Case* const currentCase,
const bool isDefault)
@@ -1556,6 +1564,10 @@ void WGSLSourceEmitter::emitFrontMatterImpl(TargetRequest* /* targetReq */)
m_writer->emit("enable f16;\n");
m_writer->emit("\n");
}
+
+ StringBuilder builder;
+ m_extensionTracker->appendExtensionRequireLinesForWGSL(builder);
+ m_writer->emit(builder.getUnownedSlice());
}
void WGSLSourceEmitter::emitIntrinsicCallExprImpl(
@@ -1626,4 +1638,28 @@ void WGSLSourceEmitter::emitInterpolationModifiersImpl(
// https://www.w3.org/TR/WGSL/#interpolation
}
+void WGSLSourceEmitter::_requireExtension(const UnownedStringSlice& name) // forwards `name` to the shared extension tracker
+{
+ m_extensionTracker->requireExtension(name);
+}
+
+void WGSLSourceEmitter::handleRequiredCapabilitiesImpl(IRInst* inst) // registers each WGSL extension named by a decoration on `inst`
+{
+ for (auto decoration : inst->getDecorations())
+ {
+ if (const auto extensionDecoration = as<IRRequireWGSLExtensionDecoration>(decoration)) // any other decoration kind is ignored
+ {
+ _requireExtension(extensionDecoration->getExtensionName());
+
+ // TODO: Make this cleaner and only enable this extension if f16 is actually used on the
+ // subgroup intrinsic. Check float type in meta file.
+ if (m_f16ExtensionEnabled && extensionDecoration->getExtensionName() == "subgroups") // `subgroups` plus the f16 flag also pulls in `subgroups_f16`
+ {
+ String extName = "subgroups_f16";
+ _requireExtension(extName.getUnownedSlice());
+ }
+ }
+ }
+}
+
} // namespace Slang
diff --git a/source/slang/slang-emit-wgsl.h b/source/slang/slang-emit-wgsl.h
index 714a722d7..441933b57 100644
--- a/source/slang/slang-emit-wgsl.h
+++ b/source/slang/slang-emit-wgsl.h
@@ -1,6 +1,7 @@
#pragma once
#include "slang-emit-c-like.h"
+#include "slang-extension-tracker.h"
namespace Slang
{
@@ -8,10 +9,8 @@ namespace Slang
class WGSLSourceEmitter : public CLikeSourceEmitter
{
public:
- WGSLSourceEmitter(const Desc& desc)
- : CLikeSourceEmitter(desc)
- {
- }
+ explicit WGSLSourceEmitter(const Desc& desc);
+
virtual bool isResourceTypeBindless(IRType* type) SLANG_OVERRIDE
{
SLANG_UNUSED(type);
@@ -58,10 +57,14 @@ public:
EmitOpInfo const& inOuterPrec) SLANG_OVERRIDE;
virtual void emitGlobalParamDefaultVal(IRGlobalParam* varDecl) SLANG_OVERRIDE;
+ virtual void handleRequiredCapabilitiesImpl(IRInst* inst) SLANG_OVERRIDE;
+
void emit(const AddressSpace addressSpace);
virtual bool shouldFoldInstIntoUseSites(IRInst* inst) SLANG_OVERRIDE;
+ virtual RefObject* getExtensionTracker() SLANG_OVERRIDE { return m_extensionTracker; }
+
private:
bool maybeEmitSystemSemantic(IRInst* inst);
@@ -73,7 +76,11 @@ private:
const char* getWgslImageFormat(IRTextureTypeBase* type);
+ void _requireExtension(const UnownedStringSlice& name);
+
bool m_f16ExtensionEnabled = false;
+
+ RefPtr<ShaderExtensionTracker> m_extensionTracker;
};
} // namespace Slang
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 70521c7ee..58376bbc1 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -1400,10 +1400,10 @@ Result linkAndOptimizeIR(
case CodeGenTarget::SPIRV:
case CodeGenTarget::SPIRVAssembly:
{
- GLSLExtensionTracker glslExtensionTracker;
- GLSLExtensionTracker* glslExtensionTrackerPtr =
+ ShaderExtensionTracker glslExtensionTracker;
+ ShaderExtensionTracker* glslExtensionTrackerPtr =
options.sourceEmitter
- ? as<GLSLExtensionTracker>(options.sourceEmitter->getExtensionTracker())
+ ? as<ShaderExtensionTracker>(options.sourceEmitter->getExtensionTracker())
: &glslExtensionTracker;
#if 0
diff --git a/source/slang/slang-glsl-extension-tracker.cpp b/source/slang/slang-extension-tracker.cpp
index 268f123f7..f3818cbba 100644
--- a/source/slang/slang-glsl-extension-tracker.cpp
+++ b/source/slang/slang-extension-tracker.cpp
@@ -1,10 +1,10 @@
-// slang-glsl-extension-tracker.cpp
-#include "slang-glsl-extension-tracker.h"
+// slang-extension-tracker.cpp
+#include "slang-extension-tracker.h"
namespace Slang
{
-void GLSLExtensionTracker::appendExtensionRequireLines(StringBuilder& ioBuilder) const
+void ShaderExtensionTracker::appendExtensionRequireLinesForGLSL(StringBuilder& ioBuilder) const
{
for (const auto& extension : m_extensionPool.getSlices())
{
@@ -14,7 +14,17 @@ void GLSLExtensionTracker::appendExtensionRequireLines(StringBuilder& ioBuilder)
}
}
-void GLSLExtensionTracker::requireSPIRVVersion(const SemanticVersion& version)
+void ShaderExtensionTracker::appendExtensionRequireLinesForWGSL(StringBuilder& ioBuilder) const // appends one `enable <name>;` line per extension in the pool
+{
+ for (const auto& extension : m_extensionPool.getSlices())
+ {
+ ioBuilder.append("enable "); // WGSL counterpart of the GLSL `#extension` line writer in this file
+ ioBuilder.append(extension);
+ ioBuilder.append(";\n");
+ }
+}
+
+void ShaderExtensionTracker::requireSPIRVVersion(const SemanticVersion& version)
{
if (version > m_spirvVersion)
{
@@ -22,7 +32,7 @@ void GLSLExtensionTracker::requireSPIRVVersion(const SemanticVersion& version)
}
}
-void GLSLExtensionTracker::requireVersion(ProfileVersion version)
+void ShaderExtensionTracker::requireVersion(ProfileVersion version)
{
// Check if this profile is newer
if ((UInt)version > (UInt)m_profileVersion)
@@ -31,7 +41,7 @@ void GLSLExtensionTracker::requireVersion(ProfileVersion version)
}
}
-void GLSLExtensionTracker::requireBaseTypeExtension(BaseType baseType)
+void ShaderExtensionTracker::requireBaseTypeExtension(BaseType baseType)
{
uint32_t bit = 1 << int(baseType);
if (m_hasBaseTypeFlags & bit)
diff --git a/source/slang/slang-glsl-extension-tracker.h b/source/slang/slang-extension-tracker.h
index 08e0c9ef1..7134c4ff5 100644
--- a/source/slang/slang-glsl-extension-tracker.h
+++ b/source/slang/slang-extension-tracker.h
@@ -1,8 +1,6 @@
-// slang-glsl-extension-tracker.h
-#ifndef SLANG_GLSL_EXTENSION_TRACKER_H
-#define SLANG_GLSL_EXTENSION_TRACKER_H
+// slang-extension-tracker.h
+#pragma once
-#include "../core/slang-basic.h"
#include "../core/slang-semantic-version.h"
#include "../core/slang-string-slice-pool.h"
#include "slang-compiler.h"
@@ -10,7 +8,7 @@
namespace Slang
{
-class GLSLExtensionTracker : public ExtensionTracker
+class ShaderExtensionTracker : public ExtensionTracker
{
public:
/// Return the list of extensionsspecified. NOTE that they are specified in the order requested,
@@ -23,11 +21,12 @@ public:
void requireSPIRVVersion(const SemanticVersion& version);
ProfileVersion getRequiredProfileVersion() const { return m_profileVersion; }
- void appendExtensionRequireLines(StringBuilder& builder) const;
+ void appendExtensionRequireLinesForGLSL(StringBuilder& builder) const;
+ void appendExtensionRequireLinesForWGSL(StringBuilder& builder) const;
const SemanticVersion& getSPIRVVersion() const { return m_spirvVersion; }
- GLSLExtensionTracker()
+ ShaderExtensionTracker()
: m_extensionPool(StringSlicePool::Style::Empty)
{
}
@@ -39,6 +38,7 @@ protected:
_getFlag(BaseType::UInt) | _getFlag(BaseType::Void) |
_getFlag(BaseType::Bool);
+ // Only valid for GLSL targets.
ProfileVersion m_profileVersion = ProfileVersion::GLSL_150;
StringSlicePool m_extensionPool;
@@ -47,4 +47,3 @@ protected:
};
} // namespace Slang
-#endif
diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp
index 7a46f45b4..04fb54924 100644
--- a/source/slang/slang-ir-glsl-legalize.cpp
+++ b/source/slang/slang-ir-glsl-legalize.cpp
@@ -1,7 +1,7 @@
// slang-ir-glsl-legalize.cpp
#include "slang-ir-glsl-legalize.h"
-#include "slang-glsl-extension-tracker.h"
+#include "slang-extension-tracker.h"
#include "slang-ir-clone.h"
#include "slang-ir-inst-pass-base.h"
#include "slang-ir-insts.h"
@@ -279,7 +279,7 @@ List<IRInst*> ScalarizedVal::leafAddresses()
struct GLSLLegalizationContext
{
Session* session;
- GLSLExtensionTracker* glslExtensionTracker;
+ ShaderExtensionTracker* glslExtensionTracker;
DiagnosticSink* sink;
Stage stage;
IRFunc* entryPointFunc;
@@ -3654,7 +3654,7 @@ void legalizeEntryPointForGLSL(
IRModule* module,
IRFunc* func,
CodeGenContext* codeGenContext,
- GLSLExtensionTracker* glslExtensionTracker)
+ ShaderExtensionTracker* glslExtensionTracker)
{
auto entryPointDecor = func->findDecoration<IREntryPointDecoration>();
SLANG_ASSERT(entryPointDecor);
@@ -3885,7 +3885,7 @@ void legalizeEntryPointsForGLSL(
IRModule* module,
const List<IRFunc*>& funcs,
CodeGenContext* context,
- GLSLExtensionTracker* glslExtensionTracker)
+ ShaderExtensionTracker* glslExtensionTracker)
{
for (auto func : funcs)
{
diff --git a/source/slang/slang-ir-glsl-legalize.h b/source/slang/slang-ir-glsl-legalize.h
index 2bb7730e7..a3e607ca8 100644
--- a/source/slang/slang-ir-glsl-legalize.h
+++ b/source/slang/slang-ir-glsl-legalize.h
@@ -9,7 +9,7 @@ namespace Slang
class DiagnosticSink;
class Session;
-class GLSLExtensionTracker;
+class ShaderExtensionTracker;
struct IRFunc;
struct IRModule;
@@ -19,7 +19,7 @@ void legalizeEntryPointsForGLSL(
IRModule* module,
const List<IRFunc*>& func,
CodeGenContext* context,
- GLSLExtensionTracker* glslExtensionTracker);
+ ShaderExtensionTracker* glslExtensionTracker);
void legalizeConstantBufferLoadForGLSL(IRModule* module);
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index f1e9624f3..d9c543efa 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -828,6 +828,7 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace)
INST(RequireSPIRVVersionDecoration, requireSPIRVVersion, 1, 0)
INST(RequireGLSLVersionDecoration, requireGLSLVersion, 1, 0)
INST(RequireGLSLExtensionDecoration, requireGLSLExtension, 1, 0)
+ INST(RequireWGSLExtensionDecoration, requireWGSLExtension, 1, 0)
INST(RequireCUDASMVersionDecoration, requireCUDASMVersion, 1, 0)
INST(RequireCapabilityAtomDecoration, requireCapabilityAtom, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index a883172ff..c342039a5 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -422,6 +422,15 @@ struct IRRequireGLSLExtensionDecoration : IRDecoration
UnownedStringSlice getExtensionName() { return getExtensionNameOperand()->getStringSlice(); }
};
+struct IRRequireWGSLExtensionDecoration : IRDecoration
+{ // IR decoration whose operand 0 is a string literal naming a required WGSL extension.
+ IR_LEAF_ISA(RequireWGSLExtensionDecoration)
+
+ IRStringLit* getExtensionNameOperand() { return cast<IRStringLit>(getOperand(0)); } // operand 0, cast to IRStringLit
+
+ UnownedStringSlice getExtensionName() { return getExtensionNameOperand()->getStringSlice(); } // string slice of operand 0
+};
+
struct IRMemoryQualifierSetDecoration : IRDecoration
{
enum
@@ -4792,6 +4801,11 @@ public:
getIntValue(getIntType(), IRIntegerValue(version)));
}
+ void addRequireWGSLExtensionDecoration(IRInst* value, UnownedStringSlice const& extensionName)
+ { // Attaches a RequireWGSLExtensionDecoration to `value`, with extensionName as its string operand.
+ addDecoration(value, kIROp_RequireWGSLExtensionDecoration, getStringValue(extensionName));
+ }
+
void addRequirePreludeDecoration(
IRInst* value,
const CapabilitySet& caps,
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index 63b16080f..c672180b7 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -2,7 +2,6 @@
#include "slang-ir-spirv-legalize.h"
#include "slang-emit-base.h"
-#include "slang-glsl-extension-tracker.h"
#include "slang-ir-call-graph.h"
#include "slang-ir-clone.h"
#include "slang-ir-composite-reg-to-mem.h"
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 5c0c3edfb..7f399c366 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -9822,6 +9822,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
getBuilder()->addRequireSPIRVVersionDecoration(inst, versionMod->version);
}
+ for (auto extensionMod : decl->getModifiersOfType<RequiredWGSLExtensionModifier>())
+ {
+ getBuilder()->addRequireWGSLExtensionDecoration(
+ inst,
+ extensionMod->extensionNameToken.getContent());
+ }
for (auto versionMod : decl->getModifiersOfType<RequiredCUDASMVersionModifier>())
{
getBuilder()->addRequireCUDASMVersionDecoration(inst, versionMod->version);
@@ -10634,6 +10640,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
Int(getIntegerLiteralValue(versionMod->versionNumberToken)));
else if (auto spvVersion = as<RequiredSPIRVVersionModifier>(modifier))
getBuilder()->addRequireSPIRVVersionDecoration(irFunc, spvVersion->version);
+ else if (auto wgslExtensionMod = as<RequiredWGSLExtensionModifier>(modifier))
+ getBuilder()->addRequireWGSLExtensionDecoration(
+ irFunc,
+ wgslExtensionMod->extensionNameToken.getContent());
else if (auto cudasmVersion = as<RequiredCUDASMVersionModifier>(modifier))
getBuilder()->addRequireCUDASMVersionDecoration(irFunc, cudasmVersion->version);
else if (as<NonDynamicUniformAttribute>(modifier))
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index 81619b700..4a9d0a576 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -8294,6 +8294,17 @@ static NodeBase* parseGLSLVersionModifier(Parser* parser, void* /*userData*/)
return modifier;
}
+static NodeBase* parseWGSLExtensionModifier(Parser* parser, void* /*userData*/)
+{ // Parses the `( identifier )` argument list of a WGSL extension modifier; userData is unused.
+ auto modifier = parser->astBuilder->create<RequiredWGSLExtensionModifier>();
+
+ parser->ReadToken(TokenType::LParent);
+ modifier->extensionNameToken = parser->ReadToken(TokenType::Identifier); // the extension name is read as an Identifier token
+ parser->ReadToken(TokenType::RParent);
+
+ return modifier; // returns the newly created RequiredWGSLExtensionModifier
+}
+
static SlangResult parseSemanticVersion(
Parser* parser,
Token& outToken,
@@ -8854,6 +8865,7 @@ static const SyntaxParseInfo g_parseSyntaxEntries[] = {
_makeParseModifier("__glsl_extension", parseGLSLExtensionModifier),
_makeParseModifier("__glsl_version", parseGLSLVersionModifier),
_makeParseModifier("__spirv_version", parseSPIRVVersionModifier),
+ _makeParseModifier("__wgsl_extension", parseWGSLExtensionModifier),
_makeParseModifier("__cuda_sm_version", parseCUDASMVersionModifier),
_makeParseModifier("__builtin_type", parseBuiltinTypeModifier),