From b380b1af6ba6f5f58e3841c2a5b14db7ee8c372d Mon Sep 17 00:00:00 2001 From: jsmall-nvidia Date: Tue, 10 Mar 2020 16:43:41 -0400 Subject: Wave Prefix Product (#1270) * Fix some typos. * Add wave-prefix-sum.slang test * First pass at implementing prefixSum. * Small improvments to prefixSum CUDA. * Small improvement to prefix sum. * Enable prefix sum in stdlib. * Wave prefix product without using a divide. * Split out SM6.5 Wave intrinsics. Template mechanism for do prefix calculations. --- prelude/slang-cuda-prelude.h | 116 +++++++++++++++++---- source/slang/hlsl.meta.slang | 86 ++++++++------- tests/hlsl-intrinsic/wave-prefix-product.slang | 16 +++ .../wave-prefix-product.slang.expected.txt | 8 ++ 4 files changed, 169 insertions(+), 57 deletions(-) create mode 100644 tests/hlsl-intrinsic/wave-prefix-product.slang create mode 100644 tests/hlsl-intrinsic/wave-prefix-product.slang.expected.txt diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h index 6f2122934..457fb4246 100644 --- a/prelude/slang-cuda-prelude.h +++ b/prelude/slang-cuda-prelude.h @@ -534,6 +534,7 @@ struct WaveOpXor { __inline__ __device__ static T getInitial(T a) { return 0; } __inline__ __device__ static T doOp(T a, T b) { return a ^ b; } + __inline__ __device__ static T doInverse(T a, T b) { return a ^ b; } }; template @@ -541,6 +542,7 @@ struct WaveOpAdd { __inline__ __device__ static T getInitial(T a) { return 0; } __inline__ __device__ static T doOp(T a, T b) { return a + b; } + __inline__ __device__ static T doInverse(T a, T b) { return a - b; } }; template @@ -548,6 +550,9 @@ struct WaveOpMul { __inline__ __device__ static T getInitial(T a) { return T(1); } __inline__ __device__ static T doOp(T a, T b) { return a * b; } + // Using this inverse for int is probably undesirable - because in general it requires T to have more precision + // There is also a performance aspect to it, where divides are generally significantly slower + __inline__ __device__ static T doInverse(T a, T b) { return a / b; } }; template @@ -823,46 +828,121 @@ __inline__ __device__ T _waveReadLaneAtMultiple(T inVal, int lane) return outVal; } -__device__ int _wavePrefixSum(int val) +// Scalar + +// Invertable means that when we get to the end of the reduce, we can remove val (to make exclusive), using +// the inverse of the op. +template +__device__ T _wavePrefixInvertableScalar(T val) { const int mask = __activemask(); const int offsetSize = _waveCalcPow2Offset(mask); const int laneId = _getLaneId(); + T result; if (offsetSize > 0) { - int sum = val; + // Sum is calculated inclusive of this lanes value + result = val; for (int i = 1; i < offsetSize; i += i) { - const int readVal = __shfl_up_sync(mask, sum, i, offsetSize); + const T readVal = __shfl_up_sync(mask, result, i, offsetSize); if (laneId >= i) { - sum += readVal; + result = INTF::doOp(result, readVal); } } - return sum - val; + // Remove val from the result, by applyin inverse + result = INTF::doInverse(result, val); } else { - int result = 0; - int remaining = mask; - while (remaining) + result = INTF::getInitial(val); + if (!_waveIsSingleLane(mask)) { - const int laneBit = remaining & -remaining; - // Get the sourceLane - const int srcLane = __ffs(laneBit) - 1; - // Broadcast (can also broadcast to self) - int readValue = __shfl_sync(mask, val, srcLane); - // Only accumulate if srcLane is less than this lane - if (srcLane < laneId) + int remaining = mask; + while (remaining) { - result += readValue; + const int laneBit = remaining & -remaining; + // Get the sourceLane + const int srcLane = __ffs(laneBit) - 1; + // Broadcast (can also broadcast to self) + const T readValue = __shfl_sync(mask, val, srcLane); + // Only accumulate if srcLane is less than this lane + if (srcLane < laneId) + { + result = INTF::doOp(result, readValue); + } + remaining &= ~laneBit; + } + } + } + return result; +} + +// This implementation separately tracks the value to be propogated, and the value +// that is the final result +template +__device__ T _wavePrefixScalar(T val) +{ + const int mask = __activemask(); + const int offsetSize = _waveCalcPow2Offset(mask); + + const int laneId = _getLaneId(); + T result = INTF::getInitial(val); + if (offsetSize > 0) + { + // For transmitted value we will do it inclusively with this lanes value + // For the result we do not include the lanes value. This means an extra multiply for each iteration + // but means we don't need to have a divide at the end and also removes overflow issues in that scenario. + for (int i = 1; i < offsetSize; i += i) + { + const T readVal = __shfl_up_sync(mask, val, i, offsetSize); + if (laneId >= i) + { + result = INTF::doOp(result, readVal); + val = INTF::doOp(val, readVal); + } + } + } + else + { + if (!_waveIsSingleLane(mask)) + { + int remaining = mask; + while (remaining) + { + const int laneBit = remaining & -remaining; + // Get the sourceLane + const int srcLane = __ffs(laneBit) - 1; + // Broadcast (can also broadcast to self) + const T readValue = __shfl_sync(mask, val, srcLane); + // Only accumulate if srcLane is less than this lane + if (srcLane < laneId) + { + result = INTF::doOp(result, readValue); + } + remaining &= ~laneBit; } - remaining &= ~laneBit; } - return result; } + return result; } + +template +__inline__ __device__ T _wavePrefixProduct(T val) { return _wavePrefixScalar, T>(val); } + +template +__inline__ __device__ T _wavePrefixSum(T val) { return _wavePrefixInvertableScalar, T>(val); } + +template +__inline__ __device__ T _wavePrefixAnd(T val) { return _wavePrefixScalar, T>(val); } + +template +__inline__ __device__ T _wavePrefixOr(T val) { return _wavePrefixScalar, T>(val); } + +template +__inline__ __device__ T _wavePrefixXor(T val) { return _wavePrefixInvertableScalar, T>(val); } /* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */ diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index b43cd009f..20158c1b1 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -2498,6 +2498,7 @@ __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) __target_intrinsic(glsl, "subgroupExclusiveMul($0)") +__target_intrinsic(cuda, "_wavePrefixProduct($0)") T WavePrefixProduct(T expr); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) @@ -2521,10 +2522,54 @@ vector WavePrefixSum(vector expr); __generic matrix WavePrefixSum(matrix expr); +__generic +__glsl_extension(GL_KHR_shader_subgroup_ballot) +__spirv_version(1.3) +__target_intrinsic(glsl, "subgroupBroadcastFirst($0)") +__target_intrinsic(cuda, "_waveReadFirst($0)") +T WaveReadLaneFirst(T expr); +__generic +__glsl_extension(GL_KHR_shader_subgroup_ballot) +__spirv_version(1.3) +__target_intrinsic(glsl, "subgroupBroadcastFirst($0)") +__target_intrinsic(cuda, "_waveReadFirstMultiple($0)") +vector WaveReadLaneFirst(vector expr); +__generic +__target_intrinsic(cuda, "_waveReadFirstMultiple($0)") +matrix WaveReadLaneFirst(matrix expr); + +// NOTE! On GLSL based targets the lane index *must* be a compile time expression! +// See https://github.com/KhronosGroup/GLSL/blob/master/extensions/khr/GL_KHR_shader_subgroup.txt +__generic +__glsl_extension(GL_KHR_shader_subgroup_ballot) +__spirv_version(1.3) +__target_intrinsic(glsl, "subgroupBroadcast($0, $1)") +__target_intrinsic(cuda, "__shfl_sync(__activemask(), $0, $1)") +T WaveReadLaneAt(T value, int lane); +__generic +__spirv_version(1.3) +__target_intrinsic(glsl, "subgroupBroadcast($0, $1)") +__target_intrinsic(cuda, "_waveReadLaneAtMultiple($0, $1)") +vector WaveReadLaneAt(vector value, int lane); +__generic +__target_intrinsic(cuda, "_waveReadLaneAtMultiple($0, $1)") +matrix WaveReadLaneAt(matrix value, int lane); + +__glsl_extension(GL_KHR_shader_subgroup_ballot) +__spirv_version(1.3) +__target_intrinsic(glsl, "subgroupBallotExclusiveBitCount(subgroupBallot($0))") +__target_intrinsic(cuda, "__popc(__ballot_sync(__activemask(), $0) & _getLaneLtMask())") +uint WavePrefixCountBits(bool value); + +// Shader model 6.5 stuff +// https://github.com/microsoft/DirectX-Specs/blob/master/d3d/HLSL_ShaderModel6_5.md +// TODO(JS): Looks like they need a mask parameter + __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) __target_intrinsic(glsl, "subgroupExclusiveAnd($0)") +__target_intrinsic(cuda, "_wavePrefixAnd($0)") T WaveMultiPrefixBitAnd(T expr); __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) @@ -2538,6 +2583,7 @@ __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) __target_intrinsic(glsl, "subgroupExclusiveOr($0)") +__target_intrinsic(cuda, "_wavePrefixOr($0)") T WaveMultiPrefixBitOr(T expr); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) @@ -2551,6 +2597,7 @@ __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) __target_intrinsic(glsl, "subgroupExclusiveXor($0)") +__target_intrinsic(cuda, "_wavePrefixXor($0)") T WaveMultiPrefixBitXor(T expr); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) @@ -2560,11 +2607,6 @@ vector WaveMultiPrefixBitXor(vector expr); __generic matrix WaveMultiPrefixBitXor(matrix expr); -__glsl_extension(GL_KHR_shader_subgroup_ballot) -__spirv_version(1.3) -__target_intrinsic(glsl, "subgroupBallotExclusiveBitCount(subgroupBallot($0))") -__target_intrinsic(cuda, "__popc(__ballot_sync(__activemask(), $0) & _getLaneLtMask())") -uint WavePrefixCountBits(bool value); uint WaveMultiPrefixCountBits(bool value, uint4 mask); @@ -2576,40 +2618,6 @@ __generic T WaveMultiPrefixSum(T value, uint4 mask) __generic vector WaveMultiPrefixSum(vector value, uint4 mask); __generic matrix WaveMultiPrefixSum(matrix value, uint4 mask); -__generic -__glsl_extension(GL_KHR_shader_subgroup_ballot) -__spirv_version(1.3) -__target_intrinsic(glsl, "subgroupBroadcastFirst($0)") -__target_intrinsic(cuda, "_waveReadFirst($0)") -T WaveReadLaneFirst(T expr); -__generic -__glsl_extension(GL_KHR_shader_subgroup_ballot) -__spirv_version(1.3) -__target_intrinsic(glsl, "subgroupBroadcastFirst($0)") -__target_intrinsic(cuda, "_waveReadFirstMultiple($0)") -vector WaveReadLaneFirst(vector expr); -__generic -__target_intrinsic(cuda, "_waveReadFirstMultiple($0)") -matrix WaveReadLaneFirst(matrix expr); - -// NOTE! On GLSL based targets the lane index *must* be a compile time expression! -// See https://github.com/KhronosGroup/GLSL/blob/master/extensions/khr/GL_KHR_shader_subgroup.txt -__generic -__glsl_extension(GL_KHR_shader_subgroup_ballot) -__spirv_version(1.3) -__target_intrinsic(glsl, "subgroupBroadcast($0, $1)") -__target_intrinsic(cuda, "__shfl_sync(__activemask(), $0, $1)") -T WaveReadLaneAt(T value, int lane); -__generic -__spirv_version(1.3) -__target_intrinsic(glsl, "subgroupBroadcast($0, $1)") -__target_intrinsic(cuda, "_waveReadLaneAtMultiple($0, $1)") -vector WaveReadLaneAt(vector value, int lane); -__generic -__target_intrinsic(cuda, "_waveReadLaneAtMultiple($0, $1)") -matrix WaveReadLaneAt(matrix value, int lane); - - // `typedef`s to help with the fact that HLSL has been sorta-kinda case insensitive at various points typedef Texture2D texture2D; diff --git a/tests/hlsl-intrinsic/wave-prefix-product.slang b/tests/hlsl-intrinsic/wave-prefix-product.slang new file mode 100644 index 000000000..bc324ed7d --- /dev/null +++ b/tests/hlsl-intrinsic/wave-prefix-product.slang @@ -0,0 +1,16 @@ +//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute +//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-slang -compute +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile cs_6_0 +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute +//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(8, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + int idx = int(dispatchThreadID.x); + int val = WavePrefixProduct(idx + 1); + outputBuffer[idx] = val; +} \ No newline at end of file diff --git a/tests/hlsl-intrinsic/wave-prefix-product.slang.expected.txt b/tests/hlsl-intrinsic/wave-prefix-product.slang.expected.txt new file mode 100644 index 000000000..03cb63ab9 --- /dev/null +++ b/tests/hlsl-intrinsic/wave-prefix-product.slang.expected.txt @@ -0,0 +1,8 @@ +1 +1 +2 +6 +18 +78 +2D0 +13B0 -- cgit v1.2.3