diff options
| -rw-r--r-- | prelude/slang-cuda-prelude.h | 163 | ||||
| -rw-r--r-- | source/slang/hlsl.meta.slang | 4 | ||||
| -rw-r--r-- | tests/hlsl-intrinsic/wave-prefix-product.slang | 12 | ||||
| -rw-r--r-- | tests/hlsl-intrinsic/wave-prefix-product.slang.expected.txt | 16 | ||||
| -rw-r--r-- | tests/hlsl-intrinsic/wave-prefix-sum.slang | 11 | ||||
| -rw-r--r-- | tests/hlsl-intrinsic/wave-prefix-sum.slang.expected.txt | 14 |
6 files changed, 196 insertions, 24 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h
index 0a2ec088b..6a1d87183 100644
--- a/prelude/slang-cuda-prelude.h
+++ b/prelude/slang-cuda-prelude.h
@@ -919,6 +919,7 @@ __device__ T _wavePrefixInvertableScalar(T val)
     return result;
 }
 
+
 // This implementation separately tracks the value to be propogated, and the value
 // that is the final result
 template <typename INTF, typename T>
@@ -967,6 +968,151 @@ __device__ T _wavePrefixScalar(T val)
     }
     return result;
 }
+
+// Copies COUNT elements from src to dst (INTF is not used here).
+template <typename INTF, typename T, size_t COUNT>
+__device__ void _copy(T* dst, const T* src)
+{
+    for (size_t j = 0; j < COUNT; ++j)
+    {
+        dst[j] = src[j];
+    }
+}
+
+// Element-wise inOut[j] = INTF::doInverse(inOut[j], val[j]) over COUNT elements.
+template <typename INTF, typename T, size_t COUNT>
+__device__ void _doInverse(T* inOut, const T* val)
+{
+    for (size_t j = 0; j < COUNT; ++j)
+    {
+        inOut[j] = INTF::doInverse(inOut[j], val[j]);
+    }
+}
+
+// Element-wise out[j] = INTF::getInitial(val[j]) over COUNT elements; out may alias val.
+template <typename INTF, typename T, size_t COUNT>
+__device__ void _setInitial(T* out, const T* val)
+{
+    for (size_t j = 0; j < COUNT; ++j)
+    {
+        out[j] = INTF::getInitial(val[j]);
+    }
+}
+
+// In-place exclusive prefix scan of the COUNT elements of val across the active lanes. The shuffle-up
+// path scans inclusively and then strips this lane's own value with INTF::doInverse, so use this only
+// with ops whose inverse is exact (e.g. add); ops such as multiply should use _wavePrefixMultiple.
+template <typename INTF, typename T, size_t COUNT>
+__device__ void _wavePrefixInvertableMultiple(T* val)
+{
+    const int mask = __activemask();
+    const int offsetSize = _waveCalcPow2Offset(mask);
+    const int laneId = _getLaneId();
+    T originalVal[COUNT];
+    _copy<INTF, T, COUNT>(originalVal, val);
+    if (offsetSize > 0)
+    {
+        // Sum is calculated inclusive of this lanes value
+        for (int i = 1; i < offsetSize; i += i)
+        {
+            // TODO(JS): Note that here I don't split the laneId outside so it's only tested once.
+            // This may be better but it would also mean that there would be shfl between lanes
+            // that are on different (albeit identical) instructions. So this seems more likely to
+            // work as expected with everything in lock step.
+            for (size_t j = 0; j < COUNT; ++j)
+            {
+                const T readVal = __shfl_up_sync(mask, val[j], i, offsetSize);
+                if (laneId >= i)
+                {
+                    val[j] = INTF::doOp(val[j], readVal);
+                }
+            }
+        }
+        // Remove originalVal from the result, by applying inverse
+        _doInverse<INTF, T, COUNT>(val, originalVal);
+    }
+    else
+    {
+        _setInitial<INTF, T, COUNT>(val, val);
+        if (!_waveIsSingleLane(mask))
+        {
+            int remaining = mask;
+            while (remaining)
+            {
+                const int laneBit = remaining & -remaining;
+                // Get the sourceLane
+                const int srcLane = __ffs(laneBit) - 1;
+                for (size_t j = 0; j < COUNT; ++j)
+                {
+                    // Broadcast (can also broadcast to self)
+                    const T readValue = __shfl_sync(mask, originalVal[j], srcLane);
+                    // Only accumulate if srcLane is less than this lane
+                    if (srcLane < laneId)
+                    {
+                        val[j] = INTF::doOp(val[j], readValue);
+                    }
+                }
+                // Clear this lane's bit once per source lane, after all COUNT elements were exchanged
+                remaining &= ~laneBit;
+            }
+        }
+    }
+}
+
+// In-place exclusive prefix scan of the COUNT elements of val across the active lanes that never calls
+// INTF::doInverse; on the shuffle-up path 'work' carries the inclusive value and val the exclusive result.
+template <typename INTF, typename T, size_t COUNT>
+__device__ void _wavePrefixMultiple(T* val)
+{
+    const int mask = __activemask();
+    const int offsetSize = _waveCalcPow2Offset(mask);
+    const int laneId = _getLaneId();
+    T work[COUNT];
+    _copy<INTF, T, COUNT>(work, val);
+    _setInitial<INTF, T, COUNT>(val, val);
+    if (offsetSize > 0)
+    {
+        // For transmitted value we will do it inclusively with this lanes value
+        // For the result we do not include the lanes value. This means an extra op for each iteration
+        // but means we don't need to have a divide at the end and also removes overflow issues in that scenario.
+        for (int i = 1; i < offsetSize; i += i)
+        {
+            for (size_t j = 0; j < COUNT; ++j)
+            {
+                const T readVal = __shfl_up_sync(mask, work[j], i, offsetSize);
+                if (laneId >= i)
+                {
+                    work[j] = INTF::doOp(work[j], readVal);
+                    val[j] = INTF::doOp(val[j], readVal);
+                }
+            }
+        }
+    }
+    else
+    {
+        if (!_waveIsSingleLane(mask))
+        {
+            int remaining = mask;
+            while (remaining)
+            {
+                const int laneBit = remaining & -remaining;
+                // Get the sourceLane
+                const int srcLane = __ffs(laneBit) - 1;
+                for (size_t j = 0; j < COUNT; ++j)
+                {
+                    // Broadcast (can also broadcast to self)
+                    const T readValue = __shfl_sync(mask, work[j], srcLane);
+                    // Only accumulate if srcLane is less than this lane
+                    if (srcLane < laneId)
+                    {
+                        val[j] = INTF::doOp(val[j], readValue);
+                    }
+                }
+                remaining &= ~laneBit;
+            }
+        }
+    }
+}
 
 template <typename T>
 __inline__ __device__ T _wavePrefixProduct(T val) { return _wavePrefixScalar<WaveOpMul<T>, T>(val); }
