summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--prelude/slang-cuda-prelude.h163
-rw-r--r--source/slang/hlsl.meta.slang4
-rw-r--r--tests/hlsl-intrinsic/wave-prefix-product.slang12
-rw-r--r--tests/hlsl-intrinsic/wave-prefix-product.slang.expected.txt16
-rw-r--r--tests/hlsl-intrinsic/wave-prefix-sum.slang11
-rw-r--r--tests/hlsl-intrinsic/wave-prefix-sum.slang.expected.txt14
6 files changed, 196 insertions, 24 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h
index 0a2ec088b..6a1d87183 100644
--- a/prelude/slang-cuda-prelude.h
+++ b/prelude/slang-cuda-prelude.h
@@ -919,6 +919,7 @@ __device__ T _wavePrefixInvertableScalar(T val)
return result;
}
+
// This implementation separately tracks the value to be propagated, and the value
// that is the final result
template <typename INTF, typename T>
@@ -967,6 +968,151 @@ __device__ T _wavePrefixScalar(T val)
}
return result;
}
+
+
+// Copies COUNT elements from `src` to `dst`, element by element.
+//
+// INTF is not used here; it is kept so the helper can be invoked with the same
+// <INTF, T, COUNT> argument list as the other multi-element helpers.
+// Nothing is returned: the only effect is the write through `dst`.
+template <typename INTF, typename T, size_t COUNT>
+__device__ void _copy(T* dst, const T* src)
+{
+    for (size_t j = 0; j < COUNT; ++j)
+    {
+        dst[j] = src[j];
+    }
+}
+
+
+// Applies INTF::doInverse element-wise: inOut[j] = INTF::doInverse(inOut[j], val[j])
+// for each of the COUNT elements.
+//
+// Nothing is returned: the result is written back through `inOut`.
+template <typename INTF, typename T, size_t COUNT>
+__device__ void _doInverse(T* inOut, const T* val)
+{
+    for (size_t j = 0; j < COUNT; ++j)
+    {
+        inOut[j] = INTF::doInverse(inOut[j], val[j]);
+    }
+}
+
+// Sets each of the COUNT elements of `out` to INTF::getInitial(val[j]).
+//
+// `out` and `val` may point at the same storage: each element is read before the
+// same index is written. Nothing is returned; the result is written through `out`.
+template <typename INTF, typename T, size_t COUNT>
+__device__ void _setInitial(T* out, const T* val)
+{
+    for (size_t j = 0; j < COUNT; ++j)
+    {
+        out[j] = INTF::getInitial(val[j]);
+    }
+}
+
+// Exclusive prefix operation across the lanes in the active mask, applied in place and
+// independently to each of the COUNT elements pointed to by `val`.
+//
+// INTF supplies the operation: `doOp` combines two values, `doInverse` removes a value
+// from an accumulated result and `getInitial` yields the starting value.
+//
+// The fast path builds an inclusive result and then strips the lane's own input with
+// `doInverse`, so INTF must be exactly invertible for every input it will see.
+//
+// Results are written back through `val`; nothing is returned.
+//
+// NOTE(review): the participant mask comes from __activemask(), which assumes the lanes
+// converged here are exactly the lanes meant to take part - confirm at call sites.
+template <typename INTF, typename T, size_t COUNT>
+__device__ void _wavePrefixInvertableMultiple(T* val)
+{
+    const int mask = __activemask();
+    // NOTE(review): a result > 0 is treated as "the width-limited shuffle path is usable";
+    // presumably the active lanes then form a power-of-two sized group - confirm against
+    // _waveCalcPow2Offset.
+    const int offsetSize = _waveCalcPow2Offset(mask);
+
+    const int laneId = _getLaneId();
+    // Keep the lane's own input: the fast path strips it from the inclusive result and the
+    // fallback path broadcasts it to the other lanes.
+    T originalVal[COUNT];
+    _copy<INTF, T, COUNT>(originalVal, val);
+
+    if (offsetSize > 0)
+    {
+        // Sum is calculated inclusive of this lane's value
+        for (int i = 1; i < offsetSize; i += i)
+        {
+            // TODO(JS): Note that here I don't split the laneId outside so it's only tested once.
+            // This may be better but it would also mean that there would be shfl between lanes
+            // that are on different (albeit identical) instructions. So this seems more likely to
+            // work as expected with everything in lock step.
+            for (size_t j = 0; j < COUNT; ++j)
+            {
+                const T readVal = __shfl_up_sync(mask, val[j], i, offsetSize);
+                if (laneId >= i)
+                {
+                    val[j] = INTF::doOp(val[j], readVal);
+                }
+            }
+        }
+        // Remove originalVal from the result, by applying inverse
+        _doInverse<INTF, T, COUNT>(val, originalVal);
+    }
+    else
+    {
+        _setInitial<INTF, T, COUNT>(val, val);
+        if (!_waveIsSingleLane(mask))
+        {
+            // Unsigned so that isolating the lowest set bit stays well defined even when
+            // only bit 31 is left.
+            unsigned int remaining = (unsigned int)mask;
+            while (remaining)
+            {
+                const unsigned int laneBit = remaining & (0u - remaining);
+                // Get the sourceLane
+                const int srcLane = __ffs((int)laneBit) - 1;
+
+                for (size_t j = 0; j < COUNT; ++j)
+                {
+                    // Broadcast (can also broadcast to self)
+                    const T readValue = __shfl_sync(mask, originalVal[j], srcLane);
+                    // Only accumulate if srcLane is less than this lane
+                    if (srcLane < laneId)
+                    {
+                        val[j] = INTF::doOp(val[j], readValue);
+                    }
+                }
+                // Retire this source lane once all of its elements have been broadcast.
+                remaining &= ~laneBit;
+            }
+        }
+    }
+}
+
+// Exclusive prefix operation across the lanes in the active mask, applied in place and
+// independently to each of the COUNT elements pointed to by `val`.
+//
+// Unlike the invertible variant this never calls INTF::doInverse: it only needs
+// `doOp` and `getInitial`, so it is usable for operations with no exact inverse.
+//
+// Results are written back through `val`; nothing is returned.
+//
+// NOTE(review): the participant mask comes from __activemask(), which assumes the lanes
+// converged here are exactly the lanes meant to take part - confirm at call sites.
+template <typename INTF, typename T, size_t COUNT>
+__device__ void _wavePrefixMultiple(T* val)
+{
+    const int mask = __activemask();
+    // NOTE(review): a result > 0 is treated as "the width-limited shuffle path is usable";
+    // presumably the active lanes then form a power-of-two sized group - confirm against
+    // _waveCalcPow2Offset.
+    const int offsetSize = _waveCalcPow2Offset(mask);
+
+    const int laneId = _getLaneId();
+
+    // `work` holds the value that is passed between lanes (it includes this lane's input);
+    // `val` is reset to the initial value and accumulates the exclusive result.
+    T work[COUNT];
+    _copy<INTF, T, COUNT>(work, val);
+    _setInitial<INTF, T, COUNT>(val, val);
+
+    if (offsetSize > 0)
+    {
+        // For transmitted value we will do it inclusively with this lane's value
+        // For the result we do not include the lane's value. This means an extra op for each iteration
+        // but means we don't need to have a divide at the end and also removes overflow issues in that scenario.
+        for (int i = 1; i < offsetSize; i += i)
+        {
+            for (size_t j = 0; j < COUNT; ++j)
+            {
+                const T readVal = __shfl_up_sync(mask, work[j], i, offsetSize);
+                if (laneId >= i)
+                {
+                    work[j] = INTF::doOp(work[j], readVal);
+                    val[j] = INTF::doOp(val[j], readVal);
+                }
+            }
+        }
+    }
+    else
+    {
+        if (!_waveIsSingleLane(mask))
+        {
+            // Unsigned so that isolating the lowest set bit stays well defined even when
+            // only bit 31 is left.
+            unsigned int remaining = (unsigned int)mask;
+            while (remaining)
+            {
+                const unsigned int laneBit = remaining & (0u - remaining);
+                // Get the sourceLane
+                const int srcLane = __ffs((int)laneBit) - 1;
+
+                for (size_t j = 0; j < COUNT; ++j)
+                {
+                    // Broadcast (can also broadcast to self)
+                    const T readValue = __shfl_sync(mask, work[j], srcLane);
+                    // Only accumulate if srcLane is less than this lane
+                    if (srcLane < laneId)
+                    {
+                        val[j] = INTF::doOp(val[j], readValue);
+                    }
+                }
+                remaining &= ~laneBit;
+            }
+        }
+    }
+}
template <typename T> // Scalar prefix product: forwards to _wavePrefixScalar with WaveOpMul<T> as the operation.
__inline__ __device__ T _wavePrefixProduct(T val) { return _wavePrefixScalar<WaveOpMul<T>, T>(val); }
@@ -975,13 +1121,20 @@ template <typename T>
__inline__ __device__ T _wavePrefixSum(T val) { return _wavePrefixInvertableScalar<WaveOpAdd<T>, T>(val); }
template <typename T>
-__inline__ __device__ T _wavePrefixAnd(T val) { return _wavePrefixScalar<WaveOpAnd<T>, T>(val); }
-
-template <typename T>
-__inline__ __device__ T _wavePrefixOr(T val) { return _wavePrefixScalar<WaveOpOr<T>, T>(val); }
+// Prefix product for a multi-element value (e.g. a vector or matrix type): `val` is
+// reinterpreted as sizeof(T) / sizeof(ElemType) elements and each element is processed
+// independently. Returns the updated copy of `val`.
+//
+// This uses the non-invertible scan, matching the scalar _wavePrefixProduct. The
+// invertible scan would have to divide the lane's own value back out, which breaks
+// when an element is 0 and is exposed to overflow/rounding in the intermediate product.
+__inline__ __device__ T _wavePrefixProductMultiple(T val)
+{
+    typedef typename ElementTypeTrait<T>::Type ElemType;
+    _wavePrefixMultiple<WaveOpMul<ElemType>, ElemType, sizeof(T) / sizeof(ElemType)>((ElemType*)&val);
+    return val;
+}
template <typename T>
-__inline__ __device__ T _wavePrefixXor(T val) { return _wavePrefixInvertableScalar<WaveOpXor<T>, T>(val); }
+// Prefix sum for a multi-element value: `val` is viewed as an array of its element type
+// and every element is run through _wavePrefixMultiple with WaveOpAdd. The modified
+// copy of `val` is returned.
+__inline__ __device__ T _wavePrefixSumMultiple(T val)
+{
+    typedef typename ElementTypeTrait<T>::Type Scalar;
+    // Number of Scalar elements packed into T.
+    const size_t scalarCount = sizeof(T) / sizeof(Scalar);
+    Scalar* const elements = (Scalar*)&val;
+    _wavePrefixMultiple<WaveOpAdd<Scalar>, Scalar, scalarCount>(elements);
+    return val;
+}
/* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 03496ccc8..2b556c10b 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -2677,8 +2677,10 @@ __generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
__target_intrinsic(glsl, "subgroupExclusiveMul($0)")
+__target_intrinsic(cuda, "_wavePrefixProductMultiple($0)")
vector<T,N> WavePrefixProduct(vector<T,N> expr);
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
+__target_intrinsic(cuda, "_wavePrefixProductMultiple($0)") // CUDA: emitted as a call to _wavePrefixProductMultiple from the CUDA prelude
matrix<T,N,M> WavePrefixProduct(matrix<T,N,M> expr);
__generic<T : __BuiltinArithmeticType>
@@ -2691,8 +2693,10 @@ __generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
__target_intrinsic(glsl, "subgroupExclusiveAdd($0)")
+__target_intrinsic(cuda, "_wavePrefixSumMultiple($0)")
vector<T,N> WavePrefixSum(vector<T,N> expr);
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
+__target_intrinsic(cuda, "_wavePrefixSumMultiple($0)") // CUDA: emitted as a call to _wavePrefixSumMultiple from the CUDA prelude
matrix<T,N,M> WavePrefixSum(matrix<T,N,M> expr);
__generic<T : __BuiltinType>
diff --git a/tests/hlsl-intrinsic/wave-prefix-product.slang b/tests/hlsl-intrinsic/wave-prefix-product.slang
index bc324ed7d..a56912616 100644
--- a/tests/hlsl-intrinsic/wave-prefix-product.slang
+++ b/tests/hlsl-intrinsic/wave-prefix-product.slang
@@ -11,6 +11,14 @@ RWStructuredBuffer<int> outputBuffer;
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) // Test entry point: writes one packed int per thread into outputBuffer.
{
int idx = int(dispatchThreadID.x);
- int val = WavePrefixProduct(idx + 1);
- outputBuffer[idx] = val;
+
+ float2 v1 = float2(1, idx + 1); // x is the constant 1; y carries the same idx + 1 input as the scalar call below.
+
+ int r0 = WavePrefixProduct(idx + 1); // scalar overload
+ float2 r1 = WavePrefixProduct(v1); // float2 (vector) overload
+
+ int r2 = int(r1.x) + int(r1.y) - 1; // folds both components into one int; the - 1 is meant to cancel r1.x, which is expected to stay 1.
+
+ outputBuffer[idx] = r0 + (r2 << 16); // scalar result in the low bits, folded vector result shifted up by 16.
+
} \ No newline at end of file
diff --git a/tests/hlsl-intrinsic/wave-prefix-product.slang.expected.txt b/tests/hlsl-intrinsic/wave-prefix-product.slang.expected.txt
index 03cb63ab9..1b233efaf 100644
--- a/tests/hlsl-intrinsic/wave-prefix-product.slang.expected.txt
+++ b/tests/hlsl-intrinsic/wave-prefix-product.slang.expected.txt
@@ -1,8 +1,8 @@
-1
-1
-2
-6
-18
-78
-2D0
-13B0
+10001
+10001
+20002
+60006
+180018
+780078
+2D002D0
+13B013B0
diff --git a/tests/hlsl-intrinsic/wave-prefix-sum.slang b/tests/hlsl-intrinsic/wave-prefix-sum.slang
index f8d9bb560..343a7afbd 100644
--- a/tests/hlsl-intrinsic/wave-prefix-sum.slang
+++ b/tests/hlsl-intrinsic/wave-prefix-sum.slang
@@ -11,6 +11,13 @@ RWStructuredBuffer<int> outputBuffer;
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) // Test entry point: writes one packed int per thread into outputBuffer.
{
int idx = int(dispatchThreadID.x);
- int val = WavePrefixSum(1 << idx);
- outputBuffer[idx] = val;
+
+ float2 v1 = float2(1, 1 << idx); // x is the constant 1; y carries the same 1 << idx input as the scalar call below.
+
+ int r0 = WavePrefixSum(1 << idx); // scalar overload
+ float2 r1 = WavePrefixSum(v1); // float2 (vector) overload
+
+ int r2 = int(r1.x) + int(r1.y) - idx; // folds both components; the - idx is meant to cancel r1.x, presumably equal to idx when lane order follows idx - TODO confirm.
+
+ outputBuffer[idx] = r0 + (r2 << 16); // scalar result in the low bits, folded vector result shifted up by 16.
} \ No newline at end of file
diff --git a/tests/hlsl-intrinsic/wave-prefix-sum.slang.expected.txt b/tests/hlsl-intrinsic/wave-prefix-sum.slang.expected.txt
index 6ec6deeea..4b4230415 100644
--- a/tests/hlsl-intrinsic/wave-prefix-sum.slang.expected.txt
+++ b/tests/hlsl-intrinsic/wave-prefix-sum.slang.expected.txt
@@ -1,8 +1,8 @@
0
-1
-3
-7
-F
-1F
-3F
-7F
+10001
+30003
+70007
+F000F
+1F001F
+3F003F
+7F007F