diff options
| author | Darren Wihandi <65404740+fairywreath@users.noreply.github.com> | 2025-02-02 15:27:11 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-02-02 12:27:11 -0800 |
| commit | 0a6828572aa4cc1f0f99993e77c321799eb88cca (patch) | |
| tree | d18f1950074958ff3276e303425eed15067ea2bc /source | |
| parent | 2949b786a7f04ad31c113b622039fb5b72bc8622 (diff) | |
Add support for WGSL subgroup operations (#6213)
* initial work
* more work
* more work on glsl intrinsics
* add subgroup broadcast for glsl
* wip add wgsl extension tracking
* enable tests, enable extensions and added some todos
* format and warning fixes
* fix wgsl extension tracker
---------
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
20 files changed, 465 insertions, 214 deletions
diff --git a/source/slang/glsl.meta.slang b/source/slang/glsl.meta.slang index dd1c5a907..ef3bfd683 100644 --- a/source/slang/glsl.meta.slang +++ b/source/slang/glsl.meta.slang @@ -6305,7 +6305,7 @@ void shader_subgroup_preamble() { // GL_KHR_shader_subgroup_basic Built-in Variables -[require(cpp_cuda_glsl_hlsl_spirv, subgroup_basic)] +[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)] void requireGLSLExtForSubgroupBasicBuiltin() { __target_switch { @@ -6317,7 +6317,7 @@ void requireGLSLExtForSubgroupBasicBuiltin() { } } -[require(cpp_cuda_glsl_hlsl_spirv, subgroup_basic)] +[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)] void setupExtForSubgroupBasicBuiltIn() { __target_switch { @@ -6329,7 +6329,7 @@ void setupExtForSubgroupBasicBuiltIn() { } __spirv_version(1.3) -[require(cpp_cuda_glsl_hlsl_spirv, subgroup_ballot)] +[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)] void requireGLSLExtForSubgroupBallotBuiltin() { __target_switch { @@ -6342,7 +6342,7 @@ void requireGLSLExtForSubgroupBallotBuiltin() { } __spirv_version(1.3) -[require(cpp_cuda_glsl_hlsl_spirv, subgroup_ballot)] +[require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)] void setupExtForSubgroupBallotBuiltIn() { __target_switch { @@ -6392,7 +6392,7 @@ public property uint gl_SubgroupID public property uint gl_SubgroupSize { - [require(cpp_cuda_glsl_hlsl_spirv, subgroup_basic)] + [require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)] get { setupExtForSubgroupBasicBuiltIn(); return WaveGetLaneCount(); @@ -6401,7 +6401,7 @@ public property uint gl_SubgroupSize public property uint gl_SubgroupInvocationID { - [require(cpp_cuda_glsl_hlsl_spirv, subgroup_basic)] + [require(cpp_cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)] get { setupExtForSubgroupBasicBuiltIn(); return WaveGetLaneIndex(); @@ -6625,7 +6625,7 @@ public void subgroupMemoryBarrierShared() __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_basic) [ForceInline] -[require(cuda_glsl_hlsl_spirv, subgroup_basic)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)] public bool subgroupElect() { __target_switch @@ -6635,6 +6635,7 @@ public bool subgroupElect() case glsl: case spirv: case hlsl: + case wgsl: return WaveIsFirstLane(); } @@ -6645,7 +6646,7 @@ public bool subgroupElect() __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_vote) [ForceInline] -[require(cuda_glsl_hlsl_spirv, subgroup_vote)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_vote)] public bool subgroupAll(bool value) { return WaveActiveAllTrue(value); @@ -6654,7 +6655,7 @@ public bool subgroupAll(bool value) __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_vote) [ForceInline] -[require(cuda_glsl_hlsl_spirv, subgroup_vote)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_vote)] public bool subgroupAny(bool value) { return WaveActiveAnyTrue(value); @@ -6688,7 +6689,7 @@ __generic<T : __BuiltinArithmeticType> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) [ForceInline] -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)] public T subgroupAdd(T value) { shader_subgroup_preamble<T>(); @@ -6699,7 +6700,7 @@ __generic<T : __BuiltinArithmeticType> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) [ForceInline] -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)] public T subgroupMul(T value) { shader_subgroup_preamble<T>(); @@ -6710,7 +6711,7 @@ __generic<T : __BuiltinArithmeticType> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) [ForceInline] -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)] public T subgroupMin(T value) { shader_subgroup_preamble<T>(); @@ -6721,7 +6722,7 @@ __generic<T : __BuiltinArithmeticType> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) [ForceInline] -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)] public T subgroupMax(T value) { shader_subgroup_preamble<T>(); @@ -6731,14 +6732,17 @@ public T subgroupMax(T value) __generic<T : __BuiltinLogicalType> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) +__wgsl_extension(subgroups) [ForceInline] -[require(glsl_spirv, subgroup_arithmetic)] +[require(glsl_spirv_wgsl, subgroup_arithmetic)] public T subgroupAnd(T value) { shader_subgroup_preamble<T>(); __target_switch { - case glsl: __intrinsic_asm "subgroupAnd($0)"; + case glsl: + case wgsl: + __intrinsic_asm "subgroupAnd($0)"; case spirv: if (__isBool<T>()) { return spirv_asm { @@ -6758,14 +6762,17 @@ public T subgroupAnd(T value) __generic<T : __BuiltinLogicalType> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) +__wgsl_extension(subgroups) [ForceInline] -[require(glsl_spirv, subgroup_arithmetic)] +[require(glsl_spirv_wgsl, subgroup_arithmetic)] public T subgroupOr(T value) { shader_subgroup_preamble<T>(); __target_switch { - case glsl: __intrinsic_asm "subgroupOr($0)"; + case glsl: + case wgsl: + __intrinsic_asm "subgroupOr($0)"; case spirv: if (__isBool<T>()) { return spirv_asm { @@ -6785,14 +6792,17 @@ public T subgroupOr(T value) __generic<T : __BuiltinLogicalType> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) +__wgsl_extension(subgroups) [ForceInline] -[require(glsl_spirv, subgroup_arithmetic)] +[require(glsl_spirv_wgsl, subgroup_arithmetic)] public T subgroupXor(T value) { shader_subgroup_preamble<T>(); __target_switch { - case glsl: __intrinsic_asm "subgroupXor($0)"; + case glsl: + case wgsl: + __intrinsic_asm "subgroupXor($0)"; case spirv: if (__isBool<T>()) { return spirv_asm { @@ -6812,14 +6822,16 @@ public T subgroupXor(T value) __generic<T : __BuiltinArithmeticType> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) +__wgsl_extension(subgroups) [ForceInline] -[require(glsl_spirv, subgroup_arithmetic)] +[require(glsl_spirv_wgsl, subgroup_arithmetic)] public T subgroupInclusiveAdd(T value) { shader_subgroup_preamble<T>(); __target_switch { case glsl: + case wgsl: __intrinsic_asm "subgroupInclusiveAdd($0)"; case spirv: if (__isFloat<T>()) @@ -6833,14 +6845,16 @@ public T subgroupInclusiveAdd(T value) __generic<T : __BuiltinArithmeticType> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) +__wgsl_extension(subgroups) [ForceInline] -[require(glsl_spirv, subgroup_arithmetic)] +[require(glsl_spirv_wgsl, subgroup_arithmetic)] public T subgroupInclusiveMul(T value) { shader_subgroup_preamble<T>(); __target_switch { case glsl: + case wgsl: __intrinsic_asm "subgroupInclusiveMul($0)"; case spirv: if (__isFloat<T>()) @@ -6974,7 +6988,7 @@ __generic<T : __BuiltinArithmeticType> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) [ForceInline] -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)] public T subgroupExclusiveAdd(T value) { shader_subgroup_preamble<T>(); @@ -6986,7 +7000,7 @@ __generic<T : __BuiltinArithmeticType> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) [ForceInline] -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)] public T subgroupExclusiveMul(T value) { shader_subgroup_preamble<T>(); @@ -7097,7 +7111,7 @@ __generic<T : __BuiltinArithmeticType, let N : int> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) [ForceInline] -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)] public vector<T,N> subgroupAdd(vector<T,N> value) { shader_subgroup_preamble<T>(); @@ -7108,7 +7122,7 @@ __generic<T : __BuiltinArithmeticType, let N : int> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) [ForceInline] -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)] public vector<T,N> subgroupMul(vector<T,N> value) { shader_subgroup_preamble<T>(); @@ -7119,7 +7133,7 @@ __generic<T : __BuiltinArithmeticType, let N : int> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) [ForceInline] -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)] public vector<T,N> subgroupMin(vector<T,N> value) { shader_subgroup_preamble<T>(); @@ -7130,7 +7144,7 @@ __generic<T : __BuiltinArithmeticType, let N : int> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) [ForceInline] -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)] public vector<T,N> subgroupMax(vector<T,N> value) { shader_subgroup_preamble<T>(); @@ -7140,14 +7154,18 @@ public vector<T,N> subgroupMax(vector<T,N> value) __generic<T : __BuiltinLogicalType, let N : int> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) +__wgsl_extension(subgroups) [ForceInline] -[require(glsl_spirv, subgroup_arithmetic)] +[require(glsl_spirv_wgsl, subgroup_arithmetic)] public vector<T,N> subgroupAnd(vector<T,N> value) { shader_subgroup_preamble<T>(); __target_switch { - case glsl: __intrinsic_asm "subgroupAnd($0)"; + case glsl: + case wgsl: + // TODO: Bool inputs are invalid for WGSL, cast them to int or don't allow them to compile. + __intrinsic_asm "subgroupAnd($0)"; case spirv: if (__isBool<T>()) { return spirv_asm { @@ -7168,14 +7186,17 @@ public vector<T,N> subgroupAnd(vector<T,N> value) __generic<T : __BuiltinLogicalType, let N : int> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) +__wgsl_extension(subgroups) [ForceInline] -[require(glsl_spirv, subgroup_arithmetic)] +[require(glsl_spirv_wgsl, subgroup_arithmetic)] public vector<T,N> subgroupOr(vector<T,N> value) { shader_subgroup_preamble<T>(); __target_switch { - case glsl: __intrinsic_asm "subgroupOr($0)"; + case glsl: + case wgsl: + __intrinsic_asm "subgroupOr($0)"; case spirv: if (__isBool<T>()) { return spirv_asm { @@ -7196,14 +7217,17 @@ public vector<T,N> subgroupOr(vector<T,N> value) __generic<T : __BuiltinLogicalType, let N : int> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) +__wgsl_extension(subgroups) [ForceInline] -[require(glsl_spirv, subgroup_arithmetic)] +[require(glsl_spirv_wgsl, subgroup_arithmetic)] public vector<T,N> subgroupXor(vector<T,N> value) { shader_subgroup_preamble<T>(); __target_switch { - case glsl: __intrinsic_asm "subgroupXor($0)"; + case glsl: + case wgsl: + __intrinsic_asm "subgroupXor($0)"; case spirv: if (__isBool<T>()) { return spirv_asm { @@ -7223,14 +7247,16 @@ public vector<T,N> subgroupXor(vector<T,N> value) __generic<T : __BuiltinArithmeticType, let N : int> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) +__wgsl_extension(subgroups) [ForceInline] -[require(glsl_spirv, subgroup_arithmetic)] +[require(glsl_spirv_wgsl, subgroup_arithmetic)] public vector<T,N> subgroupInclusiveAdd(vector<T,N> value) { shader_subgroup_preamble<T>(); __target_switch { case glsl: + case wgsl: __intrinsic_asm "subgroupInclusiveAdd($0)"; case spirv: if (__isFloat<T>()) @@ -7244,14 +7270,16 @@ public vector<T,N> subgroupInclusiveAdd(vector<T,N> value) __generic<T : __BuiltinArithmeticType, let N : int> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) +__wgsl_extension(subgroups) [ForceInline] -[require(glsl_spirv, subgroup_arithmetic)] +[require(glsl_spirv_wgsl, subgroup_arithmetic)] public vector<T,N> subgroupInclusiveMul(vector<T,N> value) { shader_subgroup_preamble<T>(); __target_switch { case glsl: + case wgsl: __intrinsic_asm "subgroupInclusiveMul($0)"; case spirv: if (__isFloat<T>()) @@ -7366,7 +7394,7 @@ __generic<T : __BuiltinArithmeticType, let N : int> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) [ForceInline] -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)] public vector<T,N> subgroupExclusiveAdd(vector<T,N> value) { shader_subgroup_preamble<T>(); @@ -7378,7 +7406,7 @@ __generic<T : __BuiltinArithmeticType, let N : int> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) [ForceInline] -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)] public vector<T,N> subgroupExclusiveMul(vector<T,N> value) { shader_subgroup_preamble<T>(); @@ -7488,51 +7516,65 @@ __generic<T : __BuiltinType> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_ballot) [ForceInline] -[require(cuda_glsl_hlsl_spirv, subgroup_ballot)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)] public T subgroupBroadcast(T value, uint id) { shader_subgroup_preamble<T>(); - return WaveMaskBroadcastLaneAt(WaveGetActiveMask(), value, id); + __target_switch + { + case wgsl: + // WGSL's intrinsic does not accept non-const ids, do shuffle instead. + __intrinsic_asm "subgroupShuffle"; + default: + return WaveBroadcastLaneAt(value, id); + } } __generic<T : __BuiltinType, let N : int> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_ballot) [ForceInline] -[require(cuda_glsl_hlsl_spirv, subgroup_ballot)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)] public vector<T,N> subgroupBroadcast(vector<T,N> value, uint id) { shader_subgroup_preamble<T>(); - return WaveMaskBroadcastLaneAt(WaveGetActiveMask(), value, id); + __target_switch + { + case wgsl: + // WGSL's intrinsic does not accept non-const ids, do shuffle instead. + __intrinsic_asm "subgroupShuffle"; + default: + return WaveBroadcastLaneAt(value, id); + } } __generic<T : __BuiltinType> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_ballot) [ForceInline] -[require(cuda_glsl_hlsl_spirv, subgroup_ballot)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)] public T subgroupBroadcastFirst(T value) { shader_subgroup_preamble<T>(); - return WaveMaskReadLaneFirst(WaveGetActiveMask(), value); + return WaveReadLaneFirst(value); } __generic<T : __BuiltinType, let N : int> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_ballot) [ForceInline] -[require(cuda_glsl_hlsl_spirv, subgroup_ballot)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)] public vector<T,N> subgroupBroadcastFirst(vector<T,N> value) { shader_subgroup_preamble<T>(); - return WaveMaskReadLaneFirst(WaveGetActiveMask(), value); + return WaveReadLaneFirst(value); } // WaveMaskBallot is not the same; it force trunc's __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_ballot) [ForceInline] -[require(cuda_glsl_hlsl_spirv, subgroup_ballot)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)] public uvec4 subgroupBallot(bool value) { return WaveActiveBallot(value); @@ -7713,7 +7755,7 @@ __generic<T : __BuiltinType> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_shuffle) [ForceInline] -[require(cuda_glsl_hlsl_spirv, subgroup_shuffle)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)] public T subgroupShuffle(T value, uint index) { shader_subgroup_preamble<T>(); @@ -7723,13 +7765,15 @@ public T subgroupShuffle(T value, uint index) __generic<T : __BuiltinType> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_shuffle) -[require(glsl_spirv, subgroup_shuffle)] +__wgsl_extension(subgroups) +[require(glsl_spirv_wgsl, subgroup_shuffle)] [ForceInline] public T subgroupShuffleXor(T value, uint mask) { shader_subgroup_preamble<T>(); __target_switch { case glsl: + case wgsl: __intrinsic_asm "subgroupShuffleXor($0,$1)"; case spirv: return spirv_asm { @@ -7743,7 +7787,7 @@ __generic<T : __BuiltinType, let N : int> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_shuffle) [ForceInline] -[require(cuda_glsl_hlsl_spirv, subgroup_shuffle)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)] public vector<T,N> subgroupShuffle(vector<T,N> value, uint index) { shader_subgroup_preamble<T>(); @@ -7753,14 +7797,16 @@ public vector<T,N> subgroupShuffle(vector<T,N> value, uint index) __generic<T : __BuiltinType, let N : int> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_shuffle) +__wgsl_extension(subgroups) [ForceInline] -[require(glsl_spirv, subgroup_shuffle)] +[require(glsl_spirv_wgsl, subgroup_shuffle)] public vector<T,N> subgroupShuffleXor(vector<T,N> value, uint mask) { shader_subgroup_preamble<T>(); __target_switch { case glsl: + case wgsl: __intrinsic_asm "subgroupShuffleXor($0,$1)"; case spirv: return spirv_asm { @@ -7776,14 +7822,16 @@ public vector<T,N> subgroupShuffleXor(vector<T,N> value, uint mask) __generic<T : __BuiltinType> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_shuffle_relative) +__wgsl_extension(subgroups) [ForceInline] -[require(glsl_spirv, subgroup_shufflerelative)] +[require(glsl_spirv_wgsl, subgroup_shufflerelative)] public T subgroupShuffleUp(T value, uint delta) { shader_subgroup_preamble<T>(); __target_switch { case glsl: + case wgsl: __intrinsic_asm "subgroupShuffleUp($0, $1)"; case spirv: return spirv_asm { @@ -7796,14 +7844,16 @@ public T subgroupShuffleUp(T value, uint delta) __generic<T : __BuiltinType> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_shuffle_relative) +__wgsl_extension(subgroups) [ForceInline] -[require(glsl_spirv, subgroup_shufflerelative)] +[require(glsl_spirv_wgsl, subgroup_shufflerelative)] public T subgroupShuffleDown(T value, uint delta) { shader_subgroup_preamble<T>(); __target_switch { case glsl: + case wgsl: __intrinsic_asm "subgroupShuffleDown($0, $1)"; case spirv: return spirv_asm { @@ -7817,14 +7867,16 @@ public T subgroupShuffleDown(T value, uint delta) __generic<T : __BuiltinType, let N : int> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_shuffle_relative) +__wgsl_extension(subgroups) [ForceInline] -[require(glsl_spirv, subgroup_shufflerelative)] +[require(glsl_spirv_wgsl, subgroup_shufflerelative)] public vector<T,N> subgroupShuffleUp(vector<T,N> value, uint delta) { shader_subgroup_preamble<T>(); __target_switch { case glsl: + case wgsl: __intrinsic_asm "subgroupShuffleUp($0, $1)"; case spirv: return spirv_asm { @@ -7837,14 +7889,16 @@ public vector<T,N> subgroupShuffleUp(vector<T,N> value, uint delta) __generic<T : __BuiltinType, let N : int> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_shuffle_relative) +__wgsl_extension(subgroups) [ForceInline] -[require(glsl_spirv, subgroup_shufflerelative)] +[require(glsl_spirv_wgsl, subgroup_shufflerelative)] public vector<T,N> subgroupShuffleDown(vector<T,N> value, uint delta) { shader_subgroup_preamble<T>(); __target_switch { case glsl: + case wgsl: __intrinsic_asm "subgroupShuffleDown($0, $1)"; case spirv: return spirv_asm { @@ -8161,7 +8215,7 @@ __generic<T : __BuiltinType> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_quad) [ForceInline] -[require(glsl_hlsl_spirv, subgroup_quad)] +[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] public T subgroupQuadSwapHorizontal(T value) { shader_subgroup_preamble<T>(); @@ -8172,7 +8226,7 @@ __generic<T : __BuiltinType> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_quad) [ForceInline] -[require(glsl_hlsl_spirv, subgroup_quad)] +[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] public T subgroupQuadSwapVertical(T value) { shader_subgroup_preamble<T>(); @@ -8183,7 +8237,7 @@ __generic<T : __BuiltinType> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_quad) [ForceInline] -[require(glsl_hlsl_spirv, subgroup_quad)] +[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] public T subgroupQuadSwapDiagonal(T value) { shader_subgroup_preamble<T>(); @@ -8206,7 +8260,7 @@ __generic<T : __BuiltinType, let N : int> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_quad) [ForceInline] -[require(glsl_hlsl_spirv, subgroup_quad)] +[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] public vector<T,N> subgroupQuadSwapHorizontal(vector<T,N> value) { shader_subgroup_preamble<T>(); @@ -8217,7 +8271,7 @@ __generic<T : __BuiltinType, let N : int> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_quad) [ForceInline] -[require(glsl_hlsl_spirv, subgroup_quad)] +[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] public vector<T,N> subgroupQuadSwapVertical(vector<T,N> value) { shader_subgroup_preamble<T>(); @@ -8228,7 +8282,7 @@ __generic<T : __BuiltinType, let N : int> __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_quad) [ForceInline] -[require(glsl_hlsl_spirv, subgroup_quad)] +[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] public vector<T,N> subgroupQuadSwapDiagonal(vector<T,N> value) { shader_subgroup_preamble<T>(); diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index ecab7ff93..3baad5c10 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -14460,42 +14460,44 @@ __generic<T : __BuiltinType, let N : int, let M : int> matrix<T,N,M> QuadReadLan __generic<T : __BuiltinType> __glsl_extension(GL_KHR_shader_subgroup_quad) __spirv_version(1.3) -[require(glsl_hlsl_spirv, subgroup_quad)] +__wgsl_extension(subgroups) +[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] T QuadReadAcrossX(T localValue) { __target_switch { - case hlsl: - __intrinsic_asm "QuadReadAcrossX"; - case glsl: - __intrinsic_asm "subgroupQuadSwapHorizontal($0)"; + case hlsl: __intrinsic_asm "QuadReadAcrossX"; + case glsl: __intrinsic_asm "subgroupQuadSwapHorizontal($0)"; case spirv: uint direction = 0u; - return spirv_asm { + return spirv_asm + { OpCapability GroupNonUniformQuad; result:$$T = OpGroupNonUniformQuadSwap Subgroup $localValue $direction; }; + case wgsl: __intrinsic_asm "quadSwapX"; } } __generic<T : __BuiltinType, let N : int> __glsl_extension(GL_KHR_shader_subgroup_quad) __spirv_version(1.3) -[require(glsl_hlsl_spirv, subgroup_quad)] +__wgsl_extension(subgroups) +[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] vector<T,N> QuadReadAcrossX(vector<T,N> localValue) { __target_switch { - case hlsl: - __intrinsic_asm "QuadReadAcrossX"; - case glsl: - __intrinsic_asm "subgroupQuadSwapHorizontal($0)"; + case hlsl: __intrinsic_asm "QuadReadAcrossX"; + case glsl: __intrinsic_asm "subgroupQuadSwapHorizontal($0)"; case spirv: uint direction = 0u; - return spirv_asm { + return spirv_asm + { OpCapability GroupNonUniformQuad; result:$$vector<T,N> = OpGroupNonUniformQuadSwap Subgroup $localValue $direction; }; + case wgsl: __intrinsic_asm "quadSwapX"; } } __generic<T : __BuiltinType, let N : int, let M : int> matrix<T,N,M> QuadReadAcrossX(matrix<T,N,M> localValue); @@ -14504,85 +14506,88 @@ __generic<T : __BuiltinType, let N : int, let M : int> matrix<T,N,M> QuadReadAcr __generic<T : __BuiltinType> __glsl_extension(GL_KHR_shader_subgroup_quad) __spirv_version(1.3) -[require(glsl_hlsl_spirv, subgroup_quad)] +__wgsl_extension(subgroups) +[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] T QuadReadAcrossY(T localValue) { __target_switch { - case hlsl: - __intrinsic_asm "QuadReadAcrossY"; - case glsl: - __intrinsic_asm "subgroupQuadSwapVertical($0)"; + case hlsl: __intrinsic_asm "QuadReadAcrossY"; + case glsl: __intrinsic_asm "subgroupQuadSwapVertical($0)"; case spirv: uint direction = 1u; - return spirv_asm { + return spirv_asm + { OpCapability GroupNonUniformQuad; result:$$T = OpGroupNonUniformQuadSwap Subgroup $localValue $direction; }; + case wgsl: __intrinsic_asm "quadSwapY"; } } __generic<T : __BuiltinType, let N : int> __glsl_extension(GL_KHR_shader_subgroup_quad) __spirv_version(1.3) -[require(glsl_hlsl_spirv, subgroup_quad)] +__wgsl_extension(subgroups) +[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] vector<T,N> QuadReadAcrossY(vector<T,N> localValue) { __target_switch { - case hlsl: - __intrinsic_asm "QuadReadAcrossY"; - case glsl: - __intrinsic_asm "subgroupQuadSwapVertical($0)"; + case hlsl: __intrinsic_asm "QuadReadAcrossY"; + case glsl: __intrinsic_asm "subgroupQuadSwapVertical($0)"; case spirv: uint direction = 1u; - return spirv_asm { + return spirv_asm + { OpCapability GroupNonUniformQuad; result:$$vector<T,N> = OpGroupNonUniformQuadSwap Subgroup $localValue $direction; }; + case wgsl: __intrinsic_asm "quadSwapY"; } } - __generic<T : __BuiltinType, let N : int, let M : int> matrix<T,N,M> QuadReadAcrossY(matrix<T,N,M> localValue); /// @category wave __generic<T : __BuiltinType> __glsl_extension(GL_KHR_shader_subgroup_quad) __spirv_version(1.3) -[require(glsl_hlsl_spirv, subgroup_quad)] +__wgsl_extension(subgroups) +[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] T QuadReadAcrossDiagonal(T localValue) { __target_switch { - case hlsl: - __intrinsic_asm "QuadReadAcrossDiagonal"; - case glsl: - __intrinsic_asm "subgroupQuadSwapDiagonal($0)"; + case hlsl: __intrinsic_asm "QuadReadAcrossDiagonal"; + case glsl: __intrinsic_asm "subgroupQuadSwapDiagonal($0)"; case spirv: uint direction = 2u; - return spirv_asm { + return spirv_asm + { OpCapability GroupNonUniformQuad; result:$$T = OpGroupNonUniformQuadSwap Subgroup $localValue $direction; }; + case wgsl: __intrinsic_asm "quadSwapDiagonal"; } } __generic<T : __BuiltinType, let N : int> __glsl_extension(GL_KHR_shader_subgroup_quad) __spirv_version(1.3) -[require(glsl_hlsl_spirv, subgroup_quad)] +__wgsl_extension(subgroups) +[require(glsl_hlsl_spirv_wgsl, subgroup_quad)] vector<T,N> QuadReadAcrossDiagonal(vector<T,N> localValue) { __target_switch { - case hlsl: - __intrinsic_asm "QuadReadAcrossDiagonal"; - case glsl: - __intrinsic_asm "subgroupQuadSwapDiagonal($0)"; + case hlsl: __intrinsic_asm "QuadReadAcrossDiagonal"; + case glsl: __intrinsic_asm "subgroupQuadSwapDiagonal($0)"; case spirv: uint direction = 2u; - return spirv_asm { + return spirv_asm + { OpCapability GroupNonUniformQuad; result:$$vector<T,N> = OpGroupNonUniformQuadSwap Subgroup $localValue $direction; }; + case wgsl: __intrinsic_asm "quadSwapDiagonal"; } } __generic<T : __BuiltinType, let N : int, let M : int> matrix<T,N,M> QuadReadAcrossDiagonal(matrix<T,N,M> localValue); @@ -14597,16 +14602,19 @@ for (auto opName : kWaveActiveBitOpEntries) { __generic<T : __BuiltinIntegerType> __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +__wgsl_extension(subgroups) +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)] T WaveActive$(opName.hlslName)(T expr) { __target_switch { - case glsl: __intrinsic_asm "subgroup$(opName.glslName)($0)"; + case glsl: + case wgsl: + __intrinsic_asm "subgroup$(opName.glslName)"; case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)"; case spirv: return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniform$(opName.spirvName) $$T result Subgroup Reduce $expr}; - default: + case cuda: return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr); } } @@ -14614,22 +14622,25 @@ T WaveActive$(opName.hlslName)(T expr) __generic<T : __BuiltinIntegerType, let N : int> __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +__wgsl_extension(subgroups) +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)] vector<T, N> WaveActive$(opName.hlslName)(vector<T, N> expr) { __target_switch { - case glsl: __intrinsic_asm "subgroup$(opName.glslName)($0)"; + case glsl: + case wgsl: + __intrinsic_asm "subgroup$(opName.glslName)"; case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)"; case spirv: return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniform$(opName.spirvName) $$vector<T, N> result Subgroup Reduce $expr}; - default: + case cuda: return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr); } } __generic<T : __BuiltinIntegerType, let N : int, let M : int> -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)] matrix<T, N, M> WaveActive$(opName.hlslName)(matrix<T, N, M> expr) { __target_switch @@ -14637,12 +14648,13 @@ matrix<T, N, M> WaveActive$(opName.hlslName)(matrix<T, N, M> expr) case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)"; case glsl: case spirv: + case wgsl: matrix<T,N,M> result; [ForceUnroll] for (int i = 0; i < N; ++i) result[i] = WaveActive$(opName.hlslName)(expr[i]); return result; - default: + case cuda: return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr); } } @@ -14659,12 +14671,15 @@ for (const char* opName : kWaveActiveMinMaxNames) { __generic<T : __BuiltinArithmeticType> __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +__wgsl_extension(subgroups) +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)] T WaveActive$(opName)(T expr) { __target_switch { - case glsl: __intrinsic_asm "subgroup$(opName)($0)"; + case glsl: + case wgsl: + __intrinsic_asm "subgroup$(opName)"; case hlsl: __intrinsic_asm "WaveActive$(opName)"; case spirv: if (__isFloat<T>()) @@ -14673,7 +14688,7 @@ T WaveActive$(opName)(T expr) return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformU$(opName) $$T result Subgroup Reduce $expr}; else return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformS$(opName) $$T result Subgroup Reduce $expr}; - default: + case cuda: return WaveMask$(opName)(WaveGetActiveMask(), expr); } } @@ -14681,12 +14696,15 @@ T WaveActive$(opName)(T expr) __generic<T : __BuiltinArithmeticType, let N : int> __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +__wgsl_extension(subgroups) +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)] vector<T, N> WaveActive$(opName)(vector<T, N> expr) { __target_switch { - case glsl: __intrinsic_asm "subgroup$(opName)($0)"; + case glsl: + case wgsl: + __intrinsic_asm "subgroup$(opName)"; case hlsl: __intrinsic_asm "WaveActive$(opName)"; case spirv: if (__isFloat<T>()) @@ -14695,13 +14713,13 @@ vector<T, N> WaveActive$(opName)(vector<T, N> expr) return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformU$(opName) $$vector<T, N> result Subgroup Reduce $expr}; else return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformS$(opName) $$vector<T, N> result Subgroup Reduce $expr}; - default: + case cuda: return WaveMask$(opName)(WaveGetActiveMask(), expr); } } __generic<T : __BuiltinArithmeticType, let N : int, let M : int> -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)] matrix<T, N, M> WaveActive$(opName)(matrix<T, N, M> expr) { __target_switch @@ -14709,12 +14727,13 @@ matrix<T, N, M> WaveActive$(opName)(matrix<T, N, M> expr) case hlsl: __intrinsic_asm "WaveActive$(opName)"; case glsl: case spirv: + case wgsl: matrix<T, N, M> result; [ForceUnroll] for (int i = 0; i < N; ++i) result[i] = WaveActive$(opName)(expr[i]); return result; - default: + case cuda: return WaveMask$(opName)(WaveGetActiveMask(), expr); } } @@ -14733,7 +14752,8 @@ for (auto opName : kWaveActivProductSumNames) { __generic<T : __BuiltinArithmeticType> __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +__wgsl_extension(subgroups) +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)] T WaveActive$(opName.hlslName)(T expr) { __target_switch @@ -14757,7 +14777,8 @@ T WaveActive$(opName.hlslName)(T expr) }; } else return expr; - default: + case wgsl: __intrinsic_asm "subgroup$(opName.glslName)"; + case cuda: return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr); } } @@ -14765,7 +14786,8 @@ T WaveActive$(opName.hlslName)(T expr) __generic<T : __BuiltinArithmeticType, let N : int> __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +__wgsl_extension(subgroups) +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)] vector<T,N> WaveActive$(opName.hlslName)(vector<T,N> expr) { __target_switch @@ -14789,13 +14811,14 @@ vector<T,N> WaveActive$(opName.hlslName)(vector<T,N> expr) }; } else return expr; - default: + case wgsl: __intrinsic_asm "subgroup$(opName.glslName)"; + case cuda: return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr); } } __generic<T : __BuiltinArithmeticType, let N : int, let M : int> -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)] matrix<T, N, M> WaveActive$(opName.hlslName)(matrix<T, N, M> expr) { __target_switch @@ -14803,12 +14826,13 @@ matrix<T, N, M> WaveActive$(opName.hlslName)(matrix<T, N, M> expr) case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)"; case glsl: case spirv: + case wgsl: matrix<T, N, M> result; [ForceUnroll] for (int i = 0; i < N; ++i) result[i] = WaveActive$(opName.hlslName)(expr[i]); return result; - default: + case cuda: return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr); } } @@ -14877,22 +14901,23 @@ bool WaveActiveAllEqual(matrix<T, N, M> value) /// @category wave __glsl_extension(GL_KHR_shader_subgroup_vote) __spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_vote)] +__wgsl_extension(subgroups) +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_vote)] bool WaveActiveAllTrue(bool condition) { __target_switch { case glsl: - __intrinsic_asm "subgroupAll($0)"; - case hlsl: - __intrinsic_asm "WaveActiveAllTrue($0)"; + case wgsl: + __intrinsic_asm "subgroupAll"; + case hlsl: __intrinsic_asm "WaveActiveAllTrue($0)"; case spirv: return spirv_asm { OpCapability GroupNonUniformVote; OpGroupNonUniformAll $$bool result Subgroup $condition }; - default: + case cuda: return WaveMaskAllTrue(WaveGetActiveMask(), condition); } } @@ -14900,13 +14925,15 @@ bool WaveActiveAllTrue(bool condition) /// @category wave __glsl_extension(GL_KHR_shader_subgroup_vote) __spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_vote)] +__wgsl_extension(subgroups) +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_vote)] bool WaveActiveAnyTrue(bool condition) { __target_switch { case glsl: - __intrinsic_asm "subgroupAny($0)"; + case wgsl: + __intrinsic_asm "subgroupAny"; case hlsl: __intrinsic_asm "WaveActiveAnyTrue($0)"; case spirv: @@ -14923,14 +14950,16 @@ bool WaveActiveAnyTrue(bool condition) /// @category wave __glsl_extension(GL_KHR_shader_subgroup_ballot) __spirv_version(1.3) +__wgsl_extension(subgroups) [NonUniformReturn] -[require(cuda_glsl_hlsl_spirv, subgroup_ballot)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)] uint4 WaveActiveBallot(bool condition) { __target_switch { case glsl: - __intrinsic_asm "subgroupBallot($0)"; + case wgsl: + __intrinsic_asm "subgroupBallot"; case hlsl: __intrinsic_asm "WaveActiveBallot"; case spirv: @@ -15004,23 +15033,23 @@ uint WaveGetLaneIndex() /// @category wave __glsl_extension(GL_KHR_shader_subgroup_basic) __spirv_version(1.3) +__wgsl_extension(subgroups) [NonUniformReturn] -[require(cuda_glsl_hlsl_spirv, subgroup_basic)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_basic)] bool WaveIsFirstLane() { __target_switch { - case glsl: - __intrinsic_asm "subgroupElect()"; - case hlsl: - __intrinsic_asm "WaveIsFirstLane()"; + case glsl: __intrinsic_asm "subgroupElect()"; + case hlsl: __intrinsic_asm "WaveIsFirstLane()"; case spirv: return spirv_asm { OpCapability GroupNonUniformBallot; OpGroupNonUniformElect $$bool result Subgroup }; - default: + case wgsl: __intrinsic_asm "subgroupElect"; + case cuda: return WaveMaskIsFirstLane(WaveGetActiveMask()); } } @@ -15059,7 +15088,8 @@ uint _WaveCountBits(uint4 value) __generic<T : __BuiltinArithmeticType> __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +__wgsl_extension(subgroups) +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)] T WavePrefixProduct(T expr) { __target_switch @@ -15083,7 +15113,8 @@ T WavePrefixProduct(T expr) }; } else return expr; - default: + case wgsl: __intrinsic_asm "subgroupExclusiveMul"; + case cuda: return WaveMaskPrefixProduct(WaveGetActiveMask(), expr); } } @@ -15092,7 +15123,8 @@ T WavePrefixProduct(T expr) __generic<T : __BuiltinArithmeticType, let N : int> __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +__wgsl_extension(subgroups) +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)] vector<T,N> WavePrefixProduct(vector<T,N> expr) { __target_switch @@ -15113,13 +15145,14 @@ vector<T,N> WavePrefixProduct(vector<T,N> expr) }; } else return expr; - default: + case wgsl: __intrinsic_asm "subgroupExclusiveMul"; + case cuda: return WaveMaskPrefixProduct(WaveGetActiveMask(), expr); } } /// @category wave __generic<T : __BuiltinArithmeticType, let N : int, let M : int> -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)] matrix<T, N, M> WavePrefixProduct(matrix<T, N, M> expr) { __target_switch @@ -15127,11 +15160,12 @@ matrix<T, N, M> WavePrefixProduct(matrix<T, N, M> expr) case hlsl: __intrinsic_asm "WavePrefixProduct"; case glsl: case spirv: + case wgsl: matrix<T, N, M> result; for (int i = 0; i < N; ++i) result[i] = WavePrefixProduct(expr[i]); return result; - default: + case cuda: return WaveMaskPrefixProduct(WaveGetActiveMask(), expr); } } @@ -15140,7 +15174,8 @@ matrix<T, N, M> WavePrefixProduct(matrix<T, N, M> expr) __generic<T : __BuiltinArithmeticType> __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +__wgsl_extension(subgroups) +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)] T WavePrefixSum(T expr) { __target_switch @@ -15161,7 +15196,8 @@ T WavePrefixSum(T expr) }; } else return expr; - default: + case wgsl: __intrinsic_asm "subgroupExclusiveAdd"; + case cuda: return WaveMaskPrefixSum(WaveGetActiveMask(), expr); } } @@ -15169,7 +15205,8 @@ T WavePrefixSum(T expr) __generic<T : __BuiltinArithmeticType, let N : int> __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +__wgsl_extension(subgroups) +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)] vector<T,N> WavePrefixSum(vector<T,N> expr) { __target_switch @@ -15190,13 +15227,14 @@ vector<T,N> WavePrefixSum(vector<T,N> expr) }; } else return expr; - default: + case wgsl: __intrinsic_asm "subgroupExclusiveAdd"; + case cuda: return WaveMaskPrefixSum(WaveGetActiveMask(), expr); } } __generic<T : __BuiltinArithmeticType, let N : int, let M : int> -[require(cuda_glsl_hlsl_spirv, subgroup_arithmetic)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_arithmetic)] matrix<T,N,M> WavePrefixSum(matrix<T,N,M> expr) { __target_switch @@ -15204,11 +15242,12 @@ matrix<T,N,M> WavePrefixSum(matrix<T,N,M> expr) case hlsl: __intrinsic_asm "WavePrefixSum"; case glsl: case spirv: + case wgsl: matrix<T, N, M> result; for (int i = 0; i < N; ++i) result[i] = WavePrefixSum(expr[i]); return result; - default: + case cuda: return WaveMaskPrefixSum(WaveGetActiveMask(), expr); } } @@ -15217,7 +15256,8 @@ matrix<T,N,M> WavePrefixSum(matrix<T,N,M> expr) __generic<T : __BuiltinType> __glsl_extension(GL_KHR_shader_subgroup_ballot) __spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_ballot)] +__wgsl_extension(subgroups) +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)] T WaveReadLaneFirst(T expr) { __target_switch @@ -15228,7 +15268,8 @@ T WaveReadLaneFirst(T expr) case hlsl: __intrinsic_asm "WaveReadLaneFirst"; case spirv: return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcastFirst $$T result Subgroup $expr}; - default: + case wgsl: __intrinsic_asm "subgroupBroadcastFirst"; + case cuda: return WaveMaskReadLaneFirst(WaveGetActiveMask(), expr); } } @@ -15236,7 +15277,8 @@ T WaveReadLaneFirst(T expr) __generic<T : __BuiltinType, let N : int> __glsl_extension(GL_KHR_shader_subgroup_ballot) __spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_ballot)] +__wgsl_extension(subgroups) +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)] vector<T,N> WaveReadLaneFirst(vector<T,N> expr) { __target_switch @@ -15247,13 +15289,14 @@ vector<T,N> WaveReadLaneFirst(vector<T,N> expr) case hlsl: __intrinsic_asm "WaveReadLaneFirst"; case spirv: return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcastFirst $$vector<T,N> result Subgroup $expr}; - default: + case wgsl: __intrinsic_asm "subgroupBroadcastFirst"; + case cuda: return WaveMaskReadLaneFirst(WaveGetActiveMask(), expr); } } __generic<T : __BuiltinType, let N : int, let M : int> -[require(cuda_glsl_hlsl_spirv, subgroup_ballot)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)] matrix<T,N,M> WaveReadLaneFirst(matrix<T,N,M> expr) { __target_switch @@ -15261,11 +15304,12 @@ matrix<T,N,M> WaveReadLaneFirst(matrix<T,N,M> expr) case hlsl: __intrinsic_asm "WaveReadLaneFirst"; case glsl: case spirv: + case wgsl: matrix<T, N, M> result; for (int i = 0; i < N; ++i) result[i] = WaveReadLaneFirst(expr[i]); return result; - default: + case cuda: return WaveMaskReadLaneFirst(WaveGetActiveMask(), expr); } } @@ -15280,7 +15324,8 @@ matrix<T,N,M> WaveReadLaneFirst(matrix<T,N,M> expr) __generic<T : __BuiltinType> __glsl_extension(GL_KHR_shader_subgroup_ballot) __spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_ballot)] +__wgsl_extension(subgroups) +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)] T WaveBroadcastLaneAt(T value, constexpr int lane) { __target_switch @@ -15292,7 +15337,8 @@ T WaveBroadcastLaneAt(T value, constexpr int lane) case spirv: let ulane = uint(lane); return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcast $$T result Subgroup $value $ulane}; - default: + case wgsl: __intrinsic_asm "subgroupBroadcast"; + case cuda: return WaveMaskBroadcastLaneAt(WaveGetActiveMask(), value, lane); } } @@ -15301,7 +15347,8 @@ T WaveBroadcastLaneAt(T value, constexpr int lane) __generic<T : __BuiltinType, let N : int> __glsl_extension(GL_KHR_shader_subgroup_ballot) __spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_ballot)] +__wgsl_extension(subgroups) +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)] vector<T,N> WaveBroadcastLaneAt(vector<T,N> value, constexpr int lane) { __target_switch @@ -15313,13 +15360,14 @@ vector<T,N> WaveBroadcastLaneAt(vector<T,N> value, constexpr int lane) case spirv: let ulane = uint(lane); return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcast $$vector<T,N> result Subgroup $value $ulane}; - default: + case wgsl: __intrinsic_asm "subgroupBroadcast"; + case cuda: return WaveMaskBroadcastLaneAt(WaveGetActiveMask(), value, lane); } } __generic<T : __BuiltinType, let N : int, let M : int> -[require(cuda_glsl_hlsl_spirv, subgroup_ballot)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)] matrix<T, N, M> WaveBroadcastLaneAt(matrix<T, N, M> value, constexpr int lane) { __target_switch @@ -15328,11 +15376,12 @@ matrix<T, N, M> WaveBroadcastLaneAt(matrix<T, N, M> value, constexpr int lane) case hlsl: __intrinsic_asm "WaveReadLaneAt"; case glsl: case spirv: + case wgsl: matrix<T, N, M> result; for (int i = 0; i < N; ++i) result[i] = WaveBroadcastLaneAt(value[i], lane); return result; - default: + case cuda: return WaveMaskBroadcastLaneAt(WaveGetActiveMask(), value, lane); } } @@ -15343,7 +15392,8 @@ matrix<T, N, M> WaveBroadcastLaneAt(matrix<T, N, M> value, constexpr int lane) __generic<T : __BuiltinType> __glsl_extension(GL_KHR_shader_subgroup_shuffle) __spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_shuffle)] +__wgsl_extension(subgroups) +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)] T WaveReadLaneAt(T value, int lane) { __target_switch @@ -15355,15 +15405,17 @@ T WaveReadLaneAt(T value, int lane) case spirv: let ulane = uint(lane); return spirv_asm {OpCapability GroupNonUniformShuffle; OpGroupNonUniformShuffle $$T result Subgroup $value $ulane}; - default: + case wgsl: __intrinsic_asm "subgroupShuffle"; + case cuda: return WaveMaskReadLaneAt(WaveGetActiveMask(), value, lane); } } __generic<T : __BuiltinType, let N : int> -__spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_shuffle) -[require(cuda_glsl_hlsl_spirv, subgroup_shuffle)] +__spirv_version(1.3) +__wgsl_extension(subgroups) +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)] vector<T,N> WaveReadLaneAt(vector<T,N> value, int lane) { __target_switch @@ -15375,13 +15427,14 @@ vector<T,N> WaveReadLaneAt(vector<T,N> value, int lane) case spirv: let ulane = uint(lane); return spirv_asm {OpCapability GroupNonUniformShuffle; OpGroupNonUniformShuffle $$vector<T,N> result Subgroup $value $ulane}; - default: + case wgsl: __intrinsic_asm "subgroupShuffle"; + case cuda: return WaveMaskReadLaneAt(WaveGetActiveMask(), value, lane); } } __generic<T : __BuiltinType, let N : int, let M : int> -[require(cuda_glsl_hlsl_spirv, subgroup_shuffle)] +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)] matrix<T, N, M> WaveReadLaneAt(matrix<T, N, M> value, int lane) { __target_switch @@ -15390,11 +15443,12 @@ matrix<T, N, M> WaveReadLaneAt(matrix<T, N, M> value, int lane) case hlsl: __intrinsic_asm "WaveReadLaneAt"; case glsl: case spirv: + case wgsl: matrix<T,N,M> result; for (int i = 0; i < N; ++i) result[i] = WaveReadLaneAt(value[i], lane); return result; - default: + case cuda: return WaveMaskReadLaneAt(WaveGetActiveMask(), value, lane); } } @@ -15406,7 +15460,8 @@ matrix<T, N, M> WaveReadLaneAt(matrix<T, N, M> value, int lane) __generic<T : __BuiltinType> __glsl_extension(GL_KHR_shader_subgroup_shuffle) __spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_shuffle)] +__wgsl_extension(subgroups) +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)] T WaveShuffle(T value, int lane) { __target_switch @@ -15418,7 +15473,8 @@ T WaveShuffle(T value, int lane) case spirv: let ulane = uint(lane); return spirv_asm {OpCapability GroupNonUniformShuffle; OpGroupNonUniformShuffle $$T result Subgroup $value $ulane}; - default: + case wgsl: __intrinsic_asm "subgroupShuffle"; + case cuda: return WaveMaskShuffle(WaveGetActiveMask(), value, lane); } } @@ -15427,7 +15483,8 @@ T WaveShuffle(T value, int lane) __generic<T : __BuiltinType, let N : int> __glsl_extension(GL_KHR_shader_subgroup_shuffle) __spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_shuffle)] +__wgsl_extension(subgroups) +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_shuffle)] vector<T,N> WaveShuffle(vector<T,N> value, int lane) { __target_switch @@ -15439,7 +15496,8 @@ vector<T,N> WaveShuffle(vector<T,N> value, int lane) case spirv: let ulane = uint(lane); return spirv_asm {OpCapability GroupNonUniformShuffle; OpGroupNonUniformShuffle $$vector<T,N> result Subgroup $value $ulane}; - default: + case wgsl: __intrinsic_asm "subgroupShuffle"; + case cuda: return WaveMaskShuffle(WaveGetActiveMask(), value, lane); } } @@ -15482,12 +15540,14 @@ uint WavePrefixCountBits(bool value) /// @category wave __glsl_extension(GL_KHR_shader_subgroup_ballot) __spirv_version(1.3) -[require(cuda_glsl_hlsl_spirv, subgroup_ballot)] +__wgsl_extension(subgroups) +[require(cuda_glsl_hlsl_spirv_wgsl, subgroup_ballot)] uint4 WaveGetConvergedMulti() { __target_switch { case glsl: + case wgsl: __intrinsic_asm "subgroupBallot(true)"; case hlsl: __intrinsic_asm "WaveActiveBallot(true)"; case cuda: __intrinsic_asm "make_uint4(__activemask(), 0, 0, 0)"; diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 36dddd15f..cc4901236 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -218,6 +218,15 @@ class RequiredSPIRVVersionModifier : public Modifier }; // A modifier to tag something as an intrinsic that requires +// a certain WGSL extension to be enabled when used +class RequiredWGSLExtensionModifier : public Modifier +{ + SLANG_AST_CLASS(RequiredWGSLExtensionModifier) + + Token extensionNameToken; +}; + +// A modifier to tag something as an intrinsic that requires // a certain CUDA SM version to be enabled when used. Specified as "major.minor" class RequiredCUDASMVersionModifier : public Modifier { diff --git a/source/slang/slang-ast-print.cpp b/source/slang/slang-ast-print.cpp index 22a5ec8d1..597a35f4b 100644 --- a/source/slang/slang-ast-print.cpp +++ b/source/slang/slang-ast-print.cpp @@ -434,6 +434,8 @@ void ASTPrinter::addDeclKindPrefix(Decl* decl) continue; if (as<RequiredGLSLExtensionModifier>(modifier)) continue; + if (as<RequiredWGSLExtensionModifier>(modifier)) + continue; if (as<GLSLLayoutModifierGroupMarker>(modifier)) continue; if (as<HLSLLayoutSemantic>(modifier)) diff --git a/source/slang/slang-capabilities.capdef b/source/slang/slang-capabilities.capdef index 3be09b8d3..f98be0e32 100644 --- a/source/slang/slang-capabilities.capdef +++ b/source/slang/slang-capabilities.capdef @@ -331,6 +331,10 @@ alias cuda_hlsl_metal_spirv = cuda | hlsl | metal | spirv; /// [Compound] alias cuda_glsl_hlsl_spirv = cuda | glsl | hlsl | spirv; +/// CUDA, GLSL, HLSL, SPIRV, and WGSL code-gen targets +/// [Compound] +alias cuda_glsl_hlsl_spirv_wgsl = cuda | glsl | hlsl | spirv | wgsl; + /// CUDA, GLSL, HLSL, Metal, and SPIRV code-gen targets /// [Compound] alias cuda_glsl_hlsl_metal_spirv = cuda | glsl | hlsl | metal | spirv; @@ -387,6 +391,10 @@ alias glsl_metal_spirv_wgsl = glsl | metal | spirv | wgsl; /// [Compound] alias glsl_spirv = glsl | spirv; +/// GLSL, SPIRV, and WGSL code-gen targets +/// [Compound] +alias glsl_spirv_wgsl = glsl | spirv | wgsl; + /// HLSL, and SPIRV code-gen targets /// [Compound] alias hlsl_spirv = hlsl | spirv; @@ -1931,13 +1939,18 @@ alias shader5_sm_5_0 = GL_ARB_gpu_shader5 | sm_5_0_version; /// Capabilities required to use GLSL-style subgroup operations 'subgroup_basic' /// [Compound] -alias subgroup_basic = GL_KHR_shader_subgroup_basic | _sm_6_0 | _cuda_sm_7_0; +alias subgroup_basic = GL_KHR_shader_subgroup_basic + | _sm_6_0 + | _cuda_sm_7_0 + | wgsl + ; /// Capabilities required to use GLSL-style subgroup operations 'subgroup_ballot' /// [Compound] alias subgroup_ballot = spirv_1_0 + GL_KHR_shader_subgroup_ballot | glsl + GL_KHR_shader_subgroup_ballot + shader5_sm_5_0 | _sm_6_0 + shader5_sm_5_0 | _cuda_sm_7_0 + shader5_sm_5_0 + | wgsl ; /// Capabilities required to use GLSL-style subgroup operations 'subgroup_ballot_activemask' /// [Compound] @@ -1952,28 +1965,50 @@ alias subgroup_basic_ballot = glsl + GL_KHR_shader_subgroup_basic + subgroup_bal | spirv + GL_KHR_shader_subgroup_basic + subgroup_ballot | hlsl + subgroup_ballot | cuda + subgroup_ballot + | wgsl ; /// Capabilities required to use GLSL-style subgroup operations 'subgroup_vote' /// [Compound] -alias subgroup_vote = GL_KHR_shader_subgroup_vote | _sm_6_0 | _cuda_sm_7_0; +alias subgroup_vote = GL_KHR_shader_subgroup_vote + | _sm_6_0 + | _cuda_sm_7_0 + | wgsl + ; /// Capabilities required to use GLSL-style subgroup operations 'subgroup_vote' /// [Compound] alias shaderinvocationgroup = subgroup_vote; /// Capabilities required to use GLSL-style subgroup operations 'subgroup_arithmetic' /// [Compound] -alias subgroup_arithmetic = GL_KHR_shader_subgroup_arithmetic | _sm_6_0 | _cuda_sm_7_0; +alias subgroup_arithmetic = GL_KHR_shader_subgroup_arithmetic + | _sm_6_0 + | _cuda_sm_7_0 + | wgsl + ; + /// Capabilities required to use GLSL-style subgroup operations 'subgroup_shuffle' /// [Compound] -alias subgroup_shuffle = GL_KHR_shader_subgroup_shuffle | _sm_6_0 | _cuda_sm_7_0; +alias subgroup_shuffle = GL_KHR_shader_subgroup_shuffle + | _sm_6_0 + | _cuda_sm_7_0 + | wgsl + ; /// Capabilities required to use GLSL-style subgroup operations 'subgroup_shuffle_relative' /// [Compound] -alias subgroup_shufflerelative = GL_KHR_shader_subgroup_shuffle_relative | _sm_6_0 | _cuda_sm_7_0; +alias subgroup_shufflerelative = GL_KHR_shader_subgroup_shuffle_relative + | _sm_6_0 + | _cuda_sm_7_0 + | wgsl + ; /// Capabilities required to use GLSL-style subgroup operations 'subgroup_clustered' /// [Compound] alias subgroup_clustered = GL_KHR_shader_subgroup_clustered | _sm_6_0 | _cuda_sm_7_0; /// Capabilities required to use GLSL-style subgroup operations 'subgroup_quad' /// [Compound] -alias subgroup_quad = GL_KHR_shader_subgroup_quad | _sm_6_0 | _cuda_sm_7_0; +alias subgroup_quad = GL_KHR_shader_subgroup_quad + | _sm_6_0 + | _cuda_sm_7_0 + | wgsl + ; /// Capabilities required to use GLSL-style subgroup operations 'subgroup_partitioned' /// [Compound] alias subgroup_partitioned = GL_NV_shader_subgroup_partitioned + subgroup_ballot_activemask | _sm_6_5 | _cuda_sm_7_0; diff --git a/source/slang/slang-compiler.cpp b/source/slang/slang-compiler.cpp index 448534ce8..04ebb753c 100644 --- a/source/slang/slang-compiler.cpp +++ b/source/slang/slang-compiler.cpp @@ -28,7 +28,7 @@ // Artifact output #include "slang-artifact-output-util.h" #include "slang-emit-cuda.h" -#include "slang-glsl-extension-tracker.h" +#include "slang-extension-tracker.h" #include "slang-lower-to-ir.h" #include "slang-mangle.h" #include "slang-parameter-binding.h" @@ -658,7 +658,7 @@ static void _appendCodeWithPath( outCodeBuilder << fileContent << "\n"; } -void trackGLSLTargetCaps(GLSLExtensionTracker* extensionTracker, CapabilitySet const& caps) +void trackGLSLTargetCaps(ShaderExtensionTracker* extensionTracker, CapabilitySet const& caps) { for (auto& conjunctions : caps.getAtomSets()) { @@ -1037,8 +1037,11 @@ static RefPtr<ExtensionTracker> _newExtensionTracker(CodeGenTarget target) } case CodeGenTarget::SPIRV: case CodeGenTarget::GLSL: + case CodeGenTarget::WGSL: + case CodeGenTarget::WGSLSPIRV: + case CodeGenTarget::WGSLSPIRVAssembly: { - return new GLSLExtensionTracker; + return new ShaderExtensionTracker; } default: return nullptr; @@ -1261,7 +1264,7 @@ SlangResult CodeGenContext::emitWithDownstreamForEntryPoints(ComPtr<IArtifact>& if (auto endToEndReq = isPassThroughEnabled()) { // If we are pass through, we may need to set extension tracker state. - if (GLSLExtensionTracker* glslTracker = as<GLSLExtensionTracker>(extensionTracker)) + if (ShaderExtensionTracker* glslTracker = as<ShaderExtensionTracker>(extensionTracker)) { trackGLSLTargetCaps(glslTracker, getTargetCaps()); } @@ -1400,7 +1403,7 @@ SlangResult CodeGenContext::emitWithDownstreamForEntryPoints(ComPtr<IArtifact>& options.flags |= CompileOptions::Flag::EnableFloat16; } } - else if (GLSLExtensionTracker* glslTracker = as<GLSLExtensionTracker>(extensionTracker)) + else if (ShaderExtensionTracker* glslTracker = as<ShaderExtensionTracker>(extensionTracker)) { DownstreamCompileOptions::CapabilityVersion version; version.kind = DownstreamCompileOptions::CapabilityVersion::Kind::SPIRV; diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index 1429139f9..25dab3fb3 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -15,13 +15,13 @@ namespace Slang { -void trackGLSLTargetCaps(GLSLExtensionTracker* extensionTracker, CapabilitySet const& caps); +void trackGLSLTargetCaps(ShaderExtensionTracker* extensionTracker, CapabilitySet const& caps); GLSLSourceEmitter::GLSLSourceEmitter(const Desc& desc) : Super(desc) { m_glslExtensionTracker = - dynamicCast<GLSLExtensionTracker>(desc.codeGenContext->getExtensionTracker()); + dynamicCast<ShaderExtensionTracker>(desc.codeGenContext->getExtensionTracker()); SLANG_ASSERT(m_glslExtensionTracker); } @@ -2997,7 +2997,7 @@ void GLSLSourceEmitter::emitFrontMatterImpl(TargetRequest* targetReq) trackGLSLTargetCaps(m_glslExtensionTracker, targetReq->getTargetCaps()); StringBuilder builder; - m_glslExtensionTracker->appendExtensionRequireLines(builder); + m_glslExtensionTracker->appendExtensionRequireLinesForGLSL(builder); m_writer->emit(builder.getUnownedSlice()); } diff --git a/source/slang/slang-emit-glsl.h b/source/slang/slang-emit-glsl.h index b07b410ca..8308a9954 100644 --- a/source/slang/slang-emit-glsl.h +++ b/source/slang/slang-emit-glsl.h @@ -3,7 +3,7 @@ #define SLANG_EMIT_GLSL_H #include "slang-emit-c-like.h" -#include "slang-glsl-extension-tracker.h" +#include "slang-extension-tracker.h" namespace Slang { @@ -180,7 +180,7 @@ protected: Dictionary<IRInst*, HashSet<IRFunc*>> m_referencingEntryPoints; - RefPtr<GLSLExtensionTracker> m_glslExtensionTracker; + RefPtr<ShaderExtensionTracker> m_glslExtensionTracker; }; } // namespace Slang diff --git a/source/slang/slang-emit-wgsl.cpp b/source/slang/slang-emit-wgsl.cpp index ce60cc2a0..aea766f9f 100644 --- a/source/slang/slang-emit-wgsl.cpp +++ b/source/slang/slang-emit-wgsl.cpp @@ -49,6 +49,14 @@ fn _slang_getNan() -> f32 } )"; +WGSLSourceEmitter::WGSLSourceEmitter(const Desc& desc) + : CLikeSourceEmitter(desc) +{ + m_extensionTracker = + dynamicCast<ShaderExtensionTracker>(desc.codeGenContext->getExtensionTracker()); + SLANG_ASSERT(m_extensionTracker); +} + void WGSLSourceEmitter::emitSwitchCaseSelectorsImpl( const SwitchRegion::Case* const currentCase, const bool isDefault) @@ -1556,6 +1564,10 @@ void WGSLSourceEmitter::emitFrontMatterImpl(TargetRequest* /* targetReq */) m_writer->emit("enable f16;\n"); m_writer->emit("\n"); } + + StringBuilder builder; + m_extensionTracker->appendExtensionRequireLinesForWGSL(builder); + m_writer->emit(builder.getUnownedSlice()); } void WGSLSourceEmitter::emitIntrinsicCallExprImpl( @@ -1626,4 +1638,28 @@ void WGSLSourceEmitter::emitInterpolationModifiersImpl( // https://www.w3.org/TR/WGSL/#interpolation } +void WGSLSourceEmitter::_requireExtension(const UnownedStringSlice& name) +{ + m_extensionTracker->requireExtension(name); +} + +void WGSLSourceEmitter::handleRequiredCapabilitiesImpl(IRInst* inst) +{ + for (auto decoration : inst->getDecorations()) + { + if (const auto extensionDecoration = as<IRRequireWGSLExtensionDecoration>(decoration)) + { + _requireExtension(extensionDecoration->getExtensionName()); + + // TODO: Make this cleaner and only enable this extension if f16 is actually used on the + // subgroup intrinsic. Check float type in meta file. + if (m_f16ExtensionEnabled && extensionDecoration->getExtensionName() == "subgroups") + { + String extName = "subgroups_f16"; + _requireExtension(extName.getUnownedSlice()); + } + } + } +} + } // namespace Slang diff --git a/source/slang/slang-emit-wgsl.h b/source/slang/slang-emit-wgsl.h index 714a722d7..441933b57 100644 --- a/source/slang/slang-emit-wgsl.h +++ b/source/slang/slang-emit-wgsl.h @@ -1,6 +1,7 @@ #pragma once #include "slang-emit-c-like.h" +#include "slang-extension-tracker.h" namespace Slang { @@ -8,10 +9,8 @@ namespace Slang class WGSLSourceEmitter : public CLikeSourceEmitter { public: - WGSLSourceEmitter(const Desc& desc) - : CLikeSourceEmitter(desc) - { - } + explicit WGSLSourceEmitter(const Desc& desc); + virtual bool isResourceTypeBindless(IRType* type) SLANG_OVERRIDE { SLANG_UNUSED(type); @@ -58,10 +57,14 @@ public: EmitOpInfo const& inOuterPrec) SLANG_OVERRIDE; virtual void emitGlobalParamDefaultVal(IRGlobalParam* varDecl) SLANG_OVERRIDE; + virtual void handleRequiredCapabilitiesImpl(IRInst* inst) SLANG_OVERRIDE; + void emit(const AddressSpace addressSpace); virtual bool shouldFoldInstIntoUseSites(IRInst* inst) SLANG_OVERRIDE; + virtual RefObject* getExtensionTracker() SLANG_OVERRIDE { return m_extensionTracker; } + private: bool maybeEmitSystemSemantic(IRInst* inst); @@ -73,7 +76,11 @@ private: const char* getWgslImageFormat(IRTextureTypeBase* type); + void _requireExtension(const UnownedStringSlice& name); + bool m_f16ExtensionEnabled = false; + + RefPtr<ShaderExtensionTracker> m_extensionTracker; }; } // namespace Slang diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 70521c7ee..58376bbc1 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -1400,10 +1400,10 @@ Result linkAndOptimizeIR( case CodeGenTarget::SPIRV: case CodeGenTarget::SPIRVAssembly: { - GLSLExtensionTracker glslExtensionTracker; - GLSLExtensionTracker* glslExtensionTrackerPtr = + ShaderExtensionTracker glslExtensionTracker; + ShaderExtensionTracker* glslExtensionTrackerPtr = options.sourceEmitter - ? as<GLSLExtensionTracker>(options.sourceEmitter->getExtensionTracker()) + ? as<ShaderExtensionTracker>(options.sourceEmitter->getExtensionTracker()) : &glslExtensionTracker; #if 0 diff --git a/source/slang/slang-glsl-extension-tracker.cpp b/source/slang/slang-extension-tracker.cpp index 268f123f7..f3818cbba 100644 --- a/source/slang/slang-glsl-extension-tracker.cpp +++ b/source/slang/slang-extension-tracker.cpp @@ -1,10 +1,10 @@ -// slang-glsl-extension-tracker.cpp -#include "slang-glsl-extension-tracker.h" +// slang-extension-tracker.cpp +#include "slang-extension-tracker.h" namespace Slang { -void GLSLExtensionTracker::appendExtensionRequireLines(StringBuilder& ioBuilder) const +void ShaderExtensionTracker::appendExtensionRequireLinesForGLSL(StringBuilder& ioBuilder) const { for (const auto& extension : m_extensionPool.getSlices()) { @@ -14,7 +14,17 @@ void GLSLExtensionTracker::appendExtensionRequireLines(StringBuilder& ioBuilder) } } -void GLSLExtensionTracker::requireSPIRVVersion(const SemanticVersion& version) +void ShaderExtensionTracker::appendExtensionRequireLinesForWGSL(StringBuilder& ioBuilder) const +{ + for (const auto& extension : m_extensionPool.getSlices()) + { + ioBuilder.append("enable "); + ioBuilder.append(extension); + ioBuilder.append(";\n"); + } +} + +void ShaderExtensionTracker::requireSPIRVVersion(const SemanticVersion& version) { if (version > m_spirvVersion) { @@ -22,7 +32,7 @@ void GLSLExtensionTracker::requireSPIRVVersion(const SemanticVersion& version) } } -void GLSLExtensionTracker::requireVersion(ProfileVersion version) +void ShaderExtensionTracker::requireVersion(ProfileVersion version) { // Check if this profile is newer if ((UInt)version > (UInt)m_profileVersion) @@ -31,7 +41,7 @@ void GLSLExtensionTracker::requireVersion(ProfileVersion version) } } -void GLSLExtensionTracker::requireBaseTypeExtension(BaseType baseType) +void ShaderExtensionTracker::requireBaseTypeExtension(BaseType baseType) { uint32_t bit = 1 << int(baseType); if (m_hasBaseTypeFlags & bit) diff --git a/source/slang/slang-glsl-extension-tracker.h b/source/slang/slang-extension-tracker.h index 08e0c9ef1..7134c4ff5 100644 --- a/source/slang/slang-glsl-extension-tracker.h +++ b/source/slang/slang-extension-tracker.h @@ -1,8 +1,6 @@ -// slang-glsl-extension-tracker.h -#ifndef SLANG_GLSL_EXTENSION_TRACKER_H -#define SLANG_GLSL_EXTENSION_TRACKER_H +// slang-extension-tracker.h +#pragma once -#include "../core/slang-basic.h" #include "../core/slang-semantic-version.h" #include "../core/slang-string-slice-pool.h" #include "slang-compiler.h" @@ -10,7 +8,7 @@ namespace Slang { -class GLSLExtensionTracker : public ExtensionTracker +class ShaderExtensionTracker : public ExtensionTracker { public: /// Return the list of extensionsspecified. NOTE that they are specified in the order requested, @@ -23,11 +21,12 @@ public: void requireSPIRVVersion(const SemanticVersion& version); ProfileVersion getRequiredProfileVersion() const { return m_profileVersion; } - void appendExtensionRequireLines(StringBuilder& builder) const; + void appendExtensionRequireLinesForGLSL(StringBuilder& builder) const; + void appendExtensionRequireLinesForWGSL(StringBuilder& builder) const; const SemanticVersion& getSPIRVVersion() const { return m_spirvVersion; } - GLSLExtensionTracker() + ShaderExtensionTracker() : m_extensionPool(StringSlicePool::Style::Empty) { } @@ -39,6 +38,7 @@ protected: _getFlag(BaseType::UInt) | _getFlag(BaseType::Void) | _getFlag(BaseType::Bool); + // Only valid for GLSL targets. ProfileVersion m_profileVersion = ProfileVersion::GLSL_150; StringSlicePool m_extensionPool; @@ -47,4 +47,3 @@ protected: }; } // namespace Slang -#endif diff --git a/source/slang/slang-ir-glsl-legalize.cpp b/source/slang/slang-ir-glsl-legalize.cpp index 7a46f45b4..04fb54924 100644 --- a/source/slang/slang-ir-glsl-legalize.cpp +++ b/source/slang/slang-ir-glsl-legalize.cpp @@ -1,7 +1,7 @@ // slang-ir-glsl-legalize.cpp #include "slang-ir-glsl-legalize.h" -#include "slang-glsl-extension-tracker.h" +#include "slang-extension-tracker.h" #include "slang-ir-clone.h" #include "slang-ir-inst-pass-base.h" #include "slang-ir-insts.h" @@ -279,7 +279,7 @@ List<IRInst*> ScalarizedVal::leafAddresses() struct GLSLLegalizationContext { Session* session; - GLSLExtensionTracker* glslExtensionTracker; + ShaderExtensionTracker* glslExtensionTracker; DiagnosticSink* sink; Stage stage; IRFunc* entryPointFunc; @@ -3654,7 +3654,7 @@ void legalizeEntryPointForGLSL( IRModule* module, IRFunc* func, CodeGenContext* codeGenContext, - GLSLExtensionTracker* glslExtensionTracker) + ShaderExtensionTracker* glslExtensionTracker) { auto entryPointDecor = func->findDecoration<IREntryPointDecoration>(); SLANG_ASSERT(entryPointDecor); @@ -3885,7 +3885,7 @@ void legalizeEntryPointsForGLSL( IRModule* module, const List<IRFunc*>& funcs, CodeGenContext* context, - GLSLExtensionTracker* glslExtensionTracker) + ShaderExtensionTracker* glslExtensionTracker) { for (auto func : funcs) { diff --git a/source/slang/slang-ir-glsl-legalize.h b/source/slang/slang-ir-glsl-legalize.h index 2bb7730e7..a3e607ca8 100644 --- a/source/slang/slang-ir-glsl-legalize.h +++ b/source/slang/slang-ir-glsl-legalize.h @@ -9,7 +9,7 @@ namespace Slang class DiagnosticSink; class Session; -class GLSLExtensionTracker; +class ShaderExtensionTracker; struct IRFunc; struct IRModule; @@ -19,7 +19,7 @@ void legalizeEntryPointsForGLSL( IRModule* module, const List<IRFunc*>& func, CodeGenContext* context, - GLSLExtensionTracker* glslExtensionTracker); + ShaderExtensionTracker* glslExtensionTracker); void legalizeConstantBufferLoadForGLSL(IRModule* module); diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index f1e9624f3..d9c543efa 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -828,6 +828,7 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace) INST(RequireSPIRVVersionDecoration, requireSPIRVVersion, 1, 0) INST(RequireGLSLVersionDecoration, requireGLSLVersion, 1, 0) INST(RequireGLSLExtensionDecoration, requireGLSLExtension, 1, 0) + INST(RequireWGSLExtensionDecoration, requireWGSLExtension, 1, 0) INST(RequireCUDASMVersionDecoration, requireCUDASMVersion, 1, 0) INST(RequireCapabilityAtomDecoration, requireCapabilityAtom, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index a883172ff..c342039a5 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -422,6 +422,15 @@ struct IRRequireGLSLExtensionDecoration : IRDecoration UnownedStringSlice getExtensionName() { return getExtensionNameOperand()->getStringSlice(); } }; +struct IRRequireWGSLExtensionDecoration : IRDecoration +{ + IR_LEAF_ISA(RequireWGSLExtensionDecoration) + + IRStringLit* getExtensionNameOperand() { return cast<IRStringLit>(getOperand(0)); } + + UnownedStringSlice getExtensionName() { return getExtensionNameOperand()->getStringSlice(); } +}; + struct IRMemoryQualifierSetDecoration : IRDecoration { enum @@ -4792,6 +4801,11 @@ public: getIntValue(getIntType(), IRIntegerValue(version))); } + void addRequireWGSLExtensionDecoration(IRInst* value, UnownedStringSlice const& extensionName) + { + addDecoration(value, kIROp_RequireWGSLExtensionDecoration, getStringValue(extensionName)); + } + void addRequirePreludeDecoration( IRInst* value, const CapabilitySet& caps, diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index 63b16080f..c672180b7 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -2,7 +2,6 @@ #include "slang-ir-spirv-legalize.h" #include "slang-emit-base.h" -#include "slang-glsl-extension-tracker.h" #include "slang-ir-call-graph.h" #include "slang-ir-clone.h" #include "slang-ir-composite-reg-to-mem.h" diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 5c0c3edfb..7f399c366 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -9822,6 +9822,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { getBuilder()->addRequireSPIRVVersionDecoration(inst, versionMod->version); } + for (auto extensionMod : decl->getModifiersOfType<RequiredWGSLExtensionModifier>()) + { + getBuilder()->addRequireWGSLExtensionDecoration( + inst, + extensionMod->extensionNameToken.getContent()); + } for (auto versionMod : decl->getModifiersOfType<RequiredCUDASMVersionModifier>()) { getBuilder()->addRequireCUDASMVersionDecoration(inst, versionMod->version); @@ -10634,6 +10640,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> Int(getIntegerLiteralValue(versionMod->versionNumberToken))); else if (auto spvVersion = as<RequiredSPIRVVersionModifier>(modifier)) getBuilder()->addRequireSPIRVVersionDecoration(irFunc, spvVersion->version); + else if (auto wgslExtensionMod = as<RequiredWGSLExtensionModifier>(modifier)) + getBuilder()->addRequireWGSLExtensionDecoration( + irFunc, + wgslExtensionMod->extensionNameToken.getContent()); else if (auto cudasmVersion = as<RequiredCUDASMVersionModifier>(modifier)) getBuilder()->addRequireCUDASMVersionDecoration(irFunc, cudasmVersion->version); else if (as<NonDynamicUniformAttribute>(modifier)) diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 81619b700..4a9d0a576 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -8294,6 +8294,17 @@ static NodeBase* parseGLSLVersionModifier(Parser* parser, void* /*userData*/) return modifier; } +static NodeBase* parseWGSLExtensionModifier(Parser* parser, void* /*userData*/) +{ + auto modifier = parser->astBuilder->create<RequiredWGSLExtensionModifier>(); + + parser->ReadToken(TokenType::LParent); + modifier->extensionNameToken = parser->ReadToken(TokenType::Identifier); + parser->ReadToken(TokenType::RParent); + + return modifier; +} + static SlangResult parseSemanticVersion( Parser* parser, Token& outToken, @@ -8854,6 +8865,7 @@ static const SyntaxParseInfo g_parseSyntaxEntries[] = { _makeParseModifier("__glsl_extension", parseGLSLExtensionModifier), _makeParseModifier("__glsl_version", parseGLSLVersionModifier), _makeParseModifier("__spirv_version", parseSPIRVVersionModifier), + _makeParseModifier("__wgsl_extension", parseWGSLExtensionModifier), _makeParseModifier("__cuda_sm_version", parseCUDASMVersionModifier), _makeParseModifier("__builtin_type", parseBuiltinTypeModifier), |
