From a10d9cd8767e88a064719d71cc97144ba8b112d1 Mon Sep 17 00:00:00 2001 From: jsmall-nvidia Date: Tue, 10 Mar 2020 12:31:25 -0400 Subject: WIP Prefix Sum for CUDA (#1268) * Fix some typos. * Add wave-prefix-sum.slang test * First pass at implementing prefixSum. * Small improvments to prefixSum CUDA. * Small improvement to prefix sum. * Enable prefix sum in stdlib. --- prelude/slang-cuda-prelude.h | 41 ++++++++++++++++++++++ source/slang/hlsl.meta.slang | 21 +++++------ tests/hlsl-intrinsic/wave-prefix-sum.slang | 16 +++++++++ .../wave-prefix-sum.slang.expected.txt | 8 +++++ 4 files changed, 76 insertions(+), 10 deletions(-) create mode 100644 tests/hlsl-intrinsic/wave-prefix-sum.slang create mode 100644 tests/hlsl-intrinsic/wave-prefix-sum.slang.expected.txt diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h index c764afba1..6f2122934 100644 --- a/prelude/slang-cuda-prelude.h +++ b/prelude/slang-cuda-prelude.h @@ -823,6 +823,47 @@ __inline__ __device__ T _waveReadLaneAtMultiple(T inVal, int lane) return outVal; } +__device__ int _wavePrefixSum(int val) +{ + const int mask = __activemask(); + const int offsetSize = _waveCalcPow2Offset(mask); + + const int laneId = _getLaneId(); + if (offsetSize > 0) + { + int sum = val; + for (int i = 1; i < offsetSize; i += i) + { + const int readVal = __shfl_up_sync(mask, sum, i, offsetSize); + if (laneId >= i) + { + sum += readVal; + } + } + return sum - val; + } + else + { + int result = 0; + int remaining = mask; + while (remaining) + { + const int laneBit = remaining & -remaining; + // Get the sourceLane + const int srcLane = __ffs(laneBit) - 1; + // Broadcast (can also broadcast to self) + int readValue = __shfl_sync(mask, val, srcLane); + // Only accumulate if srcLane is less than this lane + if (srcLane < laneId) + { + result += readValue; + } + remaining &= ~laneBit; + } + return result; + } +} + /* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */ diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 4b717d540..b43cd009f 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -2497,12 +2497,12 @@ bool WaveIsFirstLane(); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupExcusiveMul($0)") +__target_intrinsic(glsl, "subgroupExclusiveMul($0)") T WavePrefixProduct(T expr); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupExcusiveMul($0)") +__target_intrinsic(glsl, "subgroupExclusiveMul($0)") vector WavePrefixProduct(vector expr); __generic matrix WavePrefixProduct(matrix expr); @@ -2510,12 +2510,13 @@ matrix WavePrefixProduct(matrix expr); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupExcusiveAdd($0)") +__target_intrinsic(glsl, "subgroupExclusiveAdd($0)") +__target_intrinsic(cuda, "_wavePrefixSum($0)") T WavePrefixSum(T expr); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupExcusiveAdd($0)") +__target_intrinsic(glsl, "subgroupExclusiveAdd($0)") vector WavePrefixSum(vector expr); __generic matrix WavePrefixSum(matrix expr); @@ -2523,11 +2524,11 @@ matrix WavePrefixSum(matrix expr); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupExcusiveAnd($0)") +__target_intrinsic(glsl, "subgroupExclusiveAnd($0)") T WaveMultiPrefixBitAnd(T expr); __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupExcusiveAnd($0)") +__target_intrinsic(glsl, "subgroupExclusiveAnd($0)") __generic vector WaveMultiPrefixBitAnd(vector expr); __generic @@ -2536,12 +2537,12 @@ matrix WaveMultiPrefixBitAnd(matrix expr); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupExcusiveOr($0)") +__target_intrinsic(glsl, "subgroupExclusiveOr($0)") T WaveMultiPrefixBitOr(T expr); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupExcusiveOr($0)") +__target_intrinsic(glsl, "subgroupExclusiveOr($0)") vector WaveMultiPrefixBitOr(vector expr); __generic matrix WaveMultiPrefixBitOr(matrix expr); @@ -2549,12 +2550,12 @@ matrix WaveMultiPrefixBitOr(matrix expr); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupExcusiveXor($0)") +__target_intrinsic(glsl, "subgroupExclusiveXor($0)") T WaveMultiPrefixBitXor(T expr); __generic __glsl_extension(GL_KHR_shader_subgroup_arithmetic) __spirv_version(1.3) -__target_intrinsic(glsl, "subgroupExcusiveXor($0)") +__target_intrinsic(glsl, "subgroupExclusiveXor($0)") vector WaveMultiPrefixBitXor(vector expr); __generic matrix WaveMultiPrefixBitXor(matrix expr); diff --git a/tests/hlsl-intrinsic/wave-prefix-sum.slang b/tests/hlsl-intrinsic/wave-prefix-sum.slang new file mode 100644 index 000000000..f8d9bb560 --- /dev/null +++ b/tests/hlsl-intrinsic/wave-prefix-sum.slang @@ -0,0 +1,16 @@ +//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute +//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-slang -compute +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile cs_6_0 +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute +//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer +RWStructuredBuffer outputBuffer; + +[numthreads(8, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + int idx = int(dispatchThreadID.x); + int val = WavePrefixSum(1 << idx); + outputBuffer[idx] = val; +} \ No newline at end of file diff --git a/tests/hlsl-intrinsic/wave-prefix-sum.slang.expected.txt b/tests/hlsl-intrinsic/wave-prefix-sum.slang.expected.txt new file mode 100644 index 000000000..6ec6deeea --- /dev/null +++ b/tests/hlsl-intrinsic/wave-prefix-sum.slang.expected.txt @@ -0,0 +1,8 @@ +0 +1 +3 +7 +F +1F +3F +7F -- cgit v1.2.3