From c787c4b82ba76f87069911f203eb192060b5264f Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 28 Aug 2023 21:24:49 -0700 Subject: Add `target_switch` and `intrinsic_asm` statement. (#3154) * Add `target_switch` and `__intrinsic_asm` statement. * Cleanup. * WaveGetActiveMask, WaveGetActiveMask, WaveCountBits. * WaveIsFirstLane. * More wave intrinsics. * wave intrinsics. * merge fix. * Fix. * Fix. * Update test. * update test. * Fix. --------- Co-authored-by: Yong He --- source/slang/core.meta.slang | 83 ++ source/slang/hlsl.meta.slang | 1067 ++++++++++++++++---- source/slang/slang-ast-iterator.h | 15 + source/slang/slang-ast-modifier.h | 1 + source/slang/slang-ast-stmt.h | 21 + source/slang/slang-check-impl.h | 6 + source/slang/slang-check-stmt.cpp | 20 + source/slang/slang-diagnostic-defs.h | 11 +- source/slang/slang-emit-c-like.cpp | 39 +- source/slang/slang-emit-c-like.h | 14 +- source/slang/slang-emit-cpp.cpp | 6 +- source/slang/slang-emit-cpp.h | 2 +- source/slang/slang-emit-cuda.cpp | 6 +- source/slang/slang-emit-cuda.h | 2 +- source/slang/slang-emit-glsl.cpp | 5 +- source/slang/slang-emit-spirv.cpp | 6 + source/slang/slang-ir-inline.cpp | 26 +- source/slang/slang-ir-inst-defs.h | 14 + source/slang/slang-ir-insts.h | 38 +- source/slang/slang-ir-link.cpp | 3 + source/slang/slang-ir-peephole.cpp | 61 +- source/slang/slang-ir-restructure.cpp | 1 + source/slang/slang-ir-sccp.cpp | 9 +- source/slang/slang-ir-specialize-target-switch.cpp | 67 ++ source/slang/slang-ir-specialize-target-switch.h | 15 + source/slang/slang-ir.cpp | 33 +- source/slang/slang-language-server-ast-lookup.cpp | 15 + source/slang/slang-lower-to-ir.cpp | 45 +- source/slang/slang-parser.cpp | 126 ++- source/slang/slang-spirv-core-grammar-embed.cpp | 27 +- 30 files changed, 1505 insertions(+), 279 deletions(-) create mode 100644 source/slang/slang-ir-specialize-target-switch.cpp create mode 100644 source/slang/slang-ir-specialize-target-switch.h (limited to 'source') diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 157b83653..96e6d284a 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2984,6 +2984,89 @@ __prefix T operator ~(T v0) return v0.bitNot(); } +// IR level type traits. + +__generic +__intrinsic_op($(kIROp_undefined)) +T __declVal(); + +__generic +__intrinsic_op($(kIROp_DefaultConstruct)) +T __default(); + +__generic +__intrinsic_op($(kIROp_TypeEquals)) +bool __type_equals_impl(T t, U u); + +__generic +[__unsafeForceInlineEarly] +bool __type_equals(T t, U u) +{ + return __type_equals_impl(__declVal(), __declVal()); +} + +__generic +[__unsafeForceInlineEarly] +bool __type_equals() +{ + return __type_equals_impl(__declVal(), __declVal()); +} + +__generic +__intrinsic_op($(kIROp_IsBool)) +bool __isBool_impl(T t); + +__generic +[__unsafeForceInlineEarly] +bool __isBool() +{ + return __isBool_impl(__declVal()); +} + +__generic +__intrinsic_op($(kIROp_IsInt)) +bool __isInt_impl(T t); + +__generic +[__unsafeForceInlineEarly] +bool __isInt() +{ + return __isInt_impl(__declVal()); +} + +__generic +__intrinsic_op($(kIROp_IsFloat)) +bool __isFloat_impl(T t); + +__generic +[__unsafeForceInlineEarly] +bool __isFloat() +{ + return __isFloat_impl(__declVal()); +} + +__generic +__intrinsic_op($(kIROp_IsUnsignedInt)) +bool __isUnsignedInt_impl(T t); + +__generic +[__unsafeForceInlineEarly] +bool __isUnsignedInt() +{ + return __isUnsignedInt_impl(__declVal()); +} + +__generic +__intrinsic_op($(kIROp_IsSignedInt)) +bool __isSignedInt_impl(T t); + +__generic +[__unsafeForceInlineEarly] +bool __isSignedInt() +{ + return __isSignedInt_impl(__declVal()); +} + // Provide implementations to public generic arithmetic interfaces for builtin types. ${{{{ diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 81f05d3d4..2502ae75f 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -996,29 +996,72 @@ matrix acos(matrix x) // Test if all components are non-zero (HLSL SM 1.0) __generic -__target_intrinsic(cpp, "bool($0)") -__target_intrinsic(cuda, "bool($0)") -__target_intrinsic(glsl, "bool($0)") -__target_intrinsic(spirv, boolean(T), "OpCopyObject resultType resultId _0") -__target_intrinsic(spirv, integral(T), "OpINotEqual resultType resultId _0 const(,0)") -__target_intrinsic(spirv, floating(T), "OpFUnordNotEqual resultType resultId _0 const(,0)") [__readNone] -bool all(T x); +bool all(T x) +{ + __target_switch + { + default: + __intrinsic_asm "bool($0)"; + case hlsl: + __intrinsic_asm "all"; + case spirv: + let zero = __default(); + if (__isInt()) + return spirv_asm + { + OpINotEqual $$bool result $x $zero + }; + else if (__isFloat()) + return spirv_asm + { + OpFUnordNotEqual $$bool result $x $zero + }; + else if (__isBool()) + return __slang_noop_cast(x); + } +} __generic -__target_intrinsic(hlsl) -__target_intrinsic(glsl, "all(bvec$N0($0))") -__target_intrinsic(spirv, boolean(T), "OpAll resultType resultId _0") -// TODO: correct resultType in here -// __target_intrinsic(spirv, integral(T), "%c = OpINotEqual {vector} resultId _0 const(,0); OpAll resultType resultId %c") -// __target_intrinsic(spirv, floating(T), "%c = OpFUnordNotEqual {vector} resultId _0 const(,0); OpAll resultType resultId %c") [__readNone] bool all(vector x) { - bool result = true; - for(int i = 0; i < N; ++i) - result = result && all(x[i]); - return result; + __target_switch + { + case hlsl: + __intrinsic_asm "all"; + case glsl: + __intrinsic_asm "all(bvec$N0($0))"; + case spirv: + if (__isBool()) + return spirv_asm + { + OpAll $$bool result $x + }; + else if (__isInt()) + { + let zero = __default>(); + return spirv_asm + { + OpINotEqual $$vector %castResult $x $zero; + OpAll $$bool result %castResult + }; + } + else + { + let zero = __default(); + return spirv_asm + { + OpFUnordNotEqual $$vector %castResult $x $zero; + OpAll $$bool result %castResult + }; + } + default: + bool result = true; + for(int i = 0; i < N; ++i) + result = result && all(x[i]); + return result; + } } __generic @@ -1045,25 +1088,72 @@ void AllMemoryBarrierWithGroupSync(); // Test if any components is non-zero (HLSL SM 1.0) __generic -__target_intrinsic(cpp, "bool($0)") -__target_intrinsic(cuda, "bool($0)") -__target_intrinsic(glsl, "bool($0)") -__target_intrinsic(spirv, boolean(T), "OpCopyObject resultType resultId _0") -__target_intrinsic(spirv, integral(T), "OpINotEqual resultType resultId _0 const(,0)") -__target_intrinsic(spirv, floating(T), "OpFUnordNotEqual resultType resultId _0 const(,0)") [__readNone] -bool any(T x); +bool any(T x) +{ + __target_switch + { + default: + __intrinsic_asm "bool($0)"; + case hlsl: + __intrinsic_asm "any"; + case spirv: + let zero = __default(); + if (__isInt()) + return spirv_asm + { + OpINotEqual $$bool result $x $zero + }; + else if (__isFloat()) + return spirv_asm + { + OpFUnordNotEqual $$bool result $x $zero + }; + else if (__isBool()) + return __slang_noop_cast(x); + } +} __generic -__target_intrinsic(hlsl) -__target_intrinsic(glsl, "any(bvec$N0($0))") [__readNone] bool any(vector x) { - bool result = false; - for(int i = 0; i < N; ++i) - result = result || any(x[i]); - return result; + __target_switch + { + case hlsl: + __intrinsic_asm "any"; + case glsl: + __intrinsic_asm "any(bvec$N0($0))"; + case spirv: + if (__isBool()) + return spirv_asm + { + OpAny $$bool result $x + }; + else if (__isInt()) + { + let zero = __default>(); + return spirv_asm + { + OpINotEqual $$vector %castResult $x $zero; + OpAny $$bool result %castResult + }; + } + else + { + let zero = __default(); + return spirv_asm + { + OpFUnordNotEqual $$vector %castResult $x $zero; + OpAny $$bool result %castResult + }; + } + default: + bool result = false; + for(int i = 0; i < N; ++i) + result = result || any(x[i]); + return result; + } } __generic @@ -1648,12 +1738,22 @@ matrix cosh(matrix x) } // Population count -__target_intrinsic(hlsl) -__target_intrinsic(glsl, "bitCount") -__target_intrinsic(cuda, "$P_countbits($0)") -__target_intrinsic(cpp, "$P_countbits($0)") [__readNone] -uint countbits(uint value); +uint countbits(uint value) +{ + __target_switch + { + case hlsl: + __intrinsic_asm "countbits"; + case glsl: + __intrinsic_asm "bitCount"; + case cuda: + case cpp: + __intrinsic_asm "$P_countbits($0)"; + case spirv: + return spirv_asm {OpBitCount $$uint result $value}; + } +} // Cross product // TODO: SPIRV does not support integer vectors. @@ -2500,6 +2600,14 @@ __target_intrinsic(spirv, "OpMemoryBarrier const(int,ScopeWorkgroup)" "| MemorySemanticsWorkgroupMemoryMask)") void GroupMemoryBarrier(); +__target_intrinsic(glsl, "subgroupBarrier") +__target_intrinsic(spirv, "OpControlBarrier const(int,ScopeSubgroup) const(int,ScopeSubgroup)" + "const(int, MemorySemanticsAcquireReleaseMask" + "| MemorySemanticsUniformMemoryMask" + "| MemorySemanticsImageMemoryMask" + "| MemorySemanticsAtomicCounterMemoryMask" + "| MemorySemanticsWorkgroupMemoryMask)") +void __subgroupBarrier(); __target_intrinsic(glsl, "groupMemoryBarrier(), barrier()") __target_intrinsic(cuda, "__syncthreads()") @@ -3638,19 +3746,37 @@ vector refract(vector i, vector n, T eta) } // Reverse order of bits -__target_intrinsic(hlsl) -__target_intrinsic(glsl, "bitfieldReverse") -__target_intrinsic(cuda, "$P_reversebits($0)") -__target_intrinsic(cpp, "$P_reversebits($0)") [__readNone] -uint reversebits(uint value); +uint reversebits(uint value) +{ + __target_switch + { + case hlsl: + __intrinsic_asm "reversebits"; + case glsl: + __intrinsic_asm "bitfieldReverse"; + case cuda: + case cpp: + __intrinsic_asm "$P_reversebits($0)"; + case spirv: + return spirv_asm {OpBitReverse $$uint result $value}; + } +} __target_intrinsic(glsl, "bitfieldReverse") __generic [__readNone] vector reversebits(vector value) { - VECTOR_MAP_UNARY(uint, N, reversebits, value); + __target_switch + { + default: + VECTOR_MAP_UNARY(uint, N, reversebits, value); + case glsl: + __intrinsic_asm "bitfieldReverse"; + case spirv: + return spirv_asm {OpBitReverse $$vector result $value}; + } } // Round-to-nearest @@ -4072,58 +4198,161 @@ matrix trunc(matrix x) typedef uint WaveMask; __glsl_extension(GL_KHR_shader_subgroup_ballot) +__spirv_capability(GroupNonUniformBallot) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupBallot(true).x") -__target_intrinsic(cuda, "__activemask()") -__target_intrinsic(hlsl, "WaveActiveBallot(true).x") -WaveMask WaveGetConvergedMask(); +WaveMask WaveGetConvergedMask() +{ + __target_switch + { + case glsl: + __intrinsic_asm "subgroupBallot(true).x"; + case hlsl: + __intrinsic_asm "WaveActiveBallot(true).x"; + case cuda: + __intrinsic_asm "__activemask()"; + case spirv: + let _true = true; + let _scope = 3; // subgroup + return (spirv_asm + { + OpGroupNonUniformBallot $$uint4 result $_scope $_true + }).x; + } +} __intrinsic_op($(kIROp_WaveGetActiveMask)) WaveMask __WaveGetActiveMask(); __glsl_extension(GL_KHR_shader_subgroup_ballot) +__spirv_capability(GroupNonUniformBallot) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupBallot(true).x") -__target_intrinsic(hlsl, "WaveActiveBallot(true).x") WaveMask WaveGetActiveMask() { - return __WaveGetActiveMask(); + __target_switch + { + case glsl: + __intrinsic_asm "subgroupBallot(true).x"; + case hlsl: + __intrinsic_asm "WaveActiveBallot(true).x"; + case spirv: + let _true = true; + let _scope = 3; // subgroup + return (spirv_asm + { + OpGroupNonUniformBallot $$uint4 result $_scope $_true + }).x; + default: + return __WaveGetActiveMask(); + } } __glsl_extension(GL_KHR_shader_subgroup_basic) +__spirv_capability(GroupNonUniformBallot) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupElect()") -__target_intrinsic(cuda, "(($0 & -$0) == (WarpMask(1) << _getLaneId()))") -__target_intrinsic(hlsl, "WaveIsFirstLane()") -bool WaveMaskIsFirstLane(WaveMask mask); +bool WaveMaskIsFirstLane(WaveMask mask) +{ + __target_switch + { + case glsl: + __intrinsic_asm "subgroupElect()"; + case cuda: + __intrinsic_asm "(($0 & -$0) == (WarpMask(1) << _getLaneId()))"; + case hlsl: + __intrinsic_asm "WaveIsFirstLane()"; + case spirv: + let _scope = 3u; // subgroup + return spirv_asm + { + OpGroupNonUniformElect $$bool result $_scope + }; + default: + return false; + } +} __glsl_extension(GL_KHR_shader_subgroup_vote) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupAll($1)") -__target_intrinsic(cuda, "(__all_sync($0, $1) != 0)") -__target_intrinsic(hlsl, "WaveActiveAllTrue($1)") -bool WaveMaskAllTrue(WaveMask mask, bool condition); +__spirv_capability(GroupNonUniformBallot) +bool WaveMaskAllTrue(WaveMask mask, bool condition) +{ + __target_switch + { + case glsl: + __intrinsic_asm "subgroupAll($1)"; + case cuda: + __intrinsic_asm "(__all_sync($0, $1) != 0)"; + case hlsl: + __intrinsic_asm "WaveActiveAllTrue($1)"; + case spirv: + let _scope = 3u; // subgroup + return spirv_asm + { + OpGroupNonUniformAll $$bool result $_scope $condition + }; + default: + return false; + } +} __glsl_extension(GL_KHR_shader_subgroup_vote) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupAny($1)") -__target_intrinsic(cuda, "(__any_sync($0, $1) != 0)") -__target_intrinsic(hlsl, "WaveActiveAnyTrue($1)") -bool WaveMaskAnyTrue(WaveMask mask, bool condition); +__spirv_capability(GroupNonUniformBallot) +bool WaveMaskAnyTrue(WaveMask mask, bool condition) +{ + __target_switch + { + case glsl: + __intrinsic_asm "subgroupAny($1)"; + case cuda: + __intrinsic_asm "(__any_sync($0, $1) != 0)"; + case hlsl: + __intrinsic_asm "WaveActiveAnyTrue($1)"; + case spirv: + let _scope = 3u; // subgroup + return spirv_asm + { + OpGroupNonUniformAny $$bool result $_scope $condition + }; + default: + return false; + } +} __glsl_extension(GL_KHR_shader_subgroup_ballot) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupBallot($1).x") -__target_intrinsic(cuda, "__ballot_sync($0, $1)") -__target_intrinsic(hlsl, "WaveActiveBallot($1)") -WaveMask WaveMaskBallot(WaveMask mask, bool condition); +__spirv_capability(GroupNonUniformBallot) +WaveMask WaveMaskBallot(WaveMask mask, bool condition) +{ + __target_switch + { + case glsl: + __intrinsic_asm "subgroupBallot($1).x"; + case cuda: + __intrinsic_asm "__ballot_sync($0, $1)"; + case hlsl: + __intrinsic_asm "WaveActiveBallot($1)"; + case spirv: + let _scope = 3u; // subgroup + return (spirv_asm + { + OpGroupNonUniformBallot $$uint4 result $_scope $condition + }).x; + default: + return 0; + } +} -__glsl_extension(GL_KHR_shader_subgroup_ballot) -__target_intrinsic(cuda, "__popc(__ballot_sync($0, $1))") -__target_intrinsic(hlsl, "WaveActiveCountBits($1)") uint WaveMaskCountBits(WaveMask mask, bool value) { - return _WaveCountBits(WaveActiveBallot(value)); + __target_switch + { + case cuda: + __intrinsic_asm "__popc(__ballot_sync($0, $1))"; + case hlsl: + __intrinsic_asm "WaveActiveCountBits($1)"; + default: + return _WaveCountBits(WaveActiveBallot(value)); + } } // Waits until all warp lanes named in mask have executed a WaveMaskSharedSync (with the same mask) @@ -4141,12 +4370,20 @@ uint WaveMaskCountBits(WaveMask mask, bool value) // It seems this can only mean the active threads are the "threads the program flow would lead to". This implies a lockstep // "straight SIMD" style interpretation. That being the case this op on HLSL is just a memory barrier without any Sync. -__target_intrinsic(cuda, "__syncwarp($0)") -__glsl_extension(GL_KHR_shader_subgroup_basic) -__spirv_version(1.3) -__target_intrinsic(glsl, "subgroupBarrier()") -__target_intrinsic(hlsl, "AllMemoryBarrier()") -void AllMemoryBarrierWithWaveMaskSync(WaveMask mask); +void AllMemoryBarrierWithWaveMaskSync(WaveMask mask) +{ + __target_switch + { + case cuda: + __intrinsic_asm "__syncwarp($0)"; + case hlsl: + __intrinsic_asm "AllMemoryBarrier()"; + case glsl: + case spirv: + __subgroupBarrier(); + return; + } +} // On GLSL, it appears we can't use subgroupMemoryBarrierShared, because it only implies a memory ordering, it does not // imply convergence. For subgroupBarrier we have from the docs.. @@ -4163,25 +4400,51 @@ void AllMemoryBarrierWithWaveMaskSync(WaveMask mask); // also there to inform the compiler on what order reads and writes can take place. This might seem to be silly because of the 'Active' lanes // aspect of HLSL seems to make everything in lock step - but that's not quite so, it only has to apparently be that way as far as the programmers // model appears - divergence could perhaps potentially still happen. -__target_intrinsic(cuda, "__syncwarp($0)") -__glsl_extension(GL_KHR_shader_subgroup_basic) -__spirv_version(1.3) -__target_intrinsic(glsl, "subgroupBarrier()") -__target_intrinsic(hlsl, "GroupMemoryBarrier()") -void GroupMemoryBarrierWithWaveMaskSync(WaveMask mask); -__glsl_extension(GL_KHR_shader_subgroup_basic) -__spirv_version(1.3) -__target_intrinsic(glsl, "subgroupBarrier()") -__target_intrinsic(hlsl, "AllMemoryBarrier()") -void AllMemoryBarrierWithWaveSync(); +void GroupMemoryBarrierWithWaveMaskSync(WaveMask mask) +{ + __target_switch + { + case cuda: + __intrinsic_asm "__syncwarp($0)"; + case hlsl: + __intrinsic_asm "GroupMemoryBarrier()"; + case glsl: + case spirv: + __subgroupBarrier(); + return; + } +} -__glsl_extension(GL_KHR_shader_subgroup_basic) -__spirv_version(1.3) -__target_intrinsic(glsl, "subgroupBarrier()") -__target_intrinsic(hlsl, "GroupMemoryBarrier()") -__target_intrinsic(cuda, "__syncwarp()") -void GroupMemoryBarrierWithWaveSync(); +void AllMemoryBarrierWithWaveSync() +{ + __target_switch + { + case cuda: + __intrinsic_asm "__syncwarp()"; + case hlsl: + __intrinsic_asm "AllMemoryBarrier()"; + case glsl: + case spirv: + __subgroupBarrier(); + return; + } +} + +void GroupMemoryBarrierWithWaveSync() +{ + __target_switch + { + case cuda: + __intrinsic_asm "__syncwarp()"; + case hlsl: + __intrinsic_asm "GroupMemoryBarrier()"; + case glsl: + case spirv: + __subgroupBarrier(); + return; + } +} // NOTE! WaveMaskBroadcastLaneAt is *NOT* standard HLSL // It is provided as access to subgroupBroadcast which can only take a @@ -4193,17 +4456,36 @@ void GroupMemoryBarrierWithWaveSync(); __generic __glsl_extension(GL_KHR_shader_subgroup_ballot) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupBroadcast($1, $2)") -__target_intrinsic(cuda, "__shfl_sync($0, $1, $2)") -__target_intrinsic(hlsl, "WaveReadLaneAt($1, $2)") -T WaveMaskBroadcastLaneAt(WaveMask mask, T value, constexpr int lane); +__spirv_capability(GroupNonUniformBallot) +T WaveMaskBroadcastLaneAt(WaveMask mask, T value, constexpr int lane) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupBroadcast($1, $2)"; + case cuda: __intrinsic_asm "__shfl_sync($0, $1, $2)"; + case hlsl: __intrinsic_asm "WaveReadLaneAt($1, $2)"; + case spirv: + let _scope = 3u; // subgroup + return spirv_asm {OpGroupNonUniformBroadcast $$T result $_scope $value $lane}; + } +} + __generic __glsl_extension(GL_KHR_shader_subgroup_ballot) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupBroadcast($1, $2)") -__target_intrinsic(cuda, "_waveShuffleMultiple($0, $1, $2)") -__target_intrinsic(hlsl, "WaveReadLaneAt($1, $2)") -vector WaveMaskBroadcastLaneAt(WaveMask mask, vector value, constexpr int lane); +__spirv_capability(GroupNonUniformBallot) +vector WaveMaskBroadcastLaneAt(WaveMask mask, vector value, constexpr int lane) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupBroadcast($1, $2)"; + case cuda: __intrinsic_asm "_waveShuffleMultiple($0, $1, $2)"; + case hlsl: __intrinsic_asm "WaveReadLaneAt($1, $2)"; + case spirv: + let _scope = 3u; // subgroup + return spirv_asm {OpGroupNonUniformBroadcast $$vector result $_scope $value $lane}; + } +} __generic __target_intrinsic(cuda, "_waveShuffleMultiple($0, $1, $2)") __target_intrinsic(hlsl, "WaveReadLaneAt($1, $2)") @@ -4214,17 +4496,35 @@ matrix WaveMaskBroadcastLaneAt(WaveMask mask, matrix value, conste __generic __glsl_extension(GL_KHR_shader_subgroup_shuffle) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupShuffle($1, $2)") -__target_intrinsic(cuda, "__shfl_sync($0, $1, $2)") -__target_intrinsic(hlsl, "WaveReadLaneAt($1, $2)") -T WaveMaskReadLaneAt(WaveMask mask, T value, int lane); +__spirv_capability(GroupNonUniformShuffle) +T WaveMaskReadLaneAt(WaveMask mask, T value, int lane) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupShuffle($1, $2)"; + case cuda: __intrinsic_asm "__shfl_sync($0, $1, $2)"; + case hlsl: __intrinsic_asm "WaveReadLaneAt($1, $2)"; + case spirv: + let _scope = 3u; // subgroup + return spirv_asm {OpGroupNonUniformShuffle $$T result $_scope $value $lane}; + } +} __generic +__spirv_version(1.3)__glsl_extension(GL_KHR_shader_subgroup_shuffle) __spirv_version(1.3) -__glsl_extension(GL_KHR_shader_subgroup_shuffle) -__target_intrinsic(glsl, "subgroupShuffle($1, $2)") -__target_intrinsic(cuda, "_waveShuffleMultiple($0, $1, $2)") -__target_intrinsic(hlsl, "WaveReadLaneAt($1, $2)") -vector WaveMaskReadLaneAt(WaveMask mask, vector value, int lane); +__spirv_capability(GroupNonUniformShuffle) +vector WaveMaskReadLaneAt(WaveMask mask, vector value, int lane) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupShuffle($1, $2)"; + case cuda: __intrinsic_asm "_waveShuffleMultiple($0, $1, $2)"; + case hlsl: __intrinsic_asm "WaveReadLaneAt($1, $2)"; + case spirv: + let _scope = 3u; // subgroup + return spirv_asm {OpGroupNonUniformShuffle $$vector result $_scope $value $lane}; + } +} __generic __target_intrinsic(cuda, "_waveShuffleMultiple($0, $1, $2)") __target_intrinsic(hlsl, "WaveReadLaneAt($1, $2)") @@ -4234,47 +4534,75 @@ matrix WaveMaskReadLaneAt(WaveMask mask, matrix value, int lane); // which means it will only work on hardware which allows arbitrary laneIds which is not true // in general because it breaks the HLSL standard, which requires it's 'dynamically uniform' across the Wave. __generic -__glsl_extension(GL_KHR_shader_subgroup_shuffle) -__spirv_version(1.3) -__target_intrinsic(glsl, "subgroupShuffle($1, $2)") -__target_intrinsic(cuda, "__shfl_sync($0, $1, $2)") -__target_intrinsic(hlsl, "WaveReadLaneAt($1, $2)") -T WaveMaskShuffle(WaveMask mask, T value, int lane); +[__unsafeForceInlineEarly] +T WaveMaskShuffle(WaveMask mask, T value, int lane) +{ + return WaveMaskReadLaneAt(mask, value, lane); +} __generic -__glsl_extension(GL_KHR_shader_subgroup_shuffle) -__spirv_version(1.3) -__target_intrinsic(glsl, "subgroupShuffle($1, $2)") -__target_intrinsic(cuda, "_waveShuffleMultiple($0, $1, $2)") -__target_intrinsic(hlsl, "WaveReadLaneAt($1, $2)") -vector WaveMaskShuffle(WaveMask mask, vector value, int lane); +[__unsafeForceInlineEarly] +vector WaveMaskShuffle(WaveMask mask, vector value, int lane) +{ + return WaveMaskReadLaneAt(mask, value, lane); +} __generic -__target_intrinsic(cuda, "_waveShuffleMultiple($0, $1, $2)") -__target_intrinsic(hlsl, "WaveReadLaneAt($1, $2)") -matrix WaveMaskShuffle(WaveMask mask, matrix value, int lane); +[__unsafeForceInlineEarly] +matrix WaveMaskShuffle(WaveMask mask, matrix value, int lane) +{ + return WaveMaskReadLaneAt(mask, value, lane); +} __glsl_extension(GL_KHR_shader_subgroup_ballot) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupBallotExclusiveBitCount(subgroupBallot($1))") -__target_intrinsic(cuda, "__popc(__ballot_sync($0, $1) & _getLaneLtMask())") -__target_intrinsic(hlsl, "WavePrefixCountBits($1)") -uint WaveMaskPrefixCountBits(WaveMask mask, bool value); +__spirv_capability(GroupNonUniformBallot) +uint WaveMaskPrefixCountBits(WaveMask mask, bool value) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupBallotExclusiveBitCount(subgroupBallot($1))"; + case cuda: __intrinsic_asm "__popc(__ballot_sync($0, $1) & _getLaneLtMask())"; + case hlsl: __intrinsic_asm "WavePrefixCountBits($1)"; + case spirv: + let _scope = 3u; // subgroup + return spirv_asm {OpGroupNonUniformBallotBitCount $$uint result $_scope 2 $value}; + } +} // Across lane ops __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupAnd($1)") -__target_intrinsic(cuda, "_waveAnd($0, $1)") -__target_intrinsic(hlsl, "WaveActiveBitAnd($1)") -T WaveMaskBitAnd(WaveMask mask, T expr); +__spirv_capability(GroupNonUniformArithmetic) +T WaveMaskBitAnd(WaveMask mask, T expr) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupAnd($1)"; + case cuda: __intrinsic_asm "_waveAnd($0, $1)"; + case hlsl: __intrinsic_asm "WaveActiveBitAnd($1)"; + case spirv: + let _scope = 3u; // subgroup + return spirv_asm {OpGroupNonUniformBitwiseAnd $$T result $_scope 0 $expr}; + } +} + __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupAnd($1)") -__target_intrinsic(cuda, "_waveAndMultiple($0, $1)") -__target_intrinsic(hlsl, "WaveActiveBitAnd($1)") -vector WaveMaskBitAnd(WaveMask mask, vector expr); +__spirv_capability(GroupNonUniformArithmetic) +vector WaveMaskBitAnd(WaveMask mask, vector expr) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupAnd($1)"; + case cuda: __intrinsic_asm "_waveAndMultiple($0, $1)"; + case hlsl: __intrinsic_asm "WaveActiveBitAnd($1)"; + case spirv: + let _scope = 3u; // subgroup + return spirv_asm {OpGroupNonUniformBitwiseAnd $$vector result $_scope 0 $expr}; + } +} __generic __target_intrinsic(cuda, "_waveAndMultiple($0, $1)") __target_intrinsic(hlsl, "WaveActiveBitAnd($1)") @@ -4283,17 +4611,35 @@ matrix WaveMaskBitAnd(WaveMask mask, matrix expr); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupOr($1)") -__target_intrinsic(cuda, "_waveOr($0, $1)") -__target_intrinsic(hlsl, "WaveActiveBitOr($1)") -T WaveMaskBitOr(WaveMask mask, T expr); +__spirv_capability(GroupNonUniformArithmetic) +T WaveMaskBitOr(WaveMask mask, T expr) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupOr($1)"; + case cuda: __intrinsic_asm "_waveOr($0, $1)"; + case hlsl: __intrinsic_asm "WaveActiveBitOr($1)"; + case spirv: + let _scope = 3u; // subgroup + return spirv_asm {OpGroupNonUniformBitwiseOr $$T result $_scope 0 $expr}; + } +} __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupOr($1)") -__target_intrinsic(cuda, "_waveOrMultiple($0, $1)") -__target_intrinsic(hlsl, "WaveActiveBitOr($1)") -vector WaveMaskBitOr(WaveMask mask, vector expr); +__spirv_capability(GroupNonUniformArithmetic) +vector WaveMaskBitOr(WaveMask mask, vector expr) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupOr($1)"; + case cuda: __intrinsic_asm "_waveOrMultiple($0, $1)"; + case hlsl: __intrinsic_asm "WaveActiveBitOr($1)"; + case spirv: + let _scope = 3u; // subgroup + return spirv_asm {OpGroupNonUniformBitwiseOr $$vector result $_scope 0 $expr}; + } +} __generic __target_intrinsic(cuda, "_waveOrMultiple($0, $1)") __target_intrinsic(hlsl, "WaveActiveBitOr($1)") @@ -4302,17 +4648,35 @@ matrix WaveMaskBitOr(WaveMask mask, matrix expr); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupXor($1)") -__target_intrinsic(cuda, "_waveXor($0, $1)") -__target_intrinsic(hlsl, "WaveActiveBitXor($1)") -T WaveMaskBitXor(WaveMask mask, T expr); +__spirv_capability(GroupNonUniformArithmetic) +T WaveMaskBitXor(WaveMask mask, T expr) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupXor($1)"; + case cuda: __intrinsic_asm "_waveXor($0, $1)"; + case hlsl: __intrinsic_asm "WaveActiveBitXor($1)"; + case spirv: + let _scope = 3u; // subgroup + return spirv_asm {OpGroupNonUniformBitwiseXor $$T result $_scope 0 $expr}; + } +} __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupXor($1)") -__target_intrinsic(cuda, "_waveXorMultiple($0, $1)") -__target_intrinsic(hlsl, "WaveActiveBitXor($1)") -vector WaveMaskBitXor(WaveMask mask, vector expr); +__spirv_capability(GroupNonUniformArithmetic) +vector WaveMaskBitXor(WaveMask mask, vector expr) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupXor($1)"; + case cuda: __intrinsic_asm "_waveXorMultiple($0, $1)"; + case hlsl: __intrinsic_asm "WaveActiveBitXor($1)"; + case spirv: + let _scope = 3u; // subgroup + return spirv_asm {OpGroupNonUniformBitwiseXor $$vector result $_scope 0 $expr}; + } +} __generic __target_intrinsic(cuda, "_waveXorMultiple($0, $1)") __target_intrinsic(hlsl, "WaveActiveBitXor($1)") @@ -4321,17 +4685,46 @@ matrix WaveMaskBitXor(WaveMask mask, matrix expr); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupMax($1)") -__target_intrinsic(cuda, "_waveMax($0, $1)") -__target_intrinsic(hlsl, "WaveActiveMax($1)") -T WaveMaskMax(WaveMask mask, T expr); +__spirv_capability(GroupNonUniformArithmetic) +T WaveMaskMax(WaveMask mask, T expr) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupMax($1)"; + case cuda: __intrinsic_asm "_waveMax($0, $1)"; + case hlsl: __intrinsic_asm "WaveActiveMax($1)"; + case spirv: + let _scope = 3u; // subgroup + if (__isFloat()) + return spirv_asm {OpGroupNonUniformFMax $$T result $_scope 0 $expr}; + else if (__isSignedInt()) + return spirv_asm {OpGroupNonUniformSMax $$T result $_scope 0 $expr}; + else if (__isUnsignedInt()) + return spirv_asm {OpGroupNonUniformUMax $$T result $_scope 0 $expr}; + } +} __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupMax($1)") -__target_intrinsic(cuda, "_waveMaxMultiple($0, $1)") -__target_intrinsic(hlsl, "WaveActiveMax($1)") -vector WaveMaskMax(WaveMask mask, vector expr); +__spirv_capability(GroupNonUniformArithmetic) +vector WaveMaskMax(WaveMask mask, vector expr) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupMax($1)"; + case cuda: __intrinsic_asm "_waveMaxMultiple($0, $1)"; + case hlsl: __intrinsic_asm "WaveActiveMax($1)"; + case spirv: + let _scope = 3u; // subgroup + if (__isFloat()) + return spirv_asm {OpGroupNonUniformFMax $$vector result $_scope 0 $expr}; + else if (__isSignedInt()) + return spirv_asm {OpGroupNonUniformSMax $$vector result $_scope 0 $expr}; + else if (__isUnsignedInt()) + return spirv_asm {OpGroupNonUniformUMax $$vector result $_scope 0 $expr}; + } +} + __generic __target_intrinsic(cuda, "_waveMaxMultiple($0, $1)") __target_intrinsic(hlsl, "WaveActiveMax($1)") @@ -4340,17 +4733,47 @@ matrix WaveMaskMax(WaveMask mask, matrix expr); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupMin($1)") -__target_intrinsic(cuda, "_waveMin($0, $1)") -__target_intrinsic(hlsl, "WaveActiveMin($1)") -T WaveMaskMin(WaveMask mask, T expr); +__spirv_capability(GroupNonUniformArithmetic) +T WaveMaskMin(WaveMask mask, T expr) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupMin($1)"; + case cuda: __intrinsic_asm "_waveMin($0, $1)"; + case hlsl: __intrinsic_asm "WaveActiveMin($1)"; + case spirv: + let _scope = 3u; // subgroup + if (__isFloat()) + return spirv_asm {OpGroupNonUniformFMin $$T result $_scope 0 $expr}; + else if (__isSignedInt()) + return spirv_asm {OpGroupNonUniformSMin $$T result $_scope 0 $expr}; + else if (__isUnsignedInt()) + return spirv_asm {OpGroupNonUniformUMin $$T result $_scope 0 $expr}; + } +} + __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupMin($1)") -__target_intrinsic(cuda, "_waveMinMultiple($0, $1)") -__target_intrinsic(hlsl, "WaveActiveMin($1)") -vector WaveMaskMin(WaveMask mask, vector expr); +__spirv_capability(GroupNonUniformArithmetic) +vector WaveMaskMin(WaveMask mask, vector expr) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupMin($1)"; + case cuda: __intrinsic_asm "_waveMinMultiple($0, $1)"; + case hlsl: __intrinsic_asm "WaveActiveMin($1)"; + case spirv: + let _scope = 3u; // subgroup + if (__isFloat()) + return spirv_asm {OpGroupNonUniformFMin $$vector result $_scope 0 $expr}; + else if (__isSignedInt()) + return spirv_asm {OpGroupNonUniformSMin $$vector result $_scope 0 $expr}; + else if (__isUnsignedInt()) + return spirv_asm {OpGroupNonUniformUMin $$vector result $_scope 0 $expr}; + } +} + __generic __target_intrinsic(cuda, "_waveMinMultiple($0, $1)") __target_intrinsic(hlsl, "WaveActiveMin($1)") @@ -4359,17 +4782,63 @@ matrix WaveMaskMin(WaveMask mask, matrix expr); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupMul($1)") -__target_intrinsic(cuda, "_waveProduct($0, $1)") -__target_intrinsic(hlsl, "WaveActiveProduct($1)") -T WaveMaskProduct(WaveMask mask, T expr); +__spirv_capability(GroupNonUniformArithmetic) +T WaveMaskProduct(WaveMask mask, T expr) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupMul($1)"; + case cuda: __intrinsic_asm "_waveProduct($0, $1)"; + case hlsl: __intrinsic_asm "WaveActiveProduct($1)"; + case spirv: + let _scope = 3u; // subgroup + if (__isFloat()) + return spirv_asm {OpGroupNonUniformFMul $$T result $_scope 0 $expr}; + else if (__isSignedInt()) + { + return spirv_asm + { + // TODO: use the correct integer width + OpBitcast $$uint %uvalue $expr; + OpGroupNonUniformIMul $$T %mulResult $_scope 0 %uvalue; + OpBitcast $$T result %mulResult + }; + } + else if (__isUnsignedInt()) + return spirv_asm {OpGroupNonUniformIMul $$T result $_scope 0 $expr}; + } +} + __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupMul($1)") -__target_intrinsic(cuda, "_waveProductMultiple($0, $1)") -__target_intrinsic(hlsl, "WaveActiveProduct($1)") -vector WaveMaskProduct(WaveMask mask, vector expr); +__spirv_capability(GroupNonUniformArithmetic) +vector WaveMaskProduct(WaveMask mask, vector expr) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupMul($1)"; + case cuda: __intrinsic_asm "_waveProductMultiple($0, $1)"; + case hlsl: __intrinsic_asm "WaveActiveProduct($1)"; + case spirv: + let _scope = 3u; // subgroup + if (__isFloat()) + return spirv_asm {OpGroupNonUniformFMul $$vector result $_scope 0 $expr}; + else if (__isSignedInt()) + { + return spirv_asm + { + // TODO: use the correct integer width + OpBitcast $$vector %uvalue $expr; + OpGroupNonUniformIMul $$vector %mulResult $_scope 0 %uvalue; + OpBitcast $$vector result %mulResult + }; + } + else if (__isUnsignedInt()) + return spirv_asm {OpGroupNonUniformIMul $$vector result $_scope 0 $expr}; + } +} + __generic __target_intrinsic(cuda, "_waveProductMultiple($0, $1)") __target_intrinsic(hlsl, "WaveActiveProduct($1)") @@ -4378,17 +4847,61 @@ matrix WaveMaskProduct(WaveMask mask, matrix expr); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupAdd($1)") -__target_intrinsic(cuda, "_waveSum($0, $1)") -__target_intrinsic(hlsl, "WaveActiveSum($1)") -T WaveMaskSum(WaveMask mask, T expr); +__spirv_capability(GroupNonUniformArithmetic) +T WaveMaskSum(WaveMask mask, T expr) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupAdd($1)"; + case cuda: __intrinsic_asm "_waveSum($0, $1)"; + case hlsl: __intrinsic_asm "WaveActiveSum($1)"; + case spirv: + let _scope = 3u; // subgroup + if (__isFloat()) + return spirv_asm {OpGroupNonUniformFAdd $$T result $_scope 0 $expr}; + else if (__isSignedInt()) + { + return spirv_asm + { + // TODO: use the correct integer width + OpBitcast $$uint %uvalue $expr; + OpGroupNonUniformIAdd $$T %mulResult $_scope 0 %uvalue; + OpBitcast $$T result %mulResult + }; + } + else if (__isUnsignedInt()) + return spirv_asm {OpGroupNonUniformIAdd $$T result $_scope 0 $expr}; + } +} __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupAdd($1)") -__target_intrinsic(cuda, "_waveSumMultiple($0, $1)") -__target_intrinsic(hlsl, "WaveActiveSum($1)") -vector WaveMaskSum(WaveMask mask, vector expr); +__spirv_capability(GroupNonUniformArithmetic) +vector WaveMaskSum(WaveMask mask, vector expr) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupAdd($1)"; + case cuda: __intrinsic_asm "_waveSumMultiple($0, $1)"; + case hlsl: __intrinsic_asm "WaveActiveSum($1)"; + case spirv: + let _scope = 3u; // subgroup + if (__isFloat()) + return spirv_asm {OpGroupNonUniformFAdd $$vector result $_scope 0 $expr}; + else if (__isSignedInt()) + { + return spirv_asm + { + // TODO: use the correct integer width + OpBitcast $$vector %uvalue $expr; + OpGroupNonUniformIAdd $$vector %mulResult $_scope 0 %uvalue; + OpBitcast $$vector result %mulResult + }; + } + else if (__isUnsignedInt()) + return spirv_asm {OpGroupNonUniformIAdd $$vector result $_scope 0 $expr}; + } +} __generic __target_intrinsic(cuda, "_waveSumMultiple($0, $1)") __target_intrinsic(hlsl, "WaveActiveSum($1)") @@ -4397,19 +4910,52 @@ matrix WaveMaskSum(WaveMask mask, matrix expr); __generic __glsl_extension(GL_KHR_shader_subgroup_vote) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupAllEqual($1)") +__spirv_capability(GroupNonUniformVote) __cuda_sm_version(7.0) -__target_intrinsic(cuda, "_waveAllEqual($0, $1)") -__target_intrinsic(hlsl, "WaveActiveAllEqual($1)") -bool WaveMaskAllEqual(WaveMask mask, T value); +bool WaveMaskAllEqual(WaveMask mask, T value) +{ + __target_switch + { + case glsl: + __intrinsic_asm "subgroupAllEqual($1)"; + case hlsl: + __intrinsic_asm "WaveActiveAllEqual($1)"; + case cuda: + __intrinsic_asm "_waveAllEqual($0, $1)"; + case spirv: + let _scope = 3u; // subgroup + return spirv_asm + { + OpGroupNonUniformAllEqual $$bool result $_scope $value + }; + default: + return false; + } +} __generic __glsl_extension(GL_KHR_shader_subgroup_vote) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupAllEqual($1)") __cuda_sm_version(7.0) -__target_intrinsic(cuda, "_waveAllEqualMultiple($0, $1)") -__target_intrinsic(hlsl, "WaveActiveAllEqual($1)") -bool WaveMaskAllEqual(WaveMask mask, vector value); +bool WaveMaskAllEqual(WaveMask mask, vector value) +{ + __target_switch + { + case glsl: + __intrinsic_asm "subgroupAllEqual($1)"; + case hlsl: + __intrinsic_asm "WaveActiveAllEqual($1)"; + case cuda: + __intrinsic_asm "_waveAllEqualMultiple($0, $1)"; + case spirv: + let _scope = 3u; // subgroup + return spirv_asm + { + OpGroupNonUniformAllEqual $$bool result $_scope $value + }; + default: + return false; + } +} __generic __cuda_sm_version(7.0) __target_intrinsic(cuda, "_waveAllEqualMultiple($0, $1)") @@ -4772,21 +5318,47 @@ matrix WaveActiveSum(matrix expr) __generic __glsl_extension(GL_KHR_shader_subgroup_vote) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupAllEqual($0)") -__target_intrinsic(hlsl) +__spirv_capability(GroupNonUniformVote) bool WaveActiveAllEqual(T value) { - return WaveMaskAllEqual(WaveGetActiveMask(), value); + __target_switch + { + case glsl: + __intrinsic_asm "subgroupAllEqual($0)"; + case hlsl: + __intrinsic_asm "WaveActiveAllEqual"; + case spirv: + let _scope = 3u; // subgroup + return spirv_asm + { + OpGroupNonUniformAllEqual $$bool result $_scope $value + }; + default: + return WaveMaskAllEqual(WaveGetActiveMask(), value); + } } __generic __glsl_extension(GL_KHR_shader_subgroup_vote) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupAllEqual($0)") -__target_intrinsic(hlsl) +__spirv_capability(GroupNonUniformVote) bool WaveActiveAllEqual(vector value) { - return WaveMaskAllEqual(WaveGetActiveMask(), value); + __target_switch + { + case glsl: + __intrinsic_asm "subgroupAllEqual($0)"; + case hlsl: + __intrinsic_asm "WaveActiveAllEqual"; + case spirv: + let _scope = 3u; // subgroup + return spirv_asm + { + OpGroupNonUniformAllEqual $$bool result $_scope $value + }; + default: + return WaveMaskAllEqual(WaveGetActiveMask(), value); + } } __generic @@ -4798,29 +5370,70 @@ bool WaveActiveAllEqual(matrix value) __glsl_extension(GL_KHR_shader_subgroup_vote) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupAll($0)") -__target_intrinsic(hlsl) +__spirv_capability(GroupNonUniformVote) bool WaveActiveAllTrue(bool condition) { - return WaveMaskAllTrue(WaveGetActiveMask(), condition); + __target_switch + { + case glsl: + __intrinsic_asm "subgroupAll($0)"; + case hlsl: + __intrinsic_asm "WaveActiveAllTrue($0)"; + case spirv: + let _scope = 3u; // subgroup + return spirv_asm + { + OpGroupNonUniformAll $$bool result $_scope $condition + }; + default: + return WaveMaskAllTrue(WaveGetActiveMask(), condition); + } } __glsl_extension(GL_KHR_shader_subgroup_vote) __spirv_version(1.3) +__spirv_capability(GroupNonUniformVote) __target_intrinsic(glsl, "subgroupAny($0)") __target_intrinsic(hlsl) bool WaveActiveAnyTrue(bool condition) { - return WaveMaskAnyTrue(WaveGetActiveMask(), condition); + __target_switch + { + case glsl: + __intrinsic_asm "subgroupAny($0)"; + case hlsl: + __intrinsic_asm "WaveActiveAnyTrue($0)"; + case spirv: + let _scope = 3u; // subgroup + return spirv_asm + { + OpGroupNonUniformAny $$bool result $_scope $condition + }; + default: + return WaveMaskAnyTrue(WaveGetActiveMask(), condition); + } } __glsl_extension(GL_KHR_shader_subgroup_ballot) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupBallot($0)") -__target_intrinsic(hlsl) +__spirv_capability(GroupNonUniformBallot) uint4 WaveActiveBallot(bool condition) { - return WaveMaskBallot(WaveGetActiveMask(), condition); + __target_switch + { + case glsl: + __intrinsic_asm "subgroupBallot($0)"; + case hlsl: + __intrinsic_asm "WaveActiveBallot"; + case spirv: + let _scope = 3u; // Subgroup + return spirv_asm + { + OpGroupNonUniformBallot $$uint4 result $_scope $condition + }; + default: + return WaveMaskBallot(WaveGetActiveMask(), condition); + } } __target_intrinsic(hlsl) @@ -4842,27 +5455,51 @@ __target_intrinsic(cuda, "_getLaneId()") uint WaveGetLaneIndex(); __glsl_extension(GL_KHR_shader_subgroup_basic) +__spirv_capability(GroupNonUniformBallot) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupElect()") -__target_intrinsic(hlsl) bool WaveIsFirstLane() { - return WaveMaskIsFirstLane(WaveGetActiveMask()); + __target_switch + { + case glsl: + __intrinsic_asm "subgroupElect()"; + case hlsl: + __intrinsic_asm "WaveIsFirstLane()"; + case spirv: + let _scope = 3u; // subgroup + return spirv_asm + { + OpGroupNonUniformElect $$bool result $_scope + }; + default: + return WaveMaskIsFirstLane(WaveGetActiveMask()); + } } // It's useful to have a wave uint4 version of countbits, because some wave functions return uint4. // This implementation tries to limit the amount of work required by the actual lane count. +__spirv_capability(GroupNonUniformBallot) uint _WaveCountBits(uint4 value) { - // Assume since WaveGetLaneCount should be known at compile time, the branches will hopefully boil away - const uint waveLaneCount = WaveGetLaneCount(); - switch ((waveLaneCount - 1) / 32) + __target_switch { - default: - case 0: return countbits(value.x); - case 1: return countbits(value.x) + countbits(value.y); - case 2: return countbits(value.x) + countbits(value.y) + countbits(value.z); - case 3: return countbits(value.x) + countbits(value.y) + countbits(value.z) + countbits(value.w); + case spirv: + let _scope = 3u; // Subgroup + return spirv_asm + { + OpGroupNonUniformBallotBitCount $$uint result $_scope 0 $value + }; + default: + // Assume since WaveGetLaneCount should be known at compile time, the branches will hopefully boil away + const uint waveLaneCount = WaveGetLaneCount(); + switch ((waveLaneCount - 1) / 32) + { + default: + case 0: return countbits(value.x); + case 1: return countbits(value.x) + countbits(value.y); + case 2: return countbits(value.x) + countbits(value.y) + countbits(value.z); + case 3: return countbits(value.x) + countbits(value.y) + countbits(value.z) + countbits(value.w); + } } } diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h index 76effa608..d38fd9374 100644 --- a/source/slang/slang-ast-iterator.h +++ b/source/slang/slang-ast-iterator.h @@ -372,6 +372,21 @@ struct ASTIterator iterator->visitExpr(stmt->expr); } + void visitTargetSwitchStmt(TargetSwitchStmt* stmt) + { + iterator->maybeDispatchCallback(stmt); + for (auto c : stmt->targetCases) + dispatchIfNotNull(c); + } + + void visitTargetCaseStmt(TargetCaseStmt* stmt) + { + iterator->maybeDispatchCallback(stmt); + iterator->visitStmt(stmt->body); + } + + void visitIntrinsicAsmStmt(IntrinsicAsmStmt*) {} + void visitDefaultStmt(DefaultStmt* stmt) { iterator->maybeDispatchCallback(stmt); } void visitIfStmt(IfStmt* stmt) diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index f5809ff4b..b890343fc 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -137,6 +137,7 @@ class RequiredSPIRVCapabilityModifier : public Modifier { SLANG_AST_CLASS(RequiredSPIRVCapabilityModifier) int32_t capability; + String extensionName; }; // A modifier to tag something as an intrinsic that requires diff --git a/source/slang/slang-ast-stmt.h b/source/slang/slang-ast-stmt.h index b18eb077d..af1fe9ec1 100644 --- a/source/slang/slang-ast-stmt.h +++ b/source/slang/slang-ast-stmt.h @@ -94,6 +94,27 @@ class SwitchStmt : public BreakableStmt Stmt* body = nullptr; }; +class TargetCaseStmt : public Stmt +{ + SLANG_AST_CLASS(TargetCaseStmt) + int32_t capability; + Stmt* body = nullptr; +}; + +class TargetSwitchStmt : public Stmt +{ + SLANG_AST_CLASS(TargetSwitchStmt) + + List targetCases; +}; + +class IntrinsicAsmStmt : public Stmt +{ + SLANG_AST_CLASS(IntrinsicAsmStmt) + + String asmText; +}; + // A statement that is expected to appear lexically nested inside // some other construct, and thus needs to keep track of the // outer statement that it is associated with... diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 1e3bde4de..c4a7b3e6d 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -2470,6 +2470,12 @@ namespace Slang void visitCaseStmt(CaseStmt* stmt); + void visitTargetSwitchStmt(TargetSwitchStmt* stmt); + + void visitTargetCaseStmt(TargetCaseStmt* stmt); + + void visitIntrinsicAsmStmt(IntrinsicAsmStmt*) {} + void visitDefaultStmt(DefaultStmt* stmt); void visitIfStmt(IfStmt *stmt); diff --git a/source/slang/slang-check-stmt.cpp b/source/slang/slang-check-stmt.cpp index 4b4257f75..bbe3e51bd 100644 --- a/source/slang/slang-check-stmt.cpp +++ b/source/slang/slang-check-stmt.cpp @@ -264,6 +264,26 @@ namespace Slang stmt->parentStmt = switchStmt; } + void SemanticsStmtVisitor::visitTargetSwitchStmt(TargetSwitchStmt* stmt) + { + WithOuterStmt subContext(this, stmt); + + for (auto caseStmt : stmt->targetCases) + subContext.checkStmt(caseStmt); + } + + void SemanticsStmtVisitor::visitTargetCaseStmt(TargetCaseStmt* stmt) + { + auto switchStmt = FindOuterStmt(); + + if (!switchStmt) + { + getSink()->diagnose(stmt, Diagnostics::caseOutsideSwitch); + } + WithOuterStmt subContext(this, stmt); + subContext.checkStmt(stmt->body); + } + void SemanticsStmtVisitor::visitDefaultStmt(DefaultStmt* stmt) { auto switchStmt = FindOuterStmt(); diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 1b13ae636..b562ac880 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -246,13 +246,14 @@ DIAGNOSTIC(29101, Error, misplacedResultIdMarker, "the result-id marker must onl DIAGNOSTIC(29102, Note, considerOpCopyObject, "consider adding an OpCopyObject instruction to the end of the spirv_asm expression") DIAGNOSTIC(29103, Note, noSuchAddress, "unable to take the address of this address-of asm operand") DIAGNOSTIC(29104, Error, spirvInstructionWithoutResultId, "cannot use this 'x = $0...' syntax because $0 does not have a operand") -DIAGNOSTIC(29104, Error, spirvInstructionWithoutResultTypeId, "cannot use this 'x : = $0...' syntax because $0 does not have a operand") +DIAGNOSTIC(29105, Error, spirvInstructionWithoutResultTypeId, "cannot use this 'x : = $0...' syntax because $0 does not have a operand") // This is a warning because we trust that people using the spirv_asm block know what they're doing -DIAGNOSTIC(29104, Warning, spirvInstructionWithTooManyOperands, "too many operands for $0 (expected max $1), did you forget a semicolon?") -DIAGNOSTIC(29104, Error, spirvUnableToResolveName, "unknown SPIR-V identifier $0, it's not a known enumerator or opcode") -DIAGNOSTIC(29104, Error, spirvNonConstantBitwiseOr, "only integer literals and enum names can appear in a bitwise or expression") -DIAGNOSTIC(29104, Error, spirvOperandRange, "Literal ints must be in the range 0 to 0xffffffff") +DIAGNOSTIC(29106, Warning, spirvInstructionWithTooManyOperands, "too many operands for $0 (expected max $1), did you forget a semicolon?") +DIAGNOSTIC(29107, Error, spirvUnableToResolveName, "unknown SPIR-V identifier $0, it's not a known enumerator or opcode") +DIAGNOSTIC(29108, Error, spirvNonConstantBitwiseOr, "only integer literals and enum names can appear in a bitwise or expression") +DIAGNOSTIC(29109, Error, spirvOperandRange, "Literal ints must be in the range 0 to 0xffffffff") +DIAGNOSTIC(29110, Error, unknownTargetName, "unknown target name '$0'") // // 3xxxx - Semantic analysis diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index 75a15d0c9..fe25c3f19 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -970,9 +970,10 @@ String CLikeSourceEmitter::generateName(IRInst* inst) // If the instruction names something // that should be emitted as a target intrinsic, // then use that name instead. - if(auto intrinsicDecoration = findBestTargetIntrinsicDecoration(inst)) + UnownedStringSlice intrinsicDef; + if(findTargetIntrinsicDefinition(inst, intrinsicDef)) { - return String(intrinsicDecoration->getDefinition()); + return String(intrinsicDef); } // If the instruction reprsents one of the "magic" declarations @@ -1434,7 +1435,8 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst) // This is significant, because we can within a target intrinsics definition multiple accesses to the same // parameter. This is not indicated into the call, and can lead to output code computes something multiple // times as it is folding into the expression of the the target intrinsic, which we don't want. - if (auto targetIntrinsicDecoration = findBestTargetIntrinsicDecoration(funcValue)) + UnownedStringSlice intrinsicDef; + if (findTargetIntrinsicDefinition(funcValue, intrinsicDef)) { // Find the index of the original instruction, to see if it's multiply used. IRUse* args = callInst->getArgs(); @@ -1443,7 +1445,7 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst) // Look through the slice to seeing how many times this parameters is used (signified via the $0...$9) { - UnownedStringSlice slice = targetIntrinsicDecoration->getDefinition(); + UnownedStringSlice slice = intrinsicDef; const char* cur = slice.begin(); const char* end = slice.end(); @@ -1705,7 +1707,7 @@ IRTargetSpecificDecoration* CLikeSourceEmitter::findBestTargetDecoration(IRInst* return Slang::findBestTargetDecoration(inInst, getTargetCaps()); } -IRTargetIntrinsicDecoration* CLikeSourceEmitter::findBestTargetIntrinsicDecoration(IRInst* inInst) +IRTargetIntrinsicDecoration* CLikeSourceEmitter::_findBestTargetIntrinsicDecoration(IRInst* inInst) { return as(findBestTargetDecoration(inInst)); } @@ -1745,14 +1747,14 @@ IRTargetIntrinsicDecoration* CLikeSourceEmitter::findBestTargetIntrinsicDecorati } -void CLikeSourceEmitter::emitIntrinsicCallExpr(IRCall* inst, IRTargetIntrinsicDecoration* targetIntrinsic, EmitOpInfo const& inOuterPrec) +void CLikeSourceEmitter::emitIntrinsicCallExpr(IRCall* inst, UnownedStringSlice intrinsicDefinition, EmitOpInfo const& inOuterPrec) { - emitIntrinsicCallExprImpl(inst, targetIntrinsic, inOuterPrec); + emitIntrinsicCallExprImpl(inst, intrinsicDefinition, inOuterPrec); } void CLikeSourceEmitter::emitIntrinsicCallExprImpl( IRCall* inst, - IRTargetIntrinsicDecoration* targetIntrinsic, + UnownedStringSlice intrinsicDefinition, EmitOpInfo const& inOuterPrec) { auto outerPrec = inOuterPrec; @@ -1764,7 +1766,7 @@ void CLikeSourceEmitter::emitIntrinsicCallExprImpl( args++; argCount--; - auto name = targetIntrinsic->getDefinition(); + auto name = intrinsicDefinition; if(isOrdinaryName(name)) { @@ -1876,6 +1878,11 @@ void CLikeSourceEmitter::emitComInterfaceCallExpr(IRCall* inst, EmitOpInfo const maybeCloseParens(needClose); } +bool CLikeSourceEmitter::findTargetIntrinsicDefinition(IRInst* callee, UnownedStringSlice& outDefinition) +{ + return Slang::findTargetIntrinsicDefinition(callee, getTargetCaps(), outDefinition); +} + void CLikeSourceEmitter::emitCallExpr(IRCall* inst, EmitOpInfo outerPrec) { auto funcValue = inst->getOperand(0); @@ -1909,9 +1916,10 @@ void CLikeSourceEmitter::emitCallExpr(IRCall* inst, EmitOpInfo outerPrec) // We want to detect any call to an intrinsic operation, // that we can emit it directly without mangling, etc. - if(auto targetIntrinsic = findBestTargetIntrinsicDecoration(funcValue)) + UnownedStringSlice intrinsicDefinition; + if (findTargetIntrinsicDefinition(funcValue, intrinsicDefinition)) { - emitIntrinsicCallExpr(inst, targetIntrinsic, outerPrec); + emitIntrinsicCallExpr(inst, intrinsicDefinition, outerPrec); } else { @@ -3325,13 +3333,14 @@ IREntryPointLayout* CLikeSourceEmitter::asEntryPoint(IRFunc* func) return nullptr; } -bool CLikeSourceEmitter::isTargetIntrinsic(IRFunc* func) +bool CLikeSourceEmitter::isTargetIntrinsic(IRInst* inst) { // A function is a target intrinsic if and only if // it has a suitable decoration marking it as a // target intrinsic for the current compilation target. // - return findBestTargetIntrinsicDecoration(func) != nullptr; + UnownedStringSlice intrinsicDef; + return findTargetIntrinsicDefinition(inst, intrinsicDef); } void CLikeSourceEmitter::emitFunc(IRFunc* func) @@ -3373,7 +3382,7 @@ void CLikeSourceEmitter::emitStruct(IRStructType* structType) { // If the selected `struct` type is actually an intrinsic // on our target, then we don't want to emit anything at all. - if(const auto intrinsicDecoration = findBestTargetIntrinsicDecoration(structType)) + if(isTargetIntrinsic(structType)) { return; } @@ -3429,7 +3438,7 @@ void CLikeSourceEmitter::emitClass(IRClassType* classType) { // If the selected `class` type is actually an intrinsic // on our target, then we don't want to emit anything at all. - if (const auto intrinsicDecoration = findBestTargetIntrinsicDecoration(classType)) + if (isTargetIntrinsic(classType)) { return; } diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index 420132a5d..02ab28028 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -355,7 +355,11 @@ public: void emitInstResultDecl(IRInst* inst); IRTargetSpecificDecoration* findBestTargetDecoration(IRInst* inst); - IRTargetIntrinsicDecoration* findBestTargetIntrinsicDecoration(IRInst* inst); + IRTargetIntrinsicDecoration* _findBestTargetIntrinsicDecoration(IRInst* inst); + + // Find the definition of a target intrinsic either from __target_intrinsic decoration, or from + // a genericAsm inst in the function body. + bool findTargetIntrinsicDefinition(IRInst* callee, UnownedStringSlice& outDefinition); // Check if the string being used to define a target intrinsic // is an "ordinary" name, such that we can simply emit a call @@ -366,7 +370,7 @@ public: void emitIntrinsicCallExpr( IRCall* inst, - IRTargetIntrinsicDecoration* targetIntrinsic, + UnownedStringSlice intrinsicDefinition, EmitOpInfo const& inOuterPrec); void emitCallExpr(IRCall* inst, EmitOpInfo outerPrec); @@ -409,10 +413,10 @@ public: IREntryPointLayout* asEntryPoint(IRFunc* func); - // Detect if the given IR function represents a + // Detect if the given IR function/type represents a // declaration of an intrinsic/builtin for the // current code-generation target. - bool isTargetIntrinsic(IRFunc* func); + bool isTargetIntrinsic(IRInst* func); void emitFunc(IRFunc* func); void emitFuncDecorations(IRFunc* func) { emitFuncDecorationsImpl(func); } @@ -525,7 +529,7 @@ public: virtual void emitVarExpr(IRInst* inst, EmitOpInfo const& outerPrec); virtual void emitOperandImpl(IRInst* inst, EmitOpInfo const& outerPrec); virtual void emitParamTypeImpl(IRType* type, String const& name); - virtual void emitIntrinsicCallExprImpl(IRCall* inst, IRTargetIntrinsicDecoration* targetIntrinsic, EmitOpInfo const& inOuterPrec); + virtual void emitIntrinsicCallExprImpl(IRCall* inst, UnownedStringSlice intrinsicDefinition, EmitOpInfo const& inOuterPrec); virtual void emitFunctionPreambleImpl(IRInst* inst) { SLANG_UNUSED(inst); } virtual void emitLoopControlDecorationImpl(IRLoopControlDecoration* decl) { SLANG_UNUSED(decl); } virtual void emitIfDecorationsImpl(IRIfElse* ifInst) { SLANG_UNUSED(ifInst); } diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp index 0b01bdf4e..ec797bcf4 100644 --- a/source/slang/slang-emit-cpp.cpp +++ b/source/slang/slang-emit-cpp.cpp @@ -1089,7 +1089,7 @@ void CPPSourceEmitter::_emitType(IRType* type, DeclaratorInfo* declarator) void CPPSourceEmitter::emitIntrinsicCallExprImpl( IRCall* inst, - IRTargetIntrinsicDecoration* targetIntrinsic, + UnownedStringSlice intrinsicDefinition, EmitOpInfo const& inOuterPrec) { // TODO: Much of this logic duplicates code that is already @@ -1104,7 +1104,7 @@ void CPPSourceEmitter::emitIntrinsicCallExprImpl( Index argCount = Index(inst->getArgCount()); auto args = inst->getArgs(); - auto name = targetIntrinsic->getDefinition(); + auto name = intrinsicDefinition; // We will special-case some names here, that // represent callable declarations that aren't @@ -1154,7 +1154,7 @@ void CPPSourceEmitter::emitIntrinsicCallExprImpl( } // Use default impl (which will do intrinsic special macro expansion as necessary) - return Super::emitIntrinsicCallExprImpl(inst, targetIntrinsic, inOuterPrec); + return Super::emitIntrinsicCallExprImpl(inst, intrinsicDefinition, inOuterPrec); } void CPPSourceEmitter::emitLoopControlDecorationImpl(IRLoopControlDecoration* decl) diff --git a/source/slang/slang-emit-cpp.h b/source/slang/slang-emit-cpp.h index 73a0f864e..cfd3d278d 100644 --- a/source/slang/slang-emit-cpp.h +++ b/source/slang/slang-emit-cpp.h @@ -68,7 +68,7 @@ protected: void emitComInterface(IRInterfaceType* interfaceType); virtual void emitRTTIObject(IRRTTIObject* rttiObject) SLANG_OVERRIDE; virtual bool tryEmitGlobalParamImpl(IRGlobalParam* varDecl, IRType* varType) SLANG_OVERRIDE; - virtual void emitIntrinsicCallExprImpl(IRCall* inst, IRTargetIntrinsicDecoration* targetIntrinsic, EmitOpInfo const& inOuterPrec) SLANG_OVERRIDE; + virtual void emitIntrinsicCallExprImpl(IRCall* inst, UnownedStringSlice intrinsicDefinition, EmitOpInfo const& inOuterPrec) SLANG_OVERRIDE; virtual void emitLoopControlDecorationImpl(IRLoopControlDecoration* decl) SLANG_OVERRIDE; virtual void emitFuncDecorationsImpl(IRFunc* func) SLANG_OVERRIDE; virtual void emitVarDecorationsImpl(IRInst* var) SLANG_OVERRIDE; diff --git a/source/slang/slang-emit-cuda.cpp b/source/slang/slang-emit-cuda.cpp index 3e974f60e..fa0e3c7aa 100644 --- a/source/slang/slang-emit-cuda.cpp +++ b/source/slang/slang-emit-cuda.cpp @@ -436,17 +436,17 @@ void CUDASourceEmitter::_emitInitializerList(IRType* elementType, IRUse* operand m_writer->emit("\n}"); } -void CUDASourceEmitter::emitIntrinsicCallExprImpl(IRCall* inst, IRTargetIntrinsicDecoration* targetIntrinsic, EmitOpInfo const& inOuterPrec) +void CUDASourceEmitter::emitIntrinsicCallExprImpl(IRCall* inst, UnownedStringSlice intrinsicDefinition, EmitOpInfo const& inOuterPrec) { // This works around the problem, where some intrinsics that require the "half" type enabled don't use the half/float16_t type. // For example `f16tof32` can operate on float16_t *and* uint. If the input is uint, although we are // using the half feature (as far as CUDA is concerned), the half/float16_t type is not visible/directly used. - if (targetIntrinsic->getDefinition().startsWith(toSlice("__half"))) + if (intrinsicDefinition.startsWith(toSlice("__half"))) { m_extensionTracker->requireBaseType(BaseType::Half); } - Super::emitIntrinsicCallExprImpl(inst, targetIntrinsic, inOuterPrec); + Super::emitIntrinsicCallExprImpl(inst, intrinsicDefinition, inOuterPrec); } bool CUDASourceEmitter::tryEmitInstStmtImpl(IRInst* inst) diff --git a/source/slang/slang-emit-cuda.h b/source/slang/slang-emit-cuda.h index a08afd862..82d0240b3 100644 --- a/source/slang/slang-emit-cuda.h +++ b/source/slang/slang-emit-cuda.h @@ -93,7 +93,7 @@ protected: virtual bool tryEmitGlobalParamImpl(IRGlobalParam* varDecl, IRType* varType) SLANG_OVERRIDE; virtual bool tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec) SLANG_OVERRIDE; virtual bool tryEmitInstStmtImpl(IRInst* inst) SLANG_OVERRIDE; - virtual void emitIntrinsicCallExprImpl(IRCall* inst, IRTargetIntrinsicDecoration* targetIntrinsic, EmitOpInfo const& inOuterPrec) SLANG_OVERRIDE; + virtual void emitIntrinsicCallExprImpl(IRCall* inst, UnownedStringSlice intrinsicDefinition, EmitOpInfo const& inOuterPrec) SLANG_OVERRIDE; virtual void emitModuleImpl(IRModule* module, DiagnosticSink* sink) SLANG_OVERRIDE; diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index e40a5ceca..3cde9f467 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -2474,9 +2474,10 @@ void GLSLSourceEmitter::emitSimpleTypeImpl(IRType* type) } auto decorated = getResolvedInstForDecorations(type); - if(auto targetIntrinsicDecor = findBestTargetIntrinsicDecoration(decorated)) + UnownedStringSlice intrinsicDef; + if (findTargetIntrinsicDefinition(decorated, intrinsicDef)) { - m_writer->emit(targetIntrinsicDecor->getDefinition()); + m_writer->emit(intrinsicDef); return; } diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index c6aedbe3b..462f408ab 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -3960,6 +3960,12 @@ struct SPIRVEmitContext break; case kIROp_RequireSPIRVCapabilityDecoration: requireSPIRVCapability((SpvCapability)getIntVal(decoration->getOperand(0))); + if (decoration->getOperandCount() == 2) + { + auto stringLit = as(decoration->getOperand(1)); + if (stringLit->getStringSlice().getLength()) + ensureExtensionDeclaration(stringLit->getStringSlice()); + } break; } } diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp index df245f555..f3fa213da 100644 --- a/source/slang/slang-ir-inline.cpp +++ b/source/slang/slang-ir-inline.cpp @@ -654,6 +654,19 @@ struct InliningPassBase }; +static bool hasGenericAsmInst(IRInst* func) +{ + auto f = as(getResolvedInstForDecorations(func)); + if (!f) + return false; + for (auto b : f->getBlocks()) + { + if (as(b->getTerminator())) + return true; + } + return false; +} + /// An inlining pass that inlines calls to `[unsafeForceInlineEarly]` functions struct MandatoryEarlyInliningPass : InliningPassBase { @@ -665,10 +678,15 @@ struct MandatoryEarlyInliningPass : InliningPassBase bool shouldInline(CallSiteInfo const& info) { - if(info.callee->findDecoration()) - return true; if (info.callee->findDecoration()) return true; + + // Never inline a callee that has genericASM instruction. + if (hasGenericAsmInst(info.callee)) + return false; + + if(info.callee->findDecoration()) + return true; return false; } }; @@ -782,6 +800,10 @@ struct ForceInliningPass : InliningPassBase bool shouldInline(CallSiteInfo const& info) { + // Never inline a callee that has genericASM instruction. + if (hasGenericAsmInst(info.callee)) + return false; + if (info.callee->findDecoration() || info.callee->findDecoration()|| info.callee->findDecoration()) diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index cab7e973d..e09abcf75 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -543,6 +543,11 @@ INST(Throw, throw, 1, 0) INST(TryCall, tryCall, 3, 0) // switch ... INST(Switch, switch, 3, 0) +// target_switch ... +INST(TargetSwitch, targetSwitch, 1, 0) + +// A generic asm inst has an return semantics that terminates the control flow. +INST(GenericAsm, GenericAsm, 1, 0) INST(discard, discard, 0, 0) @@ -972,6 +977,13 @@ INST(SizeOf, sizeOf, 1, 0) INST(AlignOf, alignOf, 1, 0) INST(IsType, IsType, 3, 0) +INST(TypeEquals, TypeEquals, 2, 0) +INST(IsInt, IsInt, 1, 0) +INST(IsBool, IsBool, 1, 0) +INST(IsFloat, IsFloat, 1, 0) +INST(IsUnsignedInt, IsUnsignedInt, 1, 0) +INST(IsSignedInt, IsSignedInt, 1, 0) + INST(ForwardDifferentiate, ForwardDifferentiate, 1, 0) // Produces the primal computation of backward derivatives, will return an intermediate context for @@ -1054,6 +1066,7 @@ INST(DebugSource, DebugSource, 2, HOISTABLE) INST(DebugLine, DebugLine, 5, 0) /* Inline assembly */ + INST(SPIRVAsm, SPIRVAsm, 0, PARENT) INST(SPIRVAsmInst, SPIRVAsmInst, 1, 0) // These instruction serve to inform the backend precisely how to emit each @@ -1076,6 +1089,7 @@ INST(SPIRVAsmInst, SPIRVAsmInst, 1, 0) INST(SPIRVAsmOperandResult, SPIRVAsmOperandResult, 0, 0) INST_RANGE(SPIRVAsmOperand, SPIRVAsmOperandLiteral, SPIRVAsmOperandResult) + #undef PARENT #undef USE_OTHER #undef INST_RANGE diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 82b123bd4..1a0a8ed66 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2342,6 +2342,16 @@ struct IRSwitch : IRTerminatorInst IRUse* getCaseLabelUse(UInt index) { return getOperands() + 3 + index * 2 + 1; } }; +// A compile-time switch based on the current code generation target. +struct IRTargetSwitch : IRTerminatorInst +{ + IR_LEAF_ISA(TargetSwitch) + IRInst* getBreakBlock() { return getOperand(0); } + UInt getCaseCount() { return (getOperandCount() - 1) / 2; } + IRBlock* getCaseBlock(UInt index) { return (IRBlock*)getOperand(index * 2 + 2); } + IRInst* getCaseValue(UInt index) { return getOperand(index * 2 + 1); } +}; + struct IRThrow : IRTerminatorInst { IR_LEAF_ISA(Throw); @@ -2929,6 +2939,12 @@ struct IRSPIRVAsm : IRInst } }; +struct IRGenericAsm : IRInst +{ + IR_LEAF_ISA(GenericAsm) + UnownedStringSlice getAsm() { return as(getOperand(0))->getStringSlice(); } +}; + struct IRBuilderSourceLocRAII; struct IRBuilder @@ -3923,7 +3939,7 @@ public: IRSPIRVAsmOperand* emitSPIRVAsmOperandEnum(IRInst* inst, IRType* constantType); IRSPIRVAsmInst* emitSPIRVAsmInst(IRInst* opcode, List operands); IRSPIRVAsm* emitSPIRVAsm(IRType* type); - + IRInst* emitGenericAsm(UnownedStringSlice asmText); // // Decorations // @@ -4130,9 +4146,23 @@ public: addDecoration(value, kIROp_RequireSPIRVVersionDecoration, getIntValue(getBasicType(BaseType::UInt64), intValue)); } - void addRequireSPIRVCapabilityDecoration(IRInst* value, int32_t capabilityName) + void addRequireSPIRVCapabilityDecoration(IRInst* value, int32_t capabilityName, UnownedStringSlice extensionName) { - addDecoration(value, kIROp_RequireSPIRVCapabilityDecoration, getIntValue(getIntType(), IRIntegerValue(capabilityName))); + if (extensionName.getLength()) + { + addDecoration( + value, + kIROp_RequireSPIRVCapabilityDecoration, + getIntValue(getIntType(), IRIntegerValue(capabilityName)), + getStringValue(extensionName)); + } + else + { + addDecoration( + value, + kIROp_RequireSPIRVCapabilityDecoration, + getIntValue(getIntType(), IRIntegerValue(capabilityName))); + } } void addRequireCUDASMVersionDecoration(IRInst* value, const SemanticVersion& version) @@ -4499,6 +4529,8 @@ IRTargetSpecificDecoration* findBestTargetDecoration( IRInst* val, CapabilityAtom targetCapabilityAtom); +bool findTargetIntrinsicDefinition(IRInst* callee, CapabilitySet const& targetCaps, UnownedStringSlice& outDefinition); + inline IRTargetIntrinsicDecoration* findBestTargetIntrinsicDecoration( IRInst* inInst, CapabilitySet const& targetCaps) diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 66fda8a86..e37aa322e 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -7,6 +7,7 @@ #include "slang-mangle.h" #include "slang-ir-string-hash.h" #include "slang-ir-autodiff.h" +#include "slang-ir-specialize-target-switch.h" #include "slang-module-library.h" #include "../core/slang-performance-profiler.h" @@ -1629,6 +1630,8 @@ LinkedIR linkIR( } } + // Specialize target_switch branches to use the best branch for the target. + specializeTargetSwitch(targetReq, state->irModule); // TODO: *technically* we should consider the case where // we have global variables with initializers, since diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index d83c6dccd..34ccdf924 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -902,12 +902,71 @@ struct PeepholeContext : InstPassBase } break; } - + case kIROp_TypeEquals: + { + auto left = inst->getOperand(0)->getDataType(); + auto right = inst->getOperand(1)->getDataType(); + if (isConcreteType(left) && isConcreteType(right)) + { + IRBuilder builder(module); + builder.setInsertBefore(inst); + bool result = left == right; + inst->replaceUsesWith(builder.getBoolValue(result)); + maybeRemoveOldInst(inst); + changed = true; + } + break; + } + case kIROp_IsInt: + case kIROp_IsFloat: + case kIROp_IsUnsignedInt: + case kIROp_IsSignedInt: + case kIROp_IsBool: + { + auto type = inst->getOperand(0)->getDataType(); + if (auto vectorType = as(type)) + type = vectorType->getElementType(); + if (auto matType = as(type)) + type = matType->getElementType(); + if (isConcreteType(type)) + { + IRBuilder builder(module); + builder.setInsertBefore(inst); + bool result = false; + switch (inst->getOp()) + { + case kIROp_IsInt: + result = isIntegralType(type); + break; + case kIROp_IsBool: + result = type->getOp() == kIROp_BoolType; + break; + case kIROp_IsFloat: + result = isFloatingType(type); + break; + case kIROp_IsUnsignedInt: + result = isIntegralType(type) && !getIntTypeInfo(type).isSigned; + break; + case kIROp_IsSignedInt: + result = isIntegralType(type) && getIntTypeInfo(type).isSigned; + break; + } + inst->replaceUsesWith(builder.getBoolValue(result)); + maybeRemoveOldInst(inst); + changed = true; + } + break; + } default: break; } } + bool isConcreteType(IRType* type) + { + return type->parent->getOp() == kIROp_Module && !as(type); + } + bool processFunc(IRInst* func) { func->getModule()->invalidateAllAnalysis(); diff --git a/source/slang/slang-ir-restructure.cpp b/source/slang/slang-ir-restructure.cpp index 7606de263..cfd1d4597 100644 --- a/source/slang/slang-ir-restructure.cpp +++ b/source/slang/slang-ir-restructure.cpp @@ -256,6 +256,7 @@ namespace Slang case kIROp_MissingReturn: case kIROp_Return: case kIROp_discard: + case kIROp_GenericAsm: // These cases are all simple terminators that can be handled as-is // without needing to construct a separate `Region` to encapsulate them. // diff --git a/source/slang/slang-ir-sccp.cpp b/source/slang/slang-ir-sccp.cpp index ce635dca8..5ae858256 100644 --- a/source/slang/slang-ir-sccp.cpp +++ b/source/slang/slang-ir-sccp.cpp @@ -1192,7 +1192,13 @@ struct SCCPContext } cfgWorkList.add(switchInst->getDefaultLabel()); } - + else if (auto targetSwitch = as(inst)) + { + for (UInt cc = 0; cc < targetSwitch->getCaseCount(); ++cc) + { + cfgWorkList.add(targetSwitch->getCaseBlock(cc)); + } + } // There are other cases of terminator instructions not handled // above (e.g., `return` instructions), but these can't cause // additional basic blocks in the CFG to execute, so we don't @@ -1555,7 +1561,6 @@ struct SCCPContext terminator->removeAndDeallocate(); changed = true; } - } } diff --git a/source/slang/slang-ir-specialize-target-switch.cpp b/source/slang/slang-ir-specialize-target-switch.cpp new file mode 100644 index 000000000..2593389b1 --- /dev/null +++ b/source/slang/slang-ir-specialize-target-switch.cpp @@ -0,0 +1,67 @@ +#include "slang-ir-specialize-target-switch.h" +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-compiler.h" +#include "slang-capability.h" +#include "slang-ir-dce.h" + +namespace Slang +{ + void specializeTargetSwitch(TargetRequest* target, IRGlobalValueWithCode* code) + { + bool changed = false; + for (auto block : code->getBlocks()) + { + if (auto targetSwitch = as(block->getTerminator())) + { + CapabilitySet bestCapSet = CapabilitySet::makeInvalid(); + IRBlock* targetBlock = nullptr; + for (UInt i = 0; i < targetSwitch->getCaseCount(); i++) + { + auto cap = (CapabilityAtom)getIntVal(targetSwitch->getCaseValue(i)); + CapabilitySet capSet; + if (cap == CapabilityAtom::Invalid) + capSet = CapabilitySet::makeEmpty(); + else + capSet = CapabilitySet(cap); + if (capSet.isBetterForTarget(bestCapSet, target->getTargetCaps())) + { + targetBlock = targetSwitch->getCaseBlock(i); + bestCapSet = capSet; + } + } + SLANG_ASSERT(targetBlock); + IRBuilder builder(targetSwitch); + builder.setInsertBefore(targetSwitch); + builder.emitBranch(targetBlock); + targetSwitch->removeAndDeallocate(); + changed = true; + } + } + if (changed) + { + // Remove unreachable blocks after specialization. + eliminateDeadCode(code); + } + } + + void specializeTargetSwitch(TargetRequest* target, IRModule* module) + { + for (auto globalInst : module->getGlobalInsts()) + { + if (auto code = as(globalInst)) + { + specializeTargetSwitch(target, code); + if (auto gen = as(code)) + { + auto retVal = findGenericReturnVal(gen); + if (auto innerCode = as(retVal)) + { + specializeTargetSwitch(target, innerCode); + } + } + } + } + } + +} diff --git a/source/slang/slang-ir-specialize-target-switch.h b/source/slang/slang-ir-specialize-target-switch.h new file mode 100644 index 000000000..91071cec6 --- /dev/null +++ b/source/slang/slang-ir-specialize-target-switch.h @@ -0,0 +1,15 @@ +#ifndef SLANG_IR_SPECIALIZE_TARGET_SWITCH_H +#define SLANG_IR_SPECIALIZE_TARGET_SWITCH_H + +namespace Slang +{ + struct IRModule; + class TargetRequest; + + // Repalce all target_switch insts with the case that matches current target. + // + void specializeTargetSwitch(TargetRequest* target, IRModule* module); + +} + +#endif diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 0b1e9c342..cd49c6df5 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -512,6 +512,7 @@ namespace Slang case kIROp_Unreachable: case kIROp_MissingReturn: case kIROp_discard: + case kIROp_GenericAsm: break; case kIROp_unconditionalBranch: @@ -538,7 +539,11 @@ namespace Slang end = operands + terminator->getOperandCount() + 1; stride = 2; break; - + case kIROp_TargetSwitch: + begin = operands + 2; + end = operands + terminator->getOperandCount() + 1; + stride = 2; + break; default: SLANG_UNEXPECTED("unhandled terminator instruction"); UNREACHABLE_RETURN(IRBlock::SuccessorList(nullptr, nullptr)); @@ -5768,6 +5773,12 @@ namespace Slang return asmInst; } + IRInst* IRBuilder::emitGenericAsm(UnownedStringSlice asmText) + { + IRInst* arg = getStringValue(asmText); + return emitIntrinsicInst(nullptr, kIROp_GenericAsm, 1, &arg); + } + // // Decorations // @@ -7743,6 +7754,26 @@ namespace Slang return findBestTargetDecoration(val, CapabilitySet(targetCapabilityAtom)); } + bool findTargetIntrinsicDefinition(IRInst* callee, CapabilitySet const& targetCaps, UnownedStringSlice& outDefinition) + { + if (auto decor = findBestTargetIntrinsicDecoration(callee, targetCaps)) + { + outDefinition = decor->getDefinition(); + return true; + } + auto func = as(callee); + if (!func) + return false; + auto block = func->getFirstBlock(); + if (!block) + return false; + if (auto genAsm = as(block->getTerminator())) + { + outDefinition = genAsm->getAsm(); + return true; + } + return false; + } #if 0 IRFunc* cloneSimpleFuncWithoutRegistering(IRSpecContextBase* context, IRFunc* originalFunc) diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp index 29661a415..810315dd9 100644 --- a/source/slang/slang-language-server-ast-lookup.cpp +++ b/source/slang/slang-language-server-ast-lookup.cpp @@ -540,6 +540,21 @@ struct ASTLookupStmtVisitor : public StmtVisitor bool visitCaseStmt(CaseStmt* stmt) { return checkExpr(stmt->expr); } + bool visitTargetSwitchStmt(TargetSwitchStmt* stmt) + { + for (auto targetCase : stmt->targetCases) + if (dispatchIfNotNull(targetCase)) + return true; + return false; + } + + bool visitTargetCaseStmt(TargetCaseStmt* stmt) + { + return dispatchIfNotNull(stmt->body); + } + + bool visitIntrinsicAsmStmt(IntrinsicAsmStmt*) { return false; } + bool visitDefaultStmt(DefaultStmt*) { return false; } bool visitIfStmt(IfStmt* stmt) diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 1b97414fb..41347ddef 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -5723,6 +5723,49 @@ struct StmtLoweringVisitor : StmtVisitor } } + void visitTargetSwitchStmt(TargetSwitchStmt* stmt) + { + auto builder = getBuilder(); + startBlockIfNeeded(stmt); + auto initialBlock = builder->getBlock(); + auto breakLabel = builder->createBlock(); + context->shared->breakLabels.add(stmt, breakLabel); + builder->setInsertInto(initialBlock->getParent()); + List args; + args.add(breakLabel); + Dictionary mapCaseStmtToBlock; + for (auto targetCase : stmt->targetCases) + { + IRBlock* caseBlock = nullptr; + if (!mapCaseStmtToBlock.tryGetValue(targetCase->body, caseBlock)) + { + caseBlock = builder->emitBlock(); + lowerStmt(context, targetCase->body); + mapCaseStmtToBlock.add(targetCase->body, caseBlock); + if (!builder->getBlock()->getTerminator()) + builder->emitBranch(breakLabel); + } + args.add(builder->getIntValue(builder->getIntType(), targetCase->capability)); + args.add(caseBlock); + } + context->shared->breakLabels.remove(stmt); + builder->setInsertInto(initialBlock); + builder->emitIntrinsicInst(nullptr, kIROp_TargetSwitch, (UInt)args.getCount(), args.getBuffer()); + insertBlock(breakLabel); + } + + void visitTargetCaseStmt(TargetCaseStmt*) + { + SLANG_UNREACHABLE("lowering target case"); + } + + void visitIntrinsicAsmStmt(IntrinsicAsmStmt* stmt) + { + auto builder = getBuilder(); + IRInst* arg = builder->getStringValue(stmt->asmText.getUnownedSlice()); + builder->emitIntrinsicInst(nullptr, kIROp_GenericAsm, 1, &arg); + } + void visitSwitchStmt(SwitchStmt* stmt) { auto builder = getBuilder(); @@ -9026,7 +9069,7 @@ struct DeclLoweringVisitor : DeclVisitor else if (auto spvVersion = as(modifier)) getBuilder()->addRequireSPIRVVersionDecoration(irFunc, spvVersion->version); else if (auto capMod = as(modifier)) - getBuilder()->addRequireSPIRVCapabilityDecoration(irFunc, capMod->capability); + getBuilder()->addRequireSPIRVCapabilityDecoration(irFunc, capMod->capability, capMod->extensionName.getUnownedSlice()); else if (auto cudasmVersion = as(modifier)) getBuilder()->addRequireCUDASMVersionDecoration(irFunc, cudasmVersion->version); } diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 580215fc7..57a21a90a 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -4202,6 +4202,108 @@ namespace Slang return stmt; } + static Stmt* parseTargetSwitchStmt(Parser* parser) + { + TargetSwitchStmt* stmt = parser->astBuilder->create(); + parser->FillPosition(stmt); + parser->ReadToken(); + if (!beginMatch(parser, MatchedTokenType::CurlyBraces)) + { + return stmt; + } + Token closingBraceToken; + while (!AdvanceIfMatch(parser, MatchedTokenType::CurlyBraces, &closingBraceToken)) + { + List caseNames; + for (;;) + { + if (parser->LookAheadToken("case")) + { + parser->ReadToken(); + caseNames.add(parser->ReadToken()); + parser->ReadToken(TokenType::Colon); + } + else if (parser->LookAheadToken("default")) + { + auto token = parser->ReadToken(); + parser->ReadToken(TokenType::Colon); + token.setContent(UnownedStringSlice("")); + caseNames.add(token); + } + else + break; + } + if (caseNames.getCount() == 0) + { + parser->sink->diagnose( + parser->tokenReader.peekLoc(), + Diagnostics::unexpectedTokenExpectedTokenType, + parser->tokenReader.peekToken(), + "'case' or 'default'"); + parser->isRecovering = true; + goto recover; + } + else + { + Stmt* bodyStmt = nullptr; + for (;;) + { + if (parser->LookAheadToken("case") || parser->LookAheadToken("default") || parser->LookAheadToken(TokenType::RBrace) || + parser->LookAheadToken(TokenType::EndOfFile)) + break; + auto nextStmt = parser->ParseStatement(stmt); + if (nextStmt) + { + if (!bodyStmt) + { + bodyStmt = nextStmt; + } + else if (auto seqStmt = as(bodyStmt)) + { + seqStmt->stmts.add(nextStmt); + } + else + { + SeqStmt* newBody = parser->astBuilder->create(); + newBody->loc = bodyStmt->loc; + newBody->stmts.add(bodyStmt); + newBody->stmts.add(nextStmt); + bodyStmt = newBody; + } + } + } + + for (auto caseName : caseNames) + { + TargetCaseStmt* targetCase = parser->astBuilder->create(); + auto cap = findCapabilityAtom(caseName.getContent()); + if (caseName.getContent().getLength() && cap == CapabilityAtom::Invalid) + { + parser->sink->diagnose(caseName.loc, Diagnostics::unknownTargetName, caseName.getContent()); + } + targetCase->capability = int32_t(cap); + targetCase->loc = caseName.loc; + targetCase->body = bodyStmt; + stmt->targetCases.add(targetCase); + } + } + recover:; + TryRecover(parser); + } + return stmt; + } + + static Stmt* parseIntrinsicAsmStmt(Parser* parser) + { + IntrinsicAsmStmt* stmt = parser->astBuilder->create(); + parser->FillPosition(stmt); + parser->ReadToken(); + + stmt->asmText = getStringLiteralTokenValue(parser->ReadToken(TokenType::StringLiteral)); + parser->ReadToken(TokenType::Semicolon); + return stmt; + } + GpuForeachStmt* ParseGpuForeachStmt(Parser* parser) { // Hard-coding parsing of the following: @@ -4421,6 +4523,10 @@ namespace Slang } else if (LookAheadToken("switch")) statement = ParseSwitchStmt(this); + else if (LookAheadToken("__target_switch")) + statement = parseTargetSwitchStmt(this); + else if (LookAheadToken("__intrinsic_asm")) + statement = parseIntrinsicAsmStmt(this); else if (LookAheadToken("case")) statement = ParseCaseStmt(this); else if (LookAheadToken("default")) @@ -6160,11 +6266,18 @@ namespace Slang return SPIRVAsmOperand{flavor, tok, varExpr}; }; + const auto slangTypeExprOperand = [&](auto flavor) { + auto tok = parser->tokenReader.peekToken(); + const auto typeExpr = parser->ParseType(); + return SPIRVAsmOperand{ flavor, tok, typeExpr }; + }; + // The result marker if(parser->LookAheadToken("result")) { return SPIRVAsmOperand{SPIRVAsmOperand::ResultMarker, parser->ReadToken()}; } + // A regular identifier else if(parser->LookAheadToken(TokenType::Identifier)) { @@ -6206,7 +6319,7 @@ namespace Slang // A $$foo type else if(AdvanceIf(parser, TokenType::DollarDollar)) { - return slangIdentOperand(SPIRVAsmOperand::SlangType); + return slangTypeExprOperand(SPIRVAsmOperand::SlangType); } Unexpected(parser); @@ -6756,8 +6869,8 @@ namespace Slang static NodeBase* parseSPIRVCapabilityModifier(Parser* parser, void*) { - Token token; - token = parser->ReadToken(); + parser->ReadToken(TokenType::LParent); + Token token = parser->ReadToken(TokenType::Identifier); auto modifier = parser->astBuilder->create(); const SPIRVCoreGrammarInfo& spirvInfo = parser->astBuilder->getGlobalSession()->getSPIRVCoreGrammarInfo(); @@ -6765,9 +6878,12 @@ namespace Slang if (!cap) { parser->sink->diagnose(token, Diagnostics::unknownSPIRVCapability, token); - return nullptr; } - modifier->capability = int32_t(*cap); + else + { + modifier->capability = (int32_t)cap.value(); + } + parser->ReadToken(TokenType::RParent); return modifier; } diff --git a/source/slang/slang-spirv-core-grammar-embed.cpp b/source/slang/slang-spirv-core-grammar-embed.cpp index ee409dad0..cd5c6ddb7 100644 --- a/source/slang/slang-spirv-core-grammar-embed.cpp +++ b/source/slang/slang-spirv-core-grammar-embed.cpp @@ -12431,21 +12431,20 @@ static bool getOperandKindUnderneathId(const OperandKind& k, OperandKind& v) RefPtr SPIRVCoreGrammarInfo::getEmbeddedVersion() { - static SPIRVCoreGrammarInfo embedded = [](){ - SPIRVCoreGrammarInfo info; - info.opcodes.embedded = &lookupSpvOp; - info.capabilities.embedded = &lookupSpvCapability; - info.allEnumsWithTypePrefix.embedded = &lookupEnumWithTypePrefix; - info.opInfos.embedded = &getOpInfo; - info.opNames.embedded = &getOpName; - info.operandKinds.embedded = &lookupOperandKind; - info.allEnums.embedded = &lookupQualifiedEnum; - info.allEnumNames.embedded = &getQualifiedEnumName; - info.operandKindNames.embedded = &getOperandKindName; - info.operandKindUnderneathIds.embedded = &getOperandKindUnderneathId; - info.addReference(); + static RefPtr embedded = [](){ + RefPtr info = new SPIRVCoreGrammarInfo(); + info->opcodes.embedded = &lookupSpvOp; + info->capabilities.embedded = &lookupSpvCapability; + info->allEnumsWithTypePrefix.embedded = &lookupEnumWithTypePrefix; + info->opInfos.embedded = &getOpInfo; + info->opNames.embedded = &getOpName; + info->operandKinds.embedded = &lookupOperandKind; + info->allEnums.embedded = &lookupQualifiedEnum; + info->allEnumNames.embedded = &getQualifiedEnumName; + info->operandKindNames.embedded = &getOperandKindName; + info->operandKindUnderneathIds.embedded = &getOperandKindUnderneathId; return info; }(); - return &embedded; + return embedded; } } -- cgit v1.2.3