diff options
| -rw-r--r-- | prelude/slang-cuda-prelude.h | 139 | ||||
| -rw-r--r-- | source/slang/hlsl.meta.slang | 103 | ||||
| -rw-r--r-- | source/slang/slang-profile-defs.h | 2 | ||||
| -rw-r--r-- | tests/hlsl-intrinsic/wave-multi-prefix.slang | 26 | ||||
| -rw-r--r-- | tests/hlsl-intrinsic/wave-multi-prefix.slang.expected.txt | 8 |
5 files changed, 206 insertions, 72 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h index 6a1d87183..dcc585b9c 100644 --- a/prelude/slang-cuda-prelude.h +++ b/prelude/slang-cuda-prelude.h @@ -513,6 +513,18 @@ __forceinline__ __device__ int _getLaneLtMask() return (int(1) << _getLaneId()) - 1; } +// Return a mask suitable for the straight 'Prefix' style ops +__forceinline__ __device__ int _getPrefixMask() +{ + return __activemask(); +} + +// Return a mask suitable for the 'MultiPrefix' style functions +__forceinline__ __device__ int _getMultiPrefixMask(int mask) +{ + return mask; +} + // Note! Note will return true if mask is 0, but thats okay, because there must be one // lane active to execute anything __inline__ __device__ bool _waveIsSingleLane(int mask) @@ -671,9 +683,9 @@ __device__ T _waveReduceScalar(T val) while (remaining) { const int laneBit = remaining & -remaining; - /* Get the sourceLane */ + // Get the sourceLane const int srcLane = __ffs(laneBit) - 1; - /* Broadcast (can also broadcast to self) */ + // Broadcast (can also broadcast to self) result = INTF::doOp(result, __shfl_sync(mask, val, srcLane)); remaining &= ~laneBit; } @@ -718,9 +730,9 @@ __device__ void _waveReduceMultiple(T* val) while (remaining) { const int laneBit = remaining & -remaining; - /* Get the sourceLane */ + // Get the sourceLane const int srcLane = __ffs(laneBit) - 1; - /* Broadcast (can also broadcast to self) */ + // Broadcast (can also broadcast to self) for (size_t i = 0; i < COUNT; ++i) { val[i] = INTF::doOp(val[i], __shfl_sync(mask, originalVal[i], srcLane)); @@ -786,7 +798,7 @@ __inline__ __device__ T _waveMaxMultiple(T val) { typedef typename ElementTypeT template <typename T> __inline__ __device__ bool _waveAllEqual(T val) { - // __match_all_sync is a synchronises so can use __activemask() + // __match_all_sync synchronizes so can use __activemask() const int mask = __activemask(); int pred; __match_all_sync(mask, val, &pred); @@ -798,13 +810,10 @@ __inline__ __device__ bool _waveAllEqualMultiple(T inVal) { typedef typename ElementTypeTrait<T>::Type ElemType; const size_t count = sizeof(T) / sizeof(ElemType); - - // __match_all_sync is a synchronises so can use __activemask() + // __match_all_sync synchronizes so can use __activemask() const int mask = __activemask(); int pred; - const ElemType* src = (const ElemType*)&inVal; - for (size_t i = 0; i < count; ++i) { __match_all_sync(mask, src[i], &pred); @@ -829,20 +838,15 @@ __inline__ __device__ T _waveReadFirstMultiple(T inVal) { typedef typename ElementTypeTrait<T>::Type ElemType; const size_t count = sizeof(T) / sizeof(ElemType); - T outVal; - const ElemType* src = (const ElemType*)&inVal; ElemType* dst = (ElemType*)&outVal; - const int mask = __activemask(); const int lowestLaneId = __ffs(mask) - 1; - for (size_t i = 0; i < count; ++i) { dst[i] = __shfl_sync(mask, src[i], lowestLaneId); } - return outVal; } @@ -851,19 +855,14 @@ __inline__ __device__ T _waveReadLaneAtMultiple(T inVal, int lane) { typedef typename ElementTypeTrait<T>::Type ElemType; const size_t count = sizeof(T) / sizeof(ElemType); - T outVal; - const ElemType* src = (const ElemType*)&inVal; ElemType* dst = (ElemType*)&outVal; - const int mask = __activemask(); - for (size_t i = 0; i < count; ++i) { dst[i] = __shfl_sync(mask, src[i], lane); } - return outVal; } @@ -872,9 +871,8 @@ __inline__ __device__ T _waveReadLaneAtMultiple(T inVal, int lane) // Invertable means that when we get to the end of the reduce, we can remove val (to make exclusive), using // the inverse of the op. template <typename INTF, typename T> -__device__ T _wavePrefixInvertableScalar(T val) +__device__ T _wavePrefixInvertableScalar(T val, const int mask) { - const int mask = __activemask(); const int offsetSize = _waveCalcPow2Offset(mask); const int laneId = _getLaneId(); @@ -923,9 +921,8 @@ __device__ T _wavePrefixInvertableScalar(T val) // This implementation separately tracks the value to be propogated, and the value // that is the final result template <typename INTF, typename T> -__device__ T _wavePrefixScalar(T val) +__device__ T _wavePrefixScalar(T val, const int mask) { - const int mask = __activemask(); const int offsetSize = _waveCalcPow2Offset(mask); const int laneId = _getLaneId(); @@ -971,7 +968,7 @@ __device__ T _wavePrefixScalar(T val) template <typename INTF, typename T, size_t COUNT> -__device__ T _copy(T* dst, const T* src) +__device__ T _waveOpCopy(T* dst, const T* src) { for (size_t j = 0; j < COUNT; ++j) { @@ -981,7 +978,7 @@ __device__ T _copy(T* dst, const T* src) template <typename INTF, typename T, size_t COUNT> -__device__ T _doInverse(T* inOut, const T* val) +__device__ T _waveOpDoInverse(T* inOut, const T* val) { for (size_t j = 0; j < COUNT; ++j) { @@ -990,7 +987,7 @@ __device__ T _doInverse(T* inOut, const T* val) } template <typename INTF, typename T, size_t COUNT> -__device__ T _setInitial(T* out, const T* val) +__device__ T _waveOpSetInitial(T* out, const T* val) { for (size_t j = 0; j < COUNT; ++j) { @@ -999,14 +996,13 @@ __device__ T _setInitial(T* out, const T* val) } template <typename INTF, typename T, size_t COUNT> -__device__ T _wavePrefixInvertableMultiple(T* val) +__device__ T _wavePrefixInvertableMultiple(T* val, const int mask) { - const int mask = __activemask(); const int offsetSize = _waveCalcPow2Offset(mask); const int laneId = _getLaneId(); T originalVal[COUNT]; - _copy<INTF, T, COUNT>(originalVal, val); + _waveOpCopy<INTF, T, COUNT>(originalVal, val); if (offsetSize > 0) { @@ -1027,11 +1023,11 @@ __device__ T _wavePrefixInvertableMultiple(T* val) } } // Remove originalVal from the result, by applyin inverse - _doInverse<INTF, T, COUNT>(val, originalVal); + _waveOpDoInverse<INTF, T, COUNT>(val, originalVal); } else { - _setInitial<INTF, T, COUNT>(val, val); + _waveOpSetInitial<INTF, T, COUNT>(val, val); if (!_waveIsSingleLane(mask)) { int remaining = mask; @@ -1058,16 +1054,15 @@ __device__ T _wavePrefixInvertableMultiple(T* val) } template <typename INTF, typename T, size_t COUNT> -__device__ T _wavePrefixMultiple(T* val) +__device__ T _wavePrefixMultiple(T* val, const int mask) { - const int mask = __activemask(); const int offsetSize = _waveCalcPow2Offset(mask); const int laneId = _getLaneId(); T work[COUNT]; - _copy<INTF, T, COUNT>(work, val); - _setInitial<INTF, T, COUNT>(val, val); + _waveOpCopy<INTF, T, COUNT>(work, val); + _waveOpSetInitial<INTF, T, COUNT>(val, val); if (offsetSize > 0) { @@ -1113,29 +1108,89 @@ __device__ T _wavePrefixMultiple(T* val) } } } - + template <typename T> -__inline__ __device__ T _wavePrefixProduct(T val) { return _wavePrefixScalar<WaveOpMul<T>, T>(val); } +__inline__ __device__ T _wavePrefixProduct(T val, const int mask = _getPrefixMask()) { return _wavePrefixScalar<WaveOpMul<T>, T>(val, mask); } template <typename T> -__inline__ __device__ T _wavePrefixSum(T val) { return _wavePrefixInvertableScalar<WaveOpAdd<T>, T>(val); } +__inline__ __device__ T _wavePrefixSum(T val, const int mask = _getPrefixMask()) { return _wavePrefixInvertableScalar<WaveOpAdd<T>, T>(val, mask); } + +template <typename T> +__inline__ __device__ T _wavePrefixXor(T val, const int mask = _getPrefixMask()) { return _wavePrefixInvertableScalar<WaveOpXor<T>, T>(val, mask); } + +template <typename T> +__inline__ __device__ T _wavePrefixOr(T val, const int mask = _getPrefixMask()) { return _wavePrefixScalar<WaveOpOr<T>, T>(val, mask); } + +template <typename T> +__inline__ __device__ T _wavePrefixAnd(T val, const int mask = _getPrefixMask()) { return _wavePrefixScalar<WaveOpAnd<T>, T>(val, mask); } + template <typename T> -__inline__ __device__ T _wavePrefixProductMultiple(T val) +__inline__ __device__ T _wavePrefixProductMultiple(T val, const int mask = _getPrefixMask()) { typedef typename ElementTypeTrait<T>::Type ElemType; - _wavePrefixInvertableMultiple<WaveOpMul<ElemType>, ElemType, sizeof(T) / sizeof(ElemType)>((ElemType*)&val); + _wavePrefixInvertableMultiple<WaveOpMul<ElemType>, ElemType, sizeof(T) / sizeof(ElemType)>((ElemType*)&val, mask); return val; } template <typename T> -__inline__ __device__ T _wavePrefixSumMultiple(T val) +__inline__ __device__ T _wavePrefixSumMultiple(T val, const int mask = _getPrefixMask()) { typedef typename ElementTypeTrait<T>::Type ElemType; - _wavePrefixMultiple<WaveOpAdd<ElemType>, ElemType, sizeof(T) / sizeof(ElemType)>((ElemType*)&val); + _wavePrefixInvertableMultiple<WaveOpAdd<ElemType>, ElemType, sizeof(T) / sizeof(ElemType)>((ElemType*)&val, mask); return val; } +template <typename T> +__inline__ __device__ T _wavePrefixXorMultiple(T val, const int mask = _getPrefixMask()) +{ + typedef typename ElementTypeTrait<T>::Type ElemType; + _wavePrefixInvertableMultiple<WaveOpXor<ElemType>, ElemType, sizeof(T) / sizeof(ElemType)>((ElemType*)&val, mask); + return val; +} + +template <typename T> +__inline__ __device__ T _wavePrefixOrMultiple(T val, const int mask = _getPrefixMask()) +{ + typedef typename ElementTypeTrait<T>::Type ElemType; + _wavePrefixMultiple<WaveOpOr<ElemType>, ElemType, sizeof(T) / sizeof(ElemType)>((ElemType*)&val, mask); + return val; +} + +template <typename T> +__inline__ __device__ T _wavePrefixAndMultiple(T val, const int mask = _getPrefixMask()) +{ + typedef typename ElementTypeTrait<T>::Type ElemType; + _wavePrefixMultiple<WaveOpAnd<ElemType>, ElemType, sizeof(T) / sizeof(ElemType)>((ElemType*)&val, mask); + return val; +} + +template <typename T> +__inline__ __device__ uint4 _waveMatchScalar(T val) +{ + // __match_all_sync synchronizes so can use __activemask() + const int mask = __activemask(); + int pred; + return make_uint4(__match_all_sync(mask, val, &pred), 0, 0, 0); +} + +template <typename T> +__inline__ __device__ uint4 _waveMatchMultiple(const T& inVal) +{ + typedef typename ElementTypeTrait<T>::Type ElemType; + const size_t count = sizeof(T) / sizeof(ElemType); + // __match_all_sync synchronizes so can use __activemask() + const int mask = __activemask(); + int pred; + const ElemType* src = (const ElemType*)&inVal; + uint matchBits = 0xffffffff; + for (size_t i = 0; i < count && matchBits; ++i) + { + matchBits = matchBits & __match_all_sync(mask, src[i], &pred); + } + return make_uint4(matchBits, 0, 0, 0); +} + /* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */ diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 30c86b3eb..6ac1038f8 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -2628,9 +2628,7 @@ __generic<T : __BuiltinType, let N : int, let M : int> __target_intrinsic(cuda, "_waveAllEqualMultiple($0)") bool WaveActiveAllEqual(matrix<T,N,M> value); -__generic<T : __BuiltinType> uint4 WaveMatch(T value); -__generic<T : __BuiltinType, let N : int> uint4 WaveMatch(vector<T,N> value); -__generic<T : __BuiltinType, let N : int, let M : int> uint4 WaveMatch(matrix<T,N,M> value); + __glsl_extension(GL_KHR_shader_subgroup_vote) __spirv_version(1.3) @@ -2650,11 +2648,9 @@ __target_intrinsic(glsl, "subgroupBallot($0)") __target_intrinsic(cuda, "make_uint4(__ballot_sync(__activemask(), $0), 0, 0, 0)") uint4 WaveActiveBallot(bool condition); -// TODO(JS): -// subgroupBallotBitCount seems to take a uint4 parameter. __glsl_extension(GL_KHR_shader_subgroup_ballot) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupBallotBitCount($0)") +__target_intrinsic(glsl, "bitCount(subgroupBallot($0))") __target_intrinsic(cuda, "__popc(__ballot_sync(__activemask(), $0))") uint WaveActiveCountBits(bool value); @@ -2751,58 +2747,105 @@ uint WavePrefixCountBits(bool value); // https://github.com/microsoft/DirectX-Specs/blob/master/d3d/HLSL_ShaderModel6_5.md // TODO(JS): Looks like they need a mask parameter +__generic<T : __BuiltinType> +__target_intrinsic(hlsl) +__target_intrinsic(cuda, "_waveMatchScalar($0)") +uint4 WaveMatch(T value); +__generic<T : __BuiltinType, let N : int> +__target_intrinsic(hlsl) +__target_intrinsic(cuda, "_waveMatchMultiple($0)") +uint4 WaveMatch(vector<T,N> value); +__generic<T : __BuiltinType, let N : int, let M : int> +__target_intrinsic(hlsl) +__target_intrinsic(cuda, "_waveMatchMultiple($0)") +uint4 WaveMatch(matrix<T,N,M> value); + +__target_intrinsic(hlsl) +__target_intrinsic(cuda, "_popc(__ballot_sync(($1).x, $0) & _getLaneLtMask())") +uint WaveMultiPrefixCountBits(bool value, uint4 mask); + __generic<T : __BuiltinArithmeticType> +__target_intrinsic(hlsl) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupExclusiveAnd($0)") -__target_intrinsic(cuda, "_wavePrefixAnd($0)") -T WaveMultiPrefixBitAnd(T expr); +//__target_intrinsic(glsl, "subgroupExclusiveAnd($0)") +__target_intrinsic(cuda, "_wavePrefixAnd($0, _getMultiPrefixMask(($1).x))") +T WaveMultiPrefixBitAnd(T expr, uint4 mask); +__target_intrinsic(hlsl) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) __target_intrinsic(glsl, "subgroupExclusiveAnd($0)") +__target_intrinsic(cuda, "_wavePrefixAndMultiple($0, _getMultiPrefixMask(($1).x))") __generic<T : __BuiltinArithmeticType, let N : int> -vector<T,N> WaveMultiPrefixBitAnd(vector<T,N> expr); +vector<T,N> WaveMultiPrefixBitAnd(vector<T,N> expr, uint4 mask); __generic<T : __BuiltinArithmeticType, let N : int, let M : int> -matrix<T,N,M> WaveMultiPrefixBitAnd(matrix<T,N,M> expr); +__target_intrinsic(hlsl) +__target_intrinsic(cuda, "_wavePrefixAndMultiple($0, _getMultiPrefixMask(($1).x))") +matrix<T,N,M> WaveMultiPrefixBitAnd(matrix<T,N,M> expr, uint4 mask); __generic<T : __BuiltinArithmeticType> +__target_intrinsic(hlsl) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupExclusiveOr($0)") -__target_intrinsic(cuda, "_wavePrefixOr($0)") -T WaveMultiPrefixBitOr(T expr); +//__target_intrinsic(glsl, "subgroupExclusiveOr($0)") +__target_intrinsic(cuda, "_wavePrefixOr($0, _getMultiPrefixMask(($1).x))") +T WaveMultiPrefixBitOr(T expr, uint4 mask); __generic<T : __BuiltinArithmeticType, let N : int> +__target_intrinsic(hlsl) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupExclusiveOr($0)") -vector<T,N> WaveMultiPrefixBitOr(vector<T,N> expr); +//__target_intrinsic(glsl, "subgroupExclusiveOr($0)") +__target_intrinsic(cuda, "_wavePrefixOrMultiple($0, _getMultiPrefixMask(($1).x))") +vector<T,N> WaveMultiPrefixBitOr(vector<T,N> expr, uint4 mask); __generic<T : __BuiltinArithmeticType, let N : int, let M : int> -matrix<T,N,M> WaveMultiPrefixBitOr(matrix<T,N,M> expr); +__target_intrinsic(hlsl) +__target_intrinsic(cuda, "_wavePrefixOrMultiple($0, _getMultiPrefixMask(($1).x))") +matrix<T,N,M> WaveMultiPrefixBitOr(matrix<T,N,M> expr, uint4 mask); __generic<T : __BuiltinArithmeticType> +__target_intrinsic(hlsl) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) __target_intrinsic(glsl, "subgroupExclusiveXor($0)") -__target_intrinsic(cuda, "_wavePrefixXor($0)") -T WaveMultiPrefixBitXor(T expr); +__target_intrinsic(cuda, "_wavePrefixXor($0, _getMultiPrefixMask(($1).x))") +T WaveMultiPrefixBitXor(T expr, uint4 mask); __generic<T : __BuiltinArithmeticType, let N : int> +__target_intrinsic(hlsl) __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) __target_intrinsic(glsl, "subgroupExclusiveXor($0)") -vector<T,N> WaveMultiPrefixBitXor(vector<T,N> expr); +__target_intrinsic(cuda, "_wavePrefixXorMultiple($0, _getMultiPrefixMask(($1).x))") +vector<T,N> WaveMultiPrefixBitXor(vector<T,N> expr, uint4 mask); __generic<T : __BuiltinArithmeticType, let N : int, let M : int> -matrix<T,N,M> WaveMultiPrefixBitXor(matrix<T,N,M> expr); - - -uint WaveMultiPrefixCountBits(bool value, uint4 mask); +__target_intrinsic(hlsl) +__target_intrinsic(cuda, "_wavePrefixXorMultiple($0, _getMultiPrefixMask(($1).x))") +matrix<T,N,M> WaveMultiPrefixBitXor(matrix<T,N,M> expr, uint4 mask); -__generic<T : __BuiltinArithmeticType> T WaveMultiPrefixProduct(T value, uint4 mask); -__generic<T : __BuiltinArithmeticType, let N : int> vector<T,N> WaveMultiPrefixProduct(vector<T,N> value, uint4 mask); -__generic<T : __BuiltinArithmeticType, let N : int, let M : int> matrix<T,N,M> WaveMultiPrefixProduct(matrix<T,N,M> value, uint4 mask); +__generic<T : __BuiltinArithmeticType> +__target_intrinsic(hlsl) +__target_intrinsic(cuda, "_wavePrefixProduct($0, _getMultiPrefixMask(($1).x))") +T WaveMultiPrefixProduct(T value, uint4 mask); +__generic<T : __BuiltinArithmeticType, let N : int> +__target_intrinsic(hlsl) +__target_intrinsic(cuda, "_wavePrefixProductMultiple($0, _getMultiPrefixMask(($1).x))") +vector<T,N> WaveMultiPrefixProduct(vector<T,N> value, uint4 mask); +__generic<T : __BuiltinArithmeticType, let N : int, let M : int> +__target_intrinsic(hlsl) +__target_intrinsic(cuda, "_wavePrefixProductMultiple($0, _getMultiPrefixMask(($1).x))") +matrix<T,N,M> WaveMultiPrefixProduct(matrix<T,N,M> value, uint4 mask); -__generic<T : __BuiltinArithmeticType> T WaveMultiPrefixSum(T value, uint4 mask); -__generic<T : __BuiltinArithmeticType, let N : int> vector<T,N> WaveMultiPrefixSum(vector<T,N> value, uint4 mask); -__generic<T : __BuiltinArithmeticType, let N : int, let M : int> matrix<T,N,M> WaveMultiPrefixSum(matrix<T,N,M> value, uint4 mask); +__generic<T : __BuiltinArithmeticType> +__target_intrinsic(hlsl) +__target_intrinsic(cuda, "_wavePrefixSum($0, _getMultiPrefixMask(($1).x))") +T WaveMultiPrefixSum(T value, uint4 mask); +__generic<T : __BuiltinArithmeticType, let N : int> +__target_intrinsic(hlsl) +__target_intrinsic(cuda, "_wavePrefixSumMultiple($0, _getMultiPrefixMask(($1).x))") +vector<T,N> WaveMultiPrefixSum(vector<T,N> value, uint4 mask); +__generic<T : __BuiltinArithmeticType, let N : int, let M : int> +__target_intrinsic(hlsl) +__target_intrinsic(cuda, "_wavePrefixSumMultiple($0, _getMultiPrefixMask(($1).x))") +matrix<T,N,M> WaveMultiPrefixSum(matrix<T,N,M> value, uint4 mask); // `typedef`s to help with the fact that HLSL has been sorta-kinda case insensitive at various points typedef Texture2D texture2D; diff --git a/source/slang/slang-profile-defs.h b/source/slang/slang-profile-defs.h index fc2722160..7066b5942 100644 --- a/source/slang/slang-profile-defs.h +++ b/source/slang/slang-profile-defs.h @@ -129,6 +129,8 @@ PROFILE(DX_Compute_6_0, cs_6_0, Compute, DX_6_0) PROFILE(DX_Compute_6_1, cs_6_1, Compute, DX_6_1) PROFILE(DX_Compute_6_2, cs_6_2, Compute, DX_6_2) PROFILE(DX_Compute_6_3, cs_6_3, Compute, DX_6_3) +PROFILE(DX_Compute_6_4, cs_6_4, Compute, DX_6_4) +PROFILE(DX_Compute_6_5, cs_6_5, Compute, DX_6_5) PROFILE(DX_Domain_5_0, ds_5_0, Domain, DX_5_0) PROFILE(DX_Domain_5_1, ds_5_1, Domain, DX_5_1) diff --git a/tests/hlsl-intrinsic/wave-multi-prefix.slang b/tests/hlsl-intrinsic/wave-multi-prefix.slang new file mode 100644 index 000000000..3eee16e31 --- /dev/null +++ b/tests/hlsl-intrinsic/wave-multi-prefix.slang @@ -0,0 +1,26 @@ +//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute +//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-slang -compute +// We need SM6.5 for these tests +// Disable because version of dxc we are currently using doesn't support SM6.5 +//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile sm_6_5 +// Disabled because we don't have GLSL intrinsics for these it seems +//DISABLE_TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute +//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer +RWStructuredBuffer<int> outputBuffer; + +[numthreads(8, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + int idx = int(dispatchThreadID.x); + + int value = 0; + + uint4 mask = WaveMatch(true); + + // Scalar + value += WaveMultiPrefixSum(1 << idx, mask); + + outputBuffer[idx] = value; +}
\ No newline at end of file diff --git a/tests/hlsl-intrinsic/wave-multi-prefix.slang.expected.txt b/tests/hlsl-intrinsic/wave-multi-prefix.slang.expected.txt new file mode 100644 index 000000000..6ec6deeea --- /dev/null +++ b/tests/hlsl-intrinsic/wave-multi-prefix.slang.expected.txt @@ -0,0 +1,8 @@ +0 +1 +3 +7 +F +1F +3F +7F |
