diff options
| author | Yong He <yonghe@outlook.com> | 2023-08-29 05:55:49 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-08-29 20:55:49 +0800 |
| commit | 9d4e044bad6161a593806fc6fb610d41aa8b4b22 (patch) | |
| tree | 28214d3a2a56762f3b858299696f4d4f8a85686f /source | |
| parent | b8fcb586f6a931ab674b0da7f375f38aff9608d4 (diff) | |
Add more wave intrinsics. (#3162)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/hlsl.meta.slang | 382 | ||||
| -rw-r--r-- | source/slang/slang-ir-inline.cpp | 52 |
2 files changed, 339 insertions, 95 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 1255e43e0..3c966bb4a 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -4936,6 +4936,7 @@ __generic<T : __BuiltinType, let N : int> __glsl_extension(GL_KHR_shader_subgroup_vote) __spirv_version(1.3) __cuda_sm_version(7.0) +__spirv_capability(GroupNonUniformVote) bool WaveMaskAllEqual(WaveMask mask, vector<T,N> value) { __target_switch @@ -4967,17 +4968,59 @@ bool WaveMaskAllEqual(WaveMask mask, matrix<T,N,M> value); __generic<T : __BuiltinArithmeticType> __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupExclusiveMul($1)") -__target_intrinsic(cuda, "_wavePrefixProduct($0, $1)") -__target_intrinsic(hlsl, "WavePrefixProduct($1)") -T WaveMaskPrefixProduct(WaveMask mask, T expr); +__spirv_capability(GroupNonUniformArithmetic) +T WaveMaskPrefixProduct(WaveMask mask, T expr) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupExclusiveMul($1)"; + case cuda: __intrinsic_asm "_wavePrefixProduct($0, $1)"; + case hlsl: __intrinsic_asm "WavePrefixProduct($1)"; + case spirv: + if (__isFloat<T>()) + return spirv_asm {OpGroupNonUniformFMul $$T result Subgroup ExclusiveScan $expr}; + else if (__isSignedInt<T>()) + { + return spirv_asm + { + // TODO: use the correct integer width + OpBitcast $$uint %uvalue $expr; + OpGroupNonUniformIMul $$uint %mulResult Subgroup ExclusiveScan %uvalue; + OpBitcast $$T result %mulResult + }; + } + else if (__isUnsignedInt<T>()) + return spirv_asm {OpGroupNonUniformIMul $$T result Subgroup ExclusiveScan $expr}; + } +} __generic<T : __BuiltinArithmeticType, let N : int> __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupExclusiveMul($1)") -__target_intrinsic(cuda, "_wavePrefixProductMultiple($0, $1)") -__target_intrinsic(hlsl, "WavePrefixProduct($1)") -vector<T,N> WaveMaskPrefixProduct(WaveMask mask, vector<T,N> expr); +__spirv_capability(GroupNonUniformArithmetic) +vector<T,N> WaveMaskPrefixProduct(WaveMask mask, vector<T,N> expr) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupExclusiveMul($1)"; + case cuda: __intrinsic_asm "_wavePrefixProductMultiple($0, $1)"; + case hlsl: __intrinsic_asm "WavePrefixProduct($1)"; + case spirv: + if (__isFloat<T>()) + return spirv_asm {OpGroupNonUniformFMul $$vector<T,N> result Subgroup ExclusiveScan $expr}; + else if (__isSignedInt<T>()) + { + return spirv_asm + { + // TODO: use the correct integer width + OpBitcast $$uint %uvalue $expr; + OpGroupNonUniformIMul $$vector<uint,N> %mulResult Subgroup ExclusiveScan %uvalue; + OpBitcast $$T result %mulResult + }; + } + else if (__isUnsignedInt<T>()) + return spirv_asm {OpGroupNonUniformIMul $$vector<T,N> result Subgroup ExclusiveScan $expr}; + } +} __generic<T : __BuiltinArithmeticType, let N : int, let M : int> __target_intrinsic(cuda, "_wavePrefixProductMultiple($0, $1)") __target_intrinsic(hlsl, "WavePrefixProduct($1)") @@ -4986,17 +5029,60 @@ matrix<T,N,M> WaveMaskPrefixProduct(WaveMask mask, matrix<T,N,M> expr); __generic<T : __BuiltinArithmeticType> __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupExclusiveAdd($1)") -__target_intrinsic(cuda, "_wavePrefixSum($0, $1)") -__target_intrinsic(hlsl, "WavePrefixSum($1)") -T WaveMaskPrefixSum(WaveMask mask, T expr); +__spirv_capability(GroupNonUniformArithmetic) +T WaveMaskPrefixSum(WaveMask mask, T expr) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupExclusiveAdd($1)"; + case cuda: __intrinsic_asm "_wavePrefixSum($0, $1)"; + case hlsl: __intrinsic_asm "WavePrefixSum($1)"; + case spirv: + if (__isFloat<T>()) + return spirv_asm {OpGroupNonUniformFAdd $$T result Subgroup ExclusiveScan $expr}; + else if (__isSignedInt<T>()) + { + return spirv_asm + { + // TODO: use the correct integer width + %uvalue:$$uint = OpBitcast $expr; + %mulResult:$$uint = OpGroupNonUniformIAdd Subgroup ExclusiveScan %uvalue; + result:$$T = OpBitcast %mulResult + }; + } + else if (__isUnsignedInt<T>()) + return spirv_asm {OpGroupNonUniformIAdd $$T result Subgroup ExclusiveScan $expr}; + } +} + __generic<T : __BuiltinArithmeticType, let N : int> __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupExclusiveAdd($1)") -__target_intrinsic(cuda, "_wavePrefixSumMultiple($0, $1)") -__target_intrinsic(hlsl, "WavePrefixSum($1)") -vector<T,N> WaveMaskPrefixSum(WaveMask mask, vector<T,N> expr); +__spirv_capability(GroupNonUniformArithmetic) +vector<T,N> WaveMaskPrefixSum(WaveMask mask, vector<T,N> expr) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupExclusiveAdd($1)"; + case cuda: __intrinsic_asm "_wavePrefixSumMultiple($0, $1)"; + case hlsl: __intrinsic_asm "WavePrefixSum($1)"; + case spirv: + if (__isFloat<T>()) + return spirv_asm {OpGroupNonUniformFAdd $$vector<T,N> result Subgroup ExclusiveScan $expr}; + else if (__isSignedInt<T>()) + { + return spirv_asm + { + // TODO: use the correct integer width + %uvalue: $$uint = OpBitcast $expr; + %mulResult: $$vector<uint,N> = OpGroupNonUniformIAdd Subgroup ExclusiveScan %uvalue; + result: $$T = OpBitcast %mulResult + }; + } + else if (__isUnsignedInt<T>()) + return spirv_asm {OpGroupNonUniformIAdd $$vector<T,N> result Subgroup ExclusiveScan $expr}; + } +} __generic<T : __BuiltinArithmeticType, let N : int, let M : int> __target_intrinsic(cuda, "_wavePrefixSumMultiple($0, $1)") __target_intrinsic(hlsl, "WavePrefixSum($1)") @@ -5005,15 +5091,34 @@ matrix<T,N,M> WaveMaskPrefixSum(WaveMask mask, matrix<T,N,M> expr); __generic<T : __BuiltinType> __glsl_extension(GL_KHR_shader_subgroup_ballot) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupBroadcastFirst($1)") -__target_intrinsic(cuda, "_waveReadFirst($0, $1)") -T WaveMaskReadLaneFirst(WaveMask mask, T expr); +__spirv_capability(GroupNonUniformBallot) +T WaveMaskReadLaneFirst(WaveMask mask, T expr) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupBroadcastFirst($1)"; + case cuda: __intrinsic_asm "_waveReadFirst($0, $1)"; + case hlsl: __intrinsic_asm "WaveReadLaneFirst($1)"; + case spirv: + return spirv_asm {OpGroupNonUniformBroadcastFirst $$T result Subgroup $expr}; + } +} __generic<T : __BuiltinType, let N : int> __glsl_extension(GL_KHR_shader_subgroup_ballot) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupBroadcastFirst($1)") -__target_intrinsic(cuda, "_waveReadFirstMultiple($0, $1)") -vector<T,N> WaveMaskReadLaneFirst(WaveMask mask, vector<T,N> expr); +__spirv_capability(GroupNonUniformBallot) +vector<T,N> WaveMaskReadLaneFirst(WaveMask mask, vector<T,N> expr) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupBroadcastFirst($1)"; + case cuda: __intrinsic_asm "_waveReadFirstMultiple($0, $1)"; + case hlsl: __intrinsic_asm "WaveReadLaneFirst($1)"; + case spirv: + return spirv_asm {OpGroupNonUniformBroadcastFirst $$vector<T,N> result Subgroup $expr}; + } +} + __generic<T : __BuiltinType, let N : int, let M : int> __target_intrinsic(cuda, "_waveReadFirstMultiple($0, $1)") matrix<T,N,M> WaveMaskReadLaneFirst(WaveMask mask, matrix<T,N,M> expr); @@ -5023,21 +5128,38 @@ matrix<T,N,M> WaveMaskReadLaneFirst(WaveMask mask, matrix<T,N,M> expr); // TODO(JS): On HLSL it only works for 32 bits or less __generic<T : __BuiltinType> -__target_intrinsic(hlsl, "WaveMatch($1).x") __glsl_extension(GL_NV_shader_subgroup_partitioned) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupPartitionNV($1).x") __cuda_sm_version(7.0) -__target_intrinsic(cuda, "_waveMatchScalar($0, $1).x") -WaveMask WaveMaskMatch(WaveMask mask, T value); +__spirv_capability(GroupNonUniformPartitionedNV) +WaveMask WaveMaskMatch(WaveMask mask, T value) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupPartitionNV($1).x"; + case cuda: __intrinsic_asm "_waveMatchScalar($0, $1).x"; + case hlsl: __intrinsic_asm "WaveMatch($1).x"; + case spirv: + return (spirv_asm {OpGroupNonUniformPartitionNV $$uint4 result $value}).x; + } +} __generic<T : __BuiltinType, let N : int> -__target_intrinsic(hlsl, "WaveMatch($1).x") __glsl_extension(GL_NV_shader_subgroup_partitioned) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupPartitionNV($1).x") __cuda_sm_version(7.0) -__target_intrinsic(cuda, "_waveMatchMultiple($0, $1)") -WaveMask WaveMaskMatch(WaveMask mask, vector<T,N> value); +__spirv_capability(GroupNonUniformPartitionedNV) +WaveMask WaveMaskMatch(WaveMask mask, vector<T,N> value) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupPartitionNV($1).x"; + case cuda: __intrinsic_asm "_waveMatchMultiple($0, $1).x"; + case hlsl: __intrinsic_asm "WaveMatch($1).x"; + case spirv: + return (spirv_asm {OpGroupNonUniformPartitionNV $$uint4 result $value}).x; + } +} + __generic<T : __BuiltinType, let N : int, let M : int> __target_intrinsic(hlsl, "WaveMatch($1).x") __glsl_extension(GL_NV_shader_subgroup_partitioned) @@ -5048,57 +5170,111 @@ __target_intrinsic(cuda, "_waveMatchMultiple($0, $1)") WaveMask WaveMaskMatch(WaveMask mask, matrix<T,N,M> value); __generic<T : __BuiltinArithmeticType> -__target_intrinsic(hlsl, "WaveMultiPrefixBitAnd($1, uint4($0, 0, 0, 0))") __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -//__target_intrinsic(glsl, "subgroupExclusiveAnd($1)") -__target_intrinsic(cuda, "_wavePrefixAnd($0, $1)") -T WaveMaskPrefixBitAnd(WaveMask mask, T expr); -__target_intrinsic(hlsl, "WaveMultiPrefixBitAnd($1, uint4($0, 0, 0, 0))") +__spirv_capability(GroupNonUniformArithmetic) +T WaveMaskPrefixBitAnd(WaveMask mask, T expr) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupExclusiveAnd($1)"; + case cuda: __intrinsic_asm "_wavePrefixAnd($0, $1)"; + case hlsl: __intrinsic_asm "WaveMultiPrefixBitAnd($1, uint4($0, 0, 0, 0))"; + case spirv: + return spirv_asm {OpGroupNonUniformBitwiseAnd $$T result Subgroup ExclusiveScan $expr}; + } +} + +__generic<T : __BuiltinArithmeticType, let N : int> __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupExclusiveAnd($1)") -__target_intrinsic(cuda, "_wavePrefixAndMultiple($0, $1)") -__generic<T : __BuiltinArithmeticType, let N : int> -vector<T,N> WaveMaskPrefixBitAnd(WaveMask mask, vector<T,N> expr); +__spirv_capability(GroupNonUniformArithmetic) +vector<T,N> WaveMaskPrefixBitAnd(WaveMask mask, vector<T,N> expr) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupExclusiveAnd($1)"; + case cuda: __intrinsic_asm "_wavePrefixAndMultiple($0, $1)"; + case hlsl: __intrinsic_asm "WaveMultiPrefixBitAnd($1, uint4($0, 0, 0, 0))"; + case spirv: + return spirv_asm {OpGroupNonUniformBitwiseAnd $$vector<T,N> result Subgroup ExclusiveScan $expr}; + } +} + __generic<T : __BuiltinArithmeticType, let N : int, let M : int> __target_intrinsic(hlsl, "WaveMultiPrefixBitAnd($1, uint4($0, 0, 0, 0))") __target_intrinsic(cuda, "_wavePrefixAndMultiple(_getMultiPrefixMask($0, $1)") matrix<T,N,M> WaveMaskPrefixBitAnd(WaveMask mask, matrix<T,N,M> expr); __generic<T : __BuiltinArithmeticType> -__target_intrinsic(hlsl, "WaveMultiPrefixBitOr($1, uint4($0, 0, 0, 0))") __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -//__target_intrinsic(glsl, "subgroupExclusiveOr($1)") -__target_intrinsic(cuda, "_wavePrefixOr($0, $1)") -T WaveMaskPrefixBitOr(WaveMask mask, T expr); +__spirv_capability(GroupNonUniformArithmetic) +T WaveMaskPrefixBitOr(WaveMask mask, T expr) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupExclusiveOr($1)"; + case cuda: __intrinsic_asm "_wavePrefixOr($0, $1)"; + case hlsl: __intrinsic_asm "WaveMultiPrefixBitOr($1, uint4($0, 0, 0, 0))"; + case spirv: + return spirv_asm {OpGroupNonUniformBitwiseAnd $$T result Subgroup ExclusiveScan $expr}; + } +} + __generic<T : __BuiltinArithmeticType, let N : int> -__target_intrinsic(hlsl, "WaveMultiPrefixBitOr($1, uint4($0, 0, 0, 0))") __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -//__target_intrinsic(glsl, "subgroupExclusiveOr($1)") -__target_intrinsic(cuda, "_wavePrefixOrMultiple($0, $1)") -vector<T,N> WaveMaskPrefixBitOr(WaveMask mask, vector<T,N> expr); +__spirv_capability(GroupNonUniformArithmetic) +vector<T,N> WaveMaskPrefixBitOr(WaveMask mask, vector<T,N> expr) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupExclusiveOr($1)"; + case cuda: __intrinsic_asm "_wavePrefixOrMultiple($0, $1)"; + case hlsl: __intrinsic_asm "WaveMultiPrefixBitOr($1, uint4($0, 0, 0, 0))"; + case spirv: + return spirv_asm {OpGroupNonUniformBitwiseOr $$vector<T,N> result Subgroup ExclusiveScan $expr}; + } +} + __generic<T : __BuiltinArithmeticType, let N : int, let M : int> __target_intrinsic(hlsl, "WaveMultiPrefixBitOr($1, uint4($0, 0, 0, 0))") __target_intrinsic(cuda, "_wavePrefixOrMultiple($0, $1)") matrix<T,N,M> WaveMaskPrefixBitOr(WaveMask mask, matrix<T,N,M> expr); __generic<T : __BuiltinArithmeticType> -__target_intrinsic(hlsl, "WaveMultiPrefixBitXor($1, uint4($0, 0, 0, 0))") __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupExclusiveXor($1)") -__target_intrinsic(cuda, "_wavePrefixXor($0, $1)") -T WaveMaskPrefixBitXor(WaveMask mask, T expr); +__spirv_capability(GroupNonUniformArithmetic) +T WaveMaskPrefixBitXor(WaveMask mask, T expr) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupExclusiveXor($1)"; + case cuda: __intrinsic_asm "_wavePrefixXor($0, $1)"; + case hlsl: __intrinsic_asm "WaveMultiPrefixBitXor($1, uint4($0, 0, 0, 0))"; + case spirv: + return spirv_asm {OpGroupNonUniformBitwiseXor $$T result Subgroup ExclusiveScan $expr}; + } +} + __generic<T : __BuiltinArithmeticType, let N : int> -__target_intrinsic(hlsl, "WaveMultiPrefixBitXor($1, uint4($0, 0, 0, 0))") __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupExclusiveXor($1)") -__target_intrinsic(cuda, "_wavePrefixXorMultiple($0, $1)") -vector<T,N> WaveMaskPrefixBitXor(WaveMask mask, vector<T,N> expr); +__spirv_capability(GroupNonUniformArithmetic) +vector<T,N> WaveMaskPrefixBitXor(WaveMask mask, vector<T,N> expr) +{ + __target_switch + { + case glsl: __intrinsic_asm "subgroupExclusiveXor($1)"; + case cuda: __intrinsic_asm "_wavePrefixXorMultiple($0, $1)"; + case hlsl: __intrinsic_asm "WaveMultiPrefixBitXor($1, uint4($0, 0, 0, 0))"; + case spirv: + return spirv_asm {OpGroupNonUniformBitwiseXor $$vector<T,N> result Subgroup ExclusiveScan $expr}; + } +} + __generic<T : __BuiltinArithmeticType, let N : int, let M : int> __target_intrinsic(hlsl, "WaveMultiPrefixBitXor($1, uint4($0, 0, 0, 0))") __target_intrinsic(cuda, "_wavePrefixXorMultiple($0, $1)") @@ -5129,21 +5305,35 @@ __generic<T : __BuiltinType, let N : int, let M : int> matrix<T,N,M> QuadReadAcr __generic<T : __BuiltinIntegerType> __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupAnd($0)") -__target_intrinsic(hlsl) +__spirv_capability(GroupNonUniformArithmetic) T WaveActiveBitAnd(T expr) { - return WaveMaskBitAnd(WaveGetActiveMask(), expr); + __target_switch + { + case glsl: __intrinsic_asm "subgroupAnd($0)"; + case hlsl: __intrinsic_asm "WaveActiveBitAnd"; + case spirv: + return spirv_asm {OpGroupNonUniformBitwiseAnd $$T result Subgroup Reduce $expr}; + default: + return WaveMaskBitAnd(WaveGetActiveMask(), expr); + } } __generic<T : __BuiltinIntegerType, let N : int> __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupAnd($0)") -__target_intrinsic(hlsl) +__spirv_capability(GroupNonUniformArithmetic) vector<T, N> WaveActiveBitAnd(vector<T, N> expr) { - return WaveMaskBitAnd(WaveGetActiveMask(), expr); + __target_switch + { + case glsl: __intrinsic_asm "subgroupAnd($0)"; + case hlsl: __intrinsic_asm "WaveActiveBitAnd"; + case spirv: + return spirv_asm {OpGroupNonUniformBitwiseAnd $$vector<T, N> result Subgroup Reduce $expr}; + default: + return WaveMaskBitAnd(WaveGetActiveMask(), expr); + } } __generic<T : __BuiltinIntegerType, let N : int, let M : int> @@ -5156,21 +5346,35 @@ matrix<T, N, M> WaveActiveBitAnd(matrix<T, N, M> expr) __generic<T : __BuiltinIntegerType> __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupOr($0)") -__target_intrinsic(hlsl) +__spirv_capability(GroupNonUniformArithmetic) T WaveActiveBitOr(T expr) { - return WaveMaskBitOr(WaveGetActiveMask(), expr); + __target_switch + { + case glsl: __intrinsic_asm "subgroupOr($0)"; + case hlsl: __intrinsic_asm "WaveActiveBitOr"; + case spirv: + return spirv_asm {OpGroupNonUniformBitwiseOr $$T result Subgroup Reduce $expr}; + default: + return WaveMaskBitOr(WaveGetActiveMask(), expr); + } } __generic<T : __BuiltinIntegerType, let N : int> __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupOr($0)") -__target_intrinsic(hlsl) +__spirv_capability(GroupNonUniformArithmetic) vector<T,N> WaveActiveBitOr(vector<T,N> expr) { - return WaveMaskBitOr(WaveGetActiveMask(), expr); + __target_switch + { + case glsl: __intrinsic_asm "subgroupOr($0)"; + case hlsl: __intrinsic_asm "WaveActiveBitOr"; + case spirv: + return spirv_asm {OpGroupNonUniformBitwiseOr $$vector<T, N> result Subgroup Reduce $expr}; + default: + return WaveMaskBitOr(WaveGetActiveMask(), expr); + } } __generic<T : __BuiltinIntegerType, let N : int, let M : int> @@ -5183,21 +5387,35 @@ matrix<T, N, M> WaveActiveBitOr(matrix<T, N, M> expr) __generic<T : __BuiltinIntegerType> __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupXor($0)") -__target_intrinsic(hlsl) +__spirv_capability(GroupNonUniformArithmetic) T WaveActiveBitXor(T expr) { - return WaveMaskBitXor(WaveGetActiveMask(), expr); + __target_switch + { + case glsl: __intrinsic_asm "subgroupXor($0)"; + case hlsl: __intrinsic_asm "WaveActiveBitXor"; + case spirv: + return spirv_asm {OpGroupNonUniformBitwiseXor $$T result Subgroup Reduce $expr}; + default: + return WaveMaskBitXor(WaveGetActiveMask(), expr); + } } __generic<T : __BuiltinIntegerType, let N : int> __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupXor($0)") -__target_intrinsic(hlsl) +__spirv_capability(GroupNonUniformArithmetic) vector<T,N> WaveActiveBitXor(vector<T,N> expr) { - return WaveMaskBitXor(WaveGetActiveMask(), expr); + __target_switch + { + case glsl: __intrinsic_asm "subgroupXor($0)"; + case hlsl: __intrinsic_asm "WaveActiveBitXor"; + case spirv: + return spirv_asm {OpGroupNonUniformBitwiseXor $$vector<T,N> result Subgroup Reduce $expr}; + default: + return WaveMaskBitXor(WaveGetActiveMask(), expr); + } } __generic<T : __BuiltinIntegerType, let N : int, let M : int> @@ -5452,7 +5670,23 @@ __glsl_extension(GL_KHR_shader_subgroup_basic) __spirv_version(1.3) __target_intrinsic(glsl, "(gl_SubgroupInvocationID)") __target_intrinsic(cuda, "_getLaneId()") -uint WaveGetLaneIndex(); +uint WaveGetLaneIndex() +{ + __target_switch + { + case glsl: __intrinsic_asm "(gl_SubgroupInvocationID)"; + case cuda: __intrinsic_asm "_getLaneId()"; + case hlsl: __intrinsic_asm "WaveGetLaneIndex()"; + /* + case spirv: + let _scope = 3u; // subgroup + return spirv_asm + { + OpSubgroupLocalInvocationId $$uint result $_scope + }; + */ + } +} __glsl_extension(GL_KHR_shader_subgroup_basic) __spirv_capability(GroupNonUniformBallot) diff --git a/source/slang/slang-ir-inline.cpp b/source/slang/slang-ir-inline.cpp index f3fa213da..e171e2dd3 100644 --- a/source/slang/slang-ir-inline.cpp +++ b/source/slang/slang-ir-inline.cpp @@ -173,6 +173,19 @@ struct InliningPassBase return false; } + static bool hasGenericAsmInst(IRInst* func) + { + auto f = as<IRFunc>(getResolvedInstForDecorations(func)); + if (!f) + return false; + for (auto b : f->getBlocks()) + { + if (as<IRGenericAsm>(b->getTerminator())) + return true; + } + return false; + } + /// Determine whether `call` can be inlined, and if so write information about it to `outCallSite` bool canInline(IRCall* call, CallSiteInfo& outCallSite) { @@ -236,6 +249,24 @@ struct InliningPassBase if (callee->findDecoration<IRIntrinsicOpDecoration>()) return true; + // We cannot inline a function that is defined by a generic asm inst. + if (hasGenericAsmInst(callee)) + return false; + + for (auto decor : callee->getDecorations()) + { + switch (decor->getOp()) + { + case kIROp_IntrinsicOpDecoration: + return true; + case kIROp_RequireSPIRVCapabilityDecoration: + case kIROp_RequireSPIRVVersionDecoration: + case kIROp_RequireGLSLExtensionDecoration: + case kIROp_RequireGLSLVersionDecoration: + return false; + } + } + // At this point the `CallSiteInfo` is complete and // could be used for inlining, but we have additional // checks to make. @@ -654,19 +685,6 @@ struct InliningPassBase }; -static bool hasGenericAsmInst(IRInst* func) -{ - auto f = as<IRFunc>(getResolvedInstForDecorations(func)); - if (!f) - return false; - for (auto b : f->getBlocks()) - { - if (as<IRGenericAsm>(b->getTerminator())) - return true; - } - return false; -} - /// An inlining pass that inlines calls to `[unsafeForceInlineEarly]` functions struct MandatoryEarlyInliningPass : InliningPassBase { @@ -681,10 +699,6 @@ struct MandatoryEarlyInliningPass : InliningPassBase if (info.callee->findDecoration<IRIntrinsicOpDecoration>()) return true; - // Never inline a callee that has genericASM instruction. - if (hasGenericAsmInst(info.callee)) - return false; - if(info.callee->findDecoration<IRUnsafeForceInlineEarlyDecoration>()) return true; return false; @@ -800,10 +814,6 @@ struct ForceInliningPass : InliningPassBase bool shouldInline(CallSiteInfo const& info) { - // Never inline a callee that has genericASM instruction. - if (hasGenericAsmInst(info.callee)) - return false; - if (info.callee->findDecoration<IRForceInlineDecoration>() || info.callee->findDecoration<IRUnsafeForceInlineEarlyDecoration>()|| info.callee->findDecoration<IRIntrinsicOpDecoration>()) |