@@ -975,13 +1121,20 @@ template <typename T>
 __inline__ __device__ T _wavePrefixSum(T val) { return _wavePrefixInvertableScalar<WaveOpAdd<T>, T>(val); }
 
 template <typename T>
-__inline__ __device__ T _wavePrefixAnd(T val) { return _wavePrefixScalar<WaveOpAnd<T>, T>(val); }
-
-template <typename T>
-__inline__ __device__ T _wavePrefixOr(T val) { return _wavePrefixScalar<WaveOpOr<T>, T>(val); }
+__inline__ __device__ T _wavePrefixProductMultiple(T val)
+{
+    typedef typename ElementTypeTrait<T>::Type ElemType;
+    _wavePrefixMultiple<WaveOpMul<ElemType>, ElemType, sizeof(T) / sizeof(ElemType)>((ElemType*)&val);
+    return val;
+}
 
 template <typename T>
-__inline__ __device__ T _wavePrefixXor(T val) { return _wavePrefixInvertableScalar<WaveOpXor<T>, T>(val); }
+__inline__ __device__ T _wavePrefixSumMultiple(T val)
+{
+    typedef typename ElementTypeTrait<T>::Type ElemType;
+    _wavePrefixInvertableMultiple<WaveOpAdd<ElemType>, ElemType, sizeof(T) / sizeof(ElemType)>((ElemType*)&val);
+    return val;
+}
 
 /*
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */ diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 03496ccc8..2b556c10b 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -2677,8 +2677,10 @@ __generic<T : __BuiltinArithmeticType, let N : int> __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) __target_intrinsic(glsl, "subgroupExclusiveMul($0)") +__target_intrinsic(cuda, "_wavePrefixProductMultiple($0)") vector<T,N> WavePrefixProduct(vector<T,N> expr); __generic<T : __BuiltinArithmeticType, let N : int, let M : int> +__target_intrinsic(cuda, "_wavePrefixProductMultiple($0)") matrix<T,N,M> WavePrefixProduct(matrix<T,N,M> expr); __generic<T : __BuiltinArithmeticType> @@ -2691,8 +2693,10 @@ __generic<T : __BuiltinArithmeticType, let N : int> __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) __target_intrinsic(glsl, "subgroupExclusiveAdd($0)") +__target_intrinsic(cuda, "_wavePrefixSumMultiple($0)") vector<T,N> WavePrefixSum(vector<T,N> expr); __generic<T : __BuiltinArithmeticType, let N : int, let M : int> +__target_intrinsic(cuda, "_wavePrefixSumMultiple($0)") matrix<T,N,M> WavePrefixSum(matrix<T,N,M> expr); __generic<T : __BuiltinType> diff --git a/tests/hlsl-intrinsic/wave-prefix-product.slang b/tests/hlsl-intrinsic/wave-prefix-product.slang index bc324ed7d..a56912616 100644 --- a/tests/hlsl-intrinsic/wave-prefix-product.slang +++ b/tests/hlsl-intrinsic/wave-prefix-product.slang @@ -11,6 +11,14 @@ RWStructuredBuffer<int> outputBuffer; void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { int idx = int(dispatchThreadID.x); - int val = WavePrefixProduct(idx + 1); - outputBuffer[idx] = val; + + float2 v1 = float2(1, idx + 1); + + int r0 = WavePrefixProduct(idx + 1); + float2 r1 = WavePrefixProduct(v1); + + int r2 = int(r1.x) + int(r1.y) - 1; + + outputBuffer[idx] = r0 + (r2 << 16); + }
\ No newline at end of file diff --git a/tests/hlsl-intrinsic/wave-prefix-product.slang.expected.txt b/tests/hlsl-intrinsic/wave-prefix-product.slang.expected.txt index 03cb63ab9..1b233efaf 100644 --- a/tests/hlsl-intrinsic/wave-prefix-product.slang.expected.txt +++ b/tests/hlsl-intrinsic/wave-prefix-product.slang.expected.txt @@ -1,8 +1,8 @@ -1 -1 -2 -6 -18 -78 -2D0 -13B0 +10001 +10001 +20002 +60006 +180018 +780078 +2D002D0 +13B013B0 diff --git a/tests/hlsl-intrinsic/wave-prefix-sum.slang b/tests/hlsl-intrinsic/wave-prefix-sum.slang index f8d9bb560..343a7afbd 100644 --- a/tests/hlsl-intrinsic/wave-prefix-sum.slang +++ b/tests/hlsl-intrinsic/wave-prefix-sum.slang @@ -11,6 +11,13 @@ RWStructuredBuffer<int> outputBuffer; void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { int idx = int(dispatchThreadID.x); - int val = WavePrefixSum(1 << idx); - outputBuffer[idx] = val; + + float2 v1 = float2(1, 1 << idx); + + int r0 = WavePrefixSum(1 << idx); + float2 r1 = WavePrefixSum(v1); + + int r2 = int(r1.x) + int(r1.y) - idx; + + outputBuffer[idx] = r0 + (r2 << 16); }
\ No newline at end of file diff --git a/tests/hlsl-intrinsic/wave-prefix-sum.slang.expected.txt b/tests/hlsl-intrinsic/wave-prefix-sum.slang.expected.txt index 6ec6deeea..4b4230415 100644 --- a/tests/hlsl-intrinsic/wave-prefix-sum.slang.expected.txt +++ b/tests/hlsl-intrinsic/wave-prefix-sum.slang.expected.txt @@ -1,8 +1,8 @@ 0 -1 -3 -7 -F -1F -3F -7F +10001 +30003 +70007 +F000F +1F001F +3F003F +7F007F |
