summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-08-29 05:55:49 -0700
committerGitHub <noreply@github.com>2023-08-29 20:55:49 +0800
commit9d4e044bad6161a593806fc6fb610d41aa8b4b22 (patch)
tree28214d3a2a56762f3b858299696f4d4f8a85686f /source
parentb8fcb586f6a931ab674b0da7f375f38aff9608d4 (diff)
Add more wave intrinsics. (#3162)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/hlsl.meta.slang382
-rw-r--r--source/slang/slang-ir-inline.cpp52
2 files changed, 339 insertions, 95 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 1255e43e0..3c966bb4a 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -4936,6 +4936,7 @@ __generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_vote)
__spirv_version(1.3)
__cuda_sm_version(7.0)
+__spirv_capability(GroupNonUniformVote)
bool WaveMaskAllEqual(WaveMask mask, vector<T,N> value)
{
__target_switch
@@ -4967,17 +4968,59 @@ bool WaveMaskAllEqual(WaveMask mask, matrix<T,N,M> value);
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupExclusiveMul($1)")
-__target_intrinsic(cuda, "_wavePrefixProduct($0, $1)")
-__target_intrinsic(hlsl, "WavePrefixProduct($1)")
-T WaveMaskPrefixProduct(WaveMask mask, T expr);
+__spirv_capability(GroupNonUniformArithmetic)
+T WaveMaskPrefixProduct(WaveMask mask, T expr)
+{
+ __target_switch
+ {
+ case glsl: __intrinsic_asm "subgroupExclusiveMul($1)";
+ case cuda: __intrinsic_asm "_wavePrefixProduct($0, $1)";
+ case hlsl: __intrinsic_asm "WavePrefixProduct($1)";
+ case spirv:
+ if (__isFloat<T>())
+ return spirv_asm {OpGroupNonUniformFMul $$T result Subgroup ExclusiveScan $expr};
+ else if (__isSignedInt<T>())
+ {
+ return spirv_asm
+ {
+ // TODO: use the correct integer width
+ OpBitcast $$uint %uvalue $expr;
+ OpGroupNonUniformIMul $$uint %mulResult Subgroup ExclusiveScan %uvalue;
+ OpBitcast $$T result %mulResult
+ };
+ }
+ else if (__isUnsignedInt<T>())
+ return spirv_asm {OpGroupNonUniformIMul $$T result Subgroup ExclusiveScan $expr};
+ }
+}
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupExclusiveMul($1)")
-__target_intrinsic(cuda, "_wavePrefixProductMultiple($0, $1)")
-__target_intrinsic(hlsl, "WavePrefixProduct($1)")
-vector<T,N> WaveMaskPrefixProduct(WaveMask mask, vector<T,N> expr);
+__spirv_capability(GroupNonUniformArithmetic)
+vector<T,N> WaveMaskPrefixProduct(WaveMask mask, vector<T,N> expr)
+{
+    // Component-wise exclusive prefix product of `expr` across the wave.
+    // Only the CUDA intrinsic consumes `mask` ($0); the GLSL, HLSL and SPIR-V
+    // paths operate on `expr` alone.
+    __target_switch
+    {
+    case glsl: __intrinsic_asm "subgroupExclusiveMul($1)";
+    case cuda: __intrinsic_asm "_wavePrefixProductMultiple($0, $1)";
+    case hlsl: __intrinsic_asm "WavePrefixProduct($1)";
+    case spirv:
+        if (__isFloat<T>())
+            return spirv_asm {OpGroupNonUniformFMul $$vector<T,N> result Subgroup ExclusiveScan $expr};
+        else if (__isSignedInt<T>())
+        {
+            return spirv_asm
+            {
+                // TODO: use the correct integer width
+                // Every id here is an N-component vector: the operand and result types
+                // must match the IMul result type and the declared return type.
+                OpBitcast $$vector<uint,N> %uvalue $expr;
+                OpGroupNonUniformIMul $$vector<uint,N> %mulResult Subgroup ExclusiveScan %uvalue;
+                OpBitcast $$vector<T,N> result %mulResult
+            };
+        }
+        else if (__isUnsignedInt<T>())
+            return spirv_asm {OpGroupNonUniformIMul $$vector<T,N> result Subgroup ExclusiveScan $expr};
+    }
+}
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
__target_intrinsic(cuda, "_wavePrefixProductMultiple($0, $1)")
__target_intrinsic(hlsl, "WavePrefixProduct($1)")
@@ -4986,17 +5029,60 @@ matrix<T,N,M> WaveMaskPrefixProduct(WaveMask mask, matrix<T,N,M> expr);
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupExclusiveAdd($1)")
-__target_intrinsic(cuda, "_wavePrefixSum($0, $1)")
-__target_intrinsic(hlsl, "WavePrefixSum($1)")
-T WaveMaskPrefixSum(WaveMask mask, T expr);
+__spirv_capability(GroupNonUniformArithmetic)
+T WaveMaskPrefixSum(WaveMask mask, T expr)
+{
+ __target_switch
+ {
+ case glsl: __intrinsic_asm "subgroupExclusiveAdd($1)";
+ case cuda: __intrinsic_asm "_wavePrefixSum($0, $1)";
+ case hlsl: __intrinsic_asm "WavePrefixSum($1)";
+ case spirv:
+ if (__isFloat<T>())
+ return spirv_asm {OpGroupNonUniformFAdd $$T result Subgroup ExclusiveScan $expr};
+ else if (__isSignedInt<T>())
+ {
+ return spirv_asm
+ {
+ // TODO: use the correct integer width
+ %uvalue:$$uint = OpBitcast $expr;
+ %mulResult:$$uint = OpGroupNonUniformIAdd Subgroup ExclusiveScan %uvalue;
+ result:$$T = OpBitcast %mulResult
+ };
+ }
+ else if (__isUnsignedInt<T>())
+ return spirv_asm {OpGroupNonUniformIAdd $$T result Subgroup ExclusiveScan $expr};
+ }
+}
+
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupExclusiveAdd($1)")
-__target_intrinsic(cuda, "_wavePrefixSumMultiple($0, $1)")
-__target_intrinsic(hlsl, "WavePrefixSum($1)")
-vector<T,N> WaveMaskPrefixSum(WaveMask mask, vector<T,N> expr);
+__spirv_capability(GroupNonUniformArithmetic)
+vector<T,N> WaveMaskPrefixSum(WaveMask mask, vector<T,N> expr)
+{
+    // Component-wise exclusive prefix sum of `expr` across the wave.
+    // Only the CUDA intrinsic consumes `mask` ($0); the GLSL, HLSL and SPIR-V
+    // paths operate on `expr` alone.
+    __target_switch
+    {
+    case glsl: __intrinsic_asm "subgroupExclusiveAdd($1)";
+    case cuda: __intrinsic_asm "_wavePrefixSumMultiple($0, $1)";
+    case hlsl: __intrinsic_asm "WavePrefixSum($1)";
+    case spirv:
+        if (__isFloat<T>())
+            return spirv_asm {OpGroupNonUniformFAdd $$vector<T,N> result Subgroup ExclusiveScan $expr};
+        else if (__isSignedInt<T>())
+        {
+            return spirv_asm
+            {
+                // TODO: use the correct integer width
+                // Every id here is an N-component vector: the operand and result types
+                // must match the IAdd result type and the declared return type.
+                %uvalue: $$vector<uint,N> = OpBitcast $expr;
+                %addResult: $$vector<uint,N> = OpGroupNonUniformIAdd Subgroup ExclusiveScan %uvalue;
+                result: $$vector<T,N> = OpBitcast %addResult
+            };
+        }
+        else if (__isUnsignedInt<T>())
+            return spirv_asm {OpGroupNonUniformIAdd $$vector<T,N> result Subgroup ExclusiveScan $expr};
+    }
+}
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
__target_intrinsic(cuda, "_wavePrefixSumMultiple($0, $1)")
__target_intrinsic(hlsl, "WavePrefixSum($1)")
@@ -5005,15 +5091,34 @@ matrix<T,N,M> WaveMaskPrefixSum(WaveMask mask, matrix<T,N,M> expr);
__generic<T : __BuiltinType>
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupBroadcastFirst($1)")
-__target_intrinsic(cuda, "_waveReadFirst($0, $1)")
-T WaveMaskReadLaneFirst(WaveMask mask, T expr);
+__spirv_capability(GroupNonUniformBallot)
+T WaveMaskReadLaneFirst(WaveMask mask, T expr)
+{
+ __target_switch
+ {
+ case glsl: __intrinsic_asm "subgroupBroadcastFirst($1)";
+ case cuda: __intrinsic_asm "_waveReadFirst($0, $1)";
+ case hlsl: __intrinsic_asm "WaveReadLaneFirst($1)";
+ case spirv:
+ return spirv_asm {OpGroupNonUniformBroadcastFirst $$T result Subgroup $expr};
+ }
+}
__generic<T : __BuiltinType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_ballot)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupBroadcastFirst($1)")
-__target_intrinsic(cuda, "_waveReadFirstMultiple($0, $1)")
-vector<T,N> WaveMaskReadLaneFirst(WaveMask mask, vector<T,N> expr);
+__spirv_capability(GroupNonUniformBallot)
+vector<T,N> WaveMaskReadLaneFirst(WaveMask mask, vector<T,N> expr)
+{
+ __target_switch
+ {
+ case glsl: __intrinsic_asm "subgroupBroadcastFirst($1)";
+ case cuda: __intrinsic_asm "_waveReadFirstMultiple($0, $1)";
+ case hlsl: __intrinsic_asm "WaveReadLaneFirst($1)";
+ case spirv:
+ return spirv_asm {OpGroupNonUniformBroadcastFirst $$vector<T,N> result Subgroup $expr};
+ }
+}
+
__generic<T : __BuiltinType, let N : int, let M : int>
__target_intrinsic(cuda, "_waveReadFirstMultiple($0, $1)")
matrix<T,N,M> WaveMaskReadLaneFirst(WaveMask mask, matrix<T,N,M> expr);
@@ -5023,21 +5128,38 @@ matrix<T,N,M> WaveMaskReadLaneFirst(WaveMask mask, matrix<T,N,M> expr);
// TODO(JS): On HLSL it only works for 32 bits or less
__generic<T : __BuiltinType>
-__target_intrinsic(hlsl, "WaveMatch($1).x")
__glsl_extension(GL_NV_shader_subgroup_partitioned)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupPartitionNV($1).x")
__cuda_sm_version(7.0)
-__target_intrinsic(cuda, "_waveMatchScalar($0, $1).x")
-WaveMask WaveMaskMatch(WaveMask mask, T value);
+__spirv_capability(GroupNonUniformPartitionedNV)
+WaveMask WaveMaskMatch(WaveMask mask, T value)
+{
+ __target_switch
+ {
+ case glsl: __intrinsic_asm "subgroupPartitionNV($1).x";
+ case cuda: __intrinsic_asm "_waveMatchScalar($0, $1).x";
+ case hlsl: __intrinsic_asm "WaveMatch($1).x";
+ case spirv:
+ return (spirv_asm {OpGroupNonUniformPartitionNV $$uint4 result $value}).x;
+ }
+}
__generic<T : __BuiltinType, let N : int>
-__target_intrinsic(hlsl, "WaveMatch($1).x")
__glsl_extension(GL_NV_shader_subgroup_partitioned)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupPartitionNV($1).x")
__cuda_sm_version(7.0)
-__target_intrinsic(cuda, "_waveMatchMultiple($0, $1)")
-WaveMask WaveMaskMatch(WaveMask mask, vector<T,N> value);
+__spirv_capability(GroupNonUniformPartitionedNV)
+WaveMask WaveMaskMatch(WaveMask mask, vector<T,N> value)
+{
+ __target_switch
+ {
+ case glsl: __intrinsic_asm "subgroupPartitionNV($1).x";
+ case cuda: __intrinsic_asm "_waveMatchMultiple($0, $1).x";
+ case hlsl: __intrinsic_asm "WaveMatch($1).x";
+ case spirv:
+ return (spirv_asm {OpGroupNonUniformPartitionNV $$uint4 result $value}).x;
+ }
+}
+
__generic<T : __BuiltinType, let N : int, let M : int>
__target_intrinsic(hlsl, "WaveMatch($1).x")
__glsl_extension(GL_NV_shader_subgroup_partitioned)
@@ -5048,57 +5170,111 @@ __target_intrinsic(cuda, "_waveMatchMultiple($0, $1)")
WaveMask WaveMaskMatch(WaveMask mask, matrix<T,N,M> value);
__generic<T : __BuiltinArithmeticType>
-__target_intrinsic(hlsl, "WaveMultiPrefixBitAnd($1, uint4($0, 0, 0, 0))")
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-//__target_intrinsic(glsl, "subgroupExclusiveAnd($1)")
-__target_intrinsic(cuda, "_wavePrefixAnd($0, $1)")
-T WaveMaskPrefixBitAnd(WaveMask mask, T expr);
-__target_intrinsic(hlsl, "WaveMultiPrefixBitAnd($1, uint4($0, 0, 0, 0))")
+__spirv_capability(GroupNonUniformArithmetic)
+T WaveMaskPrefixBitAnd(WaveMask mask, T expr)
+{
+ __target_switch
+ {
+ case glsl: __intrinsic_asm "subgroupExclusiveAnd($1)";
+ case cuda: __intrinsic_asm "_wavePrefixAnd($0, $1)";
+ case hlsl: __intrinsic_asm "WaveMultiPrefixBitAnd($1, uint4($0, 0, 0, 0))";
+ case spirv:
+ return spirv_asm {OpGroupNonUniformBitwiseAnd $$T result Subgroup ExclusiveScan $expr};
+ }
+}
+
+__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupExclusiveAnd($1)")
-__target_intrinsic(cuda, "_wavePrefixAndMultiple($0, $1)")
-__generic<T : __BuiltinArithmeticType, let N : int>
-vector<T,N> WaveMaskPrefixBitAnd(WaveMask mask, vector<T,N> expr);
+__spirv_capability(GroupNonUniformArithmetic)
+vector<T,N> WaveMaskPrefixBitAnd(WaveMask mask, vector<T,N> expr)
+{
+ __target_switch
+ {
+ case glsl: __intrinsic_asm "subgroupExclusiveAnd($1)";
+ case cuda: __intrinsic_asm "_wavePrefixAndMultiple($0, $1)";
+ case hlsl: __intrinsic_asm "WaveMultiPrefixBitAnd($1, uint4($0, 0, 0, 0))";
+ case spirv:
+ return spirv_asm {OpGroupNonUniformBitwiseAnd $$vector<T,N> result Subgroup ExclusiveScan $expr};
+ }
+}
+
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
__target_intrinsic(hlsl, "WaveMultiPrefixBitAnd($1, uint4($0, 0, 0, 0))")
__target_intrinsic(cuda, "_wavePrefixAndMultiple(_getMultiPrefixMask($0, $1)")
matrix<T,N,M> WaveMaskPrefixBitAnd(WaveMask mask, matrix<T,N,M> expr);
__generic<T : __BuiltinArithmeticType>
-__target_intrinsic(hlsl, "WaveMultiPrefixBitOr($1, uint4($0, 0, 0, 0))")
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-//__target_intrinsic(glsl, "subgroupExclusiveOr($1)")
-__target_intrinsic(cuda, "_wavePrefixOr($0, $1)")
-T WaveMaskPrefixBitOr(WaveMask mask, T expr);
+__spirv_capability(GroupNonUniformArithmetic)
+T WaveMaskPrefixBitOr(WaveMask mask, T expr)
+{
+    // Exclusive prefix bitwise OR of `expr` across the wave.
+    // The CUDA and HLSL intrinsics consume `mask` ($0); the GLSL and SPIR-V
+    // paths operate on `expr` alone.
+    __target_switch
+    {
+    case glsl: __intrinsic_asm "subgroupExclusiveOr($1)";
+    case cuda: __intrinsic_asm "_wavePrefixOr($0, $1)";
+    case hlsl: __intrinsic_asm "WaveMultiPrefixBitOr($1, uint4($0, 0, 0, 0))";
+    case spirv:
+        // Must be the OR opcode, matching the vector overload and the other targets.
+        return spirv_asm {OpGroupNonUniformBitwiseOr $$T result Subgroup ExclusiveScan $expr};
+    }
+}
+
__generic<T : __BuiltinArithmeticType, let N : int>
-__target_intrinsic(hlsl, "WaveMultiPrefixBitOr($1, uint4($0, 0, 0, 0))")
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-//__target_intrinsic(glsl, "subgroupExclusiveOr($1)")
-__target_intrinsic(cuda, "_wavePrefixOrMultiple($0, $1)")
-vector<T,N> WaveMaskPrefixBitOr(WaveMask mask, vector<T,N> expr);
+__spirv_capability(GroupNonUniformArithmetic)
+vector<T,N> WaveMaskPrefixBitOr(WaveMask mask, vector<T,N> expr)
+{
+ __target_switch
+ {
+ case glsl: __intrinsic_asm "subgroupExclusiveOr($1)";
+ case cuda: __intrinsic_asm "_wavePrefixOrMultiple($0, $1)";
+ case hlsl: __intrinsic_asm "WaveMultiPrefixBitOr($1, uint4($0, 0, 0, 0))";
+ case spirv:
+ return spirv_asm {OpGroupNonUniformBitwiseOr $$vector<T,N> result Subgroup ExclusiveScan $expr};
+ }
+}
+
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
__target_intrinsic(hlsl, "WaveMultiPrefixBitOr($1, uint4($0, 0, 0, 0))")
__target_intrinsic(cuda, "_wavePrefixOrMultiple($0, $1)")
matrix<T,N,M> WaveMaskPrefixBitOr(WaveMask mask, matrix<T,N,M> expr);
__generic<T : __BuiltinArithmeticType>
-__target_intrinsic(hlsl, "WaveMultiPrefixBitXor($1, uint4($0, 0, 0, 0))")
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupExclusiveXor($1)")
-__target_intrinsic(cuda, "_wavePrefixXor($0, $1)")
-T WaveMaskPrefixBitXor(WaveMask mask, T expr);
+__spirv_capability(GroupNonUniformArithmetic)
+T WaveMaskPrefixBitXor(WaveMask mask, T expr)
+{
+ __target_switch
+ {
+ case glsl: __intrinsic_asm "subgroupExclusiveXor($1)";
+ case cuda: __intrinsic_asm "_wavePrefixXor($0, $1)";
+ case hlsl: __intrinsic_asm "WaveMultiPrefixBitXor($1, uint4($0, 0, 0, 0))";
+ case spirv:
+ return spirv_asm {OpGroupNonUniformBitwiseXor $$T result Subgroup ExclusiveScan $expr};
+ }
+}
+
__generic<T : __BuiltinArithmeticType, let N : int>
-__target_intrinsic(hlsl, "WaveMultiPrefixBitXor($1, uint4($0, 0, 0, 0))")
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupExclusiveXor($1)")
-__target_intrinsic(cuda, "_wavePrefixXorMultiple($0, $1)")
-vector<T,N> WaveMaskPrefixBitXor(WaveMask mask, vector<T,N> expr);
+__spirv_capability(GroupNonUniformArithmetic)
+vector<T,N> WaveMaskPrefixBitXor(WaveMask mask, vector<T,N> expr)
+{
+ __target_switch
+ {
+ case glsl: __intrinsic_asm "subgroupExclusiveXor($1)";
+ case cuda: __intrinsic_asm "_wavePrefixXorMultiple($0, $1)";
+ case hlsl: __intrinsic_asm "WaveMultiPrefixBitXor($1, uint4($0, 0, 0, 0))";
+ case spirv:
+ return spirv_asm {OpGroupNonUniformBitwiseXor $$vector<T,N> result Subgroup ExclusiveScan $expr};
+ }
+}
+
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
__target_intrinsic(hlsl, "WaveMultiPrefixBitXor($1, uint4($0, 0, 0, 0))")
__target_intrinsic(cuda, "_wavePrefixXorMultiple($0, $1)")
@@ -5129,21 +5305,35 @@ __generic<T : __BuiltinType, let N : int, let M : int> matrix<T,N,M> QuadReadAcr
__generic<T : __BuiltinIntegerType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupAnd($0)")
-__target_intrinsic(hlsl)
+__spirv_capability(GroupNonUniformArithmetic)
T WaveActiveBitAnd(T expr)
{
- return WaveMaskBitAnd(WaveGetActiveMask(), expr);
+ __target_switch
+ {
+ case glsl: __intrinsic_asm "subgroupAnd($0)";
+ case hlsl: __intrinsic_asm "WaveActiveBitAnd";
+ case spirv:
+ return spirv_asm {OpGroupNonUniformBitwiseAnd $$T result Subgroup Reduce $expr};
+ default:
+ return WaveMaskBitAnd(WaveGetActiveMask(), expr);
+ }
}
__generic<T : __BuiltinIntegerType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupAnd($0)")
-__target_intrinsic(hlsl)
+__spirv_capability(GroupNonUniformArithmetic)
vector<T, N> WaveActiveBitAnd(vector<T, N> expr)
{
- return WaveMaskBitAnd(WaveGetActiveMask(), expr);
+ __target_switch
+ {
+ case glsl: __intrinsic_asm "subgroupAnd($0)";
+ case hlsl: __intrinsic_asm "WaveActiveBitAnd";
+ case spirv:
+ return spirv_asm {OpGroupNonUniformBitwiseAnd $$vector<T, N> result Subgroup Reduce $expr};
+ default:
+ return WaveMaskBitAnd(WaveGetActiveMask(), expr);
+ }
}
__generic<T : __BuiltinIntegerType, let N : int, let M : int>
@@ -5156,21 +5346,35 @@ matrix<T, N, M> WaveActiveBitAnd(matrix<T, N, M> expr)
__generic<T : __BuiltinIntegerType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupOr($0)")
-__target_intrinsic(hlsl)
+__spirv_capability(GroupNonUniformArithmetic)
T WaveActiveBitOr(T expr)
{
- return WaveMaskBitOr(WaveGetActiveMask(), expr);
+ __target_switch
+ {
+ case glsl: __intrinsic_asm "subgroupOr($0)";
+ case hlsl: __intrinsic_asm "WaveActiveBitOr";
+ case spirv:
+ return spirv_asm {OpGroupNonUniformBitwiseOr $$T result Subgroup Reduce $expr};
+ default:
+ return WaveMaskBitOr(WaveGetActiveMask(), expr);
+ }
}
__generic<T : __BuiltinIntegerType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupOr($0)")
-__target_intrinsic(hlsl)
+__spirv_capability(GroupNonUniformArithmetic)
vector<T,N> WaveActiveBitOr(vector<T,N> expr)
{
- return WaveMaskBitOr(WaveGetActiveMask(), expr);
+ __target_switch
+ {
+ case glsl: __intrinsic_asm "subgroupOr($0)";
+ case hlsl: __intrinsic_asm "WaveActiveBitOr";
+ case spirv:
+ return spirv_asm {OpGroupNonUniformBitwiseOr $$vector<T, N> result Subgroup Reduce $expr};
+ default:
+ return WaveMaskBitOr(WaveGetActiveMask(), expr);
+ }
}
__generic<T : __BuiltinIntegerType, let N : int, let M : int>
@@ -5183,21 +5387,35 @@ matrix<T, N, M> WaveActiveBitOr(matrix<T, N, M> expr)
__generic<T : __BuiltinIntegerType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupXor($0)")
-__target_intrinsic(hlsl)
+__spirv_capability(GroupNonUniformArithmetic)
T WaveActiveBitXor(T expr)
{
- return WaveMaskBitXor(WaveGetActiveMask(), expr);
+ __target_switch
+ {
+ case glsl: __intrinsic_asm "subgroupXor($0)";
+ case hlsl: __intrinsic_asm "WaveActiveBitXor";
+ case spirv:
+ return spirv_asm {OpGroupNonUniformBitwiseXor $$T result Subgroup Reduce $expr};
+ default:
+ return WaveMaskBitXor(WaveGetActiveMask(), expr);
+ }
}
__generic<T : __BuiltinIntegerType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupXor($0)")
-__target_intrinsic(hlsl)
+__spirv_capability(GroupNonUniformArithmetic)
vector<T,N> WaveActiveBitXor(vector<T,N> expr)
{
- return WaveMaskBitXor(WaveGetActiveMask(), expr);
+ __target_switch
+ {
+ case glsl: __intrinsic_asm "subgroupXor($0)";
+ case hlsl: __intrinsic_asm "WaveActiveBitXor";
+ case spirv:
+ return spirv_asm {OpGroupNonUniformBitwiseXor $$vector<T,N> result Subgroup Reduce $expr};
+ default:
+ return WaveMaskBitXor(WaveGetActiveMask(), expr);
+ }
}
__generic<T : __BuiltinIntegerType, let N : int, let M : int>
@@ -5452,7 +5670,23 @@ __glsl_extension(GL_KHR_shader_subgroup_basic)
__spirv_version(1.3)
__target_intrinsic(glsl, "(gl_SubgroupInvocationID)")
__target_intrinsic(cuda, "_getLaneId()")
-uint WaveGetLaneIndex();
+uint WaveGetLaneIndex()
+{
+ __target_switch
+ {
+ case glsl: __intrinsic_asm "(gl_SubgroupInvocationID)";
+ case cuda: __intrinsic_asm "_getLaneId()";
+ case hlsl: __intrinsic_asm "WaveGetLaneIndex()";
+ /*
+ case spirv:
+ let _scope = 3u; // subgroup
+ return spirv_asm
+ {
+ OpSubgroupLocalInvocationId $$uint result $_scope
+ };
+ */
+ }
+}
__glsl_extension(GL_KHR_shader_subgroup_basic)
__spirv_capability(GroupNonUniformBallot)
diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp
index f3fa213da..e171e2dd3 100644
--- a/source/slang/slang-ir-inline.cpp
+++ b/source/slang/slang-ir-inline.cpp
@@ -173,6 +173,19 @@ struct InliningPassBase
return false;
}
+ static bool hasGenericAsmInst(IRInst* func)
+ {
+ auto f = as<IRFunc>(getResolvedInstForDecorations(func));
+ if (!f)
+ return false;
+ for (auto b : f->getBlocks())
+ {
+ if (as<IRGenericAsm>(b->getTerminator()))
+ return true;
+ }
+ return false;
+ }
+
/// Determine whether `call` can be inlined, and if so write information about it to `outCallSite`
bool canInline(IRCall* call, CallSiteInfo& outCallSite)
{
@@ -236,6 +249,24 @@ struct InliningPassBase
if (callee->findDecoration<IRIntrinsicOpDecoration>())
return true;
+ // We cannot inline a function that is defined by a generic asm inst.
+ if (hasGenericAsmInst(callee))
+ return false;
+
+ for (auto decor : callee->getDecorations())
+ {
+ switch (decor->getOp())
+ {
+ case kIROp_IntrinsicOpDecoration:
+ return true;
+ case kIROp_RequireSPIRVCapabilityDecoration:
+ case kIROp_RequireSPIRVVersionDecoration:
+ case kIROp_RequireGLSLExtensionDecoration:
+ case kIROp_RequireGLSLVersionDecoration:
+ return false;
+ }
+ }
+
// At this point the `CallSiteInfo` is complete and
// could be used for inlining, but we have additional
// checks to make.
@@ -654,19 +685,6 @@ struct InliningPassBase
};
-static bool hasGenericAsmInst(IRInst* func)
-{
- auto f = as<IRFunc>(getResolvedInstForDecorations(func));
- if (!f)
- return false;
- for (auto b : f->getBlocks())
- {
- if (as<IRGenericAsm>(b->getTerminator()))
- return true;
- }
- return false;
-}
-
/// An inlining pass that inlines calls to `[unsafeForceInlineEarly]` functions
struct MandatoryEarlyInliningPass : InliningPassBase
{
@@ -681,10 +699,6 @@ struct MandatoryEarlyInliningPass : InliningPassBase
if (info.callee->findDecoration<IRIntrinsicOpDecoration>())
return true;
- // Never inline a callee that has genericASM instruction.
- if (hasGenericAsmInst(info.callee))
- return false;
-
if(info.callee->findDecoration<IRUnsafeForceInlineEarlyDecoration>())
return true;
return false;
@@ -800,10 +814,6 @@ struct ForceInliningPass : InliningPassBase
bool shouldInline(CallSiteInfo const& info)
{
- // Never inline a callee that has genericASM instruction.
- if (hasGenericAsmInst(info.callee))
- return false;
-
if (info.callee->findDecoration<IRForceInlineDecoration>() ||
info.callee->findDecoration<IRUnsafeForceInlineEarlyDecoration>()||
info.callee->findDecoration<IRIntrinsicOpDecoration>())