summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjsmall-nvidia <jsmall@nvidia.com>2020-03-10 16:43:41 -0400
committerGitHub <noreply@github.com>2020-03-10 16:43:41 -0400
commitb380b1af6ba6f5f58e3841c2a5b14db7ee8c372d (patch)
tree2013ac90c39ee20e25bd08513271b5e5538dab15
parenta10d9cd8767e88a064719d71cc97144ba8b112d1 (diff)
Wave Prefix Product (#1270)
* Fix some typos. * Add wave-prefix-sum.slang test * First pass at implementing prefixSum. * Small improvments to prefixSum CUDA. * Small improvement to prefix sum. * Enable prefix sum in stdlib. * Wave prefix product without using a divide. * Split out SM6.5 Wave intrinsics. Template mechanism for do prefix calculations.
-rw-r--r--prelude/slang-cuda-prelude.h116
-rw-r--r--source/slang/hlsl.meta.slang86
-rw-r--r--tests/hlsl-intrinsic/wave-prefix-product.slang16
-rw-r--r--tests/hlsl-intrinsic/wave-prefix-product.slang.expected.txt8
4 files changed, 169 insertions, 57 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h
index 6f2122934..457fb4246 100644
--- a/prelude/slang-cuda-prelude.h
+++ b/prelude/slang-cuda-prelude.h
@@ -534,6 +534,7 @@ struct WaveOpXor
{
__inline__ __device__ static T getInitial(T a) { return 0; }
__inline__ __device__ static T doOp(T a, T b) { return a ^ b; }
+ __inline__ __device__ static T doInverse(T a, T b) { return a ^ b; }
};
template <typename T>
@@ -541,6 +542,7 @@ struct WaveOpAdd
{
__inline__ __device__ static T getInitial(T a) { return 0; }
__inline__ __device__ static T doOp(T a, T b) { return a + b; }
+ __inline__ __device__ static T doInverse(T a, T b) { return a - b; }
};
template <typename T>
@@ -548,6 +550,9 @@ struct WaveOpMul
{
__inline__ __device__ static T getInitial(T a) { return T(1); }
__inline__ __device__ static T doOp(T a, T b) { return a * b; }
+ // Using this inverse for int is probably undesirable - because in general it requires T to have more precision
+ // There is also a performance aspect to it, where divides are generally significantly slower
+ __inline__ __device__ static T doInverse(T a, T b) { return a / b; }
};
template <typename T>
@@ -823,46 +828,121 @@ __inline__ __device__ T _waveReadLaneAtMultiple(T inVal, int lane)
return outVal;
}
-__device__ int _wavePrefixSum(int val)
+// Scalar
+
+// "Invertable" (i.e. invertible) means that once the inclusive reduce has finished, we can remove val
+// from the result (making it exclusive) by applying the inverse of the op.
+template <typename INTF, typename T>
+__device__ T _wavePrefixInvertableScalar(T val)
{
const int mask = __activemask();
const int offsetSize = _waveCalcPow2Offset(mask);
const int laneId = _getLaneId();
+ T result;
if (offsetSize > 0)
{
- int sum = val;
+ // The result is calculated inclusive of this lane's value
+ result = val;
for (int i = 1; i < offsetSize; i += i)
{
- const int readVal = __shfl_up_sync(mask, sum, i, offsetSize);
+ const T readVal = __shfl_up_sync(mask, result, i, offsetSize);
if (laneId >= i)
{
- sum += readVal;
+ result = INTF::doOp(result, readVal);
}
}
- return sum - val;
+ // Remove val from the result by applying the inverse
+ result = INTF::doInverse(result, val);
}
else
{
- int result = 0;
- int remaining = mask;
- while (remaining)
+ result = INTF::getInitial(val);
+ if (!_waveIsSingleLane(mask))
{
- const int laneBit = remaining & -remaining;
- // Get the sourceLane
- const int srcLane = __ffs(laneBit) - 1;
- // Broadcast (can also broadcast to self)
- int readValue = __shfl_sync(mask, val, srcLane);
- // Only accumulate if srcLane is less than this lane
- if (srcLane < laneId)
+ int remaining = mask;
+ while (remaining)
{
- result += readValue;
+ const int laneBit = remaining & -remaining;
+ // Get the sourceLane
+ const int srcLane = __ffs(laneBit) - 1;
+ // Broadcast (can also broadcast to self)
+ const T readValue = __shfl_sync(mask, val, srcLane);
+ // Only accumulate if srcLane is less than this lane
+ if (srcLane < laneId)
+ {
+ result = INTF::doOp(result, readValue);
+ }
+ remaining &= ~laneBit;
+ }
+ }
+ }
+ return result;
+}
+
+// This implementation separately tracks the value to be propagated and the value
+// that is the final result
+template <typename INTF, typename T>
+__device__ T _wavePrefixScalar(T val)
+{
+ const int mask = __activemask();
+ const int offsetSize = _waveCalcPow2Offset(mask);
+
+ const int laneId = _getLaneId();
+ T result = INTF::getInitial(val);
+ if (offsetSize > 0)
+ {
+ // The transmitted value is accumulated inclusive of this lane's value.
+ // The result does not include this lane's value. This costs an extra op (e.g. a multiply) for each iteration,
+ // but means we don't need a divide at the end, and it also removes the overflow issues of that approach.
+ for (int i = 1; i < offsetSize; i += i)
+ {
+ const T readVal = __shfl_up_sync(mask, val, i, offsetSize);
+ if (laneId >= i)
+ {
+ result = INTF::doOp(result, readVal);
+ val = INTF::doOp(val, readVal);
+ }
+ }
+ }
+ else
+ {
+ if (!_waveIsSingleLane(mask))
+ {
+ int remaining = mask;
+ while (remaining)
+ {
+ const int laneBit = remaining & -remaining;
+ // Get the sourceLane
+ const int srcLane = __ffs(laneBit) - 1;
+ // Broadcast (can also broadcast to self)
+ const T readValue = __shfl_sync(mask, val, srcLane);
+ // Only accumulate if srcLane is less than this lane
+ if (srcLane < laneId)
+ {
+ result = INTF::doOp(result, readValue);
+ }
+ remaining &= ~laneBit;
}
- remaining &= ~laneBit;
}
- return result;
}
+ return result;
}
+
+template <typename T>
+__inline__ __device__ T _wavePrefixProduct(T val) { return _wavePrefixScalar<WaveOpMul<T>, T>(val); }
+
+template <typename T>
+__inline__ __device__ T _wavePrefixSum(T val) { return _wavePrefixInvertableScalar<WaveOpAdd<T>, T>(val); }
+
+template <typename T>
+__inline__ __device__ T _wavePrefixAnd(T val) { return _wavePrefixScalar<WaveOpAnd<T>, T>(val); }
+
+template <typename T>
+__inline__ __device__ T _wavePrefixOr(T val) { return _wavePrefixScalar<WaveOpOr<T>, T>(val); }
+
+template <typename T>
+__inline__ __device__ T _wavePrefixXor(T val) { return _wavePrefixInvertableScalar<WaveOpXor<T>, T>(val); }
/* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index b43cd009f..20158c1b1 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -2498,6 +2498,7 @@ __generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
__target_intrinsic(glsl, "subgroupExclusiveMul($0)")
+__target_intrinsic(cuda, "_wavePrefixProduct($0)")
T WavePrefixProduct(T expr);
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
@@ -2521,10 +2522,54 @@ vector<T,N> WavePrefixSum(vector<T,N> expr);
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
matrix<T,N,M> WavePrefixSum(matrix<T,N,M> expr);
+__generic<T : __BuiltinType>
+__glsl_extension(GL_KHR_shader_subgroup_ballot)
+__spirv_version(1.3)
+__target_intrinsic(glsl, "subgroupBroadcastFirst($0)")
+__target_intrinsic(cuda, "_waveReadFirst($0)")
+T WaveReadLaneFirst(T expr);
+__generic<T : __BuiltinType, let N : int>
+__glsl_extension(GL_KHR_shader_subgroup_ballot)
+__spirv_version(1.3)
+__target_intrinsic(glsl, "subgroupBroadcastFirst($0)")
+__target_intrinsic(cuda, "_waveReadFirstMultiple($0)")
+vector<T,N> WaveReadLaneFirst(vector<T,N> expr);
+__generic<T : __BuiltinType, let N : int, let M : int>
+__target_intrinsic(cuda, "_waveReadFirstMultiple($0)")
+matrix<T,N,M> WaveReadLaneFirst(matrix<T,N,M> expr);
+
+// NOTE! On GLSL based targets the lane index *must* be a compile time expression!
+// See https://github.com/KhronosGroup/GLSL/blob/master/extensions/khr/GL_KHR_shader_subgroup.txt
+__generic<T : __BuiltinType>
+__glsl_extension(GL_KHR_shader_subgroup_ballot)
+__spirv_version(1.3)
+__target_intrinsic(glsl, "subgroupBroadcast($0, $1)")
+__target_intrinsic(cuda, "__shfl_sync(__activemask(), $0, $1)")
+T WaveReadLaneAt(T value, int lane);
+__generic<T : __BuiltinType, let N : int>
+__spirv_version(1.3)
+__target_intrinsic(glsl, "subgroupBroadcast($0, $1)")
+__target_intrinsic(cuda, "_waveReadLaneAtMultiple($0, $1)")
+vector<T,N> WaveReadLaneAt(vector<T,N> value, int lane);
+__generic<T : __BuiltinType, let N : int, let M : int>
+__target_intrinsic(cuda, "_waveReadLaneAtMultiple($0, $1)")
+matrix<T,N,M> WaveReadLaneAt(matrix<T,N,M> value, int lane);
+
+__glsl_extension(GL_KHR_shader_subgroup_ballot)
+__spirv_version(1.3)
+__target_intrinsic(glsl, "subgroupBallotExclusiveBitCount(subgroupBallot($0))")
+__target_intrinsic(cuda, "__popc(__ballot_sync(__activemask(), $0) & _getLaneLtMask())")
+uint WavePrefixCountBits(bool value);
+
+// Shader model 6.5 stuff
+// https://github.com/microsoft/DirectX-Specs/blob/master/d3d/HLSL_ShaderModel6_5.md
+// TODO(JS): Looks like they need a mask parameter
+
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
__target_intrinsic(glsl, "subgroupExclusiveAnd($0)")
+__target_intrinsic(cuda, "_wavePrefixAnd($0)")
T WaveMultiPrefixBitAnd(T expr);
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
@@ -2538,6 +2583,7 @@ __generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
__target_intrinsic(glsl, "subgroupExclusiveOr($0)")
+__target_intrinsic(cuda, "_wavePrefixOr($0)")
T WaveMultiPrefixBitOr(T expr);
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
@@ -2551,6 +2597,7 @@ __generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
__target_intrinsic(glsl, "subgroupExclusiveXor($0)")
+__target_intrinsic(cuda, "_wavePrefixXor($0)")
T WaveMultiPrefixBitXor(T expr);
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
@@ -2560,11 +2607,6 @@ vector<T,N> WaveMultiPrefixBitXor(vector<T,N> expr);
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
matrix<T,N,M> WaveMultiPrefixBitXor(matrix<T,N,M> expr);
-__glsl_extension(GL_KHR_shader_subgroup_ballot)
-__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupBallotExclusiveBitCount(subgroupBallot($0))")
-__target_intrinsic(cuda, "__popc(__ballot_sync(__activemask(), $0) & _getLaneLtMask())")
-uint WavePrefixCountBits(bool value);
uint WaveMultiPrefixCountBits(bool value, uint4 mask);
@@ -2576,40 +2618,6 @@ __generic<T : __BuiltinArithmeticType> T WaveMultiPrefixSum(T value, uint4 mask)
__generic<T : __BuiltinArithmeticType, let N : int> vector<T,N> WaveMultiPrefixSum(vector<T,N> value, uint4 mask);
__generic<T : __BuiltinArithmeticType, let N : int, let M : int> matrix<T,N,M> WaveMultiPrefixSum(matrix<T,N,M> value, uint4 mask);
-__generic<T : __BuiltinType>
-__glsl_extension(GL_KHR_shader_subgroup_ballot)
-__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupBroadcastFirst($0)")
-__target_intrinsic(cuda, "_waveReadFirst($0)")
-T WaveReadLaneFirst(T expr);
-__generic<T : __BuiltinType, let N : int>
-__glsl_extension(GL_KHR_shader_subgroup_ballot)
-__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupBroadcastFirst($0)")
-__target_intrinsic(cuda, "_waveReadFirstMultiple($0)")
-vector<T,N> WaveReadLaneFirst(vector<T,N> expr);
-__generic<T : __BuiltinType, let N : int, let M : int>
-__target_intrinsic(cuda, "_waveReadFirstMultiple($0)")
-matrix<T,N,M> WaveReadLaneFirst(matrix<T,N,M> expr);
-
-// NOTE! On GLSL based targets the lane index *must* be a compile time expression!
-// See https://github.com/KhronosGroup/GLSL/blob/master/extensions/khr/GL_KHR_shader_subgroup.txt
-__generic<T : __BuiltinType>
-__glsl_extension(GL_KHR_shader_subgroup_ballot)
-__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupBroadcast($0, $1)")
-__target_intrinsic(cuda, "__shfl_sync(__activemask(), $0, $1)")
-T WaveReadLaneAt(T value, int lane);
-__generic<T : __BuiltinType, let N : int>
-__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupBroadcast($0, $1)")
-__target_intrinsic(cuda, "_waveReadLaneAtMultiple($0, $1)")
-vector<T,N> WaveReadLaneAt(vector<T,N> value, int lane);
-__generic<T : __BuiltinType, let N : int, let M : int>
-__target_intrinsic(cuda, "_waveReadLaneAtMultiple($0, $1)")
-matrix<T,N,M> WaveReadLaneAt(matrix<T,N,M> value, int lane);
-
-
// `typedef`s to help with the fact that HLSL has been sorta-kinda case insensitive at various points
typedef Texture2D texture2D;
diff --git a/tests/hlsl-intrinsic/wave-prefix-product.slang b/tests/hlsl-intrinsic/wave-prefix-product.slang
new file mode 100644
index 000000000..bc324ed7d
--- /dev/null
+++ b/tests/hlsl-intrinsic/wave-prefix-product.slang
@@ -0,0 +1,16 @@
+//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute
+//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-slang -compute
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile cs_6_0
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute
+//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+[numthreads(8, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ int idx = int(dispatchThreadID.x);
+ int val = WavePrefixProduct(idx + 1);
+ outputBuffer[idx] = val;
+} \ No newline at end of file
diff --git a/tests/hlsl-intrinsic/wave-prefix-product.slang.expected.txt b/tests/hlsl-intrinsic/wave-prefix-product.slang.expected.txt
new file mode 100644
index 000000000..03cb63ab9
--- /dev/null
+++ b/tests/hlsl-intrinsic/wave-prefix-product.slang.expected.txt
@@ -0,0 +1,8 @@
+1
+1
+2
+6
+18
+78
+2D0
+13B0