summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-10-04 11:20:35 -0700
committerGitHub <noreply@github.com>2023-10-04 11:20:35 -0700
commitac886fd3e329a9599ed1ac7a6d8b26ca5821046c (patch)
tree87bcafb3985775f9d90303d6a4239eb743164407 /source
parentd87493a46c00be37b820a473c0827bbb865eb222 (diff)
SPIRV compiler performance fixes. (#3258)
* SPIRV compiler performance fixes. * Cleanup. * update project files * Cleanup debug code. * Make redundancy removal non-recursive. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/core/slang-uint-set.h21
-rw-r--r--source/slang/hlsl.meta.slang269
-rw-r--r--source/slang/slang-ast-modifier.h9
-rw-r--r--source/slang/slang-emit-spirv.cpp28
-rw-r--r--source/slang/slang-emit.cpp39
-rw-r--r--source/slang/slang-ir-array-reg-to-mem.cpp88
-rw-r--r--source/slang/slang-ir-array-reg-to-mem.h16
-rw-r--r--source/slang/slang-ir-autodiff-cfg-norm.cpp2
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp2
-rw-r--r--source/slang/slang-ir-composite-reg-to-mem.cpp202
-rw-r--r--source/slang/slang-ir-composite-reg-to-mem.h13
-rw-r--r--source/slang/slang-ir-eliminate-multilevel-break.cpp3
-rw-r--r--source/slang/slang-ir-eliminate-phis.cpp51
-rw-r--r--source/slang/slang-ir-eliminate-phis.h11
-rw-r--r--source/slang/slang-ir-inline.cpp35
-rw-r--r--source/slang/slang-ir-inline.h3
-rw-r--r--source/slang/slang-ir-inst-defs.h1
-rw-r--r--source/slang/slang-ir-insts.h19
-rw-r--r--source/slang/slang-ir-loop-unroll.cpp2
-rw-r--r--source/slang/slang-ir-lower-generics.cpp2
-rw-r--r--source/slang/slang-ir-peephole.cpp2
-rw-r--r--source/slang/slang-ir-reachability.cpp81
-rw-r--r--source/slang/slang-ir-reachability.h35
-rw-r--r--source/slang/slang-ir-redundancy-removal.cpp30
-rw-r--r--source/slang/slang-ir-sccp.cpp45
-rw-r--r--source/slang/slang-ir-simplify-cfg.cpp149
-rw-r--r--source/slang/slang-ir-simplify-cfg.h12
-rw-r--r--source/slang/slang-ir-single-return.cpp2
-rw-r--r--source/slang/slang-ir-specialize-function-call.cpp2
-rw-r--r--source/slang/slang-ir-specialize-resources.cpp4
-rw-r--r--source/slang/slang-ir-specialize.cpp4
-rw-r--r--source/slang/slang-ir-spirv-legalize.cpp8
-rw-r--r--source/slang/slang-ir-ssa-register-allocate.cpp15
-rw-r--r--source/slang/slang-ir-ssa-register-allocate.h2
-rw-r--r--source/slang/slang-ir-ssa-simplification.cpp15
-rw-r--r--source/slang/slang-ir-ssa-simplification.h25
-rw-r--r--source/slang/slang-ir-ssa.cpp3
-rw-r--r--source/slang/slang-ir-use-uninitialized-out-param.cpp2
-rw-r--r--source/slang/slang-ir-util.cpp17
-rw-r--r--source/slang/slang-ir-util.h15
-rw-r--r--source/slang/slang-ir.cpp5
-rw-r--r--source/slang/slang-lower-to-ir.cpp6
-rw-r--r--source/slang/slang-parser.cpp21
-rw-r--r--source/slang/slang-spirv-opt.cpp2
-rw-r--r--source/slang/slang-spirv-val.cpp15
45 files changed, 798 insertions, 535 deletions
diff --git a/source/core/slang-uint-set.h b/source/core/slang-uint-set.h
index f2c573865..4912ae504 100644
--- a/source/core/slang-uint-set.h
+++ b/source/core/slang-uint-set.h
@@ -52,6 +52,8 @@ public:
/// Returns true if the value is present
inline bool contains(UInt val) const;
+ inline bool contains(const UIntSet& set) const;
+
/// ==
bool operator==(const UIntSet& set) const;
/// !=
@@ -111,6 +113,25 @@ inline bool UIntSet::contains(UInt val) const
}
// --------------------------------------------------------------------------
+inline bool UIntSet::contains(const UIntSet& set) const // Subset test: true iff every bit set in `set` is also set in *this.
+{
+ for (Index i = 0; i < set.m_buffer.getCount(); i++) // walk the words of `set`; the two buffers may differ in length
+ {
+ if (i >= m_buffer.getCount()) // this word index lies past the end of our own buffer
+ {
+ if (set.m_buffer[i]) // we hold no bits here, so any bit present in `set` is missing from *this
+ return false;
+ }
+ else
+ {
+ if ((m_buffer[i] & set.m_buffer[i]) != set.m_buffer[i]) // masking with our word drops a bit => `set` has a bit we lack
+ return false;
+ }
+ }
+ return true; // no missing bit found (also the result when `set` has no words)
+}
+
+// --------------------------------------------------------------------------
inline void UIntSet::add(UInt val)
{
const Index idx = Index(val >> kElementShift);
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 5eb08e980..0fc6fb2a9 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -5116,7 +5116,6 @@ matrix<T, N, M> trunc(matrix<T, N, M> x)
typedef uint WaveMask;
__glsl_extension(GL_KHR_shader_subgroup_ballot)
-__spirv_capability(GroupNonUniformBallot)
__spirv_version(1.3)
WaveMask WaveGetConvergedMask()
{
@@ -5132,6 +5131,7 @@ WaveMask WaveGetConvergedMask()
let _true = true;
return (spirv_asm
{
+ OpCapability GroupNonUniformBallot;
OpGroupNonUniformBallot $$uint4 result Subgroup $_true
}).x;
}
@@ -5141,7 +5141,6 @@ __intrinsic_op($(kIROp_WaveGetActiveMask))
WaveMask __WaveGetActiveMask();
__glsl_extension(GL_KHR_shader_subgroup_ballot)
-__spirv_capability(GroupNonUniformBallot)
__spirv_version(1.3)
WaveMask WaveGetActiveMask()
{
@@ -5155,6 +5154,7 @@ WaveMask WaveGetActiveMask()
let _true = true;
return (spirv_asm
{
+ OpCapability GroupNonUniformBallot;
OpGroupNonUniformBallot $$uint4 result Subgroup $_true
}).x;
default:
@@ -5163,7 +5163,6 @@ WaveMask WaveGetActiveMask()
}
__glsl_extension(GL_KHR_shader_subgroup_basic)
-__spirv_capability(GroupNonUniformBallot)
__spirv_version(1.3)
bool WaveMaskIsFirstLane(WaveMask mask)
{
@@ -5178,6 +5177,7 @@ bool WaveMaskIsFirstLane(WaveMask mask)
case spirv:
return spirv_asm
{
+ OpCapability GroupNonUniformBallot;
OpGroupNonUniformElect $$bool result Subgroup
};
default:
@@ -5187,7 +5187,6 @@ bool WaveMaskIsFirstLane(WaveMask mask)
__glsl_extension(GL_KHR_shader_subgroup_vote)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformBallot)
bool WaveMaskAllTrue(WaveMask mask, bool condition)
{
__target_switch
@@ -5201,6 +5200,7 @@ bool WaveMaskAllTrue(WaveMask mask, bool condition)
case spirv:
return spirv_asm
{
+ OpCapability GroupNonUniformBallot;
OpGroupNonUniformAll $$bool result Subgroup $condition
};
default:
@@ -5210,7 +5210,6 @@ bool WaveMaskAllTrue(WaveMask mask, bool condition)
__glsl_extension(GL_KHR_shader_subgroup_vote)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformBallot)
bool WaveMaskAnyTrue(WaveMask mask, bool condition)
{
__target_switch
@@ -5224,6 +5223,7 @@ bool WaveMaskAnyTrue(WaveMask mask, bool condition)
case spirv:
return spirv_asm
{
+ OpCapability GroupNonUniformBallot;
OpGroupNonUniformAny $$bool result Subgroup $condition
};
default:
@@ -5233,7 +5233,6 @@ bool WaveMaskAnyTrue(WaveMask mask, bool condition)
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformBallot)
WaveMask WaveMaskBallot(WaveMask mask, bool condition)
{
__target_switch
@@ -5247,6 +5246,7 @@ WaveMask WaveMaskBallot(WaveMask mask, bool condition)
case spirv:
return (spirv_asm
{
+ OpCapability GroupNonUniformBallot;
OpGroupNonUniformBallot $$uint4 result Subgroup $condition
}).x;
default:
@@ -5368,7 +5368,6 @@ void GroupMemoryBarrierWithWaveSync()
__generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformBallot)
T WaveMaskBroadcastLaneAt(WaveMask mask, T value, constexpr int lane)
{
__target_switch
@@ -5378,14 +5377,16 @@ T WaveMaskBroadcastLaneAt(WaveMask mask, T value, constexpr int lane)
case hlsl: __intrinsic_asm "WaveReadLaneAt($1, $2)";
case spirv:
let ulane = uint(lane);
- return spirv_asm {OpGroupNonUniformBroadcast $$T result Subgroup $value $ulane};
+ return spirv_asm {
+ OpCapability GroupNonUniformBallot;
+ OpGroupNonUniformBroadcast $$T result Subgroup $value $ulane;
+ };
}
}
__generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformBallot)
vector<T,N> WaveMaskBroadcastLaneAt(WaveMask mask, vector<T,N> value, constexpr int lane)
{
__target_switch
@@ -5395,7 +5396,10 @@ vector<T,N> WaveMaskBroadcastLaneAt(WaveMask mask, vector<T,N> value, constexpr
case hlsl: __intrinsic_asm "WaveReadLaneAt($1, $2)";
case spirv:
let ulane = uint(lane);
- return spirv_asm {OpGroupNonUniformBroadcast $$vector<T,N> result Subgroup $value $ulane};
+ return spirv_asm {
+ OpCapability GroupNonUniformBallot;
+ OpGroupNonUniformBroadcast $$vector<T,N> result Subgroup $value $ulane;
+ };
}
}
__generic<T : __BuiltinType, let N : int, let M : int>
@@ -5408,7 +5412,6 @@ matrix<T,N,M> WaveMaskBroadcastLaneAt(WaveMask mask, matrix<T,N,M> value, conste
__generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformShuffle)
T WaveMaskReadLaneAt(WaveMask mask, T value, int lane)
{
__target_switch
@@ -5418,13 +5421,15 @@ T WaveMaskReadLaneAt(WaveMask mask, T value, int lane)
case hlsl: __intrinsic_asm "WaveReadLaneAt($1, $2)";
case spirv:
let ulane = uint(lane);
- return spirv_asm {OpGroupNonUniformShuffle $$T result Subgroup $value $ulane};
+ return spirv_asm {
+ OpCapability GroupNonUniformShuffle;
+ OpGroupNonUniformShuffle $$T result Subgroup $value $ulane;
+ };
}
}
__generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)__glsl_extension(GL_KHR_shader_subgroup_shuffle)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformShuffle)
vector<T,N> WaveMaskReadLaneAt(WaveMask mask, vector<T,N> value, int lane)
{
__target_switch
@@ -5434,7 +5439,10 @@ vector<T,N> WaveMaskReadLaneAt(WaveMask mask, vector<T,N> value, int lane)
case hlsl: __intrinsic_asm "WaveReadLaneAt($1, $2)";
case spirv:
let ulane = uint(lane);
- return spirv_asm {OpGroupNonUniformShuffle $$vector<T,N> result Subgroup $value $ulane};
+ return spirv_asm {
+ OpCapability GroupNonUniformShuffle;
+ OpGroupNonUniformShuffle $$vector<T,N> result Subgroup $value $ulane;
+ };
}
}
__generic<T : __BuiltinType, let N : int, let M : int>
@@ -5466,7 +5474,6 @@ matrix<T,N,M> WaveMaskShuffle(WaveMask mask, matrix<T,N,M> value, int lane)
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformBallot)
uint WaveMaskPrefixCountBits(WaveMask mask, bool value)
{
__target_switch
@@ -5477,6 +5484,7 @@ uint WaveMaskPrefixCountBits(WaveMask mask, bool value)
case spirv:
return spirv_asm
{
+ OpCapability GroupNonUniformBallot;
%mask:$$uint4 = OpGroupNonUniformBallot Subgroup $value;
OpGroupNonUniformBallotBitCount $$uint result Subgroup 2 %mask
};
@@ -5488,7 +5496,6 @@ uint WaveMaskPrefixCountBits(WaveMask mask, bool value)
__generic<T : __BuiltinIntegerType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
T WaveMaskBitAnd(WaveMask mask, T expr)
{
__target_switch
@@ -5497,14 +5504,16 @@ T WaveMaskBitAnd(WaveMask mask, T expr)
case cuda: __intrinsic_asm "_waveAnd($0, $1)";
case hlsl: __intrinsic_asm "WaveActiveBitAnd($1)";
case spirv:
- return spirv_asm {OpGroupNonUniformBitwiseAnd $$T result Subgroup 0 $expr};
+ return spirv_asm {
+ OpCapability GroupNonUniformArithmetic;
+ OpGroupNonUniformBitwiseAnd $$T result Subgroup 0 $expr
+ };
}
}
__generic<T : __BuiltinIntegerType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
vector<T,N> WaveMaskBitAnd(WaveMask mask, vector<T,N> expr)
{
__target_switch
@@ -5513,7 +5522,10 @@ vector<T,N> WaveMaskBitAnd(WaveMask mask, vector<T,N> expr)
case cuda: __intrinsic_asm "_waveAndMultiple($0, $1)";
case hlsl: __intrinsic_asm "WaveActiveBitAnd($1)";
case spirv:
- return spirv_asm {OpGroupNonUniformBitwiseAnd $$vector<T,N> result Subgroup 0 $expr};
+ return spirv_asm {
+ OpCapability GroupNonUniformArithmetic;
+ OpGroupNonUniformBitwiseAnd $$vector<T,N> result Subgroup 0 $expr
+ };
}
}
__generic<T : __BuiltinIntegerType, let N : int, let M : int>
@@ -5524,7 +5536,6 @@ matrix<T,N,M> WaveMaskBitAnd(WaveMask mask, matrix<T,N,M> expr);
__generic<T : __BuiltinIntegerType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
T WaveMaskBitOr(WaveMask mask, T expr)
{
__target_switch
@@ -5533,13 +5544,15 @@ T WaveMaskBitOr(WaveMask mask, T expr)
case cuda: __intrinsic_asm "_waveOr($0, $1)";
case hlsl: __intrinsic_asm "WaveActiveBitOr($1)";
case spirv:
- return spirv_asm {OpGroupNonUniformBitwiseOr $$T result Subgroup 0 $expr};
+ return spirv_asm {
+ OpCapability GroupNonUniformArithmetic;
+ OpGroupNonUniformBitwiseOr $$T result Subgroup 0 $expr
+ };
}
}
__generic<T : __BuiltinIntegerType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
vector<T,N> WaveMaskBitOr(WaveMask mask, vector<T,N> expr)
{
__target_switch
@@ -5548,7 +5561,10 @@ vector<T,N> WaveMaskBitOr(WaveMask mask, vector<T,N> expr)
case cuda: __intrinsic_asm "_waveOrMultiple($0, $1)";
case hlsl: __intrinsic_asm "WaveActiveBitOr($1)";
case spirv:
- return spirv_asm {OpGroupNonUniformBitwiseOr $$vector<T,N> result Subgroup 0 $expr};
+ return spirv_asm {
+ OpCapability GroupNonUniformArithmetic;
+ OpGroupNonUniformBitwiseOr $$vector<T,N> result Subgroup 0 $expr
+ };
}
}
__generic<T : __BuiltinIntegerType, let N : int, let M : int>
@@ -5559,7 +5575,6 @@ matrix<T,N,M> WaveMaskBitOr(WaveMask mask, matrix<T,N,M> expr);
__generic<T : __BuiltinIntegerType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
T WaveMaskBitXor(WaveMask mask, T expr)
{
__target_switch
@@ -5568,13 +5583,15 @@ T WaveMaskBitXor(WaveMask mask, T expr)
case cuda: __intrinsic_asm "_waveXor($0, $1)";
case hlsl: __intrinsic_asm "WaveActiveBitXor($1)";
case spirv:
- return spirv_asm {OpGroupNonUniformBitwiseXor $$T result Subgroup 0 $expr};
+ return spirv_asm {
+ OpCapability GroupNonUniformArithmetic;
+ OpGroupNonUniformBitwiseXor $$T result Subgroup 0 $expr
+ };
}
}
__generic<T : __BuiltinIntegerType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
vector<T,N> WaveMaskBitXor(WaveMask mask, vector<T,N> expr)
{
__target_switch
@@ -5583,7 +5600,10 @@ vector<T,N> WaveMaskBitXor(WaveMask mask, vector<T,N> expr)
case cuda: __intrinsic_asm "_waveXorMultiple($0, $1)";
case hlsl: __intrinsic_asm "WaveActiveBitXor($1)";
case spirv:
- return spirv_asm {OpGroupNonUniformBitwiseXor $$vector<T,N> result Subgroup 0 $expr};
+ return spirv_asm {
+ OpCapability GroupNonUniformArithmetic;
+ OpGroupNonUniformBitwiseXor $$vector<T,N> result Subgroup 0 $expr
+ };
}
}
__generic<T : __BuiltinIntegerType, let N : int, let M : int>
@@ -5594,7 +5614,6 @@ matrix<T,N,M> WaveMaskBitXor(WaveMask mask, matrix<T,N,M> expr);
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
T WaveMaskMax(WaveMask mask, T expr)
{
__target_switch
@@ -5604,17 +5623,16 @@ T WaveMaskMax(WaveMask mask, T expr)
case hlsl: __intrinsic_asm "WaveActiveMax($1)";
case spirv:
if (__isFloat<T>())
- return spirv_asm {OpGroupNonUniformFMax $$T result Subgroup 0 $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMax $$T result Subgroup 0 $expr};
else if (__isSignedInt<T>())
- return spirv_asm {OpGroupNonUniformSMax $$T result Subgroup 0 $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformSMax $$T result Subgroup 0 $expr};
else if (__isUnsignedInt<T>())
- return spirv_asm {OpGroupNonUniformUMax $$T result Subgroup 0 $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformUMax $$T result Subgroup 0 $expr};
}
}
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
vector<T,N> WaveMaskMax(WaveMask mask, vector<T,N> expr)
{
__target_switch
@@ -5624,11 +5642,11 @@ vector<T,N> WaveMaskMax(WaveMask mask, vector<T,N> expr)
case hlsl: __intrinsic_asm "WaveActiveMax($1)";
case spirv:
if (__isFloat<T>())
- return spirv_asm {OpGroupNonUniformFMax $$vector<T,N> result Subgroup 0 $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMax $$vector<T,N> result Subgroup 0 $expr};
else if (__isSignedInt<T>())
- return spirv_asm {OpGroupNonUniformSMax $$vector<T,N> result Subgroup 0 $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformSMax $$vector<T,N> result Subgroup 0 $expr};
else if (__isUnsignedInt<T>())
- return spirv_asm {OpGroupNonUniformUMax $$vector<T,N> result Subgroup 0 $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformUMax $$vector<T,N> result Subgroup 0 $expr};
}
}
@@ -5640,7 +5658,6 @@ matrix<T,N,M> WaveMaskMax(WaveMask mask, matrix<T,N,M> expr);
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
T WaveMaskMin(WaveMask mask, T expr)
{
__target_switch
@@ -5650,18 +5667,17 @@ T WaveMaskMin(WaveMask mask, T expr)
case hlsl: __intrinsic_asm "WaveActiveMin($1)";
case spirv:
if (__isFloat<T>())
- return spirv_asm {OpGroupNonUniformFMin $$T result Subgroup 0 $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMin $$T result Subgroup 0 $expr};
else if (__isSignedInt<T>())
- return spirv_asm {OpGroupNonUniformSMin $$T result Subgroup 0 $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformSMin $$T result Subgroup 0 $expr};
else if (__isUnsignedInt<T>())
- return spirv_asm {OpGroupNonUniformUMin $$T result Subgroup 0 $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformUMin $$T result Subgroup 0 $expr};
}
}
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
vector<T,N> WaveMaskMin(WaveMask mask, vector<T,N> expr)
{
__target_switch
@@ -5671,11 +5687,11 @@ vector<T,N> WaveMaskMin(WaveMask mask, vector<T,N> expr)
case hlsl: __intrinsic_asm "WaveActiveMin($1)";
case spirv:
if (__isFloat<T>())
- return spirv_asm {OpGroupNonUniformFMin $$vector<T,N> result Subgroup 0 $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMin $$vector<T,N> result Subgroup 0 $expr};
else if (__isSignedInt<T>())
- return spirv_asm {OpGroupNonUniformSMin $$vector<T,N> result Subgroup 0 $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformSMin $$vector<T,N> result Subgroup 0 $expr};
else if (__isUnsignedInt<T>())
- return spirv_asm {OpGroupNonUniformUMin $$vector<T,N> result Subgroup 0 $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformUMin $$vector<T,N> result Subgroup 0 $expr};
}
}
@@ -5687,7 +5703,6 @@ matrix<T,N,M> WaveMaskMin(WaveMask mask, matrix<T,N,M> expr);
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
T WaveMaskProduct(WaveMask mask, T expr)
{
__target_switch
@@ -5697,11 +5712,12 @@ T WaveMaskProduct(WaveMask mask, T expr)
case hlsl: __intrinsic_asm "WaveActiveProduct($1)";
case spirv:
if (__isFloat<T>())
- return spirv_asm {OpGroupNonUniformFMul $$T result Subgroup 0 $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMul $$T result Subgroup 0 $expr};
else if (__isSignedInt<T>())
{
return spirv_asm
{
+ OpCapability GroupNonUniformArithmetic;
// TODO: use the correct integer width
OpBitcast $$uint %uvalue $expr;
OpGroupNonUniformIMul $$uint %mulResult Subgroup 0 %uvalue;
@@ -5709,14 +5725,13 @@ T WaveMaskProduct(WaveMask mask, T expr)
};
}
else if (__isUnsignedInt<T>())
- return spirv_asm {OpGroupNonUniformIMul $$T result Subgroup 0 $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformIMul $$T result Subgroup 0 $expr};
}
}
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
vector<T,N> WaveMaskProduct(WaveMask mask, vector<T,N> expr)
{
__target_switch
@@ -5726,11 +5741,12 @@ vector<T,N> WaveMaskProduct(WaveMask mask, vector<T,N> expr)
case hlsl: __intrinsic_asm "WaveActiveProduct($1)";
case spirv:
if (__isFloat<T>())
- return spirv_asm {OpGroupNonUniformFMul $$vector<T,N> result Subgroup 0 $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMul $$vector<T,N> result Subgroup 0 $expr};
else if (__isSignedInt<T>())
{
return spirv_asm
{
+ OpCapability GroupNonUniformArithmetic;
// TODO: use the correct integer width
OpBitcast $$vector<uint,N> %uvalue $expr;
OpGroupNonUniformIMul $$vector<uint,N> %mulResult Subgroup 0 %uvalue;
@@ -5738,7 +5754,7 @@ vector<T,N> WaveMaskProduct(WaveMask mask, vector<T,N> expr)
};
}
else if (__isUnsignedInt<T>())
- return spirv_asm {OpGroupNonUniformIMul $$vector<T,N> result Subgroup 0 $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformIMul $$vector<T,N> result Subgroup 0 $expr};
}
}
@@ -5750,7 +5766,6 @@ matrix<T,N,M> WaveMaskProduct(WaveMask mask, matrix<T,N,M> expr);
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
T WaveMaskSum(WaveMask mask, T expr)
{
__target_switch
@@ -5760,11 +5775,12 @@ T WaveMaskSum(WaveMask mask, T expr)
case hlsl: __intrinsic_asm "WaveActiveSum($1)";
case spirv:
if (__isFloat<T>())
- return spirv_asm {OpGroupNonUniformFAdd $$T result Subgroup 0 $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFAdd $$T result Subgroup 0 $expr};
else if (__isSignedInt<T>())
{
return spirv_asm
{
+ OpCapability GroupNonUniformArithmetic;
// TODO: use the correct integer width
OpBitcast $$uint %uvalue $expr;
OpGroupNonUniformIAdd $$uint %mulResult Subgroup 0 %uvalue;
@@ -5772,13 +5788,12 @@ T WaveMaskSum(WaveMask mask, T expr)
};
}
else if (__isUnsignedInt<T>())
- return spirv_asm {OpGroupNonUniformIAdd $$T result Subgroup 0 $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformIAdd $$T result Subgroup 0 $expr};
}
}
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
vector<T,N> WaveMaskSum(WaveMask mask, vector<T,N> expr)
{
__target_switch
@@ -5788,11 +5803,12 @@ vector<T,N> WaveMaskSum(WaveMask mask, vector<T,N> expr)
case hlsl: __intrinsic_asm "WaveActiveSum($1)";
case spirv:
if (__isFloat<T>())
- return spirv_asm {OpGroupNonUniformFAdd $$vector<T,N> result Subgroup 0 $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFAdd $$vector<T,N> result Subgroup 0 $expr};
else if (__isSignedInt<T>())
{
return spirv_asm
{
+ OpCapability GroupNonUniformArithmetic;
// TODO: use the correct integer width
OpBitcast $$vector<uint,N> %uvalue $expr;
OpGroupNonUniformIAdd $$vector<uint,N> %mulResult Subgroup 0 %uvalue;
@@ -5800,7 +5816,7 @@ vector<T,N> WaveMaskSum(WaveMask mask, vector<T,N> expr)
};
}
else if (__isUnsignedInt<T>())
- return spirv_asm {OpGroupNonUniformIAdd $$vector<T,N> result Subgroup 0 $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformIAdd $$vector<T,N> result Subgroup 0 $expr};
}
}
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
@@ -5811,7 +5827,6 @@ matrix<T,N,M> WaveMaskSum(WaveMask mask, matrix<T,N,M> expr);
__generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_vote)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformVote)
__cuda_sm_version(7.0)
bool WaveMaskAllEqual(WaveMask mask, T value)
{
@@ -5826,6 +5841,7 @@ bool WaveMaskAllEqual(WaveMask mask, T value)
case spirv:
return spirv_asm
{
+ OpCapability GroupNonUniformVote;
OpGroupNonUniformAllEqual $$bool result Subgroup $value
};
default:
@@ -5836,7 +5852,6 @@ __generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_vote)
__spirv_version(1.3)
__cuda_sm_version(7.0)
-__spirv_capability(GroupNonUniformVote)
bool WaveMaskAllEqual(WaveMask mask, vector<T,N> value)
{
__target_switch
@@ -5850,6 +5865,7 @@ bool WaveMaskAllEqual(WaveMask mask, vector<T,N> value)
case spirv:
return spirv_asm
{
+ OpCapability GroupNonUniformVote;
OpGroupNonUniformAllEqual $$bool result Subgroup $value
};
default:
@@ -5867,7 +5883,6 @@ bool WaveMaskAllEqual(WaveMask mask, matrix<T,N,M> value);
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
T WaveMaskPrefixProduct(WaveMask mask, T expr)
{
__target_switch
@@ -5877,11 +5892,12 @@ T WaveMaskPrefixProduct(WaveMask mask, T expr)
case hlsl: __intrinsic_asm "WavePrefixProduct($1)";
case spirv:
if (__isFloat<T>())
- return spirv_asm {OpGroupNonUniformFMul $$T result Subgroup ExclusiveScan $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMul $$T result Subgroup ExclusiveScan $expr};
else if (__isSignedInt<T>())
{
return spirv_asm
{
+ OpCapability GroupNonUniformArithmetic;
// TODO: use the correct integer width
OpBitcast $$uint %uvalue $expr;
OpGroupNonUniformIMul $$uint %mulResult Subgroup ExclusiveScan %uvalue;
@@ -5895,7 +5911,6 @@ T WaveMaskPrefixProduct(WaveMask mask, T expr)
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
vector<T,N> WaveMaskPrefixProduct(WaveMask mask, vector<T,N> expr)
{
__target_switch
@@ -5905,11 +5920,12 @@ vector<T,N> WaveMaskPrefixProduct(WaveMask mask, vector<T,N> expr)
case hlsl: __intrinsic_asm "WavePrefixProduct($1)";
case spirv:
if (__isFloat<T>())
- return spirv_asm {OpGroupNonUniformFMul $$vector<T,N> result Subgroup ExclusiveScan $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMul $$vector<T,N> result Subgroup ExclusiveScan $expr};
else if (__isSignedInt<T>())
{
return spirv_asm
{
+ OpCapability GroupNonUniformArithmetic;
// TODO: use the correct integer width
OpBitcast $$vector<uint,N> %uvalue $expr;
OpGroupNonUniformIMul $$vector<uint,N> %mulResult Subgroup ExclusiveScan %uvalue;
@@ -5917,7 +5933,7 @@ vector<T,N> WaveMaskPrefixProduct(WaveMask mask, vector<T,N> expr)
};
}
else if (__isUnsignedInt<T>())
- return spirv_asm {OpGroupNonUniformIMul $$vector<T,N> result Subgroup ExclusiveScan $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformIMul $$vector<T,N> result Subgroup ExclusiveScan $expr};
}
}
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
@@ -5928,7 +5944,6 @@ matrix<T,N,M> WaveMaskPrefixProduct(WaveMask mask, matrix<T,N,M> expr);
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
T WaveMaskPrefixSum(WaveMask mask, T expr)
{
__target_switch
@@ -5938,11 +5953,12 @@ T WaveMaskPrefixSum(WaveMask mask, T expr)
case hlsl: __intrinsic_asm "WavePrefixSum($1)";
case spirv:
if (__isFloat<T>())
- return spirv_asm {OpGroupNonUniformFAdd $$T result Subgroup ExclusiveScan $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFAdd $$T result Subgroup ExclusiveScan $expr};
else if (__isSignedInt<T>())
{
return spirv_asm
{
+ OpCapability GroupNonUniformArithmetic;
// TODO: use the correct integer width
%uvalue:$$uint = OpBitcast $expr;
%mulResult:$$uint = OpGroupNonUniformIAdd Subgroup ExclusiveScan %uvalue;
@@ -5950,14 +5966,13 @@ T WaveMaskPrefixSum(WaveMask mask, T expr)
};
}
else if (__isUnsignedInt<T>())
- return spirv_asm {OpGroupNonUniformIAdd $$T result Subgroup ExclusiveScan $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformIAdd $$T result Subgroup ExclusiveScan $expr};
}
}
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
vector<T,N> WaveMaskPrefixSum(WaveMask mask, vector<T,N> expr)
{
__target_switch
@@ -5967,11 +5982,12 @@ vector<T,N> WaveMaskPrefixSum(WaveMask mask, vector<T,N> expr)
case hlsl: __intrinsic_asm "WavePrefixSum($1)";
case spirv:
if (__isFloat<T>())
- return spirv_asm {OpGroupNonUniformFAdd $$vector<T,N> result Subgroup ExclusiveScan $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFAdd $$vector<T,N> result Subgroup ExclusiveScan $expr};
else if (__isSignedInt<T>())
{
return spirv_asm
{
+ OpCapability GroupNonUniformArithmetic;
// TODO: use the correct integer width
%uvalue: $$vector<uint,N> = OpBitcast $expr;
%mulResult: $$vector<uint,N> = OpGroupNonUniformIAdd Subgroup ExclusiveScan %uvalue;
@@ -5979,7 +5995,7 @@ vector<T,N> WaveMaskPrefixSum(WaveMask mask, vector<T,N> expr)
};
}
else if (__isUnsignedInt<T>())
- return spirv_asm {OpGroupNonUniformIAdd $$vector<T,N> result Subgroup ExclusiveScan $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformIAdd $$vector<T,N> result Subgroup ExclusiveScan $expr};
}
}
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
@@ -5990,7 +6006,6 @@ matrix<T,N,M> WaveMaskPrefixSum(WaveMask mask, matrix<T,N,M> expr);
__generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformBallot)
T WaveMaskReadLaneFirst(WaveMask mask, T expr)
{
__target_switch
@@ -5999,13 +6014,12 @@ T WaveMaskReadLaneFirst(WaveMask mask, T expr)
case cuda: __intrinsic_asm "_waveReadFirst($0, $1)";
case hlsl: __intrinsic_asm "WaveReadLaneFirst($1)";
case spirv:
- return spirv_asm {OpGroupNonUniformBroadcastFirst $$T result Subgroup $expr};
+ return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcastFirst $$T result Subgroup $expr};
}
}
__generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformBallot)
vector<T,N> WaveMaskReadLaneFirst(WaveMask mask, vector<T,N> expr)
{
__target_switch
@@ -6014,7 +6028,7 @@ vector<T,N> WaveMaskReadLaneFirst(WaveMask mask, vector<T,N> expr)
case cuda: __intrinsic_asm "_waveReadFirstMultiple($0, $1)";
case hlsl: __intrinsic_asm "WaveReadLaneFirst($1)";
case spirv:
- return spirv_asm {OpGroupNonUniformBroadcastFirst $$vector<T,N> result Subgroup $expr};
+ return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcastFirst $$vector<T,N> result Subgroup $expr};
}
}
@@ -6079,7 +6093,6 @@ WaveMask WaveMaskMatch(WaveMask mask, matrix<T,N,M> value);
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
T WaveMaskPrefixBitAnd(WaveMask mask, T expr)
{
__target_switch
@@ -6088,14 +6101,13 @@ T WaveMaskPrefixBitAnd(WaveMask mask, T expr)
case cuda: __intrinsic_asm "_wavePrefixAnd($0, $1)";
case hlsl: __intrinsic_asm "WaveMultiPrefixBitAnd($1, uint4($0, 0, 0, 0))";
case spirv:
- return spirv_asm {OpGroupNonUniformBitwiseAnd $$T result Subgroup ExclusiveScan $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformBitwiseAnd $$T result Subgroup ExclusiveScan $expr};
}
}
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
vector<T,N> WaveMaskPrefixBitAnd(WaveMask mask, vector<T,N> expr)
{
__target_switch
@@ -6104,7 +6116,7 @@ vector<T,N> WaveMaskPrefixBitAnd(WaveMask mask, vector<T,N> expr)
case cuda: __intrinsic_asm "_wavePrefixAndMultiple($0, $1)";
case hlsl: __intrinsic_asm "WaveMultiPrefixBitAnd($1, uint4($0, 0, 0, 0))";
case spirv:
- return spirv_asm {OpGroupNonUniformBitwiseAnd $$vector<T,N> result Subgroup ExclusiveScan $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformBitwiseAnd $$vector<T,N> result Subgroup ExclusiveScan $expr};
}
}
@@ -6116,7 +6128,6 @@ matrix<T,N,M> WaveMaskPrefixBitAnd(WaveMask mask, matrix<T,N,M> expr);
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
T WaveMaskPrefixBitOr(WaveMask mask, T expr)
{
__target_switch
@@ -6125,14 +6136,13 @@ T WaveMaskPrefixBitOr(WaveMask mask, T expr)
case cuda: __intrinsic_asm "_wavePrefixOr($0, $1)";
case hlsl: __intrinsic_asm "WaveMultiPrefixBitOr($1, uint4($0, 0, 0, 0))";
case spirv:
- return spirv_asm {OpGroupNonUniformBitwiseAnd $$T result Subgroup ExclusiveScan $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformBitwiseOr $$T result Subgroup ExclusiveScan $expr};
}
}
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
vector<T,N> WaveMaskPrefixBitOr(WaveMask mask, vector<T,N> expr)
{
__target_switch
@@ -6141,7 +6151,7 @@ vector<T,N> WaveMaskPrefixBitOr(WaveMask mask, vector<T,N> expr)
case cuda: __intrinsic_asm "_wavePrefixOrMultiple($0, $1)";
case hlsl: __intrinsic_asm "WaveMultiPrefixBitOr($1, uint4($0, 0, 0, 0))";
case spirv:
- return spirv_asm {OpGroupNonUniformBitwiseOr $$vector<T,N> result Subgroup ExclusiveScan $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformBitwiseOr $$vector<T,N> result Subgroup ExclusiveScan $expr};
}
}
@@ -6153,7 +6163,6 @@ matrix<T,N,M> WaveMaskPrefixBitOr(WaveMask mask, matrix<T,N,M> expr);
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
T WaveMaskPrefixBitXor(WaveMask mask, T expr)
{
__target_switch
@@ -6162,14 +6171,13 @@ T WaveMaskPrefixBitXor(WaveMask mask, T expr)
case cuda: __intrinsic_asm "_wavePrefixXor($0, $1)";
case hlsl: __intrinsic_asm "WaveMultiPrefixBitXor($1, uint4($0, 0, 0, 0))";
case spirv:
- return spirv_asm {OpGroupNonUniformBitwiseXor $$T result Subgroup ExclusiveScan $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformBitwiseXor $$T result Subgroup ExclusiveScan $expr};
}
}
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
vector<T,N> WaveMaskPrefixBitXor(WaveMask mask, vector<T,N> expr)
{
__target_switch
@@ -6178,7 +6186,7 @@ vector<T,N> WaveMaskPrefixBitXor(WaveMask mask, vector<T,N> expr)
case cuda: __intrinsic_asm "_wavePrefixXorMultiple($0, $1)";
case hlsl: __intrinsic_asm "WaveMultiPrefixBitXor($1, uint4($0, 0, 0, 0))";
case spirv:
- return spirv_asm {OpGroupNonUniformBitwiseXor $$vector<T,N> result Subgroup ExclusiveScan $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformBitwiseXor $$vector<T,N> result Subgroup ExclusiveScan $expr};
}
}
@@ -6218,7 +6226,6 @@ for (auto opName : kWaveActiveBitOpEntries) {
__generic<T : __BuiltinIntegerType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
T WaveActive$(opName.hlslName)(T expr)
{
__target_switch
@@ -6226,7 +6233,7 @@ T WaveActive$(opName.hlslName)(T expr)
case glsl: __intrinsic_asm "subgroup$(opName.glslName)($0)";
case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)";
case spirv:
- return spirv_asm {OpGroupNonUniform$(opName.spirvName) $$T result Subgroup Reduce $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniform$(opName.spirvName) $$T result Subgroup Reduce $expr};
default:
return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr);
}
@@ -6235,7 +6242,6 @@ T WaveActive$(opName.hlslName)(T expr)
__generic<T : __BuiltinIntegerType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
vector<T, N> WaveActive$(opName.hlslName)(vector<T, N> expr)
{
__target_switch
@@ -6243,7 +6249,7 @@ vector<T, N> WaveActive$(opName.hlslName)(vector<T, N> expr)
case glsl: __intrinsic_asm "subgroup$(opName.glslName)($0)";
case hlsl: __intrinsic_asm "WaveActive$(opName.hlslName)";
case spirv:
- return spirv_asm {OpGroupNonUniform$(opName.spirvName) $$vector<T, N> result Subgroup Reduce $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniform$(opName.spirvName) $$vector<T, N> result Subgroup Reduce $expr};
default:
return WaveMask$(opName.hlslName)(WaveGetActiveMask(), expr);
}
@@ -6268,7 +6274,6 @@ for (const char* opName : kWaveActiveMinMaxNames) {
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
T WaveActive$(opName)(T expr)
{
__target_switch
@@ -6277,11 +6282,11 @@ T WaveActive$(opName)(T expr)
case hlsl: __intrinsic_asm "WaveActive$(opName)";
case spirv:
if (__isFloat<T>())
- return spirv_asm {OpGroupNonUniformF$(opName) $$T result Subgroup Reduce $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformF$(opName) $$T result Subgroup Reduce $expr};
else if (__isUnsignedInt<T>())
- return spirv_asm {OpGroupNonUniformU$(opName) $$T result Subgroup Reduce $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformU$(opName) $$T result Subgroup Reduce $expr};
else
- return spirv_asm {OpGroupNonUniformS$(opName) $$T result Subgroup Reduce $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformS$(opName) $$T result Subgroup Reduce $expr};
default:
return WaveMask$(opName)(WaveGetActiveMask(), expr);
}
@@ -6290,7 +6295,6 @@ T WaveActive$(opName)(T expr)
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
vector<T, N> WaveActive$(opName)(vector<T, N> expr)
{
__target_switch
@@ -6299,11 +6303,11 @@ vector<T, N> WaveActive$(opName)(vector<T, N> expr)
case hlsl: __intrinsic_asm "WaveActive$(opName)";
case spirv:
if (__isFloat<T>())
- return spirv_asm {OpGroupNonUniformF$(opName) $$vector<T, N> result Subgroup Reduce $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformF$(opName) $$vector<T, N> result Subgroup Reduce $expr};
else if (__isUnsignedInt<T>())
- return spirv_asm {OpGroupNonUniformU$(opName) $$vector<T, N> result Subgroup Reduce $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformU$(opName) $$vector<T, N> result Subgroup Reduce $expr};
else
- return spirv_asm {OpGroupNonUniformS$(opName) $$vector<T, N> result Subgroup Reduce $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformS$(opName) $$vector<T, N> result Subgroup Reduce $expr};
default:
return WaveMask$(opName)(WaveGetActiveMask(), expr);
}
@@ -6415,7 +6419,6 @@ ${{{{
__generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_vote)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformVote)
bool WaveActiveAllEqual(T value)
{
__target_switch
@@ -6427,6 +6430,7 @@ bool WaveActiveAllEqual(T value)
case spirv:
return spirv_asm
{
+ OpCapability GroupNonUniformVote;
OpGroupNonUniformAllEqual $$bool result Subgroup $value
};
default:
@@ -6437,7 +6441,6 @@ bool WaveActiveAllEqual(T value)
__generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_vote)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformVote)
bool WaveActiveAllEqual(vector<T,N> value)
{
__target_switch
@@ -6449,6 +6452,7 @@ bool WaveActiveAllEqual(vector<T,N> value)
case spirv:
return spirv_asm
{
+ OpCapability GroupNonUniformVote;
OpGroupNonUniformAllEqual $$bool result Subgroup $value
};
default:
@@ -6465,7 +6469,6 @@ bool WaveActiveAllEqual(matrix<T, N, M> value)
__glsl_extension(GL_KHR_shader_subgroup_vote)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformVote)
bool WaveActiveAllTrue(bool condition)
{
__target_switch
@@ -6477,6 +6480,7 @@ bool WaveActiveAllTrue(bool condition)
case spirv:
return spirv_asm
{
+ OpCapability GroupNonUniformVote;
OpGroupNonUniformAll $$bool result Subgroup $condition
};
default:
@@ -6486,7 +6490,6 @@ bool WaveActiveAllTrue(bool condition)
__glsl_extension(GL_KHR_shader_subgroup_vote)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformVote)
bool WaveActiveAnyTrue(bool condition)
{
__target_switch
@@ -6498,6 +6501,7 @@ bool WaveActiveAnyTrue(bool condition)
case spirv:
return spirv_asm
{
+ OpCapability GroupNonUniformVote;
OpGroupNonUniformAny $$bool result Subgroup $condition
};
default:
@@ -6507,7 +6511,6 @@ bool WaveActiveAnyTrue(bool condition)
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformBallot)
uint4 WaveActiveBallot(bool condition)
{
__target_switch
@@ -6519,6 +6522,7 @@ uint4 WaveActiveBallot(bool condition)
case spirv:
return spirv_asm
{
+ OpCapability GroupNonUniformBallot;
OpGroupNonUniformBallot $$uint4 result Subgroup $condition
};
default:
@@ -6534,7 +6538,6 @@ uint WaveActiveCountBits(bool value)
__glsl_extension(GL_KHR_shader_subgroup_basic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniform)
uint WaveGetLaneCount()
{
__target_switch
@@ -6545,6 +6548,7 @@ uint WaveGetLaneCount()
case spirv:
return spirv_asm
{
+ OpCapability GroupNonUniform;
result:$$uint = OpLoad builtin(SubgroupSize:uint)
};
}
@@ -6552,7 +6556,6 @@ uint WaveGetLaneCount()
__glsl_extension(GL_KHR_shader_subgroup_basic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniform)
uint WaveGetLaneIndex()
{
__target_switch
@@ -6563,13 +6566,13 @@ uint WaveGetLaneIndex()
case spirv:
return spirv_asm
{
+ OpCapability GroupNonUniform;
result:$$uint = OpLoad builtin(SubgroupLocalInvocationId:uint)
};
}
}
__glsl_extension(GL_KHR_shader_subgroup_basic)
-__spirv_capability(GroupNonUniformBallot)
__spirv_version(1.3)
bool WaveIsFirstLane()
{
@@ -6582,6 +6585,7 @@ bool WaveIsFirstLane()
case spirv:
return spirv_asm
{
+ OpCapability GroupNonUniformBallot;
OpGroupNonUniformElect $$bool result Subgroup
};
default:
@@ -6591,7 +6595,6 @@ bool WaveIsFirstLane()
// It's useful to have a wave uint4 version of countbits, because some wave functions return uint4.
// This implementation tries to limit the amount of work required by the actual lane count.
-__spirv_capability(GroupNonUniformBallot)
uint _WaveCountBits(uint4 value)
{
__target_switch
@@ -6599,6 +6602,7 @@ uint _WaveCountBits(uint4 value)
case spirv:
return spirv_asm
{
+ OpCapability GroupNonUniformBallot;
OpGroupNonUniformBallotBitCount $$uint result Subgroup Reduce $value
};
default:
@@ -6621,7 +6625,6 @@ uint _WaveCountBits(uint4 value)
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
T WavePrefixProduct(T expr)
{
__target_switch
@@ -6630,11 +6633,15 @@ T WavePrefixProduct(T expr)
case hlsl: __intrinsic_asm "WavePrefixProduct";
case spirv:
if (__isFloat<T>())
- return spirv_asm {OpGroupNonUniformFMul $$T result Subgroup ExclusiveScan $expr};
+ return spirv_asm {
+ OpCapability GroupNonUniformArithmetic;
+ OpGroupNonUniformFMul $$T result Subgroup ExclusiveScan $expr
+ };
else if (__isSignedInt<T>())
{
return spirv_asm
{
+ OpCapability GroupNonUniformArithmetic;
// TODO: use the correct integer width
OpBitcast $$uint %uvalue $expr;
OpGroupNonUniformIMul $$uint %mulResult Subgroup ExclusiveScan %uvalue;
@@ -6642,7 +6649,7 @@ T WavePrefixProduct(T expr)
};
}
else if (__isUnsignedInt<T>())
- return spirv_asm {OpGroupNonUniformIMul $$T result Subgroup ExclusiveScan $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformIMul $$T result Subgroup ExclusiveScan $expr};
default:
return WaveMaskPrefixProduct(WaveGetActiveMask(), expr);
}
@@ -6652,7 +6659,6 @@ T WavePrefixProduct(T expr)
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
vector<T,N> WavePrefixProduct(vector<T,N> expr)
{
__target_switch
@@ -6661,11 +6667,12 @@ vector<T,N> WavePrefixProduct(vector<T,N> expr)
case hlsl: __intrinsic_asm "WavePrefixProduct";
case spirv:
if (__isFloat<T>())
- return spirv_asm {OpGroupNonUniformFMul $$vector<T,N> result Subgroup ExclusiveScan $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFMul $$vector<T,N> result Subgroup ExclusiveScan $expr};
else if (__isSignedInt<T>())
{
return spirv_asm
{
+ OpCapability GroupNonUniformArithmetic;
// TODO: use the correct integer width
OpBitcast $$vector<uint,N> %uvalue $expr;
OpGroupNonUniformIMul $$vector<uint,N> %mulResult Subgroup ExclusiveScan %uvalue;
@@ -6673,7 +6680,7 @@ vector<T,N> WavePrefixProduct(vector<T,N> expr)
};
}
else if (__isUnsignedInt<T>())
- return spirv_asm {OpGroupNonUniformIMul $$vector<T,N> result Subgroup ExclusiveScan $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformIMul $$vector<T,N> result Subgroup ExclusiveScan $expr};
default:
return WaveMaskPrefixProduct(WaveGetActiveMask(), expr);
}
@@ -6689,7 +6696,6 @@ matrix<T, N, M> WavePrefixProduct(matrix<T, N, M> expr)
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
T WavePrefixSum(T expr)
{
__target_switch
@@ -6698,11 +6704,12 @@ T WavePrefixSum(T expr)
case hlsl: __intrinsic_asm "WavePrefixSum";
case spirv:
if (__isFloat<T>())
- return spirv_asm {OpGroupNonUniformFAdd $$T result Subgroup ExclusiveScan $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFAdd $$T result Subgroup ExclusiveScan $expr};
else if (__isSignedInt<T>())
{
return spirv_asm
{
+ OpCapability GroupNonUniformArithmetic;
// TODO: use the correct integer width
%uvalue:$$uint = OpBitcast $expr;
%mulResult:$$uint = OpGroupNonUniformIAdd Subgroup ExclusiveScan %uvalue;
@@ -6710,7 +6717,7 @@ T WavePrefixSum(T expr)
};
}
else if (__isUnsignedInt<T>())
- return spirv_asm {OpGroupNonUniformIAdd $$T result Subgroup ExclusiveScan $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformIAdd $$T result Subgroup ExclusiveScan $expr};
default:
return WaveMaskPrefixSum(WaveGetActiveMask(), expr);
}
@@ -6719,7 +6726,6 @@ T WavePrefixSum(T expr)
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformArithmetic)
vector<T,N> WavePrefixSum(vector<T,N> expr)
{
__target_switch
@@ -6728,11 +6734,12 @@ vector<T,N> WavePrefixSum(vector<T,N> expr)
case hlsl: __intrinsic_asm "WavePrefixSum";
case spirv:
if (__isFloat<T>())
- return spirv_asm {OpGroupNonUniformFAdd $$vector<T,N> result Subgroup ExclusiveScan $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformFAdd $$vector<T,N> result Subgroup ExclusiveScan $expr};
else if (__isSignedInt<T>())
{
return spirv_asm
{
+ OpCapability GroupNonUniformArithmetic;
// TODO: use the correct integer width
%uvalue:$$vector<uint,N> = OpBitcast $expr;
%mulResult:$$vector<uint,N> = OpGroupNonUniformIAdd Subgroup ExclusiveScan %uvalue;
@@ -6740,7 +6747,7 @@ vector<T,N> WavePrefixSum(vector<T,N> expr)
};
}
else if (__isUnsignedInt<T>())
- return spirv_asm {OpGroupNonUniformIAdd $$vector<T,N> result Subgroup ExclusiveScan $expr};
+ return spirv_asm {OpCapability GroupNonUniformArithmetic; OpGroupNonUniformIAdd $$vector<T,N> result Subgroup ExclusiveScan $expr};
default:
return WaveMaskPrefixSum(WaveGetActiveMask(), expr);
}
@@ -6756,7 +6763,6 @@ matrix<T,N,M> WavePrefixSum(matrix<T,N,M> expr)
__generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformBallot)
T WaveReadLaneFirst(T expr)
{
__target_switch
@@ -6764,7 +6770,7 @@ T WaveReadLaneFirst(T expr)
case glsl: __intrinsic_asm "subgroupBroadcastFirst($0)";
case hlsl: __intrinsic_asm "WaveReadLaneFirst";
case spirv:
- return spirv_asm {OpGroupNonUniformBroadcastFirst $$T result Subgroup $expr};
+ return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcastFirst $$T result Subgroup $expr};
default:
return WaveMaskReadLaneFirst(WaveGetActiveMask(), expr);
}
@@ -6773,7 +6779,6 @@ T WaveReadLaneFirst(T expr)
__generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformBallot)
vector<T,N> WaveReadLaneFirst(vector<T,N> expr)
{
__target_switch
@@ -6781,7 +6786,7 @@ vector<T,N> WaveReadLaneFirst(vector<T,N> expr)
case glsl: __intrinsic_asm "subgroupBroadcastFirst($0)";
case hlsl: __intrinsic_asm "WaveReadLaneFirst";
case spirv:
- return spirv_asm {OpGroupNonUniformBroadcastFirst $$vector<T,N> result Subgroup $expr};
+ return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcastFirst $$vector<T,N> result Subgroup $expr};
default:
return WaveMaskReadLaneFirst(WaveGetActiveMask(), expr);
}
@@ -6803,7 +6808,6 @@ matrix<T,N,M> WaveReadLaneFirst(matrix<T,N,M> expr)
__generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformBallot)
T WaveBroadcastLaneAt(T value, constexpr int lane)
{
__target_switch
@@ -6812,7 +6816,7 @@ T WaveBroadcastLaneAt(T value, constexpr int lane)
case hlsl: __intrinsic_asm "WaveReadLaneAt";
case spirv:
let ulane = uint(lane);
- return spirv_asm {OpGroupNonUniformBroadcast $$T result Subgroup $value $ulane};
+ return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcast $$T result Subgroup $value $ulane};
default:
return WaveMaskBroadcastLaneAt(WaveGetActiveMask(), value, lane);
}
@@ -6821,7 +6825,6 @@ T WaveBroadcastLaneAt(T value, constexpr int lane)
__generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformBallot)
vector<T,N> WaveBroadcastLaneAt(vector<T,N> value, constexpr int lane)
{
__target_switch
@@ -6830,7 +6833,7 @@ vector<T,N> WaveBroadcastLaneAt(vector<T,N> value, constexpr int lane)
case hlsl: __intrinsic_asm "WaveReadLaneAt";
case spirv:
let ulane = uint(lane);
- return spirv_asm {OpGroupNonUniformBroadcast $$vector<T,N> result Subgroup $value $ulane};
+ return spirv_asm {OpCapability GroupNonUniformBallot; OpGroupNonUniformBroadcast $$vector<T,N> result Subgroup $value $ulane};
default:
return WaveMaskBroadcastLaneAt(WaveGetActiveMask(), value, lane);
}
@@ -6849,7 +6852,6 @@ matrix<T, N, M> WaveBroadcastLaneAt(matrix<T, N, M> value, constexpr int lane)
__generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformShuffle)
T WaveReadLaneAt(T value, int lane)
{
__target_switch
@@ -6858,7 +6860,7 @@ T WaveReadLaneAt(T value, int lane)
case hlsl: __intrinsic_asm "WaveReadLaneAt";
case spirv:
let ulane = uint(lane);
- return spirv_asm {OpGroupNonUniformShuffle $$T result Subgroup $value $ulane};
+ return spirv_asm {OpCapability GroupNonUniformShuffle; OpGroupNonUniformShuffle $$T result Subgroup $value $ulane};
default:
return WaveMaskReadLaneAt(WaveGetActiveMask(), value, lane);
}
@@ -6867,7 +6869,6 @@ T WaveReadLaneAt(T value, int lane)
__generic<T : __BuiltinType, let N : int>
__spirv_version(1.3)
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
-__spirv_capability(GroupNonUniformShuffle)
vector<T,N> WaveReadLaneAt(vector<T,N> value, int lane)
{
__target_switch
@@ -6876,7 +6877,7 @@ vector<T,N> WaveReadLaneAt(vector<T,N> value, int lane)
case hlsl: __intrinsic_asm "WaveReadLaneAt";
case spirv:
let ulane = uint(lane);
- return spirv_asm {OpGroupNonUniformShuffle $$vector<T,N> result Subgroup $value $ulane};
+ return spirv_asm {OpCapability GroupNonUniformShuffle; OpGroupNonUniformShuffle $$vector<T,N> result Subgroup $value $ulane};
default:
return WaveMaskReadLaneAt(WaveGetActiveMask(), value, lane);
}
@@ -6896,7 +6897,6 @@ matrix<T, N, M> WaveReadLaneAt(matrix<T, N, M> value, int lane)
__generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformShuffle)
T WaveShuffle(T value, int lane)
{
__target_switch
@@ -6905,7 +6905,7 @@ T WaveShuffle(T value, int lane)
case hlsl: __intrinsic_asm "WaveReadLaneAt";
case spirv:
let ulane = uint(lane);
- return spirv_asm {OpGroupNonUniformShuffle $$T result Subgroup $value $ulane};
+ return spirv_asm {OpCapability GroupNonUniformShuffle; OpGroupNonUniformShuffle $$T result Subgroup $value $ulane};
default:
return WaveMaskShuffle(WaveGetActiveMask(), value, lane);
}
@@ -6914,7 +6914,6 @@ T WaveShuffle(T value, int lane)
__generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_shuffle)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformShuffle)
vector<T,N> WaveShuffle(vector<T,N> value, int lane)
{
__target_switch
@@ -6923,7 +6922,7 @@ vector<T,N> WaveShuffle(vector<T,N> value, int lane)
case hlsl: __intrinsic_asm "WaveReadLaneAt";
case spirv:
let ulane = uint(lane);
- return spirv_asm {OpGroupNonUniformShuffle $$vector<T,N> result Subgroup $value $ulane};
+ return spirv_asm {OpCapability GroupNonUniformShuffle; OpGroupNonUniformShuffle $$vector<T,N> result Subgroup $value $ulane};
default:
return WaveMaskShuffle(WaveGetActiveMask(), value, lane);
}
@@ -6938,7 +6937,6 @@ matrix<T, N, M> WaveShuffle(matrix<T, N, M> value, int lane)
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformBallot)
uint WavePrefixCountBits(bool value)
{
__target_switch
@@ -6948,6 +6946,7 @@ uint WavePrefixCountBits(bool value)
case spirv:
return spirv_asm
{
+ OpCapability GroupNonUniformBallot;
%mask:$$uint4 = OpGroupNonUniformBallot Subgroup $value;
OpGroupNonUniformBallotBitCount $$uint result Subgroup 2 %mask
};
@@ -6958,7 +6957,6 @@ uint WavePrefixCountBits(bool value)
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-__spirv_capability(GroupNonUniformBallot)
uint4 WaveGetConvergedMulti()
{
__target_switch
@@ -6970,6 +6968,7 @@ uint4 WaveGetConvergedMulti()
let _true = true;
return spirv_asm
{
+ OpCapability GroupNonUniformBallot;
OpGroupNonUniformBallot $$uint4 result Subgroup $_true
};
}
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 68824b931..bf79028d7 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -132,15 +132,6 @@ class RequiredSPIRVVersionModifier : public Modifier
};
// A modifier to tag something as an intrinsic that requires
-// a certain SPIRV capability to be enabled when used.
-class RequiredSPIRVCapabilityModifier : public Modifier
-{
- SLANG_AST_CLASS(RequiredSPIRVCapabilityModifier)
- int32_t capability;
- String extensionName;
-};
-
-// A modifier to tag something as an intrinsic that requires
// a certain CUDA SM version to be enabled when used. Specified as "major.minor"
class RequiredCUDASMVersionModifier : public Modifier
{
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 62e7d428b..0ba637978 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -3310,7 +3310,12 @@ struct SPIRVEmitContext
// Find phi arguments from incoming branch instructions that target `block`.
for (auto use = block->firstUse; use; use = use->nextUse)
{
- auto branchInst = use->getUser();
+ auto branchInst = as<IRUnconditionalBranch>(use->getUser());
+ if (!branchInst)
+ continue;
+ if (branchInst->getTargetBlock() != inst->getParent())
+ continue;
+
UInt argStartIndex = 0;
switch (branchInst->getOp())
{
@@ -4742,27 +4747,6 @@ struct SPIRVEmitContext
}
}
- void handleRequiredCapabilitiesImpl(IRInst* inst)
- {
- for (auto decoration : inst->getDecorations())
- {
- switch (decoration->getOp())
- {
- default:
- break;
- case kIROp_RequireSPIRVCapabilityDecoration:
- requireSPIRVCapability((SpvCapability)getIntVal(decoration->getOperand(0)));
- if (decoration->getOperandCount() == 2)
- {
- auto stringLit = as<IRStringLit>(decoration->getOperand(1));
- if (stringLit->getStringSlice().getLength())
- ensureExtensionDeclaration(stringLit->getStringSlice());
- }
- break;
- }
- }
- }
-
SPIRVEmitContext(IRModule* module, TargetRequest* target, DiagnosticSink* sink)
: SPIRVEmitSharedContext(module, target, sink)
, m_irModule(module)
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 6a49e9842..bbf6885a8 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -10,6 +10,7 @@
#include "slang-ir-byte-address-legalize.h"
#include "slang-ir-collect-global-uniforms.h"
#include "slang-ir-cleanup-void.h"
+#include "slang-ir-composite-reg-to-mem.h"
#include "slang-ir-dce.h"
#include "slang-ir-diff-call.h"
#include "slang-ir-autodiff.h"
@@ -360,7 +361,7 @@ Result linkAndOptimizeIR(
// Lower all the LValue implict casts (used for out/inout/ref scenarios)
lowerLValueCast(targetRequest, irModule);
- simplifyIR(irModule, sink);
+ simplifyIR(irModule, IRSimplificationOptions::getDefault(), sink);
// Fill in default matrix layout into matrix types that left layout unspecified.
specializeMatrixLayout(codeGenContext->getTargetReq(), irModule);
@@ -472,7 +473,7 @@ Result linkAndOptimizeIR(
validateIRModuleIfEnabled(codeGenContext, irModule);
- simplifyIR(irModule, sink);
+ simplifyIR(irModule, IRSimplificationOptions::getFast(), sink);
if (!ArtifactDescUtil::isCpuLikeTarget(artifactDesc))
{
@@ -501,7 +502,7 @@ Result linkAndOptimizeIR(
// up downstream passes like type legalization, so we
// will run a DCE pass to clean up after the specialization.
//
- simplifyIR(irModule, sink);
+ simplifyIR(irModule, IRSimplificationOptions::getDefault(), sink);
validateIRModuleIfEnabled(codeGenContext, irModule);
@@ -591,7 +592,7 @@ Result linkAndOptimizeIR(
// to see if we can clean up any temporaries created by legalization.
// (e.g., things that used to be aggregated might now be split up,
// so that we can work with the individual fields).
- simplifyIR(irModule, sink);
+ simplifyIR(irModule, IRSimplificationOptions::getFast(), sink);
#if 0
dumpIRIfEnabled(codeGenContext, irModule, "AFTER SSA");
@@ -924,12 +925,20 @@ Result linkAndOptimizeIR(
// bit_cast on basic types.
lowerBitCast(targetRequest, irModule);
- eliminateMultiLevelBreak(irModule);
if (isKhronosTarget(targetRequest) && targetRequest->shouldEmitSPIRVDirectly())
- performIntrinsicFunctionFunctionInlining(irModule);
+ {
+ //performIntrinsicFunctionFunctionInlining(irModule);
+ performSpirvInlining(irModule);
+ eliminateDeadCode(irModule);
+ }
+ eliminateMultiLevelBreak(irModule);
- simplifyIR(irModule, sink);
+ {
+ IRSimplificationOptions simplificationOptions = IRSimplificationOptions::getFast();
+ simplificationOptions.cfgOptions.removeTrivialSingleIterationLoops = true;
+ simplifyIR(irModule, simplificationOptions, sink);
+ }
// As a late step, we need to take the SSA-form IR and move things *out*
// of SSA form, by eliminating all "phi nodes" (block parameters) and
@@ -956,7 +965,13 @@ Result linkAndOptimizeIR(
}
// We only want to accumulate locations if liveness tracking is enabled.
- eliminatePhis(livenessMode, irModule);
+ PhiEliminationOptions phiEliminationOptions;
+ if (isKhronosTarget(targetRequest) && targetRequest->shouldEmitSPIRVDirectly())
+ {
+ phiEliminationOptions.eliminateCompositeTypedPhiOnly = false;
+ phiEliminationOptions.useRegisterAllocation = true;
+ }
+ eliminatePhis(livenessMode, irModule, phiEliminationOptions);
#if 0
dumpIRIfEnabled(codeGenContext, irModule, "PHIS ELIMINATED");
#endif
@@ -1000,7 +1015,7 @@ Result linkAndOptimizeIR(
}
// Run a final round of simplifications to clean up unused things after phi-elimination.
- simplifyNonSSAIR(irModule);
+ simplifyNonSSAIR(irModule, IRSimplificationOptions::getFast());
// We include one final step to (optionally) dump the IR and validate
// it after all of the optimization passes are complete. This should
@@ -1263,15 +1278,17 @@ SlangResult emitSPIRVForEntryPointsDirectly(
List<uint8_t> spirv, outSpirv;
emitSPIRVFromIR(codeGenContext, irModule, irEntryPoints, spirv);
+#if 0
String optErr;
if (SLANG_FAILED(optimizeSPIRV(spirv, optErr, outSpirv)))
{
codeGenContext->getSink()->diagnose(SourceLoc(), Diagnostics::spirvOptFailed, optErr);
- outSpirv = _Move(spirv);
+ spirv = _Move(outSpirv);
}
+#endif
auto artifact = ArtifactUtil::createArtifactForCompileTarget(asExternal(codeGenContext->getTargetFormat()));
- artifact->addRepresentationUnknown(ListBlob::moveCreate(outSpirv));
+ artifact->addRepresentationUnknown(ListBlob::moveCreate(spirv));
ArtifactUtil::addAssociated(artifact, linkedIR.metadata);
diff --git a/source/slang/slang-ir-array-reg-to-mem.cpp b/source/slang/slang-ir-array-reg-to-mem.cpp
deleted file mode 100644
index 34bd5b148..000000000
--- a/source/slang/slang-ir-array-reg-to-mem.cpp
+++ /dev/null
@@ -1,88 +0,0 @@
-#include "slang-ir-array-reg-to-mem.h"
-
-#include "slang-ir.h"
-#include "slang-ir-insts.h"
-#include "slang-ir-util.h"
-
-namespace Slang
-{
- bool eliminateArrayTypeParameters(IRFunc* func)
- {
- IRBuilder builder(func);
- bool changed = false;
- List<UInt> arrayParamIds;
- UInt idx = 0;
- List<IRParam*> paramWorkList;
- for (auto param : func->getParams())
- {
- if (auto arrayType = as<IRArrayTypeBase>(param->getFullType()))
- {
- paramWorkList.add(param);
- arrayParamIds.add(idx);
- }
- idx++;
- }
- for (auto param : paramWorkList)
- {
- // We have an array type parameter, so we need to replace it with a pointer to the array
- // type.
- //
- // We will also need to insert a `load` instruction at the start of the function body
- // to load the actual pointer value from the parameter.
- //
- if (auto arrayType = as<IRArrayTypeBase>(param->getFullType()))
- {
- changed = true;
- auto ptrArrayType = builder.getPtrType(arrayType);
- auto newParam = builder.createParam(ptrArrayType);
- newParam->insertBefore(param);
- setInsertAfterOrdinaryInst(&builder, param);
- auto regVal = builder.emitLoad(newParam);
- param->replaceUsesWith(regVal);
- param->removeAndDeallocate();
- }
- }
- if (changed)
- {
- // The function is modified, we need to also update its type.
- List<IRType*> paramTypes;
- for (auto param : func->getParams())
- {
- paramTypes.add(param->getFullType());
- }
- auto newFuncType = builder.getFuncType((UInt)paramTypes.getCount(), paramTypes.getBuffer(), func->getResultType());
- func->setFullType(newFuncType);
-
- // Update all the call sites to pass the arrays by pointer.
- traverseUses(func, [&](IRUse* use)
- {
- if (const auto call = as<IRCall>(use->getUser()))
- {
- builder.setInsertBefore(call);
- for (auto paramId : arrayParamIds)
- {
- auto arg = call->getArg(paramId);
- SLANG_ASSERT(as<IRPtrTypeBase>(paramTypes[paramId]));
- auto var = builder.emitVar(as<IRPtrTypeBase>(paramTypes[paramId])->getValueType());
- builder.emitStore(var, arg);
- call->setArg(paramId, var);
- }
- }
- });
- }
- return changed;
- }
-
- bool eliminateArrayTypeSSARegisters(IRModule* module)
- {
- bool changed = false;
- for (auto inst : module->getGlobalInsts())
- {
- if (auto func = as<IRFunc>(inst))
- {
- changed |= eliminateArrayTypeParameters(func);
- }
- }
- return changed;
- }
-}
diff --git a/source/slang/slang-ir-array-reg-to-mem.h b/source/slang/slang-ir-array-reg-to-mem.h
deleted file mode 100644
index 941125e1c..000000000
--- a/source/slang/slang-ir-array-reg-to-mem.h
+++ /dev/null
@@ -1,16 +0,0 @@
-// slang-ir-array-reg-to-mem.h
-#pragma once
-
-namespace Slang
-{
- struct IRModule;
- struct IRCall;
- struct IRInst;
- struct IRFunc;
-
- /// Eliminate SSA registers and IRParams of array type and turn them into pointers to memory objects.
- bool eliminateArrayTypeSSARegisters(IRModule* module);
-
- bool eliminateArrayTypeParameters(IRFunc* func);
-
-}
diff --git a/source/slang/slang-ir-autodiff-cfg-norm.cpp b/source/slang/slang-ir-autodiff-cfg-norm.cpp
index 3602e77ae..bb0b4a8c6 100644
--- a/source/slang/slang-ir-autodiff-cfg-norm.cpp
+++ b/source/slang/slang-ir-autodiff-cfg-norm.cpp
@@ -798,7 +798,7 @@ void normalizeCFG(
// Remove phis to simplify our pass. We'll add them back in later
// with constructSSA.
//
- eliminatePhisInFunc(LivenessMode::Disabled, func->getModule(), func, false);
+ eliminatePhisInFunc(LivenessMode::Disabled, func->getModule(), func, PhiEliminationOptions::getFast());
CFGNormalizationContext context = {module, options.sink};
CFGNormalizationPass cfgPass(context);
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 2465c1b74..8f3ab2a34 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -1702,7 +1702,7 @@ SlangResult ForwardDiffTranscriber::prepareFuncForForwardDiff(IRFunc* func)
if (SLANG_SUCCEEDED(result))
{
disableIRValidationAtInsert();
- simplifyFunc(func);
+ simplifyFunc(func, IRSimplificationOptions::getDefault());
enableIRValidationAtInsert();
}
return result;
diff --git a/source/slang/slang-ir-composite-reg-to-mem.cpp b/source/slang/slang-ir-composite-reg-to-mem.cpp
new file mode 100644
index 000000000..243a0e2b0
--- /dev/null
+++ b/source/slang/slang-ir-composite-reg-to-mem.cpp
@@ -0,0 +1,202 @@
+#include "slang-ir-composite-reg-to-mem.h"
+
+#include "slang-ir.h"
+#include "slang-ir-insts.h"
+#include "slang-ir-util.h"
+#include "slang-ir-dce.h"
+
+namespace Slang
+{
+ struct RegisterReplacementWorkItem // One queued rewrite for replaceRegisterUseWithAddrUse: an SSA value plus the address that stands in for it.
+ {
+ IRInst* ssaValue; // Value whose uses are to be rewritten.
+ IRInst* addr; // Address that rewritten uses load from, or derive element/field addresses from.
+ IRInst* initialStore; // User of `ssaValue` that is skipped during the rewrite; may be nullptr.
+ };
+
+ void replaceRegisterUseWithAddrUse( // Single step: rewrites each use of `ssaValue` to go through `addr`; getElement/fieldExtract users are queued on `workList` rather than handled recursively.
+ List<RegisterReplacementWorkItem>& workList, // out: follow-up items, one per user that was turned into an address computation
+ IRInst* ssaValue, // value whose uses are rewritten
+ IRInst* addr, // address standing in for `ssaValue`
+ IRInst* initialStore) // user to leave untouched (e.g. the store that seeds `addr` with `ssaValue`); may be nullptr
+ {
+ IRBuilder builder(ssaValue);
+ traverseUses(ssaValue, [&](IRUse* use)
+ {
+ auto user = use->getUser();
+ if (user == initialStore) // leave the initial store alone so it keeps consuming the SSA value itself
+ return;
+ builder.setInsertBefore(user); // anything emitted below lands immediately before the user
+ IRInst* newAddr = nullptr; // stays null unless the user is a getElement/fieldExtract whose operand list begins with `use`
+ // If the user is itself a getElement/getField inst,
+ // we want to follow that chain and recursively replace
+ // their users.
+ if (auto getElementUser = as<IRGetElement>(user))
+ {
+ if (getElementUser->getOperands() == use) // true only when `use` is the first entry of the user's operand list
+ {
+ newAddr = builder.emitElementAddress(
+ builder.getPtrType(user->getFullType()), // pointer to the element type the getElement produced
+ addr,
+ getElementUser->getIndex());
+ }
+ }
+ else if (auto getFieldUser = as<IRFieldExtract>(user))
+ {
+ if (getFieldUser->getOperands() == use) // same first-operand test as for getElement
+ {
+ newAddr = builder.emitFieldAddress(
+ builder.getPtrType(user->getFullType()), // pointer to the field type the fieldExtract produced
+ addr,
+ getFieldUser->getField());
+ }
+ }
+ if (newAddr)
+ {
+ workList.add(RegisterReplacementWorkItem{ user, newAddr, nullptr }); // defer: the user's own uses get rewritten against newAddr later; the user itself is not modified here
+ }
+ else
+ {
+ // For all other uses, we emit a load from addr and use it.
+ auto val = builder.emitLoad(addr);
+ builder.replaceOperand(use, val); // only this one operand slot is redirected to the loaded value
+ }
+ });
+ }
+
+ void replaceRegisterUseWithAddrUse(IRInst* ssaValue, IRInst* addr, IRInst* initialStore) // Driver: runs the single-step overload in rounds until no new work items are produced.
+ {
+ List<RegisterReplacementWorkItem> workList, pendingWorkList; // workList = current round, pendingWorkList = items discovered during it
+ workList.add(RegisterReplacementWorkItem{ ssaValue, addr, initialStore }); // seed with the caller's request
+
+ while (workList.getCount())
+ {
+ for (auto item : workList)
+ {
+ replaceRegisterUseWithAddrUse(pendingWorkList, item.ssaValue, item.addr, item.initialStore); // follow-up items are appended to pendingWorkList
+ }
+ workList.swapWith(pendingWorkList); // the newly discovered items become the next round
+ pendingWorkList.clear(); // after the swap this holds the finished round; discard it
+ }
+ }
+
+ void convertCompositeTypeParametersToPointers(IRFunc* func) // Per-function step: array/struct parameters become pointer parameters, then array/struct-typed load/call results are routed through memory.
+ {
+ IRBuilder builder(func);
+ List<UInt> compositeParamIds; // positions of the parameters being converted; reused below to patch call arguments
+ UInt idx = 0; // position of the parameter currently being inspected
+ List<IRParam*> paramWorkList; // parameters to convert, collected before any replacement happens
+ if (!func->findDecoration<IREntryPointDecoration>())
+ {
+ // Only translate function parameters for non entry points.
+ for (auto param : func->getParams())
+ {
+ if (as<IRArrayTypeBase>(param->getFullType()) ||
+ as<IRStructType>(param->getFullType()))
+ {
+ paramWorkList.add(param);
+ compositeParamIds.add(idx);
+ }
+ idx++;
+ }
+ }
+ for (auto param : paramWorkList)
+ {
+ // We have a composite type parameter, so we need to replace it with a pointer.
+ // The pointer-typed replacement is inserted before the old parameter, which is then removed.
+
+ auto ptrCompositeType = builder.getPtrType(param->getFullType());
+ auto newParam = builder.createParam(ptrCompositeType);
+ newParam->insertBefore(param);
+ replaceRegisterUseWithAddrUse(param, newParam, nullptr); // no initial store to skip for parameters
+ param->removeAndDeallocate(); // NOTE(review): getElement/fieldExtract users queued by the rewrite are not modified and still reference `param` here - confirm this is safe until eliminateDeadCode runs
+ }
+ if (paramWorkList.getCount())
+ {
+ // The function is modified, we need to also update its type.
+ List<IRType*> paramTypes;
+ for (auto param : func->getParams())
+ {
+ paramTypes.add(param->getFullType());
+ }
+ auto newFuncType = builder.getFuncType((UInt)paramTypes.getCount(), paramTypes.getBuffer(), func->getResultType()); // same result type, updated parameter types
+ func->setFullType(newFuncType);
+
+ // Update all the call sites to pass the composite by pointer.
+ traverseUses(func, [&](IRUse* use)
+ {
+ if (const auto call = as<IRCall>(use->getUser())) // only users that are call insts are patched
+ {
+ builder.setInsertBefore(call);
+ for (auto paramId : compositeParamIds)
+ {
+ auto arg = call->getArg(paramId); // the by-value argument currently passed
+ SLANG_ASSERT(as<IRPtrTypeBase>(paramTypes[paramId])); // converted parameters must be pointer-typed by now
+ auto var = builder.emitVar(as<IRPtrTypeBase>(paramTypes[paramId])->getValueType()); // fresh variable of the pointee type, emitted before the call
+ builder.emitStore(var, arg); // copy the by-value argument into it
+ call->setArg(paramId, var); // and pass the variable instead of the value
+ }
+ }
+ });
+ }
+
+ // Now work through all the local values and process uses of composite-typed `load` and `call` results.
+ for (auto block : func->getBlocks())
+ {
+ for (auto inst : block->getModifiableChildren())
+ {
+ if (!as<IRArrayTypeBase>(inst->getDataType()) &&
+ !as<IRStructType>(inst->getDataType()))
+ continue; // only array/struct-typed insts are of interest
+ if (inst->getParent() != block)
+ continue; // skip anything whose parent is not this block (presumably insts moved or removed during iteration - TODO confirm)
+ IRInst* tempVar = nullptr; // address that will stand in for `inst`
+ IRInst* initialStore = nullptr; // store of `inst` into tempVar, if one is emitted
+ builder.setInsertAfter(inst); // new var/store go right after the producing inst
+ switch (inst->getOp())
+ {
+ case kIROp_Load:
+ {
+ auto ptr = inst->getOperand(0);
+ auto rootPtr = getRootAddr(ptr);
+ if (as<IRConstantBufferType>(rootPtr->getDataType()) ||
+ as<IRParameterBlockType>(rootPtr->getDataType()))
+ {
+ tempVar = ptr; // root is a constant buffer / parameter block: reuse the load's source address, no var or store emitted
+ }
+ else
+ {
+ tempVar = builder.emitVar(inst->getFullType()); // otherwise copy the loaded value into a fresh variable
+ initialStore = builder.emitStore(tempVar, inst);
+ }
+ break;
+ }
+ case kIROp_Call:
+ {
+ tempVar = builder.emitVar(inst->getFullType()); // call results are always copied into a fresh variable
+ initialStore = builder.emitStore(tempVar, inst);
+ break;
+ }
+ default:
+ break; // every other opcode is left as is
+ }
+
+ if (!tempVar)
+ continue; // opcode not handled above
+ replaceRegisterUseWithAddrUse(inst, tempVar, initialStore); // users now read through tempVar; the initial store keeps using `inst`
+ }
+ }
+ eliminateDeadCode(func); // run dead-code elimination over the function after the rewrite
+ }
+
+ void convertCompositeTypeParametersToPointers(IRModule* module) // Module-level entry point: applies the per-function conversion to every IRFunc among the module's global insts.
+ {
+ for (auto inst : module->getGlobalInsts())
+ {
+ if (auto func = as<IRFunc>(inst)) // global insts that are not IRFunc are ignored
+ {
+ convertCompositeTypeParametersToPointers(func);
+ }
+ }
+ }
+}
diff --git a/source/slang/slang-ir-composite-reg-to-mem.h b/source/slang/slang-ir-composite-reg-to-mem.h
new file mode 100644
index 000000000..9ccfc8734
--- /dev/null
+++ b/source/slang/slang-ir-composite-reg-to-mem.h
@@ -0,0 +1,13 @@
+// slang-ir-composite-reg-to-mem.h
+#pragma once
+
+namespace Slang
+{
+ struct IRModule;
+ struct IRCall;
+ struct IRInst;
+ struct IRFunc;
+
+ /// Convert array/struct-typed parameters of non-entry-point functions into pointers, update the call sites accordingly, and route array/struct-typed load/call results through memory.
+ void convertCompositeTypeParametersToPointers(IRModule* module);
+}
diff --git a/source/slang/slang-ir-eliminate-multilevel-break.cpp b/source/slang/slang-ir-eliminate-multilevel-break.cpp
index 5ff71a248..eafdbbcec 100644
--- a/source/slang/slang-ir-eliminate-multilevel-break.cpp
+++ b/source/slang/slang-ir-eliminate-multilevel-break.cpp
@@ -303,7 +303,6 @@ struct EliminateMultiLevelBreakContext
void processFunc(IRGlobalValueWithCode* func)
{
-
normalizeBranchesIntoBreakBlocks(func);
// If func does not have any multi-level breaks, return.
@@ -316,7 +315,7 @@ struct EliminateMultiLevelBreakContext
}
// To make things easy, eliminate Phis before perform transformations.
- eliminatePhisInFunc(LivenessMode::Disabled, irModule, func);
+ eliminatePhisInFunc(LivenessMode::Disabled, irModule, func, PhiEliminationOptions::getFast());
// Before modifying the cfg, we gather all required info from the existing cfg.
FuncContext funcInfo;
diff --git a/source/slang/slang-ir-eliminate-phis.cpp b/source/slang/slang-ir-eliminate-phis.cpp
index 1023e6148..b17fad6ec 100644
--- a/source/slang/slang-ir-eliminate-phis.cpp
+++ b/source/slang/slang-ir-eliminate-phis.cpp
@@ -1,6 +1,7 @@
// slang-ir-eliminate-phis.cpp
#include "slang-ir-eliminate-phis.h"
#include "slang-ir-ssa-register-allocate.h"
+#include "slang-ir-util.h"
// This file implements a pass to take code in the Slang IR out out SSA form
// by eliminating all "phi nodes."
@@ -68,20 +69,20 @@ struct PhiEliminationContext
IRModule* m_module = nullptr;
IRBuilder m_builder;
LivenessMode m_livenessMode;
- bool m_useRegisterAllocation;
+ PhiEliminationOptions m_options;
PhiEliminationContext(LivenessMode livenessMode, IRModule* module)
: m_module(module)
, m_builder(module)
, m_livenessMode(livenessMode)
- , m_useRegisterAllocation(true)
+ , m_options()
{}
- PhiEliminationContext(LivenessMode livenessMode, IRModule* module, bool useRegisterAllocation)
+ PhiEliminationContext(LivenessMode livenessMode, IRModule* module, PhiEliminationOptions options)
: m_module(module)
, m_builder(module)
, m_livenessMode(livenessMode)
- , m_useRegisterAllocation(useRegisterAllocation)
+ , m_options(options)
{}
// We start with the top-down logic of the pass, which is to process
@@ -220,9 +221,9 @@ struct PhiEliminationContext
m_func = func;
m_dominatorTree = nullptr;
- if (m_useRegisterAllocation)
+ if (m_options.useRegisterAllocation)
{
- m_registerAllocation = allocateRegistersForFunc(func, m_dominatorTree);
+ m_registerAllocation = allocateRegistersForFunc(func, m_dominatorTree, m_options.eliminateCompositeTypedPhiOnly);
m_mapRegToTempVar = createTempVarForInsts(func);
}
}
@@ -435,7 +436,7 @@ struct PhiEliminationContext
{
Index paramIndex = paramCounter++;
mapParamToIndex.add(param, paramIndex);
-
+
IRInst* temp = nullptr;
// Have we already allocated a register for this inst?
@@ -445,7 +446,9 @@ struct PhiEliminationContext
m_mapRegToTempVar.tryGetValue(registerInfo->get(), temp);
}
- if (!temp)
+ bool shouldAllocTemp = !m_options.eliminateCompositeTypedPhiOnly || isCompositeType(param->getFullType());
+
+ if (!temp && shouldAllocTemp)
{
// Note that the `emitVar` operation expects to be passed the
// type *stored* in the variable, but the IR `var` instruction
@@ -471,7 +474,7 @@ struct PhiEliminationContext
PhiInfo phiInfo;
auto& paramInfo = phiInfo.param;
paramInfo.param = param;
- paramInfo.temp = cast<IRVar>(temp);
+ paramInfo.temp = as<IRVar>(temp);
phiInfos.add(phiInfo);
}
}
@@ -515,6 +518,9 @@ struct PhiEliminationContext
auto& paramInfo = phiInfo.param;
auto param = paramInfo.param;
auto temp = paramInfo.temp;
+
+ if (!temp)
+ continue;
// We will repeatedly replace whatever the *first*
// use of `param` is, until there are no more uses
@@ -738,6 +744,10 @@ struct PhiEliminationContext
{
assignment.state = kState_Done;
}
+ else if (!dstParam.temp)
+ {
+ assignment.state = kState_Done;
+ }
else
{
// Otherwise we start out assuming that the assignment is ready
@@ -909,17 +919,24 @@ struct PhiEliminationContext
static const Count kMaxNewOperandCount = 3;
SLANG_ASSERT(newOperandCount <= kMaxNewOperandCount);
- IRInst* newOperands[kMaxNewOperandCount] = {};
+ ShortList<IRInst*> newOperands;
for (Index i = 0; i < newOperandCount; ++i)
{
- newOperands[i] = oldBranch->getOperand(i);
+ newOperands.add(oldBranch->getOperand(i));
+ }
+
+ // Add operands for any remaining phi parameters that have not been eliminated.
+ for (UInt i = 0; i < (UInt)phiInfos.getCount(); i++)
+ {
+ if (!phiInfos[i].param.temp)
+ newOperands.add(oldBranch->getArg(i));
}
auto newBranch = m_builder.emitIntrinsicInst(
oldBranch->getFullType(),
oldBranch->getOp(),
- newOperandCount,
- newOperands);
+ newOperands.getCount(),
+ newOperands.getArrayView().getBuffer());
oldBranch->transferDecorationsTo(newBranch);
// TODO: We could consider just modifying `branch` in-place by clearing
@@ -1122,9 +1139,9 @@ struct PhiEliminationContext
}
};
-void eliminatePhis(LivenessMode livenessMode, IRModule* module, bool useRegisterAllocation)
+void eliminatePhis(LivenessMode livenessMode, IRModule* module, PhiEliminationOptions options)
{
- PhiEliminationContext context(livenessMode, module, useRegisterAllocation);
+ PhiEliminationContext context(livenessMode, module, options);
context.eliminatePhisInModule();
}
@@ -1132,9 +1149,9 @@ void eliminatePhisInFunc(
LivenessMode livenessMode,
IRModule* module,
IRGlobalValueWithCode* func,
- bool useRegisterAllocation)
+ PhiEliminationOptions options)
{
- PhiEliminationContext context(livenessMode, module, useRegisterAllocation);
+ PhiEliminationContext context(livenessMode, module, options);
context.eliminatePhisInFunc(func);
}
diff --git a/source/slang/slang-ir-eliminate-phis.h b/source/slang/slang-ir-eliminate-phis.h
index 9bfbe51b8..6655a16e4 100644
--- a/source/slang/slang-ir-eliminate-phis.h
+++ b/source/slang/slang-ir-eliminate-phis.h
@@ -8,6 +8,13 @@ namespace Slang
struct CodeGenContext;
struct IRModule;
+ struct PhiEliminationOptions // Knobs for eliminatePhis / eliminatePhisInFunc.
+ {
+ bool eliminateCompositeTypedPhiOnly = false; // when true, temp vars are only allocated for phi params whose type satisfies isCompositeType(); the other phi params are kept
+ bool useRegisterAllocation = true; // when true, allocateRegistersForFunc() is run before phis are eliminated
+ static PhiEliminationOptions getFast() { return PhiEliminationOptions{ false, false }; } // positional init: eliminateCompositeTypedPhiOnly = false, useRegisterAllocation = false
+ };
+
/// Eliminate all "phi nodes" from the given `module`.
///
/// This process moves the code in `module` *out* of SSA form,
@@ -15,11 +22,11 @@ namespace Slang
/// are not themselves based on an SSA representation.
///
/// If livenessMode is enabled LiveRangeStarts will be inserted into the module.
- void eliminatePhis(LivenessMode livenessMode, IRModule* module, bool useRegisterAllocation = true);
+ void eliminatePhis(LivenessMode livenessMode, IRModule* module, PhiEliminationOptions options);
void eliminatePhisInFunc(
LivenessMode livenessMode,
IRModule* module,
IRGlobalValueWithCode* func,
- bool useRegisterAllocation = true);
+ PhiEliminationOptions options);
}
diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp
index 8412a1912..3b01d7dde 100644
--- a/source/slang/slang-ir-inline.cpp
+++ b/source/slang/slang-ir-inline.cpp
@@ -252,9 +252,6 @@ struct InliningPassBase
{
case kIROp_IntrinsicOpDecoration:
return true;
- case kIROp_RequireSPIRVCapabilityDecoration:
- // Don't inline a function with spirv capability decoration to avoid losing it.
- return false;
}
}
@@ -947,7 +944,7 @@ void performGLSLResourceReturnFunctionInlining(IRModule* module)
while (changed)
{
changed = pass.considerAllCallSites();
- simplifyIR(module);
+ simplifyIR(module, IRSimplificationOptions::getFast());
}
}
@@ -964,8 +961,6 @@ struct IntrinsicFunctionInliningPass : InliningPassBase
auto func = as<IRFunc>(getResolvedInstForDecorations(info.callee));
if (!func)
return false;
- if (func->findDecorationImpl(kIROp_RequireSPIRVCapabilityDecoration))
- return false;
auto returnInst = as<IRReturn>(func->getFirstBlock()->getTerminator());
if (!returnInst)
return false;
@@ -1025,4 +1020,32 @@ bool inlineCall(IRCall* call)
}
+struct SpirvInliningPass : InliningPassBase // Inlining pass whose policy accepts every call site except those whose callee carries an entry-point decoration.
+{
+ typedef InliningPassBase Super;
+
+ SpirvInliningPass(IRModule* module)
+ : Super(module)
+ {}
+
+ bool shouldInline(CallSiteInfo const& info) // returns true unless info.callee has an IREntryPointDecoration
+ {
+ if (!info.callee->findDecoration<IREntryPointDecoration>())
+ return true;
+ return false;
+ }
+};
+
+void performSpirvInlining(IRModule* module) // Runs SpirvInliningPass over `module` repeatedly until a run reports no change.
+{
+ SLANG_PROFILE;
+ while (true)
+ {
+ SpirvInliningPass pass(module); // a fresh pass object is constructed for every iteration
+ if (pass.considerAllCallSites()) // true result: go around again
+ continue;
+ break; // false result: done
+ }
+}
+
} // namespace Slang
diff --git a/source/slang/slang-ir-inline.h b/source/slang/slang-ir-inline.h
index 539bb26c0..5e888ad07 100644
--- a/source/slang/slang-ir-inline.h
+++ b/source/slang/slang-ir-inline.h
@@ -34,6 +34,9 @@ namespace Slang
/// Inline simple intrinsic functions whose definition is a single asm block.
void performIntrinsicFunctionFunctionInlining(IRModule* module);
+ /// Inline all functions for SPIRV emit.
+ void performSpirvInlining(IRModule* module);
+
/// Inline a specific call.
bool inlineCall(IRCall* call);
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 5a8d7fe06..bd32a1896 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -688,7 +688,6 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(VulkanHitObjectAttributesDecoration, vulkanHitObjectAttributes, 0, 0)
INST(RequireSPIRVVersionDecoration, requireSPIRVVersion, 1, 0)
- INST(RequireSPIRVCapabilityDecoration, requireSPIRVCapability, 1, 0)
INST(RequireGLSLVersionDecoration, requireGLSLVersion, 1, 0)
INST(RequireGLSLExtensionDecoration, requireGLSLExtension, 1, 0)
INST(RequireCUDASMVersionDecoration, requireCUDASMVersion, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 9ff2df889..070f989b5 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -4264,25 +4264,6 @@ public:
addDecoration(value, kIROp_RequireSPIRVVersionDecoration, getIntValue(getBasicType(BaseType::UInt64), intValue));
}
- void addRequireSPIRVCapabilityDecoration(IRInst* value, int32_t capabilityName, UnownedStringSlice extensionName)
- {
- if (extensionName.getLength())
- {
- addDecoration(
- value,
- kIROp_RequireSPIRVCapabilityDecoration,
- getIntValue(getIntType(), IRIntegerValue(capabilityName)),
- getStringValue(extensionName));
- }
- else
- {
- addDecoration(
- value,
- kIROp_RequireSPIRVCapabilityDecoration,
- getIntValue(getIntType(), IRIntegerValue(capabilityName)));
- }
- }
-
void addRequireCUDASMVersionDecoration(IRInst* value, const SemanticVersion& version)
{
SemanticVersion::IntegerType intValue = version.toInteger();
diff --git a/source/slang/slang-ir-loop-unroll.cpp b/source/slang/slang-ir-loop-unroll.cpp
index c4ef1650c..b5af2d974 100644
--- a/source/slang/slang-ir-loop-unroll.cpp
+++ b/source/slang/slang-ir-loop-unroll.cpp
@@ -459,7 +459,7 @@ bool unrollLoopsInFunc(
// Make sure we simplify things as much as possible before
// attempting to potentially unroll outer loop.
- simplifyCFG(func);
+ simplifyCFG(func, CFGSimplificationOptions::getDefault());
eliminateDeadCode(func);
}
return true;
diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp
index 722515692..2002b42cc 100644
--- a/source/slang/slang-ir-lower-generics.cpp
+++ b/source/slang/slang-ir-lower-generics.cpp
@@ -255,7 +255,7 @@ namespace Slang
// real RTTI objects and witness tables.
specializeRTTIObjects(&sharedContext, sink);
- simplifyIR(module);
+ simplifyIR(module, IRSimplificationOptions::getFast());
lowerTuples(module, sink);
if (sink->getErrorCount() != 0)
diff --git a/source/slang/slang-ir-peephole.cpp b/source/slang/slang-ir-peephole.cpp
index a7e89eb4f..b6b7823c9 100644
--- a/source/slang/slang-ir-peephole.cpp
+++ b/source/slang/slang-ir-peephole.cpp
@@ -359,6 +359,7 @@ struct PeepholeContext : InstPassBase
changed = true;
break;
}
+ startIndex += vecSize->getValue();
}
else
{
@@ -807,7 +808,6 @@ struct PeepholeContext : InstPassBase
case kIROp_Or:
changed = tryOptimizeArithmeticInst(inst);
break;
-
case kIROp_Param:
{
auto block = as<IRBlock>(inst->parent);
diff --git a/source/slang/slang-ir-reachability.cpp b/source/slang/slang-ir-reachability.cpp
index 1a3cab386..1a4aa271b 100644
--- a/source/slang/slang-ir-reachability.cpp
+++ b/source/slang/slang-ir-reachability.cpp
@@ -5,35 +5,64 @@ namespace Slang
// Computes whether block1 can reach block2.
// A block is considered not reachable from itself unless there is a backedge in the CFG.
-bool ReachabilityContext::computeReachability(IRBlock* block1, IRBlock* block2)
-{
- workList.clear();
- reachableBlocks.clear();
- workList.add(block1);
- for (Index i = 0; i < workList.getCount(); i++)
+ ReachabilityContext::ReachabilityContext(IRGlobalValueWithCode* code) // Precomputes, for each block of `code`, the set of blocks that can reach it; isBlockReachable() then only does lookups.
+ {
- auto src = workList[i];
- for (auto successor : src->getSuccessors())
+ int id = 0; // ids are assigned in getBlocks() order, starting at 0
+ for (auto block : code->getBlocks())
+ {
+ mapBlockToId[block] = id;
+ id++;
+ allBlocks.add(block);
+ }
+ sourceBlocks.setCount(allBlocks.getCount()); // one set per block
+ for (auto &srcBlock : sourceBlocks)
+ srcBlock.resizeAndClear(allBlocks.getCount()); // each set is sized by the block count
+
+ if (allBlocks.getCount() == 0)
+ return; // nothing to propagate
+
+ List<IRBlock*> workList; // blocks processed in the current round
+ List<IRBlock*> pendingWorkList; // blocks whose set changed and must be revisited next round
+ workList.add(allBlocks[0]); // propagation is seeded with the first block only
+ while (workList.getCount())
{
- if (successor == block2)
- return true;
- if (reachableBlocks.add(successor))
- workList.add(successor);
+ pendingWorkList.clear();
+ for (Index i = 0; i < workList.getCount(); i++)
+ {
+ auto src = workList[i];
+ auto srcId = mapBlockToId.getValue(src);
+ for (auto successor : src->getSuccessors())
+ {
+ auto successorId = mapBlockToId.getValue(successor);
+ auto& blockSet = sourceBlocks[successorId]; // ids of blocks known so far to reach `successor`
+ bool changed = false;
+ if (!blockSet.contains(srcId)) // src reaches successor directly
+ {
+ blockSet.add(srcId);
+ changed = true;
+ }
+ if (!blockSet.contains(sourceBlocks[srcId])) // everything that reaches src also reaches successor
+ {
+ blockSet.unionWith(sourceBlocks[srcId]);
+ changed = true;
+ }
+ if (changed)
+ pendingWorkList.add(successor); // successor is only revisited when its set grew
+ }
+ }
+ workList.swapWith(pendingWorkList); // next round works on the blocks that changed
}
+
+ }
- return false;
-}
-bool ReachabilityContext::isBlockReachable(IRBlock* from, IRBlock* to)
-{
- BlockPair pair;
- pair.first = from;
- pair.second = to;
- bool result = false;
- if (reachabilityResults.tryGetValue(pair, result))
- return result;
- result = computeReachability(from, to);
- reachabilityResults[pair] = result;
- return result;
-}
+ bool ReachabilityContext::isBlockReachable(IRBlock* from, IRBlock* to) // Answers "can `from` reach `to`?" from the sets built in the constructor.
+ {
+ if (!from) return false; // a null block on either side is reported as not reachable
+ if (!to) return false;
+ int* fromId = mapBlockToId.tryGetValue(from);
+ int* toId = mapBlockToId.tryGetValue(to);
+ if (!fromId || !toId)
+ return true; // a block not registered in mapBlockToId: answer true
+ return sourceBlocks[*toId].contains(*fromId); // `to`'s set holds the ids of the blocks recorded as reaching it
+ }
}
diff --git a/source/slang/slang-ir-reachability.h b/source/slang/slang-ir-reachability.h
index 74463b7fe..ef6a182df 100644
--- a/source/slang/slang-ir-reachability.h
+++ b/source/slang/slang-ir-reachability.h
@@ -9,39 +9,18 @@ namespace Slang
// A context for computing and caching reachability between blocks on the CFG.
struct ReachabilityContext
{
- struct BlockPair
- {
- IRBlock* first;
- IRBlock* second;
- bool operator == (const BlockPair& other) const
- {
- return first == other.first && second == other.second;
- }
- HashCode getHashCode() const
- {
- Hasher h;
- h.hashValue(first);
- h.hashValue(second);
- return h.getResult();
- }
- };
- Dictionary<BlockPair, bool> reachabilityResults;
-
- List<IRBlock*> workList;
- HashSet<IRBlock*> reachableBlocks;
+ Dictionary<IRBlock*, int> mapBlockToId;
+ List<IRBlock*> allBlocks;
+ List<UIntSet> sourceBlocks; // sourceBlocks[i] stores the set of blocks from which block i can be reached.
- // Computes whether block1 can reach block2.
- // A block is considered not reachable from itself unless there is a backedge in the CFG.
- bool computeReachability(IRBlock* block1, IRBlock* block2);
+ ReachabilityContext() = default;
+ ReachabilityContext(IRGlobalValueWithCode* code);
bool isBlockReachable(IRBlock* from, IRBlock* to);
bool isInstReachable(IRInst* inst1, IRInst* inst2)
{
- if (isBlockReachable(as<IRBlock>(inst1->getParent()), as<IRBlock>(inst2->getParent())))
- return true;
-
- // If the parent blocks are not reachable, but inst1 and inst2 are in the same block,
+ // If inst1 and inst2 are in the same block,
// we test if inst2 appears after inst1.
if (inst1->getParent() == inst2->getParent())
{
@@ -52,7 +31,7 @@ struct ReachabilityContext
}
}
- return false;
+ return isBlockReachable(as<IRBlock>(inst1->getParent()), as<IRBlock>(inst2->getParent()));
}
};
diff --git a/source/slang/slang-ir-redundancy-removal.cpp b/source/slang/slang-ir-redundancy-removal.cpp
index 42a31ea91..c7986cfbc 100644
--- a/source/slang/slang-ir-redundancy-removal.cpp
+++ b/source/slang/slang-ir-redundancy-removal.cpp
@@ -53,9 +53,10 @@ struct RedundancyRemovalContext
return changed;
}
- bool removeRedundancyInBlock(DeduplicateContext& deduplicateContext, IRGlobalValueWithCode* func, IRBlock* block)
+ bool removeRedundancyInBlock(Dictionary<IRBlock*, DeduplicateContext>& mapBlockToDedupContext, IRGlobalValueWithCode* func, IRBlock* block)
{
bool result = false;
+ auto& deduplicateContext = mapBlockToDedupContext.getValue(block);
for (auto instP : block->getModifiableChildren())
{
auto resultInst = deduplicateContext.deduplicate(instP, [&](IRInst* inst)
@@ -82,9 +83,8 @@ struct RedundancyRemovalContext
}
for (auto child : dom->getImmediatelyDominatedBlocks(block))
{
- DeduplicateContext subContext;
+ DeduplicateContext& subContext = mapBlockToDedupContext.getValue(child);
subContext.deduplicateMap = deduplicateContext.deduplicateMap;
- result |= removeRedundancyInBlock(subContext, func, child);
}
return result;
}
@@ -116,8 +116,28 @@ bool removeRedundancyInFunc(IRGlobalValueWithCode* func)
RedundancyRemovalContext context;
context.dom = computeDominatorTree(func);
- DeduplicateContext deduplicateCtx;
- bool result = context.removeRedundancyInBlock(deduplicateCtx, func, root);
+ Dictionary<IRBlock*, DeduplicateContext> mapBlockToDeduplicateContext;
+ for (auto block : func->getBlocks())
+ {
+ mapBlockToDeduplicateContext[block] = DeduplicateContext();
+ }
+ List<IRBlock*> workList, pendingWorkList;
+ workList.add(root);
+ bool result = false;
+ while (workList.getCount())
+ {
+ for (auto block : workList)
+ {
+ result |= context.removeRedundancyInBlock(mapBlockToDeduplicateContext, func, block);
+
+ for (auto child : context.dom->getImmediatelyDominatedBlocks(block))
+ {
+ pendingWorkList.add(child);
+ }
+ }
+ workList.swapWith(pendingWorkList);
+ pendingWorkList.clear();
+ }
if (auto normalFunc = as<IRFunc>(func))
{
result |= eliminateRedundantLoadStore(normalFunc);
diff --git a/source/slang/slang-ir-sccp.cpp b/source/slang/slang-ir-sccp.cpp
index d874514fe..fc3766e63 100644
--- a/source/slang/slang-ir-sccp.cpp
+++ b/source/slang/slang-ir-sccp.cpp
@@ -720,6 +720,39 @@ struct SCCPContext
{
SLANG_SCCP_RETURN_IF_NONE_OR_ANY(v0)
auto c0 = as<IRConstant>(v0.value);
+
+ uint64_t sourceValueBits = 0;
+ switch (c0->getDataType()->getOp())
+ {
+ case kIROp_FloatType:
+ {
+ float fval = (float)c0->value.floatVal;
+ memcpy(&sourceValueBits, &fval, sizeof(fval));
+ break;
+ }
+ case kIROp_DoubleType:
+ {
+ double dval = c0->value.floatVal;
+ memcpy(&sourceValueBits, &dval, sizeof(dval));
+ break;
+ }
+ case kIROp_BoolType:
+ {
+ sourceValueBits = c0->value.intVal;
+ break;
+ }
+ default:
+ if (isIntegralType(c0->getDataType()))
+ {
+ sourceValueBits = c0->value.intVal;
+ }
+ else
+ {
+ return LatticeVal::getAny();
+ }
+ break;
+ }
+
IRInst* resultVal = nullptr;
switch (type->getOp())
{
@@ -729,7 +762,7 @@ struct SCCPContext
case kIROp_IntPtrType:
case kIROp_UIntPtrType:
#endif
- resultVal = getBuilder()->getIntValue(type, c0->value.intVal);
+ resultVal = getBuilder()->getIntValue(type, sourceValueBits);
break;
case kIROp_IntType:
case kIROp_UIntType:
@@ -737,21 +770,17 @@ struct SCCPContext
case kIROp_IntPtrType:
case kIROp_UIntPtrType:
#endif
- {
- float val = (float)c0->value.floatVal;
- uint32_t intVal = (uint32_t)FloatAsInt(val);
- resultVal = getBuilder()->getIntValue(type, intVal);
- }
+ resultVal = getBuilder()->getIntValue(type, (uint32_t)sourceValueBits);
break;
case kIROp_FloatType:
{
- uint32_t val = (uint32_t)c0->value.intVal;
+ uint32_t val = (uint32_t)sourceValueBits;
float floatVal = IntAsFloat((int)val);
resultVal = getBuilder()->getFloatValue(type, floatVal);
}
break;
case kIROp_DoubleType:
- resultVal = getBuilder()->getFloatValue(type, Int64AsDouble(c0->value.intVal));
+ resultVal = getBuilder()->getFloatValue(type, Int64AsDouble(sourceValueBits));
break;
default:
break;
diff --git a/source/slang/slang-ir-simplify-cfg.cpp b/source/slang/slang-ir-simplify-cfg.cpp
index e00c24bdc..0fe752c8f 100644
--- a/source/slang/slang-ir-simplify-cfg.cpp
+++ b/source/slang/slang-ir-simplify-cfg.cpp
@@ -15,20 +15,9 @@ struct CFGSimplificationContext
{
RefPtr<RegionTree> regionTree;
RefPtr<IRDominatorTree> domTree;
+ Dictionary<IRInst*, List<IRInst*>> relatedAddrMap;
};
-static BreakableRegion* findBreakableRegion(Region* region)
-{
- for (;;)
- {
- if (auto b = as<BreakableRegion>(region))
- return b;
- region = region->getParent();
- if (!region)
- return nullptr;
- }
-}
-
static bool isBlockInRegion(IRDominatorTree* domTree, IRTerminatorInst* regionHeader, IRBlock* block)
{
auto headerBlock = cast<IRBlock>(regionHeader->getParent());
@@ -55,10 +44,26 @@ static bool isBlockInRegion(IRDominatorTree* domTree, IRTerminatorInst* regionHe
return true;
}
+static IRInst* findBreakableRegionHeaderInst(IRDominatorTree* domTree, IRBlock* block) // Returns the terminator of the nearest strict dominator of `block` that ends in a switch or loop, or nullptr if there is none.
+{
+ for (auto idom = domTree->getImmediateDominator(block); idom; idom = domTree->getImmediateDominator(idom)) // walk up the dominator tree, starting above `block` itself
+ {
+ auto terminator = idom->getTerminator();
+ switch (terminator->getOp())
+ {
+ case kIROp_Switch:
+ case kIROp_loop:
+ return terminator; // first switch/loop terminator found on the way up
+ }
+ }
+ return nullptr; // ran out of dominators without finding one
+}
+
// Test if a loop is trivial: a trivial loop runs for a single iteration without any back edges, and
// there is only one break out of the loop at the very end. The function generates `regionTree` if
// it is needed and hasn't been generated yet.
static bool isTrivialSingleIterationLoop(
+ CFGSimplificationContext& context,
IRGlobalValueWithCode* func,
IRLoop* loop)
{
@@ -80,40 +85,21 @@ static bool isTrivialSingleIterationLoop(
//
// We need to verify this is a trivial loop by checking if there is any multi-level breaks
// that skips out of this loop.
- CFGSimplificationContext context;
if (!context.domTree)
context.domTree = computeDominatorTree(func);
- if (!context.regionTree)
- context.regionTree = generateRegionTreeForFunc(func, nullptr);
-
- SimpleRegion* targetBlockRegion = nullptr;
- if (!context.regionTree->mapBlockToRegion.tryGetValue(targetBlock, targetBlockRegion))
- return false;
- BreakableRegion* loopBreakableRegion = findBreakableRegion(targetBlockRegion);
- LoopRegion* loopRegion = as<LoopRegion>(loopBreakableRegion);
- if (!loopRegion)
+ bool hasMultiLevelBreaks = false;
+ auto loopBlocks = collectBlocksInRegion(context.domTree, loop, &hasMultiLevelBreaks);
+ if (hasMultiLevelBreaks)
return false;
- for (auto block : func->getBlocks())
+ for (auto block : loopBlocks)
{
- if (!context.domTree->dominates(loop->getTargetBlock(), block))
- continue;
- if (context.domTree->dominates(loop->getBreakBlock(), block))
- continue;
- SimpleRegion* region = nullptr;
- if (!context.regionTree->mapBlockToRegion.tryGetValue(block, region))
- return false;
-
for (auto branchTarget : block->getSuccessors())
{
- SimpleRegion* targetRegion = nullptr;
- if (!context.regionTree->mapBlockToRegion.tryGetValue(branchTarget, targetRegion))
+ if (!context.domTree->dominates(loop->getParent(), branchTarget))
return false;
- // If multi-level break out that skips over this loop exists, then this is not a trivial loop.
- if (targetRegion->isDescendentOf(loopRegion))
- continue;
if (targetBlock != loop->getBreakBlock())
return false;
- if (findBreakableRegion(region) != loopRegion)
+ if (findBreakableRegionHeaderInst(context.domTree, block) != loop)
{
// If the break is initiated from a nested region, this is not trivial.
return false;
@@ -164,10 +150,12 @@ static bool isTrivialSingleIterationLoop(
return true;
}
-static bool doesLoopHasSideEffect(IRGlobalValueWithCode* func, IRLoop* loopInst)
+static bool doesLoopHasSideEffect(CFGSimplificationContext& context, ReachabilityContext& reachability, IRGlobalValueWithCode* func, IRLoop* loopInst)
{
bool hasMultiLevelBreaks = false;
- auto blocks = collectBlocksInRegion(func, loopInst, &hasMultiLevelBreaks);
+ if (!context.domTree)
+ context.domTree = computeDominatorTree(func);
+ auto blocks = collectBlocksInRegion(context.domTree.get(), loopInst, &hasMultiLevelBreaks);
// We'll currently not deal with loops that contain multi-level breaks.
if (hasMultiLevelBreaks)
@@ -177,25 +165,26 @@ static bool doesLoopHasSideEffect(IRGlobalValueWithCode* func, IRLoop* loopInst)
for (auto b : blocks)
loopBlocks.add(b);
- ReachabilityContext reachability = {};
-
// Construct a map from a root address to all derived addresses.
- Dictionary<IRInst*, List<IRInst*>> relatedAddrMap;
- for (auto b : func->getBlocks())
+ Dictionary<IRInst*, List<IRInst*>>& relatedAddrMap = context.relatedAddrMap;
+ if (!relatedAddrMap.getCount())
{
- for (auto inst : b->getChildren())
+ for (auto b : func->getBlocks())
{
- if (as<IRPtrTypeBase>(inst->getDataType()))
+ for (auto inst : b->getChildren())
{
- auto root = getRootAddr(inst);
- if (!root) continue;
- auto list = relatedAddrMap.tryGetValue(root);
- if (!list)
+ if (as<IRPtrTypeBase>(inst->getDataType()))
{
- relatedAddrMap.add(root, List<IRInst*>());
- list = relatedAddrMap.tryGetValue(root);
+ auto root = getRootAddr(inst);
+ if (!root) continue;
+ auto list = relatedAddrMap.tryGetValue(root);
+ if (!list)
+ {
+ relatedAddrMap.add(root, List<IRInst*>());
+ list = relatedAddrMap.tryGetValue(root);
+ }
+ list->add(inst);
}
- list->add(inst);
}
}
}
@@ -779,7 +768,7 @@ static bool removeTrivialPhiParams(IRBlock* block)
return changed;
}
-static bool processFunc(IRGlobalValueWithCode* func)
+static bool processFunc(IRGlobalValueWithCode* func, CFGSimplificationOptions options)
{
auto firstBlock = func->getFirstBlock();
if (!firstBlock)
@@ -787,6 +776,10 @@ static bool processFunc(IRGlobalValueWithCode* func)
IRBuilder builder(func->getModule());
+ bool isReachabilityContextValid = false;
+ ReachabilityContext reachabilityContext;
+ CFGSimplificationContext simplificationContext;
+
bool changed = false;
for (;;)
{
@@ -815,6 +808,7 @@ static bool processFunc(IRGlobalValueWithCode* func)
{
loop->continueBlock.set(loop->getTargetBlock());
continueBlock->removeAndDeallocate();
+ simplificationContext = CFGSimplificationContext();
changed = true;
}
@@ -822,7 +816,7 @@ static bool processFunc(IRGlobalValueWithCode* func)
// break at the end of the loop, we can remove the header and turn it into
// a normal branch.
auto targetBlock = loop->getTargetBlock();
- if (isTrivialSingleIterationLoop(func, loop))
+ if (options.removeTrivialSingleIterationLoops && isTrivialSingleIterationLoop(simplificationContext, func, loop))
{
builder.setInsertBefore(loop);
List<IRInst*> args;
@@ -832,26 +826,44 @@ static bool processFunc(IRGlobalValueWithCode* func)
}
builder.emitBranch(targetBlock, args.getCount(), args.getBuffer());
loop->removeAndDeallocate();
+ simplificationContext = CFGSimplificationContext();
changed = true;
}
- else if (!doesLoopHasSideEffect(func, loop))
+ else if (options.removeSideEffectFreeLoops)
{
- // The loop isn't computing anything useful outside the loop.
- // We can delete the entire loop.
- builder.setInsertBefore(loop);
- SLANG_ASSERT(loop->getBreakBlock()->getFirstParam() == nullptr);
- builder.emitBranch(loop->getBreakBlock());
- loop->removeAndDeallocate();
- changed = true;
+ if (!isReachabilityContextValid)
+ {
+ isReachabilityContextValid = true;
+ reachabilityContext = ReachabilityContext(func);
+ }
+ if (!doesLoopHasSideEffect(simplificationContext, reachabilityContext, func, loop))
+ {
+ // The loop isn't computing anything useful outside the loop.
+ // We can delete the entire loop.
+ builder.setInsertBefore(loop);
+ SLANG_ASSERT(loop->getBreakBlock()->getFirstParam() == nullptr);
+ builder.emitBranch(loop->getBreakBlock());
+ loop->removeAndDeallocate();
+ simplificationContext = CFGSimplificationContext();
+ changed = true;
+ }
}
}
else if (auto condBranch = as<IRIfElse>(block->getTerminator()))
{
- changed |= trySimplifyIfElse(builder, condBranch);
+ if (trySimplifyIfElse(builder, condBranch))
+ {
+ simplificationContext = CFGSimplificationContext();
+ changed = true;
+ }
}
else if (auto switchBranch = as<IRSwitch>(block->getTerminator()))
{
- changed |= trySimplifySwitch(builder, switchBranch);
+ if (trySimplifySwitch(builder, switchBranch))
+ {
+ simplificationContext = CFGSimplificationContext();
+ changed = true;
+ }
}
// If `block` does not end with an unconditional branch, bail.
@@ -867,6 +879,7 @@ static bool processFunc(IRGlobalValueWithCode* func)
if (block->hasMoreThanOneUse())
break;
changed = true;
+ simplificationContext = CFGSimplificationContext();
Index paramIndex = 0;
auto inst = successor->getFirstDecorationOrChild();
while (inst)
@@ -911,7 +924,7 @@ static bool processFunc(IRGlobalValueWithCode* func)
return changed;
}
-bool simplifyCFG(IRModule* module)
+bool simplifyCFG(IRModule* module, CFGSimplificationOptions options)
{
bool changed = false;
for (auto inst : module->getGlobalInsts())
@@ -922,15 +935,15 @@ bool simplifyCFG(IRModule* module)
}
if (auto func = as<IRFunc>(inst))
{
- changed |= processFunc(func);
+ changed |= processFunc(func, options);
}
}
return changed;
}
-bool simplifyCFG(IRGlobalValueWithCode* func)
+bool simplifyCFG(IRGlobalValueWithCode* func, CFGSimplificationOptions options)
{
- return processFunc(func);
+ return processFunc(func, options);
}
} // namespace Slang
diff --git a/source/slang/slang-ir-simplify-cfg.h b/source/slang/slang-ir-simplify-cfg.h
index 6bfa6e2bf..4bc37bd29 100644
--- a/source/slang/slang-ir-simplify-cfg.h
+++ b/source/slang/slang-ir-simplify-cfg.h
@@ -6,11 +6,19 @@ namespace Slang
struct IRModule;
struct IRGlobalValueWithCode;
+ struct CFGSimplificationOptions
+ {
+ bool removeTrivialSingleIterationLoops = true;
+ bool removeSideEffectFreeLoops = true;
+ static CFGSimplificationOptions getDefault() { return CFGSimplificationOptions(); }
+ static CFGSimplificationOptions getFast() { return CFGSimplificationOptions{ false, false }; }
+ };
+
/// Simplifies control flow graph by merging basic blocks that
/// forms a simple linear chain.
/// Returns true if changed.
- bool simplifyCFG(IRModule* module);
+ bool simplifyCFG(IRModule* module, CFGSimplificationOptions options);
- bool simplifyCFG(IRGlobalValueWithCode* func);
+ bool simplifyCFG(IRGlobalValueWithCode* func, CFGSimplificationOptions options);
}
diff --git a/source/slang/slang-ir-single-return.cpp b/source/slang/slang-ir-single-return.cpp
index 0b61e5065..519a4f2d4 100644
--- a/source/slang/slang-ir-single-return.cpp
+++ b/source/slang/slang-ir-single-return.cpp
@@ -18,7 +18,7 @@ struct SingleReturnContext : public InstPassBase
void processFunc(IRGlobalValueWithCode* func)
{
IRBuilder builder(module);
- simplifyCFG(func);
+ simplifyCFG(func, CFGSimplificationOptions::getFast());
// We make use of the `eliminate-multi-level-break` pass to implement the transformation.
// To be able to do that, we need to prepare `func` so that the entire function body
diff --git a/source/slang/slang-ir-specialize-function-call.cpp b/source/slang/slang-ir-specialize-function-call.cpp
index f9e106920..6aa2e4998 100644
--- a/source/slang/slang-ir-specialize-function-call.cpp
+++ b/source/slang/slang-ir-specialize-function-call.cpp
@@ -891,7 +891,7 @@ struct FunctionParameterSpecializationContext
//
addCallsToWorkListRec(newFunc);
- simplifyFunc(newFunc);
+ simplifyFunc(newFunc, IRSimplificationOptions::getFast());
return newFunc;
}
diff --git a/source/slang/slang-ir-specialize-resources.cpp b/source/slang/slang-ir-specialize-resources.cpp
index 49f1b3177..95dc39729 100644
--- a/source/slang/slang-ir-specialize-resources.cpp
+++ b/source/slang/slang-ir-specialize-resources.cpp
@@ -1185,7 +1185,7 @@ bool specializeResourceUsage(
// and turned into SSA temporaries. Such optimization may enable
// the following passes to "see" and specialize more cases.
//
- simplifyIR(irModule);
+ simplifyIR(irModule, IRSimplificationOptions::getFast());
result |= changed;
}
if (unspecializableFuncs.getCount() == 0)
@@ -1205,7 +1205,7 @@ bool specializeResourceUsage(
inlineCall(call);
});
}
- simplifyIR(irModule);
+ simplifyIR(irModule, IRSimplificationOptions::getFast());
}
return result;
}
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index b2cd5edc4..e1a563b04 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -1511,7 +1511,7 @@ struct SpecializationContext
//
addToWorkList(newFunc);
- simplifyFunc(newFunc);
+ simplifyFunc(newFunc, IRSimplificationOptions::getFast());
return newFunc;
}
@@ -2369,7 +2369,7 @@ IRInst* specializeGenericImpl(
// the same thing.
if (auto func = as<IRFunc>(specializedVal))
{
- simplifyFunc(func);
+ simplifyFunc(func, IRSimplificationOptions::getFast());
}
return specializedVal;
diff --git a/source/slang/slang-ir-spirv-legalize.cpp b/source/slang/slang-ir-spirv-legalize.cpp
index b92f5e910..715b4f6ce 100644
--- a/source/slang/slang-ir-spirv-legalize.cpp
+++ b/source/slang/slang-ir-spirv-legalize.cpp
@@ -13,7 +13,7 @@
#include "slang-ir-layout.h"
#include "slang-ir-util.h"
#include "slang-ir-dominators.h"
-#include "slang-ir-array-reg-to-mem.h"
+#include "slang-ir-composite-reg-to-mem.h"
#include "slang-ir-sccp.h"
#include "slang-ir-dce.h"
#include "slang-ir-simplify-cfg.h"
@@ -1516,9 +1516,7 @@ struct SPIRVLegalizationContext : public SourceEmitterBase
void processModule()
{
-#if 0
- eliminateArrayTypeSSARegisters(m_module);
-#endif
+ convertCompositeTypeParametersToPointers(m_module);
// Process global params before anything else, so we don't generate inefficient
// array marhalling code for array-typed global params.
@@ -1739,7 +1737,7 @@ void simplifyIRForSpirvLegalization(DiagnosticSink* sink, IRModule* module)
funcChanged |= applySparseConditionalConstantPropagation(func, sink);
funcChanged |= peepholeOptimize(func);
funcChanged |= removeRedundancyInFunc(func);
- funcChanged |= simplifyCFG(func);
+ funcChanged |= simplifyCFG(func, CFGSimplificationOptions::getFast());
eliminateDeadCode(func);
}
}
diff --git a/source/slang/slang-ir-ssa-register-allocate.cpp b/source/slang/slang-ir-ssa-register-allocate.cpp
index ce502b454..90a0f05e5 100644
--- a/source/slang/slang-ir-ssa-register-allocate.cpp
+++ b/source/slang/slang-ir-ssa-register-allocate.cpp
@@ -13,6 +13,12 @@ namespace Slang
struct RegisterAllocateContext
{
OrderedDictionary<IRType*, List<RefPtr<RegisterInfo>>> mapTypeToRegisterList;
+ bool allocateForCompositeTypeOnly;
+
+ RegisterAllocateContext(bool compositeTypeOnly)
+ :allocateForCompositeTypeOnly(compositeTypeOnly)
+ {}
+
List<RefPtr<RegisterInfo>>& getRegisterListForType(IRType* type)
{
if (auto list = mapTypeToRegisterList.tryGetValue(type))
@@ -142,7 +148,7 @@ struct RegisterAllocateContext
RegisterAllocationResult allocateRegisters(IRGlobalValueWithCode* func, RefPtr<IRDominatorTree>& inOutDom)
{
- ReachabilityContext reachabilityContext;
+ ReachabilityContext reachabilityContext(func);
mapTypeToRegisterList.clear();
auto dom = computeDominatorTree(func);
@@ -275,6 +281,7 @@ struct RegisterAllocateContext
}
return result;
}
+
bool instNeedsProcessing(IRGlobalValueWithCode* func, IRInst* inst)
{
switch (inst->getOp())
@@ -282,6 +289,8 @@ struct RegisterAllocateContext
case kIROp_Param:
if (inst->getParent() == func->getFirstBlock())
return false;
+ if (allocateForCompositeTypeOnly && !isCompositeType(inst->getFullType()))
+ return false;
return true;
case kIROp_UpdateElement:
return true;
@@ -303,9 +312,9 @@ struct RegisterAllocateContext
}
};
-RegisterAllocationResult allocateRegistersForFunc(IRGlobalValueWithCode* func, RefPtr<IRDominatorTree>& inOutDom)
+RegisterAllocationResult allocateRegistersForFunc(IRGlobalValueWithCode* func, RefPtr<IRDominatorTree>& inOutDom, bool allocateForCompositeTypeOnly)
{
- RegisterAllocateContext context;
+ RegisterAllocateContext context(allocateForCompositeTypeOnly);
if (context.needProcessing(func))
return context.allocateRegisters(func, inOutDom);
return RegisterAllocationResult();
diff --git a/source/slang/slang-ir-ssa-register-allocate.h b/source/slang/slang-ir-ssa-register-allocate.h
index 1e8c586cd..1ef575716 100644
--- a/source/slang/slang-ir-ssa-register-allocate.h
+++ b/source/slang/slang-ir-ssa-register-allocate.h
@@ -19,6 +19,6 @@ struct RegisterAllocationResult
Dictionary<IRInst*, RefPtr<RegisterInfo>> mapInstToRegister;
};
-RegisterAllocationResult allocateRegistersForFunc(IRGlobalValueWithCode* func, RefPtr<IRDominatorTree>& inOutDom);
+RegisterAllocationResult allocateRegistersForFunc(IRGlobalValueWithCode* func, RefPtr<IRDominatorTree>& inOutDom, bool allocateForCompositeTypesOnly);
}
diff --git a/source/slang/slang-ir-ssa-simplification.cpp b/source/slang/slang-ir-ssa-simplification.cpp
index dbef20732..104ce05e4 100644
--- a/source/slang/slang-ir-ssa-simplification.cpp
+++ b/source/slang/slang-ir-ssa-simplification.cpp
@@ -11,12 +11,13 @@
#include "slang-ir-redundancy-removal.h"
#include "slang-ir-propagate-func-properties.h"
#include "../core/slang-performance-profiler.h"
+#include "slang-ir-util.h"
namespace Slang
{
// Run a combination of SSA, SCCP, SimplifyCFG, and DeadCodeElimination pass
// until no more changes are possible.
- void simplifyIR(IRModule* module, DiagnosticSink* sink)
+ void simplifyIR(IRModule* module, IRSimplificationOptions options, DiagnosticSink* sink)
{
SLANG_PROFILE;
bool changed = true;
@@ -50,7 +51,7 @@ namespace Slang
funcChanged |= applySparseConditionalConstantPropagation(func, sink);
funcChanged |= peepholeOptimize(func);
funcChanged |= removeRedundancyInFunc(func);
- funcChanged |= simplifyCFG(func);
+ funcChanged |= simplifyCFG(func, options.cfgOptions);
eliminateDeadCode(func);
funcChanged |= constructSSA(func);
changed |= funcChanged;
@@ -67,7 +68,7 @@ namespace Slang
}
}
- void simplifyNonSSAIR(IRModule* module)
+ void simplifyNonSSAIR(IRModule* module, IRSimplificationOptions options)
{
bool changed = true;
const int kMaxIterations = 8;
@@ -76,20 +77,20 @@ namespace Slang
{
changed = false;
changed |= peepholeOptimize(module);
+
changed |= removeRedundancy(module);
- changed |= simplifyCFG(module);
+ changed |= simplifyCFG(module, options.cfgOptions);
// Note: we disregard the `changed` state from dead code elimination pass since
// SCCP pass could be generating temporarily evaluated constant values and never actually use them.
// DCE will always remove those nearly generated consts and always returns true here.
eliminateDeadCode(module);
-
iterationCounter++;
}
}
- void simplifyFunc(IRGlobalValueWithCode* func, DiagnosticSink* sink)
+ void simplifyFunc(IRGlobalValueWithCode* func, IRSimplificationOptions options, DiagnosticSink* sink)
{
bool changed = true;
const int kMaxIterations = 8;
@@ -103,7 +104,7 @@ namespace Slang
changed |= applySparseConditionalConstantPropagation(func, sink);
changed |= peepholeOptimize(func);
changed |= removeRedundancyInFunc(func);
- changed |= simplifyCFG(func);
+ changed |= simplifyCFG(func, options.cfgOptions);
// Note: we disregard the `changed` state from dead code elimination pass since
// SCCP pass could be generating temporarily evaluated constant values and never actually use them.
diff --git a/source/slang/slang-ir-ssa-simplification.h b/source/slang/slang-ir-ssa-simplification.h
index fd7aa0ad8..88347aa17 100644
--- a/source/slang/slang-ir-ssa-simplification.h
+++ b/source/slang/slang-ir-ssa-simplification.h
@@ -1,18 +1,37 @@
// slang-ir-ssa-simplification.h
#pragma once
+#include "slang-ir-simplify-cfg.h"
+
namespace Slang
{
struct IRModule;
struct IRGlobalValueWithCode;
class DiagnosticSink;
+ struct IRSimplificationOptions
+ {
+ CFGSimplificationOptions cfgOptions;
+ static IRSimplificationOptions getDefault()
+ {
+ IRSimplificationOptions result;
+ return result;
+ }
+ static IRSimplificationOptions getFast()
+ {
+ IRSimplificationOptions result;
+ result.cfgOptions.removeSideEffectFreeLoops = false;
+ result.cfgOptions.removeTrivialSingleIterationLoops = false;
+ return result;
+ }
+ };
+
// Run a combination of SSA, SCCP, SimplifyCFG, and DeadCodeElimination pass
// until no more changes are possible.
- void simplifyIR(IRModule* module, DiagnosticSink* sink = nullptr);
+ void simplifyIR(IRModule* module, IRSimplificationOptions options, DiagnosticSink* sink = nullptr);
// Run simplifications on IR that is out of SSA form.
- void simplifyNonSSAIR(IRModule* module);
+ void simplifyNonSSAIR(IRModule* module, IRSimplificationOptions options);
- void simplifyFunc(IRGlobalValueWithCode* func, DiagnosticSink* sink = nullptr);
+ void simplifyFunc(IRGlobalValueWithCode* func, IRSimplificationOptions options, DiagnosticSink* sink = nullptr);
}
diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp
index 730943bf8..788c9a391 100644
--- a/source/slang/slang-ir-ssa.cpp
+++ b/source/slang/slang-ir-ssa.cpp
@@ -5,6 +5,7 @@
#include "slang-ir-clone.h"
#include "slang-ir-insts.h"
#include "slang-ir-validate.h"
+#include "slang-ir-util.h"
namespace Slang {
@@ -843,6 +844,8 @@ void processBlock(
IRBlock* block,
SSABlockInfo* blockInfo)
{
+ hoistInstOutOfASMBlocks(block);
+
// Before starting, check if this block can be sealed
maybeSealBlock(context, blockInfo);
diff --git a/source/slang/slang-ir-use-uninitialized-out-param.cpp b/source/slang/slang-ir-use-uninitialized-out-param.cpp
index 479538441..07a2b1bc2 100644
--- a/source/slang/slang-ir-use-uninitialized-out-param.cpp
+++ b/source/slang/slang-ir-use-uninitialized-out-param.cpp
@@ -20,7 +20,7 @@ namespace Slang
if (!firstBlock)
return;
- ReachabilityContext reachability;
+ ReachabilityContext reachability(func);
for (auto param : firstBlock->getParams())
{
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 5ecbc8121..073b8bf96 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -1166,6 +1166,23 @@ UnownedStringSlice getBuiltinFuncName(IRInst* callee)
return decor->getName();
}
+void hoistInstOutOfASMBlocks(IRBlock* block)
+{
+ for (auto inst : block->getChildren())
+ {
+ if (auto asmBlock = as<IRSPIRVAsm>(inst))
+ {
+ IRInst* next = nullptr;
+ for (auto i = asmBlock->getFirstChild(); i; i = next)
+ {
+ next = i->getNextInst();
+ if (!as<IRSPIRVAsmInst>(i) && !as<IRSPIRVAsmOperand>(i))
+ i->insertBefore(asmBlock);
+ }
+ }
+ }
+}
+
UnownedStringSlice getBasicTypeNameHint(IRType* basicType)
{
switch (basicType->getOp())
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index c13ce1931..ff6298f39 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -279,6 +279,21 @@ static void overAllBlocks(IRModule* module, F f)
}
}
+void hoistInstOutOfASMBlocks(IRBlock* block);
+
+inline bool isCompositeType(IRType* type)
+{
+ switch (type->getOp())
+ {
+ case kIROp_StructType:
+ case kIROp_ArrayType:
+ case kIROp_UnsizedArrayType:
+ return true;
+ default:
+ return false;
+ }
+}
+
}
#endif
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 3b7fb9ac8..6a3a26bd5 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -397,7 +397,10 @@ namespace Slang
//
if (auto lastParam = getLastParam())
{
- param->insertAfter(lastParam);
+ if (lastParam->next)
+ param->insertAfter(lastParam);
+ else
+ param->insertAtEnd(this);
}
//
// Otherwise, if there are any existing
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index f62d4b24a..c90230c1f 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -9371,8 +9371,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
getBuilder()->addRequireGLSLVersionDecoration(irFunc, Int(getIntegerLiteralValue(versionMod->versionNumberToken)));
else if (auto spvVersion = as<RequiredSPIRVVersionModifier>(modifier))
getBuilder()->addRequireSPIRVVersionDecoration(irFunc, spvVersion->version);
- else if (auto capMod = as<RequiredSPIRVCapabilityModifier>(modifier))
- getBuilder()->addRequireSPIRVCapabilityDecoration(irFunc, capMod->capability, capMod->extensionName.getUnownedSlice());
else if (auto cudasmVersion = as<RequiredCUDASMVersionModifier>(modifier))
getBuilder()->addRequireCUDASMVersionDecoration(irFunc, cudasmVersion->version);
}
@@ -10164,7 +10162,7 @@ RefPtr<IRModule> generateIRForTranslationUnit(
// temporaries and do basic simplifications.
//
constructSSA(module);
- simplifyCFG(module);
+ simplifyCFG(module, CFGSimplificationOptions::getDefault());
applySparseConditionalConstantPropagation(module, compileRequest->getSink());
peepholeOptimize(module);
@@ -10213,7 +10211,7 @@ RefPtr<IRModule> generateIRForTranslationUnit(
bool changed = false;
performMandatoryEarlyInlining(module);
changed |= constructSSA(module);
- simplifyCFG(module);
+ simplifyCFG(module, CFGSimplificationOptions::getDefault());
changed |= applySparseConditionalConstantPropagation(module, compileRequest->getSink());
changed |= peepholeOptimize(module);
for (auto inst : module->getGlobalInsts())
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index 3f8225084..517ab9e10 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -6943,26 +6943,6 @@ namespace Slang
return nullptr;
}
- static NodeBase* parseSPIRVCapabilityModifier(Parser* parser, void*)
- {
- parser->ReadToken(TokenType::LParent);
- Token token = parser->ReadToken(TokenType::Identifier);
- auto modifier = parser->astBuilder->create<RequiredSPIRVCapabilityModifier>();
- const SPIRVCoreGrammarInfo& spirvInfo =
- parser->astBuilder->getGlobalSession()->getSPIRVCoreGrammarInfo();
- const auto cap = spirvInfo.capabilities.lookup(token.getContent());
- if (!cap)
- {
- parser->sink->diagnose(token, Diagnostics::unknownSPIRVCapability, token);
- }
- else
- {
- modifier->capability = (int32_t)cap.value();
- }
- parser->ReadToken(TokenType::RParent);
- return modifier;
- }
-
static NodeBase* parseCUDASMVersionModifier(Parser* parser, void* /*userData*/)
{
Token token;
@@ -7323,7 +7303,6 @@ namespace Slang
_makeParseModifier("__glsl_extension", parseGLSLExtensionModifier),
_makeParseModifier("__glsl_version", parseGLSLVersionModifier),
_makeParseModifier("__spirv_version", parseSPIRVVersionModifier),
- _makeParseModifier("__spirv_capability", parseSPIRVCapabilityModifier),
_makeParseModifier("__cuda_sm_version", parseCUDASMVersionModifier),
_makeParseModifier("__builtin_type", parseBuiltinTypeModifier),
diff --git a/source/slang/slang-spirv-opt.cpp b/source/slang/slang-spirv-opt.cpp
index 34c87db9f..786358324 100644
--- a/source/slang/slang-spirv-opt.cpp
+++ b/source/slang/slang-spirv-opt.cpp
@@ -23,8 +23,6 @@ SlangResult optimizeSPIRV(const List<uint8_t>& spirv, String& outErr, List<uint8
// Set up our process
CommandLine commandLine;
commandLine.m_executableLocation.setName("spirv-opt");
- commandLine.addArg("--merge-return");
- commandLine.addArg("--inline-entry-points-exhaustive");
commandLine.addArg("--eliminate-dead-functions");
commandLine.addArg("--eliminate-local-single-block");
commandLine.addArg("--eliminate-local-single-store");
diff --git a/source/slang/slang-spirv-val.cpp b/source/slang/slang-spirv-val.cpp
index c03102d9f..e62564cc4 100644
--- a/source/slang/slang-spirv-val.cpp
+++ b/source/slang/slang-spirv-val.cpp
@@ -55,22 +55,14 @@ SlangResult debugValidateSPIRV(const List<uint8_t>& spirv)
const auto out = p->getStream(StdStreamType::Out);
const auto err = p->getStream(StdStreamType::ErrorOut);
- // Write the assembly
- SLANG_RETURN_ON_FAIL(in->write(spirv.getBuffer(), spirv.getCount()));
- in->close();
+ List<Byte> outData;
+ List<Byte> errData;
+ SLANG_RETURN_ON_FAIL(StreamUtil::readAndWrite(in, spirv.getArrayView(), out, outData, err, errData));
// Wait for it to finish
if(!p->waitForTermination(1000))
return SLANG_FAIL;
-
- // TODO: allow inheriting stderr in Process
- List<Byte> outData;
- SLANG_RETURN_ON_FAIL(StreamUtil::readAll(out, 0, outData));
- fwrite(outData.getBuffer(), outData.getCount(), 1, stderr);
- outData.clear();
- SLANG_RETURN_ON_FAIL(StreamUtil::readAll(err, 0, outData));
-
// If we failed, dump the spirv first.
const auto ret = p->getReturnValue();
if(ret != 0)
@@ -83,6 +75,7 @@ SlangResult debugValidateSPIRV(const List<uint8_t>& spirv)
}
fwrite(outData.getBuffer(), outData.getCount(), 1, stderr);
+ fwrite(errData.getBuffer(), errData.getCount(), 1, stderr);
return ret == 0 ? SLANG_OK : SLANG_FAIL;
}