From ac886fd3e329a9599ed1ac7a6d8b26ca5821046c Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 4 Oct 2023 11:20:35 -0700 Subject: SPIRV compiler performance fixes. (#3258) * SPIRV compiler performance fixes. * Cleanup. * update project files * Cleanup debug code. * Make redundancy removal non-recursive. --------- Co-authored-by: Yong He --- source/core/slang-uint-set.h | 21 ++ source/slang/hlsl.meta.slang | 269 ++++++++++----------- source/slang/slang-ast-modifier.h | 9 - source/slang/slang-emit-spirv.cpp | 28 +-- source/slang/slang-emit.cpp | 39 ++- source/slang/slang-ir-array-reg-to-mem.cpp | 88 ------- source/slang/slang-ir-array-reg-to-mem.h | 16 -- source/slang/slang-ir-autodiff-cfg-norm.cpp | 2 +- source/slang/slang-ir-autodiff-fwd.cpp | 2 +- source/slang/slang-ir-composite-reg-to-mem.cpp | 202 ++++++++++++++++ source/slang/slang-ir-composite-reg-to-mem.h | 13 + .../slang/slang-ir-eliminate-multilevel-break.cpp | 3 +- source/slang/slang-ir-eliminate-phis.cpp | 51 ++-- source/slang/slang-ir-eliminate-phis.h | 11 +- source/slang/slang-ir-inline.cpp | 35 ++- source/slang/slang-ir-inline.h | 3 + source/slang/slang-ir-inst-defs.h | 1 - source/slang/slang-ir-insts.h | 19 -- source/slang/slang-ir-loop-unroll.cpp | 2 +- source/slang/slang-ir-lower-generics.cpp | 2 +- source/slang/slang-ir-peephole.cpp | 2 +- source/slang/slang-ir-reachability.cpp | 81 +++++-- source/slang/slang-ir-reachability.h | 35 +-- source/slang/slang-ir-redundancy-removal.cpp | 30 ++- source/slang/slang-ir-sccp.cpp | 45 +++- source/slang/slang-ir-simplify-cfg.cpp | 149 ++++++------ source/slang/slang-ir-simplify-cfg.h | 12 +- source/slang/slang-ir-single-return.cpp | 2 +- source/slang/slang-ir-specialize-function-call.cpp | 2 +- source/slang/slang-ir-specialize-resources.cpp | 4 +- source/slang/slang-ir-specialize.cpp | 4 +- source/slang/slang-ir-spirv-legalize.cpp | 8 +- source/slang/slang-ir-ssa-register-allocate.cpp | 15 +- source/slang/slang-ir-ssa-register-allocate.h | 2 +- source/slang/slang-ir-ssa-simplification.cpp | 15 +- source/slang/slang-ir-ssa-simplification.h | 25 +- source/slang/slang-ir-ssa.cpp | 3 + .../slang/slang-ir-use-uninitialized-out-param.cpp | 2 +- source/slang/slang-ir-util.cpp | 17 ++ source/slang/slang-ir-util.h | 15 ++ source/slang/slang-ir.cpp | 5 +- source/slang/slang-lower-to-ir.cpp | 6 +- source/slang/slang-parser.cpp | 21 -- source/slang/slang-spirv-opt.cpp | 2 - source/slang/slang-spirv-val.cpp | 15 +- 45 files changed, 798 insertions(+), 535 deletions(-) delete mode 100644 source/slang/slang-ir-array-reg-to-mem.cpp delete mode 100644 source/slang/slang-ir-array-reg-to-mem.h create mode 100644 source/slang/slang-ir-composite-reg-to-mem.cpp create mode 100644 source/slang/slang-ir-composite-reg-to-mem.h (limited to 'source') diff --git a/source/core/slang-uint-set.h b/source/core/slang-uint-set.h index f2c573865..4912ae504 100644 --- a/source/core/slang-uint-set.h +++ b/source/core/slang-uint-set.h @@ -52,6 +52,8 @@ public: /// Returns true if the value is present inline bool contains(UInt val) const; + inline bool contains(const UIntSet& set) const; + /// == bool operator==(const UIntSet& set) const; /// != @@ -110,6 +112,25 @@ inline bool UIntSet::contains(UInt val) const ((m_buffer[idx] & (Element(1) << (val & kElementMask))) != 0); } +// -------------------------------------------------------------------------- +inline bool UIntSet::contains(const UIntSet& set) const +{ + for (Index i = 0; i < set.m_buffer.getCount(); i++) + { + if (i >= m_buffer.getCount()) + { + if (set.m_buffer[i]) + return false; + } + else + { + if ((m_buffer[i] & set.m_buffer[i]) != set.m_buffer[i]) + return false; + } + } + return true; +} + // -------------------------------------------------------------------------- inline void UIntSet::add(UInt val) { diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 5eb08e980..0fc6fb2a9 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -5116,7 +5116,6 @@ matrix trunc(matrix x) typedef uint WaveMask; __glsl_extension(GL_KHR_shader_subgroup_ballot) -__spirv_capability(GroupNonUniformBallot) __spirv_version(1.3) WaveMask WaveGetConvergedMask() { @@ -5132,6 +5131,7 @@ WaveMask WaveGetConvergedMask() let _true = true; return (spirv_asm { + OpCapability GroupNonUniformBallot; OpGroupNonUniformBallot $$uint4 result Subgroup $_true }).x; } @@ -5141,7 +5141,6 @@ __intrinsic_op($(kIROp_WaveGetActiveMask)) WaveMask __WaveGetActiveMask(); __glsl_extension(GL_KHR_shader_subgroup_ballot) -__spirv_capability(GroupNonUniformBallot) __spirv_version(1.3) WaveMask WaveGetActiveMask() { @@ -5155,6 +5154,7 @@ WaveMask WaveGetActiveMask() let _true = true; return (spirv_asm { + OpCapability GroupNonUniformBallot; OpGroupNonUniformBallot $$uint4 result Subgroup $_true }).x; default: @@ -5163,7 +5163,6 @@ WaveMask WaveGetActiveMask() } __glsl_extension(GL_KHR_shader_subgroup_basic) -__spirv_capability(GroupNonUniformBallot) __spirv_version(1.3) bool WaveMaskIsFirstLane(WaveMask mask) { @@ -5178,6 +5177,7 @@ bool WaveMaskIsFirstLane(WaveMask mask) case spirv: return spirv_asm { + OpCapability GroupNonUniformBallot; OpGroupNonUniformElect $$bool result Subgroup }; default: @@ -5187,7 +5187,6 @@ bool WaveMaskIsFirstLane(WaveMask mask) __glsl_extension(GL_KHR_shader_subgroup_vote) __spirv_version(1.3) -__spirv_capability(GroupNonUniformBallot) bool WaveMaskAllTrue(WaveMask mask, bool condition) { __target_switch @@ -5201,6 +5200,7 @@ bool WaveMaskAllTrue(WaveMask mask, bool condition) case spirv: return spirv_asm { + OpCapability GroupNonUniformBallot; OpGroupNonUniformAll $$bool result Subgroup $condition }; default: @@ -5210,7 +5210,6 @@ bool WaveMaskAllTrue(WaveMask mask, bool condition) __glsl_extension(GL_KHR_shader_subgroup_vote) __spirv_version(1.3) -__spirv_capability(GroupNonUniformBallot) bool WaveMaskAnyTrue(WaveMask mask, bool condition) { __target_switch @@ -5224,6 +5223,7 @@ bool WaveMaskAnyTrue(WaveMask mask, bool condition) case spirv: return spirv_asm { + OpCapability GroupNonUniformBallot; OpGroupNonUniformAny $$bool result Subgroup $condition }; default: @@ -5233,7 +5233,6 @@ bool WaveMaskAnyTrue(WaveMask mask, bool condition) __glsl_extension(GL_KHR_shader_subgroup_ballot) __spirv_version(1.3) -__spirv_capability(GroupNonUniformBallot) WaveMask WaveMaskBallot(WaveMask mask, bool condition) { __target_switch @@ -5247,6 +5246,7 @@ WaveMask WaveMaskBallot(WaveMask mask, bool condition) case spirv: return (spirv_asm { + OpCapability GroupNonUniformBallot; OpGroupNonUniformBallot $$uint4 result Subgroup $condition }).x; default: @@ -5368,7 +5368,6 @@ void GroupMemoryBarrierWithWaveSync() __generic __glsl_extension(GL_KHR_shader_subgroup_ballot) __spirv_version(1.3) -__spirv_capability(GroupNonUniformBallot) T WaveMaskBroadcastLaneAt(WaveMask mask, T value, constexpr int lane) { __target_switch @@ -5378,14 +5377,16 @@ T WaveMaskBroadcastLaneAt(WaveMask mask, T value, constexpr int lane) case hlsl: __intrinsic_asm "WaveReadLaneAt($1, $2)"; case spirv: let ulane = uint(lane); - return spirv_asm {OpGroupNonUniformBroadcast $$T result Subgroup $value $ulane}; + return spirv_asm { + OpCapability GroupNonUniformBallot; + OpGroupNonUniformBroadcast $$T result Subgroup $value $ulane; + }; } } __generic __glsl_extension(GL_KHR_shader_subgroup_ballot) __spirv_version(1.3) -__spirv_capability(GroupNonUniformBallot) vector WaveMaskBroadcastLaneAt(WaveMask mask, vector value, constexpr int lane) { __target_switch @@ -5395,7 +5396,10 @@ vector WaveMaskBroadcastLaneAt(WaveMask mask, vector value, constexpr case hlsl: __intrinsic_asm "WaveReadLaneAt($1, $2)"; case spirv: let ulane = uint(lane); - return spirv_asm {OpGroupNonUniformBroadcast $$vector result Subgroup $value $ulane}; + return spirv_asm { + OpCapability GroupNonUniformBallot; + OpGroupNonUniformBroadcast $$vector result Subgroup $value $ulane; + }; } } __generic @@ -5408,7 +5412,6 @@ matrix WaveMaskBroadcastLaneAt(WaveMask mask, matrix value, conste __generic __glsl_extension(GL_KHR_shader_subgroup_shuffle) __spirv_version(1.3) -__spirv_capability(GroupNonUniformShuffle) T WaveMaskReadLaneAt(WaveMask mask, T value, int lane) { __target_switch @@ -5418,13 +5421,15 @@ T WaveMaskReadLaneAt(WaveMask mask, T value, int lane) case hlsl: __intrinsic_asm "WaveReadLaneAt($1, $2)"; case spirv: let ulane = uint(lane); - return spirv_asm {OpGroupNonUniformShuffle $$T result Subgroup $value $ulane}; + return spirv_asm { + OpCapability GroupNonUniformShuffle; + OpGroupNonUniformShuffle $$T result Subgroup $value $ulane; + }; } } __generic __spirv_version(1.3)__glsl_extension(GL_KHR_shader_subgroup_shuffle) __spirv_version(1.3) -__spirv_capability(GroupNonUniformShuffle) vector WaveMaskReadLaneAt(WaveMask mask, vector value, int lane) { __target_switch @@ -5434,7 +5439,10 @@ vector WaveMaskReadLaneAt(WaveMask mask, vector value, int lane) case hlsl: __intrinsic_asm "WaveReadLaneAt($1, $2)"; case spirv: let ulane = uint(lane); - return spirv_asm {OpGroupNonUniformShuffle $$vector result Subgroup $value $ulane}; + return spirv_asm { + OpCapability GroupNonUniformShuffle; + OpGroupNonUniformShuffle $$vector result Subgroup $value $ulane; + }; } } __generic @@ -5466,7 +5474,6 @@ matrix WaveMaskShuffle(WaveMask mask, matrix value, int lane) __glsl_extension(GL_KHR_shader_subgroup_ballot) __spirv_version(1.3) -__spirv_capability(GroupNonUniformBallot) uint WaveMaskPrefixCountBits(WaveMask mask, bool value) { __target_switch @@ -5477,6 +5484,7 @@ uint WaveMaskPrefixCountBits(WaveMask mask, bool value) case spirv: return spirv_asm { + OpCapability GroupNonUniformBallot; %mask:$$uint4 = OpGroupNonUniformBallot Subgroup $value; OpGroupNonUniformBallotBitCount $$uint result Subgroup 2 %mask }; @@ -5488,7 +5496,6 @@ uint WaveMaskPrefixCountBits(WaveMask mask, bool value) __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) T WaveMaskBitAnd(WaveMask mask, T expr) { __target_switch @@ -5497,14 +5504,16 @@ T WaveMaskBitAnd(WaveMask mask, T expr) case cuda: __intrinsic_asm "_waveAnd($0, $1)"; case hlsl: __intrinsic_asm "WaveActiveBitAnd($1)"; case spirv: - return spirv_asm {OpGroupNonUniformBitwiseAnd $$T result Subgroup 0 $expr}; + return spirv_asm { + OpCapability GroupNonUniformArithmetic; + OpGroupNonUniformBitwiseAnd $$T result Subgroup 0 $expr + }; } } __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) vector WaveMaskBitAnd(WaveMask mask, vector expr) { __target_switch @@ -5513,7 +5522,10 @@ vector WaveMaskBitAnd(WaveMask mask, vector expr) case cuda: __intrinsic_asm "_waveAndMultiple($0, $1)"; case hlsl: __intrinsic_asm "WaveActiveBitAnd($1)"; case spirv: - return spirv_asm {OpGroupNonUniformBitwiseAnd $$vector result Subgroup 0 $expr}; + return spirv_asm { + OpCapability GroupNonUniformArithmetic; + OpGroupNonUniformBitwiseAnd $$vector result Subgroup 0 $expr + }; } } __generic @@ -5524,7 +5536,6 @@ matrix WaveMaskBitAnd(WaveMask mask, matrix expr); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) T WaveMaskBitOr(WaveMask mask, T expr) { __target_switch @@ -5533,13 +5544,15 @@ T WaveMaskBitOr(WaveMask mask, T expr) case cuda: __intrinsic_asm "_waveOr($0, $1)"; case hlsl: __intrinsic_asm "WaveActiveBitOr($1)"; case spirv: - return spirv_asm {OpGroupNonUniformBitwiseOr $$T result Subgroup 0 $expr}; + return spirv_asm { + OpCapability GroupNonUniformArithmetic; + OpGroupNonUniformBitwiseOr $$T result Subgroup 0 $expr + }; } } __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) vector WaveMaskBitOr(WaveMask mask, vector expr) { __target_switch @@ -5548,7 +5561,10 @@ vector WaveMaskBitOr(WaveMask mask, vector expr) case cuda: __intrinsic_asm "_waveOrMultiple($0, $1)"; case hlsl: __intrinsic_asm "WaveActiveBitOr($1)"; case spirv: - return spirv_asm {OpGroupNonUniformBitwiseOr $$vector result Subgroup 0 $expr}; + return spirv_asm { + OpCapability GroupNonUniformArithmetic; + OpGroupNonUniformBitwiseOr $$vector result Subgroup 0 $expr + }; } } __generic @@ -5559,7 +5575,6 @@ matrix WaveMaskBitOr(WaveMask mask, matrix expr); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) T WaveMaskBitXor(WaveMask mask, T expr) { __target_switch @@ -5568,13 +5583,15 @@ T WaveMaskBitXor(WaveMask mask, T expr) case cuda: __intrinsic_asm "_waveXor($0, $1)"; case hlsl: __intrinsic_asm "WaveActiveBitXor($1)"; case spirv: - return spirv_asm {OpGroupNonUniformBitwiseXor $$T result Subgroup 0 $expr}; + return spirv_asm { + OpCapability GroupNonUniformArithmetic; + OpGroupNonUniformBitwiseXor $$T result Subgroup 0 $expr + }; } } __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) vector WaveMaskBitXor(WaveMask mask, vector expr) { __target_switch @@ -5583,7 +5600,10 @@ vector WaveMaskBitXor(WaveMask mask, vector expr) case cuda: __intrinsic_asm "_waveXorMultiple($0, $1)"; case hlsl: __intrinsic_asm "WaveActiveBitXor($1)"; case spirv: - return spirv_asm {OpGroupNonUniformBitwiseXor $$vector result Subgroup 0 $expr}; + return spirv_asm { + OpCapability GroupNonUniformArithmetic; + OpGroupNonUniformBitwiseXor $$vector result Subgroup 0 $expr + }; } } __generic @@ -5594,7 +5614,6 @@ matrix WaveMaskBitXor(WaveMask mask, matrix expr); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) T WaveMaskMax(WaveMask mask, T expr) { __target_switch @@ -5604,17 +5623,16 @@ T WaveMaskMax(WaveMask mask, T expr) case hlsl: __intrinsic_asm "WaveActiveMax($1)"; case spirv: if (__isFloat()) - return spirv_asm {OpGroupNonUniformFMax $$T result Subgroup 0 $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMax $$T result Subgroup 0 $expr}; else if (__isSignedInt()) - return spirv_asm {OpGroupNonUniformSMax $$T result Subgroup 0 $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformSMax $$T result Subgroup 0 $expr}; else if (__isUnsignedInt()) - return spirv_asm {OpGroupNonUniformUMax $$T result Subgroup 0 $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformUMax $$T result Subgroup 0 $expr}; } } __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) vector WaveMaskMax(WaveMask mask, vector expr) { __target_switch @@ -5624,11 +5642,11 @@ vector WaveMaskMax(WaveMask mask, vector expr) case hlsl: __intrinsic_asm "WaveActiveMax($1)"; case spirv: if (__isFloat()) - return spirv_asm {OpGroupNonUniformFMax $$vector result Subgroup 0 $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMax $$vector result Subgroup 0 $expr}; else if (__isSignedInt()) - return spirv_asm {OpGroupNonUniformSMax $$vector result Subgroup 0 $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformSMax $$vector result Subgroup 0 $expr}; else if (__isUnsignedInt()) - return spirv_asm {OpGroupNonUniformUMax $$vector result Subgroup 0 $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformUMax $$vector result Subgroup 0 $expr}; } } @@ -5640,7 +5658,6 @@ matrix WaveMaskMax(WaveMask mask, matrix expr); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) T WaveMaskMin(WaveMask mask, T expr) { __target_switch @@ -5650,18 +5667,17 @@ T WaveMaskMin(WaveMask mask, T expr) case hlsl: __intrinsic_asm "WaveActiveMin($1)"; case spirv: if (__isFloat()) - return spirv_asm {OpGroupNonUniformFMin $$T result Subgroup 0 $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMin $$T result Subgroup 0 $expr}; else if (__isSignedInt()) - return spirv_asm {OpGroupNonUniformSMin $$T result Subgroup 0 $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformSMin $$T result Subgroup 0 $expr}; else if (__isUnsignedInt()) - return spirv_asm {OpGroupNonUniformUMin $$T result Subgroup 0 $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformUMin $$T result Subgroup 0 $expr}; } } __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) vector WaveMaskMin(WaveMask mask, vector expr) { __target_switch @@ -5671,11 +5687,11 @@ vector WaveMaskMin(WaveMask mask, vector expr) case hlsl: __intrinsic_asm "WaveActiveMin($1)"; case spirv: if (__isFloat()) - return spirv_asm {OpGroupNonUniformFMin $$vector result Subgroup 0 $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMin $$vector result Subgroup 0 $expr}; else if (__isSignedInt()) - return spirv_asm {OpGroupNonUniformSMin $$vector result Subgroup 0 $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformSMin $$vector result Subgroup 0 $expr}; else if (__isUnsignedInt()) - return spirv_asm {OpGroupNonUniformUMin $$vector result Subgroup 0 $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformUMin $$vector result Subgroup 0 $expr}; } } @@ -5687,7 +5703,6 @@ matrix WaveMaskMin(WaveMask mask, matrix expr); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) T WaveMaskProduct(WaveMask mask, T expr) { __target_switch @@ -5697,11 +5712,12 @@ T WaveMaskProduct(WaveMask mask, T expr) case hlsl: __intrinsic_asm "WaveActiveProduct($1)"; case spirv: if (__isFloat()) - return spirv_asm {OpGroupNonUniformFMul $$T result Subgroup 0 $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMul $$T result Subgroup 0 $expr}; else if (__isSignedInt()) { return spirv_asm { + OpCapability GroupNonUniformArithmetic; // TODO: use the correct integer width OpBitcast $$uint %uvalue $expr; OpGroupNonUniformIMul $$uint %mulResult Subgroup 0 %uvalue; @@ -5709,14 +5725,13 @@ T WaveMaskProduct(WaveMask mask, T expr) }; } else if (__isUnsignedInt()) - return spirv_asm {OpGroupNonUniformIMul $$T result Subgroup 0 $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformIMul $$T result Subgroup 0 $expr}; } } __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) vector WaveMaskProduct(WaveMask mask, vector expr) { __target_switch @@ -5726,11 +5741,12 @@ vector WaveMaskProduct(WaveMask mask, vector expr) case hlsl: __intrinsic_asm "WaveActiveProduct($1)"; case spirv: if (__isFloat()) - return spirv_asm {OpGroupNonUniformFMul $$vector result Subgroup 0 $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMul $$vector result Subgroup 0 $expr}; else if (__isSignedInt()) { return spirv_asm { + OpCapability GroupNonUniformArithmetic; // TODO: use the correct integer width OpBitcast $$vector %uvalue $expr; OpGroupNonUniformIMul $$vector %mulResult Subgroup 0 %uvalue; @@ -5738,7 +5754,7 @@ vector WaveMaskProduct(WaveMask mask, vector expr) }; } else if (__isUnsignedInt()) - return spirv_asm {OpGroupNonUniformIMul $$vector result Subgroup 0 $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformIMul $$vector result Subgroup 0 $expr}; } } @@ -5750,7 +5766,6 @@ matrix WaveMaskProduct(WaveMask mask, matrix expr); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) T WaveMaskSum(WaveMask mask, T expr) { __target_switch @@ -5760,11 +5775,12 @@ T WaveMaskSum(WaveMask mask, T expr) case hlsl: __intrinsic_asm "WaveActiveSum($1)"; case spirv: if (__isFloat()) - return spirv_asm {OpGroupNonUniformFAdd $$T result Subgroup 0 $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFAdd $$T result Subgroup 0 $expr}; else if (__isSignedInt()) { return spirv_asm { + OpCapability GroupNonUniformArithmetic; // TODO: use the correct integer width OpBitcast $$uint %uvalue $expr; OpGroupNonUniformIAdd $$uint %mulResult Subgroup 0 %uvalue; @@ -5772,13 +5788,12 @@ T WaveMaskSum(WaveMask mask, T expr) }; } else if (__isUnsignedInt()) - return spirv_asm {OpGroupNonUniformIAdd $$T result Subgroup 0 $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformIAdd $$T result Subgroup 0 $expr}; } } __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) vector WaveMaskSum(WaveMask mask, vector expr) { __target_switch @@ -5788,11 +5803,12 @@ vector WaveMaskSum(WaveMask mask, vector expr) case hlsl: __intrinsic_asm "WaveActiveSum($1)"; case spirv: if (__isFloat()) - return spirv_asm {OpGroupNonUniformFAdd $$vector result Subgroup 0 $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFAdd $$vector result Subgroup 0 $expr}; else if (__isSignedInt()) { return spirv_asm { + OpCapability GroupNonUniformArithmetic; // TODO: use the correct integer width OpBitcast $$vector %uvalue $expr; OpGroupNonUniformIAdd $$vector %mulResult Subgroup 0 %uvalue; @@ -5800,7 +5816,7 @@ vector WaveMaskSum(WaveMask mask, vector expr) }; } else if (__isUnsignedInt()) - return spirv_asm {OpGroupNonUniformIAdd $$vector result Subgroup 0 $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformIAdd $$vector result Subgroup 0 $expr}; } } __generic @@ -5811,7 +5827,6 @@ matrix WaveMaskSum(WaveMask mask, matrix expr); __generic __glsl_extension(GL_KHR_shader_subgroup_vote) __spirv_version(1.3) -__spirv_capability(GroupNonUniformVote) __cuda_sm_version(7.0) bool WaveMaskAllEqual(WaveMask mask, T value) { @@ -5826,6 +5841,7 @@ bool WaveMaskAllEqual(WaveMask mask, T value) case spirv: return spirv_asm { + OpCapability GroupNonUniformVote; OpGroupNonUniformAllEqual $$bool result Subgroup $value }; default: @@ -5836,7 +5852,6 @@ __generic __glsl_extension(GL_KHR_shader_subgroup_vote) __spirv_version(1.3) __cuda_sm_version(7.0) -__spirv_capability(GroupNonUniformVote) bool WaveMaskAllEqual(WaveMask mask, vector value) { __target_switch @@ -5850,6 +5865,7 @@ bool WaveMaskAllEqual(WaveMask mask, vector value) case spirv: return spirv_asm { + OpCapability GroupNonUniformVote; OpGroupNonUniformAllEqual $$bool result Subgroup $value }; default: @@ -5867,7 +5883,6 @@ bool WaveMaskAllEqual(WaveMask mask, matrix value); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) T WaveMaskPrefixProduct(WaveMask mask, T expr) { __target_switch @@ -5877,11 +5892,12 @@ T WaveMaskPrefixProduct(WaveMask mask, T expr) case hlsl: __intrinsic_asm "WavePrefixProduct($1)"; case spirv: if (__isFloat()) - return spirv_asm {OpGroupNonUniformFMul $$T result Subgroup ExclusiveScan $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMul $$T result Subgroup ExclusiveScan $expr}; else if (__isSignedInt()) { return spirv_asm { + OpCapability GroupNonUniformArithmetic; // TODO: use the correct integer width OpBitcast $$uint %uvalue $expr; OpGroupNonUniformIMul $$uint %mulResult Subgroup ExclusiveScan %uvalue; @@ -5895,7 +5911,6 @@ T WaveMaskPrefixProduct(WaveMask mask, T expr) __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) vector WaveMaskPrefixProduct(WaveMask mask, vector expr) { __target_switch @@ -5905,11 +5920,12 @@ vector WaveMaskPrefixProduct(WaveMask mask, vector expr) case hlsl: __intrinsic_asm "WavePrefixProduct($1)"; case spirv: if (__isFloat()) - return spirv_asm {OpGroupNonUniformFMul $$vector result Subgroup ExclusiveScan $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMul $$vector result Subgroup ExclusiveScan $expr}; else if (__isSignedInt()) { return spirv_asm { + OpCapability GroupNonUniformArithmetic; // TODO: use the correct integer width OpBitcast $$vector %uvalue $expr; OpGroupNonUniformIMul $$vector %mulResult Subgroup ExclusiveScan %uvalue; @@ -5917,7 +5933,7 @@ vector WaveMaskPrefixProduct(WaveMask mask, vector expr) }; } else if (__isUnsignedInt()) - return spirv_asm {OpGroupNonUniformIMul $$vector result Subgroup ExclusiveScan $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformIMul $$vector result Subgroup ExclusiveScan $expr}; } } __generic @@ -5928,7 +5944,6 @@ matrix WaveMaskPrefixProduct(WaveMask mask, matrix expr); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) T WaveMaskPrefixSum(WaveMask mask, T expr) { __target_switch @@ -5938,11 +5953,12 @@ T WaveMaskPrefixSum(WaveMask mask, T expr) case hlsl: __intrinsic_asm "WavePrefixSum($1)"; case spirv: if (__isFloat()) - return spirv_asm {OpGroupNonUniformFAdd $$T result Subgroup ExclusiveScan $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFAdd $$T result Subgroup ExclusiveScan $expr}; else if (__isSignedInt()) { return spirv_asm { + OpCapability GroupNonUniformArithmetic; // TODO: use the correct integer width %uvalue:$$uint = OpBitcast $expr; %mulResult:$$uint = OpGroupNonUniformIAdd Subgroup ExclusiveScan %uvalue; @@ -5950,14 +5966,13 @@ T WaveMaskPrefixSum(WaveMask mask, T expr) }; } else if (__isUnsignedInt()) - return spirv_asm {OpGroupNonUniformIAdd $$T result Subgroup ExclusiveScan $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformIAdd $$T result Subgroup ExclusiveScan $expr}; } } __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) vector WaveMaskPrefixSum(WaveMask mask, vector expr) { __target_switch @@ -5967,11 +5982,12 @@ vector WaveMaskPrefixSum(WaveMask mask, vector expr) case hlsl: __intrinsic_asm "WavePrefixSum($1)"; case spirv: if (__isFloat()) - return spirv_asm {OpGroupNonUniformFAdd $$vector result Subgroup ExclusiveScan $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFAdd $$vector result Subgroup ExclusiveScan $expr}; else if (__isSignedInt()) { return spirv_asm { + OpCapability GroupNonUniformArithmetic; // TODO: use the correct integer width %uvalue: $$vector = OpBitcast $expr; %mulResult: $$vector = OpGroupNonUniformIAdd Subgroup ExclusiveScan %uvalue; @@ -5979,7 +5995,7 @@ vector WaveMaskPrefixSum(WaveMask mask, vector expr) }; } else if (__isUnsignedInt()) - return spirv_asm {OpGroupNonUniformIAdd $$vector result Subgroup ExclusiveScan $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformIAdd $$vector result Subgroup ExclusiveScan $expr}; } } __generic @@ -5990,7 +6006,6 @@ matrix WaveMaskPrefixSum(WaveMask mask, matrix expr); __generic __glsl_extension(GL_KHR_shader_subgroup_ballot) __spirv_version(1.3) -__spirv_capability(GroupNonUniformBallot) T WaveMaskReadLaneFirst(WaveMask mask, T expr) { __target_switch @@ -5999,13 +6014,12 @@ T WaveMaskReadLaneFirst(WaveMask mask, T expr) case cuda: __intrinsic_asm "_waveReadFirst($0, $1)"; case hlsl: __intrinsic_asm "WaveReadLaneFirst($1)"; case spirv: - return spirv_asm {OpGroupNonUniformBroadcastFirst $$T result Subgroup $expr}; + return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcastFirst $$T result Subgroup $expr}; } } __generic __glsl_extension(GL_KHR_shader_subgroup_ballot) __spirv_version(1.3) -__spirv_capability(GroupNonUniformBallot) vector WaveMaskReadLaneFirst(WaveMask mask, vector expr) { __target_switch @@ -6014,7 +6028,7 @@ vector WaveMaskReadLaneFirst(WaveMask mask, vector expr) case cuda: __intrinsic_asm "_waveReadFirstMultiple($0, $1)"; case hlsl: __intrinsic_asm "WaveReadLaneFirst($1)"; case spirv: - return spirv_asm {OpGroupNonUniformBroadcastFirst $$vector result Subgroup $expr}; + return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcastFirst $$vector result Subgroup $expr}; } } @@ -6079,7 +6093,6 @@ WaveMask WaveMaskMatch(WaveMask mask, matrix value); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) T WaveMaskPrefixBitAnd(WaveMask mask, T expr) { __target_switch @@ -6088,14 +6101,13 @@ T WaveMaskPrefixBitAnd(WaveMask mask, T expr) case cuda: __intrinsic_asm "_wavePrefixAnd($0, $1)"; case hlsl: __intrinsic_asm "WaveMultiPrefixBitAnd($1, uint4($0, 0, 0, 0))"; case spirv: - return spirv_asm {OpGroupNonUniformBitwiseAnd $$T result Subgroup ExclusiveScan $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformBitwiseAnd $$T result Subgroup ExclusiveScan $expr}; } } __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) vector WaveMaskPrefixBitAnd(WaveMask mask, vector expr) { __target_switch @@ -6104,7 +6116,7 @@ vector WaveMaskPrefixBitAnd(WaveMask mask, vector expr) case cuda: __intrinsic_asm "_wavePrefixAndMultiple($0, $1)"; case hlsl: __intrinsic_asm "WaveMultiPrefixBitAnd($1, uint4($0, 0, 0, 0))"; case spirv: - return spirv_asm {OpGroupNonUniformBitwiseAnd $$vector result Subgroup ExclusiveScan $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformBitwiseAnd $$vector result Subgroup ExclusiveScan $expr}; } } @@ -6116,7 +6128,6 @@ matrix WaveMaskPrefixBitAnd(WaveMask mask, matrix expr); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) T WaveMaskPrefixBitOr(WaveMask mask, T expr) { __target_switch @@ -6125,14 +6136,13 @@ T WaveMaskPrefixBitOr(WaveMask mask, T expr) case cuda: __intrinsic_asm "_wavePrefixOr($0, $1)"; case hlsl: __intrinsic_asm "WaveMultiPrefixBitOr($1, uint4($0, 0, 0, 0))"; case spirv: - return spirv_asm {OpGroupNonUniformBitwiseAnd $$T result Subgroup ExclusiveScan $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformBitwiseAnd $$T result Subgroup ExclusiveScan $expr}; } } __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) vector WaveMaskPrefixBitOr(WaveMask mask, vector expr) { __target_switch @@ -6141,7 +6151,7 @@ vector WaveMaskPrefixBitOr(WaveMask mask, vector expr) case cuda: __intrinsic_asm "_wavePrefixOrMultiple($0, $1)"; case hlsl: __intrinsic_asm "WaveMultiPrefixBitOr($1, uint4($0, 0, 0, 0))"; case spirv: - return spirv_asm {OpGroupNonUniformBitwiseOr $$vector result Subgroup ExclusiveScan $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformBitwiseOr $$vector result Subgroup ExclusiveScan $expr}; } } @@ -6153,7 +6163,6 @@ matrix WaveMaskPrefixBitOr(WaveMask mask, matrix expr); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) T WaveMaskPrefixBitXor(WaveMask mask, T expr) { __target_switch @@ -6162,14 +6171,13 @@ T WaveMaskPrefixBitXor(WaveMask mask, T expr) case cuda: __intrinsic_asm "_wavePrefixXor($0, $1)"; case hlsl: __intrinsic_asm "WaveMultiPrefixBitXor($1, uint4($0, 0, 0, 0))"; case spirv: - return spirv_asm {OpGroupNonUniformBitwiseXor $$T result Subgroup ExclusiveScan $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformBitwiseXor $$T result Subgroup ExclusiveScan $expr}; } } __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) vector WaveMaskPrefixBitXor(WaveMask mask, vector expr) { __target_switch @@ -6178,7 +6186,7 @@ vector WaveMaskPrefixBitXor(WaveMask mask, vector expr) case cuda: __intrinsic_asm "_wavePrefixXorMultiple($0, $1)"; case hlsl: __intrinsic_asm "WaveMultiPrefixBitXor($1, uint4($0, 0, 0, 0))"; case spirv: - return spirv_asm {OpGroupNonUniformBitwiseXor $$vector result Subgroup ExclusiveScan $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformBitwiseXor $$vector result Subgroup ExclusiveScan $expr}; } } @@ -6218,7 +6226,6 @@ for (auto opName : kWaveActiveBitOpEntries) { __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) T WaveActive$(opName.hlslName)(T expr) { __target_switch @@ -6226,7 +6233,7 @@ T WaveActive$(opName.hlslName)(T expr) case glsl: __intrinsic_asm "subgroup$(opName.glslName)($0)"; case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)"; case spirv: - return spirv_asm {OpGroupNonUniform$(opName.spirvName) $$T result Subgroup Reduce $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniform$(opName.spirvName) $$T result Subgroup Reduce $expr}; default: return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr); } @@ -6235,7 +6242,6 @@ T WaveActive$(opName.hlslName)(T expr) __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) vector WaveActive$(opName.hlslName)(vector expr) { __target_switch @@ -6243,7 +6249,7 @@ vector WaveActive$(opName.hlslName)(vector expr) case glsl: __intrinsic_asm "subgroup$(opName.glslName)($0)"; case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)"; case spirv: - return spirv_asm {OpGroupNonUniform$(opName.spirvName) $$vector result Subgroup Reduce $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniform$(opName.spirvName) $$vector result Subgroup Reduce $expr}; default: return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr); } @@ -6268,7 +6274,6 @@ for (const char* opName : kWaveActiveMinMaxNames) { __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) T WaveActive$(opName)(T expr) { __target_switch @@ -6277,11 +6282,11 @@ T WaveActive$(opName)(T expr) case hlsl: __intrinsic_asm "WaveActive$(opName)"; case spirv: if (__isFloat()) - return spirv_asm {OpGroupNonUniformF$(opName) $$T result Subgroup Reduce $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformF$(opName) $$T result Subgroup Reduce $expr}; else if (__isUnsignedInt()) - return spirv_asm {OpGroupNonUniformU$(opName) $$T result Subgroup Reduce $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformU$(opName) $$T result Subgroup Reduce $expr}; else - return spirv_asm {OpGroupNonUniformS$(opName) $$T result Subgroup Reduce $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformS$(opName) $$T result Subgroup Reduce $expr}; default: return WaveMask$(opName)(WaveGetActiveMask(), expr); } @@ -6290,7 +6295,6 @@ T WaveActive$(opName)(T expr) __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) vector WaveActive$(opName)(vector expr) { __target_switch @@ -6299,11 +6303,11 @@ vector WaveActive$(opName)(vector expr) case hlsl: __intrinsic_asm "WaveActive$(opName)"; case spirv: if (__isFloat()) - return spirv_asm {OpGroupNonUniformF$(opName) $$vector result Subgroup Reduce $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformF$(opName) $$vector result Subgroup Reduce $expr}; else if (__isUnsignedInt()) - return spirv_asm {OpGroupNonUniformU$(opName) $$vector result Subgroup Reduce $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformU$(opName) $$vector result Subgroup Reduce $expr}; else - return spirv_asm {OpGroupNonUniformS$(opName) $$vector result Subgroup Reduce $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformS$(opName) $$vector result Subgroup Reduce $expr}; default: return WaveMask$(opName)(WaveGetActiveMask(), expr); } @@ -6415,7 +6419,6 @@ ${{{{ __generic __glsl_extension(GL_KHR_shader_subgroup_vote) __spirv_version(1.3) -__spirv_capability(GroupNonUniformVote) bool WaveActiveAllEqual(T value) { __target_switch @@ -6427,6 +6430,7 @@ bool WaveActiveAllEqual(T value) case spirv: return spirv_asm { + OpCapability GroupNonUniformVote; OpGroupNonUniformAllEqual $$bool result Subgroup $value }; default: @@ -6437,7 +6441,6 @@ bool WaveActiveAllEqual(T value) __generic __glsl_extension(GL_KHR_shader_subgroup_vote) __spirv_version(1.3) -__spirv_capability(GroupNonUniformVote) bool WaveActiveAllEqual(vector value) { __target_switch @@ -6449,6 +6452,7 @@ bool WaveActiveAllEqual(vector value) case spirv: return spirv_asm { + OpCapability GroupNonUniformVote; OpGroupNonUniformAllEqual $$bool result Subgroup $value }; default: @@ -6465,7 +6469,6 @@ bool WaveActiveAllEqual(matrix value) __glsl_extension(GL_KHR_shader_subgroup_vote) __spirv_version(1.3) -__spirv_capability(GroupNonUniformVote) bool WaveActiveAllTrue(bool condition) { __target_switch @@ -6477,6 +6480,7 @@ bool WaveActiveAllTrue(bool condition) case spirv: return spirv_asm { + OpCapability GroupNonUniformVote; OpGroupNonUniformAll $$bool result Subgroup $condition }; default: @@ -6486,7 +6490,6 @@ bool WaveActiveAllTrue(bool condition) __glsl_extension(GL_KHR_shader_subgroup_vote) __spirv_version(1.3) -__spirv_capability(GroupNonUniformVote) bool WaveActiveAnyTrue(bool condition) { __target_switch @@ -6498,6 +6501,7 @@ bool WaveActiveAnyTrue(bool condition) case spirv: return spirv_asm { + OpCapability GroupNonUniformVote; OpGroupNonUniformAny $$bool result Subgroup $condition }; default: @@ -6507,7 +6511,6 @@ bool WaveActiveAnyTrue(bool condition) __glsl_extension(GL_KHR_shader_subgroup_ballot) __spirv_version(1.3) -__spirv_capability(GroupNonUniformBallot) uint4 WaveActiveBallot(bool condition) { __target_switch @@ -6519,6 +6522,7 @@ uint4 WaveActiveBallot(bool condition) case spirv: return spirv_asm { + OpCapability GroupNonUniformBallot; OpGroupNonUniformBallot $$uint4 result Subgroup $condition }; default: @@ -6534,7 +6538,6 @@ uint WaveActiveCountBits(bool value) __glsl_extension(GL_KHR_shader_subgroup_basic) __spirv_version(1.3) -__spirv_capability(GroupNonUniform) uint WaveGetLaneCount() { __target_switch @@ -6545,6 +6548,7 @@ uint WaveGetLaneCount() case spirv: return spirv_asm { + OpCapability GroupNonUniform; result:$$uint = OpLoad builtin(SubgroupSize:uint) }; } @@ -6552,7 +6556,6 @@ uint WaveGetLaneCount() __glsl_extension(GL_KHR_shader_subgroup_basic) __spirv_version(1.3) -__spirv_capability(GroupNonUniform) uint WaveGetLaneIndex() { __target_switch @@ -6563,13 +6566,13 @@ uint WaveGetLaneIndex() case spirv: return spirv_asm { + OpCapability GroupNonUniform; result:$$uint = OpLoad builtin(SubgroupLocalInvocationId:uint) }; } } __glsl_extension(GL_KHR_shader_subgroup_basic) -__spirv_capability(GroupNonUniformBallot) __spirv_version(1.3) bool WaveIsFirstLane() { @@ -6582,6 +6585,7 @@ bool WaveIsFirstLane() case spirv: return spirv_asm { + OpCapability GroupNonUniformBallot; OpGroupNonUniformElect $$bool result Subgroup }; default: @@ -6591,7 +6595,6 @@ bool WaveIsFirstLane() // It's useful to have a wave uint4 version of countbits, because some wave functions return uint4. // This implementation tries to limit the amount of work required by the actual lane count. -__spirv_capability(GroupNonUniformBallot) uint _WaveCountBits(uint4 value) { __target_switch @@ -6599,6 +6602,7 @@ uint _WaveCountBits(uint4 value) case spirv: return spirv_asm { + OpCapability GroupNonUniformBallot; OpGroupNonUniformBallotBitCount $$uint result Subgroup Reduce $value }; default: @@ -6621,7 +6625,6 @@ uint _WaveCountBits(uint4 value) __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) T WavePrefixProduct(T expr) { __target_switch @@ -6630,11 +6633,15 @@ T WavePrefixProduct(T expr) case hlsl: __intrinsic_asm "WavePrefixProduct"; case spirv: if (__isFloat()) - return spirv_asm {OpGroupNonUniformFMul $$T result Subgroup ExclusiveScan $expr}; + return spirv_asm { + OpCapability GroupNonUniformArithmetic; + OpGroupNonUniformFMul $$T result Subgroup ExclusiveScan $expr + }; else if (__isSignedInt()) { return spirv_asm { + OpCapability GroupNonUniformArithmetic; // TODO: use the correct integer width OpBitcast $$uint %uvalue $expr; OpGroupNonUniformIMul $$uint %mulResult Subgroup ExclusiveScan %uvalue; @@ -6642,7 +6649,7 @@ T WavePrefixProduct(T expr) }; } else if (__isUnsignedInt()) - return spirv_asm {OpGroupNonUniformIMul $$T result Subgroup ExclusiveScan $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformIMul $$T result Subgroup ExclusiveScan $expr}; default: return WaveMaskPrefixProduct(WaveGetActiveMask(), expr); } @@ -6652,7 +6659,6 @@ T WavePrefixProduct(T expr) __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) vector WavePrefixProduct(vector expr) { __target_switch @@ -6661,11 +6667,12 @@ vector WavePrefixProduct(vector expr) case hlsl: __intrinsic_asm "WavePrefixProduct"; case spirv: if (__isFloat()) - return spirv_asm {OpGroupNonUniformFMul $$vector result Subgroup ExclusiveScan $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMul $$vector result Subgroup ExclusiveScan $expr}; else if (__isSignedInt()) { return spirv_asm { + OpCapability GroupNonUniformArithmetic; // TODO: use the correct integer width OpBitcast $$vector %uvalue $expr; OpGroupNonUniformIMul $$vector %mulResult Subgroup ExclusiveScan %uvalue; @@ -6673,7 +6680,7 @@ vector WavePrefixProduct(vector expr) }; } else if (__isUnsignedInt()) - return spirv_asm {OpGroupNonUniformIMul $$vector result Subgroup ExclusiveScan $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformIMul $$vector result Subgroup ExclusiveScan $expr}; default: return WaveMaskPrefixProduct(WaveGetActiveMask(), expr); } @@ -6689,7 +6696,6 @@ matrix WavePrefixProduct(matrix expr) __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) T WavePrefixSum(T expr) { __target_switch @@ -6698,11 +6704,12 @@ T WavePrefixSum(T expr) case hlsl: __intrinsic_asm "WavePrefixSum"; case spirv: if (__isFloat()) - return spirv_asm {OpGroupNonUniformFAdd $$T result Subgroup ExclusiveScan $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFAdd $$T result Subgroup ExclusiveScan $expr}; else if (__isSignedInt()) { return spirv_asm { + OpCapability GroupNonUniformArithmetic; // TODO: use the correct integer width %uvalue:$$uint = OpBitcast $expr; %mulResult:$$uint = OpGroupNonUniformIAdd Subgroup ExclusiveScan %uvalue; @@ -6710,7 +6717,7 @@ T WavePrefixSum(T expr) }; } else if (__isUnsignedInt()) - return spirv_asm {OpGroupNonUniformIAdd $$T result Subgroup ExclusiveScan $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformIAdd $$T result Subgroup ExclusiveScan $expr}; default: return WaveMaskPrefixSum(WaveGetActiveMask(), expr); } @@ -6719,7 +6726,6 @@ T WavePrefixSum(T expr) __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__spirv_capability(GroupNonUniformArithmetic) vector WavePrefixSum(vector expr) { __target_switch @@ -6728,11 +6734,12 @@ vector WavePrefixSum(vector expr) case hlsl: __intrinsic_asm "WavePrefixSum"; case spirv: if (__isFloat()) - return spirv_asm {OpGroupNonUniformFAdd $$vector result Subgroup ExclusiveScan $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFAdd $$vector result Subgroup ExclusiveScan $expr}; else if (__isSignedInt()) { return spirv_asm { + OpCapability GroupNonUniformArithmetic; // TODO: use the correct integer width %uvalue:$$vector = OpBitcast $expr; %mulResult:$$vector = OpGroupNonUniformIAdd Subgroup ExclusiveScan %uvalue; @@ -6740,7 +6747,7 @@ vector WavePrefixSum(vector expr) }; } else if (__isUnsignedInt()) - return spirv_asm {OpGroupNonUniformIAdd $$vector result Subgroup ExclusiveScan $expr}; + return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformIAdd $$vector result Subgroup ExclusiveScan $expr}; default: return WaveMaskPrefixSum(WaveGetActiveMask(), expr); } @@ -6756,7 +6763,6 @@ matrix WavePrefixSum(matrix expr) __generic __glsl_extension(GL_KHR_shader_subgroup_ballot) __spirv_version(1.3) -__spirv_capability(GroupNonUniformBallot) T WaveReadLaneFirst(T expr) { __target_switch @@ -6764,7 +6770,7 @@ T WaveReadLaneFirst(T expr) case glsl: __intrinsic_asm "subgroupBroadcastFirst($0)"; case hlsl: __intrinsic_asm "WaveReadLaneFirst"; case spirv: - return spirv_asm {OpGroupNonUniformBroadcastFirst $$T result Subgroup $expr}; + return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcastFirst $$T result Subgroup $expr}; default: return WaveMaskReadLaneFirst(WaveGetActiveMask(), expr); } @@ -6773,7 +6779,6 @@ T WaveReadLaneFirst(T expr) __generic __glsl_extension(GL_KHR_shader_subgroup_ballot) __spirv_version(1.3) -__spirv_capability(GroupNonUniformBallot) vector WaveReadLaneFirst(vector expr) { __target_switch @@ -6781,7 +6786,7 @@ vector WaveReadLaneFirst(vector expr) case glsl: __intrinsic_asm "subgroupBroadcastFirst($0)"; case hlsl: __intrinsic_asm "WaveReadLaneFirst"; case spirv: - return spirv_asm {OpGroupNonUniformBroadcastFirst $$vector result Subgroup $expr}; + return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcastFirst $$vector result Subgroup $expr}; default: return WaveMaskReadLaneFirst(WaveGetActiveMask(), expr); } @@ -6803,7 +6808,6 @@ matrix WaveReadLaneFirst(matrix expr) __generic __glsl_extension(GL_KHR_shader_subgroup_ballot) __spirv_version(1.3) -__spirv_capability(GroupNonUniformBallot) T WaveBroadcastLaneAt(T value, constexpr int lane) { __target_switch @@ -6812,7 +6816,7 @@ T WaveBroadcastLaneAt(T value, constexpr int lane) case hlsl: __intrinsic_asm "WaveReadLaneAt"; case spirv: let ulane = uint(lane); - return spirv_asm {OpGroupNonUniformBroadcast $$T result Subgroup $value $ulane}; + return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcast $$T result Subgroup $value $ulane}; default: return WaveMaskBroadcastLaneAt(WaveGetActiveMask(), value, lane); } @@ -6821,7 +6825,6 @@ T WaveBroadcastLaneAt(T value, constexpr int lane) __generic __glsl_extension(GL_KHR_shader_subgroup_ballot) __spirv_version(1.3) -__spirv_capability(GroupNonUniformBallot) vector WaveBroadcastLaneAt(vector value, constexpr int lane) { __target_switch @@ -6830,7 +6833,7 @@ vector WaveBroadcastLaneAt(vector value, constexpr int lane) case hlsl: __intrinsic_asm "WaveReadLaneAt"; case spirv: let ulane = uint(lane); - return spirv_asm {OpGroupNonUniformBroadcast $$vector result Subgroup $value $ulane}; + return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcast $$vector result Subgroup $value $ulane}; default: return WaveMaskBroadcastLaneAt(WaveGetActiveMask(), value, lane); } @@ -6849,7 +6852,6 @@ matrix WaveBroadcastLaneAt(matrix value, constexpr int lane) __generic __glsl_extension(GL_KHR_shader_subgroup_shuffle) __spirv_version(1.3) -__spirv_capability(GroupNonUniformShuffle) T WaveReadLaneAt(T value, int lane) { __target_switch @@ -6858,7 +6860,7 @@ T WaveReadLaneAt(T value, int lane) case hlsl: __intrinsic_asm "WaveReadLaneAt"; case spirv: let ulane = uint(lane); - return spirv_asm {OpGroupNonUniformShuffle $$T result Subgroup $value $ulane}; + return spirv_asm {OpCapability GroupNonUniformShuffle; OpGroupNonUniformShuffle $$T result Subgroup $value $ulane}; default: return WaveMaskReadLaneAt(WaveGetActiveMask(), value, lane); } @@ -6867,7 +6869,6 @@ T WaveReadLaneAt(T value, int lane) __generic __spirv_version(1.3) __glsl_extension(GL_KHR_shader_subgroup_shuffle) -__spirv_capability(GroupNonUniformShuffle) vector WaveReadLaneAt(vector value, int lane) { __target_switch @@ -6876,7 +6877,7 @@ vector WaveReadLaneAt(vector value, int lane) case hlsl: __intrinsic_asm "WaveReadLaneAt"; case spirv: let ulane = uint(lane); - return spirv_asm {OpGroupNonUniformShuffle $$vector result Subgroup $value $ulane}; + return spirv_asm {OpCapability GroupNonUniformShuffle; OpGroupNonUniformShuffle $$vector result Subgroup $value $ulane}; default: return WaveMaskReadLaneAt(WaveGetActiveMask(), value, lane); } @@ -6896,7 +6897,6 @@ matrix WaveReadLaneAt(matrix value, int lane) __generic __glsl_extension(GL_KHR_shader_subgroup_shuffle) __spirv_version(1.3) -__spirv_capability(GroupNonUniformShuffle) T WaveShuffle(T value, int lane) { __target_switch @@ -6905,7 +6905,7 @@ T WaveShuffle(T value, int lane) case hlsl: __intrinsic_asm "WaveReadLaneAt"; case spirv: let ulane = uint(lane); - return spirv_asm {OpGroupNonUniformShuffle $$T result Subgroup $value $ulane}; + return spirv_asm {OpCapability GroupNonUniformShuffle; OpGroupNonUniformShuffle $$T result Subgroup $value $ulane}; default: return WaveMaskShuffle(WaveGetActiveMask(), value, lane); } @@ -6914,7 +6914,6 @@ T WaveShuffle(T value, int lane) __generic __glsl_extension(GL_KHR_shader_subgroup_shuffle) __spirv_version(1.3) -__spirv_capability(GroupNonUniformShuffle) vector WaveShuffle(vector value, int lane) { __target_switch @@ -6923,7 +6922,7 @@ vector WaveShuffle(vector value, int lane) case hlsl: __intrinsic_asm "WaveReadLaneAt"; case spirv: let ulane = uint(lane); - return spirv_asm {OpGroupNonUniformShuffle $$vector result Subgroup $value $ulane}; + return spirv_asm {OpCapability GroupNonUniformShuffle; OpGroupNonUniformShuffle $$vector result Subgroup $value $ulane}; default: return WaveMaskShuffle(WaveGetActiveMask(), value, lane); } @@ -6938,7 +6937,6 @@ matrix WaveShuffle(matrix value, int lane) __glsl_extension(GL_KHR_shader_subgroup_ballot) __spirv_version(1.3) -__spirv_capability(GroupNonUniformBallot) uint WavePrefixCountBits(bool value) { __target_switch @@ -6948,6 +6946,7 @@ uint WavePrefixCountBits(bool value) case spirv: return spirv_asm { + OpCapability GroupNonUniformBallot; %mask:$$uint4 = OpGroupNonUniformBallot Subgroup $value; OpGroupNonUniformBallotBitCount $$uint result Subgroup 2 %mask }; @@ -6958,7 +6957,6 @@ uint WavePrefixCountBits(bool value) __glsl_extension(GL_KHR_shader_subgroup_ballot) __spirv_version(1.3) -__spirv_capability(GroupNonUniformBallot) uint4 WaveGetConvergedMulti() { __target_switch @@ -6970,6 +6968,7 @@ uint4 WaveGetConvergedMulti() let _true = true; return spirv_asm { + OpCapability GroupNonUniformBallot; OpGroupNonUniformBallot $$uint4 result Subgroup $_true }; } diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 68824b931..bf79028d7 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -131,15 +131,6 @@ class RequiredSPIRVVersionModifier : public Modifier SemanticVersion version; }; -// A modifier to tag something as an intrinsic that requires -// a certain SPIRV capability to be enabled when used. -class RequiredSPIRVCapabilityModifier : public Modifier -{ - SLANG_AST_CLASS(RequiredSPIRVCapabilityModifier) - int32_t capability; - String extensionName; -}; - // A modifier to tag something as an intrinsic that requires // a certain CUDA SM version to be enabled when used. Specified as "major.minor" class RequiredCUDASMVersionModifier : public Modifier diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 62e7d428b..0ba637978 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -3310,7 +3310,12 @@ struct SPIRVEmitContext // Find phi arguments from incoming branch instructions that target `block`. for (auto use = block->firstUse; use; use = use->nextUse) { - auto branchInst = use->getUser(); + auto branchInst = as(use->getUser()); + if (!branchInst) + continue; + if (branchInst->getTargetBlock() != inst->getParent()) + continue; + UInt argStartIndex = 0; switch (branchInst->getOp()) { @@ -4742,27 +4747,6 @@ struct SPIRVEmitContext } } - void handleRequiredCapabilitiesImpl(IRInst* inst) - { - for (auto decoration : inst->getDecorations()) - { - switch (decoration->getOp()) - { - default: - break; - case kIROp_RequireSPIRVCapabilityDecoration: - requireSPIRVCapability((SpvCapability)getIntVal(decoration->getOperand(0))); - if (decoration->getOperandCount() == 2) - { - auto stringLit = as(decoration->getOperand(1)); - if (stringLit->getStringSlice().getLength()) - ensureExtensionDeclaration(stringLit->getStringSlice()); - } - break; - } - } - } - SPIRVEmitContext(IRModule* module, TargetRequest* target, DiagnosticSink* sink) : SPIRVEmitSharedContext(module, target, sink) , m_irModule(module) diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 6a49e9842..bbf6885a8 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -10,6 +10,7 @@ #include "slang-ir-byte-address-legalize.h" #include "slang-ir-collect-global-uniforms.h" #include "slang-ir-cleanup-void.h" +#include "slang-ir-composite-reg-to-mem.h" #include "slang-ir-dce.h" #include "slang-ir-diff-call.h" #include "slang-ir-autodiff.h" @@ -360,7 +361,7 @@ Result linkAndOptimizeIR( // Lower all the LValue implict casts (used for out/inout/ref scenarios) lowerLValueCast(targetRequest, irModule); - simplifyIR(irModule, sink); + simplifyIR(irModule, IRSimplificationOptions::getDefault(), sink); // Fill in default matrix layout into matrix types that left layout unspecified. specializeMatrixLayout(codeGenContext->getTargetReq(), irModule); @@ -472,7 +473,7 @@ Result linkAndOptimizeIR( validateIRModuleIfEnabled(codeGenContext, irModule); - simplifyIR(irModule, sink); + simplifyIR(irModule, IRSimplificationOptions::getFast(), sink); if (!ArtifactDescUtil::isCpuLikeTarget(artifactDesc)) { @@ -501,7 +502,7 @@ Result linkAndOptimizeIR( // up downstream passes like type legalization, so we // will run a DCE pass to clean up after the specialization. // - simplifyIR(irModule, sink); + simplifyIR(irModule, IRSimplificationOptions::getDefault(), sink); validateIRModuleIfEnabled(codeGenContext, irModule); @@ -591,7 +592,7 @@ Result linkAndOptimizeIR( // to see if we can clean up any temporaries created by legalization. // (e.g., things that used to be aggregated might now be split up, // so that we can work with the individual fields). - simplifyIR(irModule, sink); + simplifyIR(irModule, IRSimplificationOptions::getFast(), sink); #if 0 dumpIRIfEnabled(codeGenContext, irModule, "AFTER SSA"); @@ -924,12 +925,20 @@ Result linkAndOptimizeIR( // bit_cast on basic types. lowerBitCast(targetRequest, irModule); - eliminateMultiLevelBreak(irModule); if (isKhronosTarget(targetRequest) && targetRequest->shouldEmitSPIRVDirectly()) - performIntrinsicFunctionFunctionInlining(irModule); + { + //performIntrinsicFunctionFunctionInlining(irModule); + performSpirvInlining(irModule); + eliminateDeadCode(irModule); + } + eliminateMultiLevelBreak(irModule); - simplifyIR(irModule, sink); + { + IRSimplificationOptions simplificationOptions = IRSimplificationOptions::getFast(); + simplificationOptions.cfgOptions.removeTrivialSingleIterationLoops = true; + simplifyIR(irModule, IRSimplificationOptions::getFast(), sink); + } // As a late step, we need to take the SSA-form IR and move things *out* // of SSA form, by eliminating all "phi nodes" (block parameters) and @@ -956,7 +965,13 @@ Result linkAndOptimizeIR( } // We only want to accumulate locations if liveness tracking is enabled. - eliminatePhis(livenessMode, irModule); + PhiEliminationOptions phiEliminationOptions; + if (isKhronosTarget(targetRequest) && targetRequest->shouldEmitSPIRVDirectly()) + { + phiEliminationOptions.eliminateCompositeTypedPhiOnly = false; + phiEliminationOptions.useRegisterAllocation = true; + } + eliminatePhis(livenessMode, irModule, phiEliminationOptions); #if 0 dumpIRIfEnabled(codeGenContext, irModule, "PHIS ELIMINATED"); #endif @@ -1000,7 +1015,7 @@ Result linkAndOptimizeIR( } // Run a final round of simplifications to clean up unused things after phi-elimination. - simplifyNonSSAIR(irModule); + simplifyNonSSAIR(irModule, IRSimplificationOptions::getFast()); // We include one final step to (optionally) dump the IR and validate // it after all of the optimization passes are complete. This should @@ -1263,15 +1278,17 @@ SlangResult emitSPIRVForEntryPointsDirectly( List spirv, outSpirv; emitSPIRVFromIR(codeGenContext, irModule, irEntryPoints, spirv); +#if 0 String optErr; if (SLANG_FAILED(optimizeSPIRV(spirv, optErr, outSpirv))) { codeGenContext->getSink()->diagnose(SourceLoc(), Diagnostics::spirvOptFailed, optErr); - outSpirv = _Move(spirv); + spirv = _Move(outSpirv); } +#endif auto artifact = ArtifactUtil::createArtifactForCompileTarget(asExternal(codeGenContext->getTargetFormat())); - artifact->addRepresentationUnknown(ListBlob::moveCreate(outSpirv)); + artifact->addRepresentationUnknown(ListBlob::moveCreate(spirv)); ArtifactUtil::addAssociated(artifact, linkedIR.metadata); diff --git a/source/slang/slang-ir-array-reg-to-mem.cpp b/source/slang/slang-ir-array-reg-to-mem.cpp deleted file mode 100644 index 34bd5b148..000000000 --- a/source/slang/slang-ir-array-reg-to-mem.cpp +++ /dev/null @@ -1,88 +0,0 @@ -#include "slang-ir-array-reg-to-mem.h" - -#include "slang-ir.h" -#include "slang-ir-insts.h" -#include "slang-ir-util.h" - -namespace Slang -{ - bool eliminateArrayTypeParameters(IRFunc* func) - { - IRBuilder builder(func); - bool changed = false; - List arrayParamIds; - UInt idx = 0; - List paramWorkList; - for (auto param : func->getParams()) - { - if (auto arrayType = as(param->getFullType())) - { - paramWorkList.add(param); - arrayParamIds.add(idx); - } - idx++; - } - for (auto param : paramWorkList) - { - // We have an array type parameter, so we need to replace it with a pointer to the array - // type. - // - // We will also need to insert a `load` instruction at the start of the function body - // to load the actual pointer value from the parameter. - // - if (auto arrayType = as(param->getFullType())) - { - changed = true; - auto ptrArrayType = builder.getPtrType(arrayType); - auto newParam = builder.createParam(ptrArrayType); - newParam->insertBefore(param); - setInsertAfterOrdinaryInst(&builder, param); - auto regVal = builder.emitLoad(newParam); - param->replaceUsesWith(regVal); - param->removeAndDeallocate(); - } - } - if (changed) - { - // The function is modified, we need to also update its type. - List paramTypes; - for (auto param : func->getParams()) - { - paramTypes.add(param->getFullType()); - } - auto newFuncType = builder.getFuncType((UInt)paramTypes.getCount(), paramTypes.getBuffer(), func->getResultType()); - func->setFullType(newFuncType); - - // Update all the call sites to pass the arrays by pointer. - traverseUses(func, [&](IRUse* use) - { - if (const auto call = as(use->getUser())) - { - builder.setInsertBefore(call); - for (auto paramId : arrayParamIds) - { - auto arg = call->getArg(paramId); - SLANG_ASSERT(as(paramTypes[paramId])); - auto var = builder.emitVar(as(paramTypes[paramId])->getValueType()); - builder.emitStore(var, arg); - call->setArg(paramId, var); - } - } - }); - } - return changed; - } - - bool eliminateArrayTypeSSARegisters(IRModule* module) - { - bool changed = false; - for (auto inst : module->getGlobalInsts()) - { - if (auto func = as(inst)) - { - changed |= eliminateArrayTypeParameters(func); - } - } - return changed; - } -} diff --git a/source/slang/slang-ir-array-reg-to-mem.h b/source/slang/slang-ir-array-reg-to-mem.h deleted file mode 100644 index 941125e1c..000000000 --- a/source/slang/slang-ir-array-reg-to-mem.h +++ /dev/null @@ -1,16 +0,0 @@ -// slang-ir-array-reg-to-mem.h -#pragma once - -namespace Slang -{ - struct IRModule; - struct IRCall; - struct IRInst; - struct IRFunc; - - /// Eliminate SSA registers and IRParams of array type and turn them into pointers to memory objects. - bool eliminateArrayTypeSSARegisters(IRModule* module); - - bool eliminateArrayTypeParameters(IRFunc* func); - -} diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp index 3602e77ae..bb0b4a8c6 100644 --- a/source/slang/slang-ir-autodiff-cfg-norm.cpp +++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp @@ -798,7 +798,7 @@ void normalizeCFG( // Remove phis to simplify our pass. We'll add them back in later // with constructSSA. // - eliminatePhisInFunc(LivenessMode::Disabled, func->getModule(), func, false); + eliminatePhisInFunc(LivenessMode::Disabled, func->getModule(), func, PhiEliminationOptions::getFast()); CFGNormalizationContext context = {module, options.sink}; CFGNormalizationPass cfgPass(context); diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 2465c1b74..8f3ab2a34 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -1702,7 +1702,7 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func) if (SLANG_SUCCEEDED(result)) { disableIRValidationAtInsert(); - simplifyFunc(func); + simplifyFunc(func, IRSimplificationOptions::getDefault()); enableIRValidationAtInsert(); } return result; diff --git a/source/slang/slang-ir-composite-reg-to-mem.cpp b/source/slang/slang-ir-composite-reg-to-mem.cpp new file mode 100644 index 000000000..243a0e2b0 --- /dev/null +++ b/source/slang/slang-ir-composite-reg-to-mem.cpp @@ -0,0 +1,202 @@ +#include "slang-ir-composite-reg-to-mem.h" + +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-ir-util.h" +#include "slang-ir-dce.h" + +namespace Slang +{ + struct RegisterReplacementWorkItem + { + IRInst* ssaValue; + IRInst* addr; + IRInst* initialStore; + }; + + void replaceRegisterUseWithAddrUse( + List& workList, + IRInst* ssaValue, + IRInst* addr, + IRInst* initialStore) + { + IRBuilder builder(ssaValue); + traverseUses(ssaValue, [&](IRUse* use) + { + auto user = use->getUser(); + if (user == initialStore) + return; + builder.setInsertBefore(user); + IRInst* newAddr = nullptr; + // If the user is itself a getElement/getField inst, + // we want to follow that chain and recursively replace + // their users. + if (auto getElementUser = as(user)) + { + if (getElementUser->getOperands() == use) + { + newAddr = builder.emitElementAddress( + builder.getPtrType(user->getFullType()), + addr, + getElementUser->getIndex()); + } + } + else if (auto getFieldUser = as(user)) + { + if (getFieldUser->getOperands() == use) + { + newAddr = builder.emitFieldAddress( + builder.getPtrType(user->getFullType()), + addr, + getFieldUser->getField()); + } + } + if (newAddr) + { + workList.add(RegisterReplacementWorkItem{ user, newAddr, nullptr }); + } + else + { + // For all other uses, we emit a load from addr and use it. + auto val = builder.emitLoad(addr); + builder.replaceOperand(use, val); + } + }); + } + + void replaceRegisterUseWithAddrUse(IRInst* ssaValue, IRInst* addr, IRInst* initialStore) + { + List workList, pendingWorkList; + workList.add(RegisterReplacementWorkItem{ ssaValue, addr, initialStore }); + + while (workList.getCount()) + { + for (auto item : workList) + { + replaceRegisterUseWithAddrUse(pendingWorkList, item.ssaValue, item.addr, item.initialStore); + } + workList.swapWith(pendingWorkList); + pendingWorkList.clear(); + } + } + + void convertCompositeTypeParametersToPointers(IRFunc* func) + { + IRBuilder builder(func); + List compositeParamIds; + UInt idx = 0; + List paramWorkList; + if (!func->findDecoration()) + { + // Only translate function parameters for non entry points. + for (auto param : func->getParams()) + { + if (as(param->getFullType()) || + as(param->getFullType())) + { + paramWorkList.add(param); + compositeParamIds.add(idx); + } + idx++; + } + } + for (auto param : paramWorkList) + { + // We have a composite type parameter, so we need to replace it with a pointer. + // + + auto ptrCompositeType = builder.getPtrType(param->getFullType()); + auto newParam = builder.createParam(ptrCompositeType); + newParam->insertBefore(param); + replaceRegisterUseWithAddrUse(param, newParam, nullptr); + param->removeAndDeallocate(); + } + if (paramWorkList.getCount()) + { + // The function is modified, we need to also update its type. + List paramTypes; + for (auto param : func->getParams()) + { + paramTypes.add(param->getFullType()); + } + auto newFuncType = builder.getFuncType((UInt)paramTypes.getCount(), paramTypes.getBuffer(), func->getResultType()); + func->setFullType(newFuncType); + + // Update all the call sites to pass the composite by pointer. + traverseUses(func, [&](IRUse* use) + { + if (const auto call = as(use->getUser())) + { + builder.setInsertBefore(call); + for (auto paramId : compositeParamIds) + { + auto arg = call->getArg(paramId); + SLANG_ASSERT(as(paramTypes[paramId])); + auto var = builder.emitVar(as(paramTypes[paramId])->getValueType()); + builder.emitStore(var, arg); + call->setArg(paramId, var); + } + } + }); + } + + // Now work through all the local values and process uses of `Load(composite)`. + for (auto block : func->getBlocks()) + { + for (auto inst : block->getModifiableChildren()) + { + if (!as(inst->getDataType()) && + !as(inst->getDataType())) + continue; + if (inst->getParent() != block) + continue; + IRInst* tempVar = nullptr; + IRInst* initialStore = nullptr; + builder.setInsertAfter(inst); + switch (inst->getOp()) + { + case kIROp_Load: + { + auto ptr = inst->getOperand(0); + auto rootPtr = getRootAddr(ptr); + if (as(rootPtr->getDataType()) || + as(rootPtr->getDataType())) + { + tempVar = ptr; + } + else + { + tempVar = builder.emitVar(inst->getFullType()); + initialStore = builder.emitStore(tempVar, inst); + } + break; + } + case kIROp_Call: + { + tempVar = builder.emitVar(inst->getFullType()); + initialStore = builder.emitStore(tempVar, inst); + break; + } + default: + break; + } + + if (!tempVar) + continue; + replaceRegisterUseWithAddrUse(inst, tempVar, initialStore); + } + } + eliminateDeadCode(func); + } + + void convertCompositeTypeParametersToPointers(IRModule* module) + { + for (auto inst : module->getGlobalInsts()) + { + if (auto func = as(inst)) + { + convertCompositeTypeParametersToPointers(func); + } + } + } +} diff --git a/source/slang/slang-ir-composite-reg-to-mem.h b/source/slang/slang-ir-composite-reg-to-mem.h new file mode 100644 index 000000000..9ccfc8734 --- /dev/null +++ b/source/slang/slang-ir-composite-reg-to-mem.h @@ -0,0 +1,13 @@ +// slang-ir-composite-reg-to-mem.h +#pragma once + +namespace Slang +{ + struct IRModule; + struct IRCall; + struct IRInst; + struct IRFunc; + + /// Convert parameters of composite type into pointers and modify the callsites accordingly. + void convertCompositeTypeParametersToPointers(IRModule* module); +} diff --git a/source/slang/slang-ir-eliminate-multilevel-break.cpp b/source/slang/slang-ir-eliminate-multilevel-break.cpp index 5ff71a248..eafdbbcec 100644 --- a/source/slang/slang-ir-eliminate-multilevel-break.cpp +++ b/source/slang/slang-ir-eliminate-multilevel-break.cpp @@ -303,7 +303,6 @@ struct EliminateMultiLevelBreakContext void processFunc(IRGlobalValueWithCode* func) { - normalizeBranchesIntoBreakBlocks(func); // If func does not have any multi-level breaks, return. @@ -316,7 +315,7 @@ struct EliminateMultiLevelBreakContext } // To make things easy, eliminate Phis before perform transformations. - eliminatePhisInFunc(LivenessMode::Disabled, irModule, func); + eliminatePhisInFunc(LivenessMode::Disabled, irModule, func, PhiEliminationOptions::getFast()); // Before modifying the cfg, we gather all required info from the existing cfg. FuncContext funcInfo; diff --git a/source/slang/slang-ir-eliminate-phis.cpp b/source/slang/slang-ir-eliminate-phis.cpp index 1023e6148..b17fad6ec 100644 --- a/source/slang/slang-ir-eliminate-phis.cpp +++ b/source/slang/slang-ir-eliminate-phis.cpp @@ -1,6 +1,7 @@ // slang-ir-eliminate-phis.cpp #include "slang-ir-eliminate-phis.h" #include "slang-ir-ssa-register-allocate.h" +#include "slang-ir-util.h" // This file implements a pass to take code in the Slang IR out out SSA form // by eliminating all "phi nodes." @@ -68,20 +69,20 @@ struct PhiEliminationContext IRModule* m_module = nullptr; IRBuilder m_builder; LivenessMode m_livenessMode; - bool m_useRegisterAllocation; + PhiEliminationOptions m_options; PhiEliminationContext(LivenessMode livenessMode, IRModule* module) : m_module(module) , m_builder(module) , m_livenessMode(livenessMode) - , m_useRegisterAllocation(true) + , m_options() {} - PhiEliminationContext(LivenessMode livenessMode, IRModule* module, bool useRegisterAllocation) + PhiEliminationContext(LivenessMode livenessMode, IRModule* module, PhiEliminationOptions options) : m_module(module) , m_builder(module) , m_livenessMode(livenessMode) - , m_useRegisterAllocation(useRegisterAllocation) + , m_options(options) {} // We start with the top-down logic of the pass, which is to process @@ -220,9 +221,9 @@ struct PhiEliminationContext m_func = func; m_dominatorTree = nullptr; - if (m_useRegisterAllocation) + if (m_options.useRegisterAllocation) { - m_registerAllocation = allocateRegistersForFunc(func, m_dominatorTree); + m_registerAllocation = allocateRegistersForFunc(func, m_dominatorTree, m_options.eliminateCompositeTypedPhiOnly); m_mapRegToTempVar = createTempVarForInsts(func); } } @@ -435,7 +436,7 @@ struct PhiEliminationContext { Index paramIndex = paramCounter++; mapParamToIndex.add(param, paramIndex); - + IRInst* temp = nullptr; // Have we already allocated a register for this inst? @@ -445,7 +446,9 @@ struct PhiEliminationContext m_mapRegToTempVar.tryGetValue(registerInfo->get(), temp); } - if (!temp) + bool shouldAllocTemp = !m_options.eliminateCompositeTypedPhiOnly || isCompositeType(param->getFullType()); + + if (!temp && shouldAllocTemp) { // Note that the `emitVar` operation expects to be passed the // type *stored* in the variable, but the IR `var` instruction @@ -471,7 +474,7 @@ struct PhiEliminationContext PhiInfo phiInfo; auto& paramInfo = phiInfo.param; paramInfo.param = param; - paramInfo.temp = cast(temp); + paramInfo.temp = as(temp); phiInfos.add(phiInfo); } } @@ -515,6 +518,9 @@ struct PhiEliminationContext auto& paramInfo = phiInfo.param; auto param = paramInfo.param; auto temp = paramInfo.temp; + + if (!temp) + continue; // We will repeatedly replace whatever the *first* // use of `param` is, until there are no more uses @@ -738,6 +744,10 @@ struct PhiEliminationContext { assignment.state = kState_Done; } + else if (!dstParam.temp) + { + assignment.state = kState_Done; + } else { // Otherwise we start out assuming that the assignment is ready @@ -909,17 +919,24 @@ struct PhiEliminationContext static const Count kMaxNewOperandCount = 3; SLANG_ASSERT(newOperandCount <= kMaxNewOperandCount); - IRInst* newOperands[kMaxNewOperandCount] = {}; + ShortList newOperands; for (Index i = 0; i < newOperandCount; ++i) { - newOperands[i] = oldBranch->getOperand(i); + newOperands.add(oldBranch->getOperand(i)); + } + + // Add operands for any remaining phi parameters that has not been eliminated. + for (UInt i = 0; i < (UInt)phiInfos.getCount(); i++) + { + if (!phiInfos[i].param.temp) + newOperands.add(oldBranch->getArg(i)); } auto newBranch = m_builder.emitIntrinsicInst( oldBranch->getFullType(), oldBranch->getOp(), - newOperandCount, - newOperands); + newOperands.getCount(), + newOperands.getArrayView().getBuffer()); oldBranch->transferDecorationsTo(newBranch); // TODO: We could consider just modifying `branch` in-place by clearing @@ -1122,9 +1139,9 @@ struct PhiEliminationContext } }; -void eliminatePhis(LivenessMode livenessMode, IRModule* module, bool useRegisterAllocation) +void eliminatePhis(LivenessMode livenessMode, IRModule* module, PhiEliminationOptions options) { - PhiEliminationContext context(livenessMode, module, useRegisterAllocation); + PhiEliminationContext context(livenessMode, module, options); context.eliminatePhisInModule(); } @@ -1132,9 +1149,9 @@ void eliminatePhisInFunc( LivenessMode livenessMode, IRModule* module, IRGlobalValueWithCode* func, - bool useRegisterAllocation) + PhiEliminationOptions options) { - PhiEliminationContext context(livenessMode, module, useRegisterAllocation); + PhiEliminationContext context(livenessMode, module, options); context.eliminatePhisInFunc(func); } diff --git a/source/slang/slang-ir-eliminate-phis.h b/source/slang/slang-ir-eliminate-phis.h index 9bfbe51b8..6655a16e4 100644 --- a/source/slang/slang-ir-eliminate-phis.h +++ b/source/slang/slang-ir-eliminate-phis.h @@ -8,6 +8,13 @@ namespace Slang struct CodeGenContext; struct IRModule; + struct PhiEliminationOptions + { + bool eliminateCompositeTypedPhiOnly = false; + bool useRegisterAllocation = true; + static PhiEliminationOptions getFast() { return PhiEliminationOptions{ false, false }; } + }; + /// Eliminate all "phi nodes" from the given `module`. /// /// This process moves the code in `module` *out* of SSA form, @@ -15,11 +22,11 @@ namespace Slang /// are not themselves based on an SSA representation. /// /// If livenessMode is enabled LiveRangeStarts will be inserted into the module. - void eliminatePhis(LivenessMode livenessMode, IRModule* module, bool useRegisterAllocation = true); + void eliminatePhis(LivenessMode livenessMode, IRModule* module, PhiEliminationOptions options); void eliminatePhisInFunc( LivenessMode livenessMode, IRModule* module, IRGlobalValueWithCode* func, - bool useRegisterAllocation = true); + PhiEliminationOptions options); } diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp index 8412a1912..3b01d7dde 100644 --- a/source/slang/slang-ir-inline.cpp +++ b/source/slang/slang-ir-inline.cpp @@ -252,9 +252,6 @@ struct InliningPassBase { case kIROp_IntrinsicOpDecoration: return true; - case kIROp_RequireSPIRVCapabilityDecoration: - // Don't inline a function with spirv capability decoration to avoid losing it. - return false; } } @@ -947,7 +944,7 @@ void performGLSLResourceReturnFunctionInlining(IRModule* module) while (changed) { changed = pass.considerAllCallSites(); - simplifyIR(module); + simplifyIR(module, IRSimplificationOptions::getFast()); } } @@ -964,8 +961,6 @@ struct IntrinsicFunctionInliningPass : InliningPassBase auto func = as(getResolvedInstForDecorations(info.callee)); if (!func) return false; - if (func->findDecorationImpl(kIROp_RequireSPIRVCapabilityDecoration)) - return false; auto returnInst = as(func->getFirstBlock()->getTerminator()); if (!returnInst) return false; @@ -1025,4 +1020,32 @@ bool inlineCall(IRCall* call) } +struct SpirvInliningPass : InliningPassBase +{ + typedef InliningPassBase Super; + + SpirvInliningPass(IRModule* module) + : Super(module) + {} + + bool shouldInline(CallSiteInfo const& info) + { + if (!info.callee->findDecoration()) + return true; + return false; + } +}; + +void performSpirvInlining(IRModule* module) +{ + SLANG_PROFILE; + while (true) + { + SpirvInliningPass pass(module); + if (pass.considerAllCallSites()) + continue; + break; + } +} + } // namespace Slang diff --git a/source/slang/slang-ir-inline.h b/source/slang/slang-ir-inline.h index 539bb26c0..5e888ad07 100644 --- a/source/slang/slang-ir-inline.h +++ b/source/slang/slang-ir-inline.h @@ -34,6 +34,9 @@ namespace Slang /// Inline simple intrinsic functions whose definition is a single asm block. void performIntrinsicFunctionFunctionInlining(IRModule* module); + /// Inline all functions for SPIRV emit. + void performSpirvInlining(IRModule* module); + /// Inline a specific call. bool inlineCall(IRCall* call); } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 5a8d7fe06..bd32a1896 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -688,7 +688,6 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(VulkanHitObjectAttributesDecoration, vulkanHitObjectAttributes, 0, 0) INST(RequireSPIRVVersionDecoration, requireSPIRVVersion, 1, 0) - INST(RequireSPIRVCapabilityDecoration, requireSPIRVCapability, 1, 0) INST(RequireGLSLVersionDecoration, requireGLSLVersion, 1, 0) INST(RequireGLSLExtensionDecoration, requireGLSLExtension, 1, 0) INST(RequireCUDASMVersionDecoration, requireCUDASMVersion, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 9ff2df889..070f989b5 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -4264,25 +4264,6 @@ public: addDecoration(value, kIROp_RequireSPIRVVersionDecoration, getIntValue(getBasicType(BaseType::UInt64), intValue)); } - void addRequireSPIRVCapabilityDecoration(IRInst* value, int32_t capabilityName, UnownedStringSlice extensionName) - { - if (extensionName.getLength()) - { - addDecoration( - value, - kIROp_RequireSPIRVCapabilityDecoration, - getIntValue(getIntType(), IRIntegerValue(capabilityName)), - getStringValue(extensionName)); - } - else - { - addDecoration( - value, - kIROp_RequireSPIRVCapabilityDecoration, - getIntValue(getIntType(), IRIntegerValue(capabilityName))); - } - } - void addRequireCUDASMVersionDecoration(IRInst* value, const SemanticVersion& version) { SemanticVersion::IntegerType intValue = version.toInteger(); diff --git a/source/slang/slang-ir-loop-unroll.cpp b/source/slang/slang-ir-loop-unroll.cpp index c4ef1650c..b5af2d974 100644 --- a/source/slang/slang-ir-loop-unroll.cpp +++ b/source/slang/slang-ir-loop-unroll.cpp @@ -459,7 +459,7 @@ bool unrollLoopsInFunc( // Make sure we simplify things as much as possible before // attempting to potentially unroll outer loop. - simplifyCFG(func); + simplifyCFG(func, CFGSimplificationOptions::getDefault()); eliminateDeadCode(func); } return true; diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp index 722515692..2002b42cc 100644 --- a/source/slang/slang-ir-lower-generics.cpp +++ b/source/slang/slang-ir-lower-generics.cpp @@ -255,7 +255,7 @@ namespace Slang // real RTTI objects and witness tables. specializeRTTIObjects(&sharedContext, sink); - simplifyIR(module); + simplifyIR(module, IRSimplificationOptions::getFast()); lowerTuples(module, sink); if (sink->getErrorCount() != 0) diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp index a7e89eb4f..b6b7823c9 100644 --- a/source/slang/slang-ir-peephole.cpp +++ b/source/slang/slang-ir-peephole.cpp @@ -359,6 +359,7 @@ struct PeepholeContext : InstPassBase changed = true; break; } + startIndex += vecSize->getValue(); } else { @@ -807,7 +808,6 @@ struct PeepholeContext : InstPassBase case kIROp_Or: changed = tryOptimizeArithmeticInst(inst); break; - case kIROp_Param: { auto block = as(inst->parent); diff --git a/source/slang/slang-ir-reachability.cpp b/source/slang/slang-ir-reachability.cpp index 1a3cab386..1a4aa271b 100644 --- a/source/slang/slang-ir-reachability.cpp +++ b/source/slang/slang-ir-reachability.cpp @@ -5,35 +5,64 @@ namespace Slang // Computes whether block1 can reach block2. // A block is considered not reachable from itself unless there is a backedge in the CFG. -bool ReachabilityContext::computeReachability(IRBlock* block1, IRBlock* block2) -{ - workList.clear(); - reachableBlocks.clear(); - workList.add(block1); - for (Index i = 0; i < workList.getCount(); i++) + ReachabilityContext::ReachabilityContext(IRGlobalValueWithCode* code) { - auto src = workList[i]; - for (auto successor : src->getSuccessors()) + int id = 0; + for (auto block : code->getBlocks()) + { + mapBlockToId[block] = id; + id++; + allBlocks.add(block); + } + sourceBlocks.setCount(allBlocks.getCount()); + for (auto &srcBlock : sourceBlocks) + srcBlock.resizeAndClear(allBlocks.getCount()); + + if (allBlocks.getCount() == 0) + return; + + List workList; + List pendingWorkList; + workList.add(allBlocks[0]); + while (workList.getCount()) { - if (successor == block2) - return true; - if (reachableBlocks.add(successor)) - workList.add(successor); + pendingWorkList.clear(); + for (Index i = 0; i < workList.getCount(); i++) + { + auto src = workList[i]; + auto srcId = mapBlockToId.getValue(src); + for (auto successor : src->getSuccessors()) + { + auto successorId = mapBlockToId.getValue(successor); + auto& blockSet = sourceBlocks[successorId]; + bool changed = false; + if (!blockSet.contains(srcId)) + { + blockSet.add(srcId); + changed = true; + } + if (!blockSet.contains(sourceBlocks[srcId])) + { + blockSet.unionWith(sourceBlocks[srcId]); + changed = true; + } + if (changed) + pendingWorkList.add(successor); + } + } + workList.swapWith(pendingWorkList); } + } - return false; -} -bool ReachabilityContext::isBlockReachable(IRBlock* from, IRBlock* to) -{ - BlockPair pair; - pair.first = from; - pair.second = to; - bool result = false; - if (reachabilityResults.tryGetValue(pair, result)) - return result; - result = computeReachability(from, to); - reachabilityResults[pair] = result; - return result; -} + bool ReachabilityContext::isBlockReachable(IRBlock* from, IRBlock* to) + { + if (!from) return false; + if (!to) return false; + int* fromId = mapBlockToId.tryGetValue(from); + int* toId = mapBlockToId.tryGetValue(to); + if (!fromId || !toId) + return true; + return sourceBlocks[*toId].contains(*fromId); + } } diff --git a/source/slang/slang-ir-reachability.h b/source/slang/slang-ir-reachability.h index 74463b7fe..ef6a182df 100644 --- a/source/slang/slang-ir-reachability.h +++ b/source/slang/slang-ir-reachability.h @@ -9,39 +9,18 @@ namespace Slang // A context for computing and caching reachability between blocks on the CFG. struct ReachabilityContext { - struct BlockPair - { - IRBlock* first; - IRBlock* second; - bool operator == (const BlockPair& other) const - { - return first == other.first && second == other.second; - } - HashCode getHashCode() const - { - Hasher h; - h.hashValue(first); - h.hashValue(second); - return h.getResult(); - } - }; - Dictionary reachabilityResults; - - List workList; - HashSet reachableBlocks; + Dictionary mapBlockToId; + List allBlocks; + List sourceBlocks; // sourcesBlocks[i] stores the set of blocks from which block i can be reached. - // Computes whether block1 can reach block2. - // A block is considered not reachable from itself unless there is a backedge in the CFG. - bool computeReachability(IRBlock* block1, IRBlock* block2); + ReachabilityContext() = default; + ReachabilityContext(IRGlobalValueWithCode* code); bool isBlockReachable(IRBlock* from, IRBlock* to); bool isInstReachable(IRInst* inst1, IRInst* inst2) { - if (isBlockReachable(as(inst1->getParent()), as(inst2->getParent()))) - return true; - - // If the parent blocks are not reachable, but inst1 and inst2 are in the same block, + // If inst1 and inst2 are in the same block, // we test if inst2 appears after inst1. if (inst1->getParent() == inst2->getParent()) { @@ -52,7 +31,7 @@ struct ReachabilityContext } } - return false; + return isBlockReachable(as(inst1->getParent()), as(inst2->getParent())); } }; diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp index 42a31ea91..c7986cfbc 100644 --- a/source/slang/slang-ir-redundancy-removal.cpp +++ b/source/slang/slang-ir-redundancy-removal.cpp @@ -53,9 +53,10 @@ struct RedundancyRemovalContext return changed; } - bool removeRedundancyInBlock(DeduplicateContext& deduplicateContext, IRGlobalValueWithCode* func, IRBlock* block) + bool removeRedundancyInBlock(Dictionary& mapBlockToDedupContext, IRGlobalValueWithCode* func, IRBlock* block) { bool result = false; + auto& deduplicateContext = mapBlockToDedupContext.getValue(block); for (auto instP : block->getModifiableChildren()) { auto resultInst = deduplicateContext.deduplicate(instP, [&](IRInst* inst) @@ -82,9 +83,8 @@ struct RedundancyRemovalContext } for (auto child : dom->getImmediatelyDominatedBlocks(block)) { - DeduplicateContext subContext; + DeduplicateContext& subContext = mapBlockToDedupContext.getValue(child); subContext.deduplicateMap = deduplicateContext.deduplicateMap; - result |= removeRedundancyInBlock(subContext, func, child); } return result; } @@ -116,8 +116,28 @@ bool removeRedundancyInFunc(IRGlobalValueWithCode* func) RedundancyRemovalContext context; context.dom = computeDominatorTree(func); - DeduplicateContext deduplicateCtx; - bool result = context.removeRedundancyInBlock(deduplicateCtx, func, root); + Dictionary mapBlockToDeduplicateContext; + for (auto block : func->getBlocks()) + { + mapBlockToDeduplicateContext[block] = DeduplicateContext(); + } + List workList, pendingWorkList; + workList.add(root); + bool result = false; + while (workList.getCount()) + { + for (auto block : workList) + { + result |= context.removeRedundancyInBlock(mapBlockToDeduplicateContext, func, block); + + for (auto child : context.dom->getImmediatelyDominatedBlocks(block)) + { + pendingWorkList.add(child); + } + } + workList.swapWith(pendingWorkList); + pendingWorkList.clear(); + } if (auto normalFunc = as(func)) { result |= eliminateRedundantLoadStore(normalFunc); diff --git a/source/slang/slang-ir-sccp.cpp b/source/slang/slang-ir-sccp.cpp index d874514fe..fc3766e63 100644 --- a/source/slang/slang-ir-sccp.cpp +++ b/source/slang/slang-ir-sccp.cpp @@ -720,6 +720,39 @@ struct SCCPContext { SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v0) auto c0 = as(v0.value); + + uint64_t sourceValueBits = 0; + switch (c0->getDataType()->getOp()) + { + case kIROp_FloatType: + { + float fval = (float)c0->value.floatVal; + memcpy(&sourceValueBits, &fval, sizeof(fval)); + break; + } + case kIROp_DoubleType: + { + double dval = c0->value.floatVal; + memcpy(&sourceValueBits, &dval, sizeof(dval)); + break; + } + case kIROp_BoolType: + { + sourceValueBits = c0->value.intVal; + break; + } + default: + if (isIntegralType(c0->getDataType())) + { + sourceValueBits = c0->value.intVal; + } + else + { + return LatticeVal::getAny(); + } + break; + } + IRInst* resultVal = nullptr; switch (type->getOp()) { @@ -729,7 +762,7 @@ struct SCCPContext case kIROp_IntPtrType: case kIROp_UIntPtrType: #endif - resultVal = getBuilder()->getIntValue(type, c0->value.intVal); + resultVal = getBuilder()->getIntValue(type, sourceValueBits); break; case kIROp_IntType: case kIROp_UIntType: @@ -737,21 +770,17 @@ struct SCCPContext case kIROp_IntPtrType: case kIROp_UIntPtrType: #endif - { - float val = (float)c0->value.floatVal; - uint32_t intVal = (uint32_t)FloatAsInt(val); - resultVal = getBuilder()->getIntValue(type, intVal); - } + resultVal = getBuilder()->getIntValue(type, (uint32_t)sourceValueBits); break; case kIROp_FloatType: { - uint32_t val = (uint32_t)c0->value.intVal; + uint32_t val = (uint32_t)sourceValueBits; float floatVal = IntAsFloat((int)val); resultVal = getBuilder()->getFloatValue(type, floatVal); } break; case kIROp_DoubleType: - resultVal = getBuilder()->getFloatValue(type, Int64AsDouble(c0->value.intVal)); + resultVal = getBuilder()->getFloatValue(type, Int64AsDouble(sourceValueBits)); break; default: break; diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp index e00c24bdc..0fe752c8f 100644 --- a/source/slang/slang-ir-simplify-cfg.cpp +++ b/source/slang/slang-ir-simplify-cfg.cpp @@ -15,20 +15,9 @@ struct CFGSimplificationContext { RefPtr regionTree; RefPtr domTree; + Dictionary> relatedAddrMap; }; -static BreakableRegion* findBreakableRegion(Region* region) -{ - for (;;) - { - if (auto b = as(region)) - return b; - region = region->getParent(); - if (!region) - return nullptr; - } -} - static bool isBlockInRegion(IRDominatorTree* domTree, IRTerminatorInst* regionHeader, IRBlock* block) { auto headerBlock = cast(regionHeader->getParent()); @@ -55,10 +44,26 @@ static bool isBlockInRegion(IRDominatorTree* domTree, IRTerminatorInst* regionHe return true; } +static IRInst* findBreakableRegionHeaderInst(IRDominatorTree* domTree, IRBlock* block) +{ + for (auto idom = domTree->getImmediateDominator(block); idom; idom = domTree->getImmediateDominator(idom)) + { + auto terminator = idom->getTerminator(); + switch (terminator->getOp()) + { + case kIROp_Switch: + case kIROp_loop: + return terminator; + } + } + return nullptr; +} + // Test if a loop is trivial: a trivial loop runs for a single iteration without any back edges, and // there is only one break out of the loop at the very end. The function generates `regionTree` if // it is needed and hasn't been generated yet. static bool isTrivialSingleIterationLoop( + CFGSimplificationContext& context, IRGlobalValueWithCode* func, IRLoop* loop) { @@ -80,40 +85,21 @@ static bool isTrivialSingleIterationLoop( // // We need to verify this is a trivial loop by checking if there is any multi-level breaks // that skips out of this loop. - CFGSimplificationContext context; if (!context.domTree) context.domTree = computeDominatorTree(func); - if (!context.regionTree) - context.regionTree = generateRegionTreeForFunc(func, nullptr); - - SimpleRegion* targetBlockRegion = nullptr; - if (!context.regionTree->mapBlockToRegion.tryGetValue(targetBlock, targetBlockRegion)) - return false; - BreakableRegion* loopBreakableRegion = findBreakableRegion(targetBlockRegion); - LoopRegion* loopRegion = as(loopBreakableRegion); - if (!loopRegion) + bool hasMultiLevelBreaks = false; + auto loopBlocks = collectBlocksInRegion(context.domTree, loop, &hasMultiLevelBreaks); + if (hasMultiLevelBreaks) return false; - for (auto block : func->getBlocks()) + for (auto block : loopBlocks) { - if (!context.domTree->dominates(loop->getTargetBlock(), block)) - continue; - if (context.domTree->dominates(loop->getBreakBlock(), block)) - continue; - SimpleRegion* region = nullptr; - if (!context.regionTree->mapBlockToRegion.tryGetValue(block, region)) - return false; - for (auto branchTarget : block->getSuccessors()) { - SimpleRegion* targetRegion = nullptr; - if (!context.regionTree->mapBlockToRegion.tryGetValue(branchTarget, targetRegion)) + if (!context.domTree->dominates(loop->getParent(), branchTarget)) return false; - // If multi-level break out that skips over this loop exists, then this is not a trivial loop. - if (targetRegion->isDescendentOf(loopRegion)) - continue; if (targetBlock != loop->getBreakBlock()) return false; - if (findBreakableRegion(region) != loopRegion) + if (findBreakableRegionHeaderInst(context.domTree, block) != loop) { // If the break is initiated from a nested region, this is not trivial. return false; @@ -164,10 +150,12 @@ static bool isTrivialSingleIterationLoop( return true; } -static bool doesLoopHasSideEffect(IRGlobalValueWithCode* func, IRLoop* loopInst) +static bool doesLoopHasSideEffect(CFGSimplificationContext& context, ReachabilityContext& reachability, IRGlobalValueWithCode* func, IRLoop* loopInst) { bool hasMultiLevelBreaks = false; - auto blocks = collectBlocksInRegion(func, loopInst, &hasMultiLevelBreaks); + if (!context.domTree) + context.domTree = computeDominatorTree(func); + auto blocks = collectBlocksInRegion(context.domTree.get(), loopInst, &hasMultiLevelBreaks); // We'll currently not deal with loops that contain multi-level breaks. if (hasMultiLevelBreaks) @@ -177,25 +165,26 @@ static bool doesLoopHasSideEffect(IRGlobalValueWithCode* func, IRLoop* loopInst) for (auto b : blocks) loopBlocks.add(b); - ReachabilityContext reachability = {}; - // Construct a map from a root address to all derived addresses. - Dictionary> relatedAddrMap; - for (auto b : func->getBlocks()) + Dictionary>& relatedAddrMap = context.relatedAddrMap; + if (!relatedAddrMap.getCount()) { - for (auto inst : b->getChildren()) + for (auto b : func->getBlocks()) { - if (as(inst->getDataType())) + for (auto inst : b->getChildren()) { - auto root = getRootAddr(inst); - if (!root) continue; - auto list = relatedAddrMap.tryGetValue(root); - if (!list) + if (as(inst->getDataType())) { - relatedAddrMap.add(root, List()); - list = relatedAddrMap.tryGetValue(root); + auto root = getRootAddr(inst); + if (!root) continue; + auto list = relatedAddrMap.tryGetValue(root); + if (!list) + { + relatedAddrMap.add(root, List()); + list = relatedAddrMap.tryGetValue(root); + } + list->add(inst); } - list->add(inst); } } } @@ -779,7 +768,7 @@ static bool removeTrivialPhiParams(IRBlock* block) return changed; } -static bool processFunc(IRGlobalValueWithCode* func) +static bool processFunc(IRGlobalValueWithCode* func, CFGSimplificationOptions options) { auto firstBlock = func->getFirstBlock(); if (!firstBlock) @@ -787,6 +776,10 @@ static bool processFunc(IRGlobalValueWithCode* func) IRBuilder builder(func->getModule()); + bool isReachabilityContextValid = false; + ReachabilityContext reachabilityContext; + CFGSimplificationContext simplificationContext; + bool changed = false; for (;;) { @@ -815,6 +808,7 @@ static bool processFunc(IRGlobalValueWithCode* func) { loop->continueBlock.set(loop->getTargetBlock()); continueBlock->removeAndDeallocate(); + simplificationContext = CFGSimplificationContext(); changed = true; } @@ -822,7 +816,7 @@ static bool processFunc(IRGlobalValueWithCode* func) // break at the end of the loop, we can remove the header and turn it into // a normal branch. auto targetBlock = loop->getTargetBlock(); - if (isTrivialSingleIterationLoop(func, loop)) + if (options.removeTrivialSingleIterationLoops && isTrivialSingleIterationLoop(simplificationContext, func, loop)) { builder.setInsertBefore(loop); List args; @@ -832,26 +826,44 @@ static bool processFunc(IRGlobalValueWithCode* func) } builder.emitBranch(targetBlock, args.getCount(), args.getBuffer()); loop->removeAndDeallocate(); + simplificationContext = CFGSimplificationContext(); changed = true; } - else if (!doesLoopHasSideEffect(func, loop)) + else if (options.removeSideEffectFreeLoops) { - // The loop isn't computing anything useful outside the loop. - // We can delete the entire loop. - builder.setInsertBefore(loop); - SLANG_ASSERT(loop->getBreakBlock()->getFirstParam() == nullptr); - builder.emitBranch(loop->getBreakBlock()); - loop->removeAndDeallocate(); - changed = true; + if (!isReachabilityContextValid) + { + isReachabilityContextValid = true; + reachabilityContext = ReachabilityContext(func); + } + if (!doesLoopHasSideEffect(simplificationContext, reachabilityContext, func, loop)) + { + // The loop isn't computing anything useful outside the loop. + // We can delete the entire loop. + builder.setInsertBefore(loop); + SLANG_ASSERT(loop->getBreakBlock()->getFirstParam() == nullptr); + builder.emitBranch(loop->getBreakBlock()); + loop->removeAndDeallocate(); + simplificationContext = CFGSimplificationContext(); + changed = true; + } } } else if (auto condBranch = as(block->getTerminator())) { - changed |= trySimplifyIfElse(builder, condBranch); + if (trySimplifyIfElse(builder, condBranch)) + { + simplificationContext = CFGSimplificationContext(); + changed = true; + } } else if (auto switchBranch = as(block->getTerminator())) { - changed |= trySimplifySwitch(builder, switchBranch); + if (trySimplifySwitch(builder, switchBranch)) + { + simplificationContext = CFGSimplificationContext(); + changed = true; + } } // If `block` does not end with an unconditional branch, bail. @@ -867,6 +879,7 @@ static bool processFunc(IRGlobalValueWithCode* func) if (block->hasMoreThanOneUse()) break; changed = true; + simplificationContext = CFGSimplificationContext(); Index paramIndex = 0; auto inst = successor->getFirstDecorationOrChild(); while (inst) @@ -911,7 +924,7 @@ static bool processFunc(IRGlobalValueWithCode* func) return changed; } -bool simplifyCFG(IRModule* module) +bool simplifyCFG(IRModule* module, CFGSimplificationOptions options) { bool changed = false; for (auto inst : module->getGlobalInsts()) @@ -922,15 +935,15 @@ bool simplifyCFG(IRModule* module) } if (auto func = as(inst)) { - changed |= processFunc(func); + changed |= processFunc(func, options); } } return changed; } -bool simplifyCFG(IRGlobalValueWithCode* func) +bool simplifyCFG(IRGlobalValueWithCode* func, CFGSimplificationOptions options) { - return processFunc(func); + return processFunc(func, options); } } // namespace Slang diff --git a/source/slang/slang-ir-simplify-cfg.h b/source/slang/slang-ir-simplify-cfg.h index 6bfa6e2bf..4bc37bd29 100644 --- a/source/slang/slang-ir-simplify-cfg.h +++ b/source/slang/slang-ir-simplify-cfg.h @@ -6,11 +6,19 @@ namespace Slang struct IRModule; struct IRGlobalValueWithCode; + struct CFGSimplificationOptions + { + bool removeTrivialSingleIterationLoops = true; + bool removeSideEffectFreeLoops = true; + static CFGSimplificationOptions getDefault() { return CFGSimplificationOptions(); } + static CFGSimplificationOptions getFast() { return CFGSimplificationOptions{ false, false }; } + }; + /// Simplifies control flow graph by merging basic blocks that /// forms a simple linear chain. /// Returns true if changed. - bool simplifyCFG(IRModule* module); + bool simplifyCFG(IRModule* module, CFGSimplificationOptions options); - bool simplifyCFG(IRGlobalValueWithCode* func); + bool simplifyCFG(IRGlobalValueWithCode* func, CFGSimplificationOptions options); } diff --git a/source/slang/slang-ir-single-return.cpp b/source/slang/slang-ir-single-return.cpp index 0b61e5065..519a4f2d4 100644 --- a/source/slang/slang-ir-single-return.cpp +++ b/source/slang/slang-ir-single-return.cpp @@ -18,7 +18,7 @@ struct SingleReturnContext : public InstPassBase void processFunc(IRGlobalValueWithCode* func) { IRBuilder builder(module); - simplifyCFG(func); + simplifyCFG(func, CFGSimplificationOptions::getFast()); // We make use of the `eliminate-multi-level-break` pass to implement the transformation. // To be able to do that, we need to prepare `func` so that the entire function body diff --git a/source/slang/slang-ir-specialize-function-call.cpp b/source/slang/slang-ir-specialize-function-call.cpp index f9e106920..6aa2e4998 100644 --- a/source/slang/slang-ir-specialize-function-call.cpp +++ b/source/slang/slang-ir-specialize-function-call.cpp @@ -891,7 +891,7 @@ struct FunctionParameterSpecializationContext // addCallsToWorkListRec(newFunc); - simplifyFunc(newFunc); + simplifyFunc(newFunc, IRSimplificationOptions::getFast()); return newFunc; } diff --git a/source/slang/slang-ir-specialize-resources.cpp b/source/slang/slang-ir-specialize-resources.cpp index 49f1b3177..95dc39729 100644 --- a/source/slang/slang-ir-specialize-resources.cpp +++ b/source/slang/slang-ir-specialize-resources.cpp @@ -1185,7 +1185,7 @@ bool specializeResourceUsage( // and turned into SSA temporaries. Such optimization may enable // the following passes to "see" and specialize more cases. // - simplifyIR(irModule); + simplifyIR(irModule, IRSimplificationOptions::getFast()); result |= changed; } if (unspecializableFuncs.getCount() == 0) @@ -1205,7 +1205,7 @@ bool specializeResourceUsage( inlineCall(call); }); } - simplifyIR(irModule); + simplifyIR(irModule, IRSimplificationOptions::getFast()); } return result; } diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index b2cd5edc4..e1a563b04 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -1511,7 +1511,7 @@ struct SpecializationContext // addToWorkList(newFunc); - simplifyFunc(newFunc); + simplifyFunc(newFunc, IRSimplificationOptions::getFast()); return newFunc; } @@ -2369,7 +2369,7 @@ IRInst* specializeGenericImpl( // the same thing. if (auto func = as(specializedVal)) { - simplifyFunc(func); + simplifyFunc(func, IRSimplificationOptions::getFast()); } return specializedVal; diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp index b92f5e910..715b4f6ce 100644 --- a/source/slang/slang-ir-spirv-legalize.cpp +++ b/source/slang/slang-ir-spirv-legalize.cpp @@ -13,7 +13,7 @@ #include "slang-ir-layout.h" #include "slang-ir-util.h" #include "slang-ir-dominators.h" -#include "slang-ir-array-reg-to-mem.h" +#include "slang-ir-composite-reg-to-mem.h" #include "slang-ir-sccp.h" #include "slang-ir-dce.h" #include "slang-ir-simplify-cfg.h" @@ -1516,9 +1516,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase void processModule() { -#if 0 - eliminateArrayTypeSSARegisters(m_module); -#endif + convertCompositeTypeParametersToPointers(m_module); // Process global params before anything else, so we don't generate inefficient // array marhalling code for array-typed global params. @@ -1739,7 +1737,7 @@ void simplifyIRForSpirvLegalization(DiagnosticSink* sink, IRModule* module) funcChanged |= applySparseConditionalConstantPropagation(func, sink); funcChanged |= peepholeOptimize(func); funcChanged |= removeRedundancyInFunc(func); - funcChanged |= simplifyCFG(func); + funcChanged |= simplifyCFG(func, CFGSimplificationOptions::getFast()); eliminateDeadCode(func); } } diff --git a/source/slang/slang-ir-ssa-register-allocate.cpp b/source/slang/slang-ir-ssa-register-allocate.cpp index ce502b454..90a0f05e5 100644 --- a/source/slang/slang-ir-ssa-register-allocate.cpp +++ b/source/slang/slang-ir-ssa-register-allocate.cpp @@ -13,6 +13,12 @@ namespace Slang struct RegisterAllocateContext { OrderedDictionary>> mapTypeToRegisterList; + bool allocateForCompositeTypeOnly; + + RegisterAllocateContext(bool compositeTypeOnly) + :allocateForCompositeTypeOnly(compositeTypeOnly) + {} + List>& getRegisterListForType(IRType* type) { if (auto list = mapTypeToRegisterList.tryGetValue(type)) @@ -142,7 +148,7 @@ struct RegisterAllocateContext RegisterAllocationResult allocateRegisters(IRGlobalValueWithCode* func, RefPtr& inOutDom) { - ReachabilityContext reachabilityContext; + ReachabilityContext reachabilityContext(func); mapTypeToRegisterList.clear(); auto dom = computeDominatorTree(func); @@ -275,6 +281,7 @@ struct RegisterAllocateContext } return result; } + bool instNeedsProcessing(IRGlobalValueWithCode* func, IRInst* inst) { switch (inst->getOp()) @@ -282,6 +289,8 @@ struct RegisterAllocateContext case kIROp_Param: if (inst->getParent() == func->getFirstBlock()) return false; + if (allocateForCompositeTypeOnly && !isCompositeType(inst->getFullType())) + return false; return true; case kIROp_UpdateElement: return true; @@ -303,9 +312,9 @@ struct RegisterAllocateContext } }; -RegisterAllocationResult allocateRegistersForFunc(IRGlobalValueWithCode* func, RefPtr& inOutDom) +RegisterAllocationResult allocateRegistersForFunc(IRGlobalValueWithCode* func, RefPtr& inOutDom, bool allocateForCompositeTypeOnly) { - RegisterAllocateContext context; + RegisterAllocateContext context(allocateForCompositeTypeOnly); if (context.needProcessing(func)) return context.allocateRegisters(func, inOutDom); return RegisterAllocationResult(); diff --git a/source/slang/slang-ir-ssa-register-allocate.h b/source/slang/slang-ir-ssa-register-allocate.h index 1e8c586cd..1ef575716 100644 --- a/source/slang/slang-ir-ssa-register-allocate.h +++ b/source/slang/slang-ir-ssa-register-allocate.h @@ -19,6 +19,6 @@ struct RegisterAllocationResult Dictionary> mapInstToRegister; }; -RegisterAllocationResult allocateRegistersForFunc(IRGlobalValueWithCode* func, RefPtr& inOutDom); +RegisterAllocationResult allocateRegistersForFunc(IRGlobalValueWithCode* func, RefPtr& inOutDom, bool allocateForCompositeTypesOnly); } diff --git a/source/slang/slang-ir-ssa-simplification.cpp b/source/slang/slang-ir-ssa-simplification.cpp index dbef20732..104ce05e4 100644 --- a/source/slang/slang-ir-ssa-simplification.cpp +++ b/source/slang/slang-ir-ssa-simplification.cpp @@ -11,12 +11,13 @@ #include "slang-ir-redundancy-removal.h" #include "slang-ir-propagate-func-properties.h" #include "../core/slang-performance-profiler.h" +#include "slang-ir-util.h" namespace Slang { // Run a combination of SSA, SCCP, SimplifyCFG, and DeadCodeElimination pass // until no more changes are possible. - void simplifyIR(IRModule* module, DiagnosticSink* sink) + void simplifyIR(IRModule* module, IRSimplificationOptions options, DiagnosticSink* sink) { SLANG_PROFILE; bool changed = true; @@ -50,7 +51,7 @@ namespace Slang funcChanged |= applySparseConditionalConstantPropagation(func, sink); funcChanged |= peepholeOptimize(func); funcChanged |= removeRedundancyInFunc(func); - funcChanged |= simplifyCFG(func); + funcChanged |= simplifyCFG(func, options.cfgOptions); eliminateDeadCode(func); funcChanged |= constructSSA(func); changed |= funcChanged; @@ -67,7 +68,7 @@ namespace Slang } } - void simplifyNonSSAIR(IRModule* module) + void simplifyNonSSAIR(IRModule* module, IRSimplificationOptions options) { bool changed = true; const int kMaxIterations = 8; @@ -76,20 +77,20 @@ namespace Slang { changed = false; changed |= peepholeOptimize(module); + changed |= removeRedundancy(module); - changed |= simplifyCFG(module); + changed |= simplifyCFG(module, options.cfgOptions); // Note: we disregard the `changed` state from dead code elimination pass since // SCCP pass could be generating temporarily evaluated constant values and never actually use them. // DCE will always remove those nearly generated consts and always returns true here. eliminateDeadCode(module); - iterationCounter++; } } - void simplifyFunc(IRGlobalValueWithCode* func, DiagnosticSink* sink) + void simplifyFunc(IRGlobalValueWithCode* func, IRSimplificationOptions options, DiagnosticSink* sink) { bool changed = true; const int kMaxIterations = 8; @@ -103,7 +104,7 @@ namespace Slang changed |= applySparseConditionalConstantPropagation(func, sink); changed |= peepholeOptimize(func); changed |= removeRedundancyInFunc(func); - changed |= simplifyCFG(func); + changed |= simplifyCFG(func, options.cfgOptions); // Note: we disregard the `changed` state from dead code elimination pass since // SCCP pass could be generating temporarily evaluated constant values and never actually use them. diff --git a/source/slang/slang-ir-ssa-simplification.h b/source/slang/slang-ir-ssa-simplification.h index fd7aa0ad8..88347aa17 100644 --- a/source/slang/slang-ir-ssa-simplification.h +++ b/source/slang/slang-ir-ssa-simplification.h @@ -1,18 +1,37 @@ // slang-ir-ssa-simplification.h #pragma once +#include "slang-ir-simplify-cfg.h" + namespace Slang { struct IRModule; struct IRGlobalValueWithCode; class DiagnosticSink; + struct IRSimplificationOptions + { + CFGSimplificationOptions cfgOptions; + static IRSimplificationOptions getDefault() + { + IRSimplificationOptions result; + return result; + } + static IRSimplificationOptions getFast() + { + IRSimplificationOptions result; + result.cfgOptions.removeSideEffectFreeLoops = false; + result.cfgOptions.removeTrivialSingleIterationLoops = false; + return result; + } + }; + // Run a combination of SSA, SCCP, SimplifyCFG, and DeadCodeElimination pass // until no more changes are possible. - void simplifyIR(IRModule* module, DiagnosticSink* sink = nullptr); + void simplifyIR(IRModule* module, IRSimplificationOptions options, DiagnosticSink* sink = nullptr); // Run simplifications on IR that is out of SSA form. - void simplifyNonSSAIR(IRModule* module); + void simplifyNonSSAIR(IRModule* module, IRSimplificationOptions options); - void simplifyFunc(IRGlobalValueWithCode* func, DiagnosticSink* sink = nullptr); + void simplifyFunc(IRGlobalValueWithCode* func, IRSimplificationOptions options, DiagnosticSink* sink = nullptr); } diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp index 730943bf8..788c9a391 100644 --- a/source/slang/slang-ir-ssa.cpp +++ b/source/slang/slang-ir-ssa.cpp @@ -5,6 +5,7 @@ #include "slang-ir-clone.h" #include "slang-ir-insts.h" #include "slang-ir-validate.h" +#include "slang-ir-util.h" namespace Slang { @@ -843,6 +844,8 @@ void processBlock( IRBlock* block, SSABlockInfo* blockInfo) { + hoistInstOutOfASMBlocks(block); + // Before starting, check if this block can be sealed maybeSealBlock(context, blockInfo); diff --git a/source/slang/slang-ir-use-uninitialized-out-param.cpp b/source/slang/slang-ir-use-uninitialized-out-param.cpp index 479538441..07a2b1bc2 100644 --- a/source/slang/slang-ir-use-uninitialized-out-param.cpp +++ b/source/slang/slang-ir-use-uninitialized-out-param.cpp @@ -20,7 +20,7 @@ namespace Slang if (!firstBlock) return; - ReachabilityContext reachability; + ReachabilityContext reachability(func); for (auto param : firstBlock->getParams()) { diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 5ecbc8121..073b8bf96 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -1166,6 +1166,23 @@ UnownedStringSlice getBuiltinFuncName(IRInst* callee) return decor->getName(); } +void hoistInstOutOfASMBlocks(IRBlock* block) +{ + for (auto inst : block->getChildren()) + { + if (auto asmBlock = as(inst)) + { + IRInst* next = nullptr; + for (auto i = asmBlock->getFirstChild(); i; i = next) + { + next = i->getNextInst(); + if (!as(i) && !as(i)) + i->insertBefore(asmBlock); + } + } + } +} + UnownedStringSlice getBasicTypeNameHint(IRType* basicType) { switch (basicType->getOp()) diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h index c13ce1931..ff6298f39 100644 --- a/source/slang/slang-ir-util.h +++ b/source/slang/slang-ir-util.h @@ -279,6 +279,21 @@ static void overAllBlocks(IRModule* module, F f) } } +void hoistInstOutOfASMBlocks(IRBlock* block); + +inline bool isCompositeType(IRType* type) +{ + switch (type->getOp()) + { + case kIROp_StructType: + case kIROp_ArrayType: + case kIROp_UnsizedArrayType: + return true; + default: + return false; + } +} + } #endif diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 3b7fb9ac8..6a3a26bd5 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -397,7 +397,10 @@ namespace Slang // if (auto lastParam = getLastParam()) { - param->insertAfter(lastParam); + if (lastParam->next) + param->insertAfter(lastParam); + else + param->insertAtEnd(this); } // // Otherwise, if there are any existing diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index f62d4b24a..c90230c1f 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -9371,8 +9371,6 @@ struct DeclLoweringVisitor : DeclVisitor getBuilder()->addRequireGLSLVersionDecoration(irFunc, Int(getIntegerLiteralValue(versionMod->versionNumberToken))); else if (auto spvVersion = as(modifier)) getBuilder()->addRequireSPIRVVersionDecoration(irFunc, spvVersion->version); - else if (auto capMod = as(modifier)) - getBuilder()->addRequireSPIRVCapabilityDecoration(irFunc, capMod->capability, capMod->extensionName.getUnownedSlice()); else if (auto cudasmVersion = as(modifier)) getBuilder()->addRequireCUDASMVersionDecoration(irFunc, cudasmVersion->version); } @@ -10164,7 +10162,7 @@ RefPtr generateIRForTranslationUnit( // temporaries and do basic simplifications. // constructSSA(module); - simplifyCFG(module); + simplifyCFG(module, CFGSimplificationOptions::getDefault()); applySparseConditionalConstantPropagation(module, compileRequest->getSink()); peepholeOptimize(module); @@ -10213,7 +10211,7 @@ RefPtr generateIRForTranslationUnit( bool changed = false; performMandatoryEarlyInlining(module); changed |= constructSSA(module); - simplifyCFG(module); + simplifyCFG(module, CFGSimplificationOptions::getDefault()); changed |= applySparseConditionalConstantPropagation(module, compileRequest->getSink()); changed |= peepholeOptimize(module); for (auto inst : module->getGlobalInsts()) diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 3f8225084..517ab9e10 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -6943,26 +6943,6 @@ namespace Slang return nullptr; } - static NodeBase* parseSPIRVCapabilityModifier(Parser* parser, void*) - { - parser->ReadToken(TokenType::LParent); - Token token = parser->ReadToken(TokenType::Identifier); - auto modifier = parser->astBuilder->create(); - const SPIRVCoreGrammarInfo& spirvInfo = - parser->astBuilder->getGlobalSession()->getSPIRVCoreGrammarInfo(); - const auto cap = spirvInfo.capabilities.lookup(token.getContent()); - if (!cap) - { - parser->sink->diagnose(token, Diagnostics::unknownSPIRVCapability, token); - } - else - { - modifier->capability = (int32_t)cap.value(); - } - parser->ReadToken(TokenType::RParent); - return modifier; - } - static NodeBase* parseCUDASMVersionModifier(Parser* parser, void* /*userData*/) { Token token; @@ -7323,7 +7303,6 @@ namespace Slang _makeParseModifier("__glsl_extension", parseGLSLExtensionModifier), _makeParseModifier("__glsl_version", parseGLSLVersionModifier), _makeParseModifier("__spirv_version", parseSPIRVVersionModifier), - _makeParseModifier("__spirv_capability", parseSPIRVCapabilityModifier), _makeParseModifier("__cuda_sm_version", parseCUDASMVersionModifier), _makeParseModifier("__builtin_type", parseBuiltinTypeModifier), diff --git a/source/slang/slang-spirv-opt.cpp b/source/slang/slang-spirv-opt.cpp index 34c87db9f..786358324 100644 --- a/source/slang/slang-spirv-opt.cpp +++ b/source/slang/slang-spirv-opt.cpp @@ -23,8 +23,6 @@ SlangResult optimizeSPIRV(const List& spirv, String& outErr, List& spirv) const auto out = p->getStream(StdStreamType::Out); const auto err = p->getStream(StdStreamType::ErrorOut); - // Write the assembly - SLANG_RETURN_ON_FAIL(in->write(spirv.getBuffer(), spirv.getCount())); - in->close(); + List outData; + List errData; + SLANG_RETURN_ON_FAIL(StreamUtil::readAndWrite(in, spirv.getArrayView(), out, outData, err, errData)); // Wait for it to finish if(!p->waitForTermination(1000)) return SLANG_FAIL; - - // TODO: allow inheriting stderr in Process - List outData; - SLANG_RETURN_ON_FAIL(StreamUtil::readAll(out, 0, outData)); - fwrite(outData.getBuffer(), outData.getCount(), 1, stderr); - outData.clear(); - SLANG_RETURN_ON_FAIL(StreamUtil::readAll(err, 0, outData)); - // If we failed, dump the spirv first. const auto ret = p->getReturnValue(); if(ret != 0) @@ -83,6 +75,7 @@ SlangResult debugValidateSPIRV(const List& spirv) } fwrite(outData.getBuffer(), outData.getCount(), 1, stderr); + fwrite(errData.getBuffer(), errData.getCount(), 1, stderr); return ret == 0 ? SLANG_OK : SLANG_FAIL; } -- cgit v1.2.3