summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorjsmall-nvidia <jsmall@nvidia.com>2020-03-10 12:31:25 -0400
committerGitHub <noreply@github.com>2020-03-10 12:31:25 -0400
commita10d9cd8767e88a064719d71cc97144ba8b112d1 (patch)
treec54745fb698c8cacfeb1c4440261eb899338f20e
parent721d2e8a2d457081cd3d9b081979d436b7002c2c (diff)
WIP Prefix Sum for CUDA (#1268)
* Fix some typos. * Add wave-prefix-sum.slang test * First pass at implementing prefixSum. * Small improvements to prefixSum CUDA. * Small improvement to prefix sum. * Enable prefix sum in stdlib.
-rw-r--r--prelude/slang-cuda-prelude.h41
-rw-r--r--source/slang/hlsl.meta.slang21
-rw-r--r--tests/hlsl-intrinsic/wave-prefix-sum.slang16
-rw-r--r--tests/hlsl-intrinsic/wave-prefix-sum.slang.expected.txt8
4 files changed, 76 insertions, 10 deletions
diff --git a/prelude/slang-cuda-prelude.h b/prelude/slang-cuda-prelude.h
index c764afba1..6f2122934 100644
--- a/prelude/slang-cuda-prelude.h
+++ b/prelude/slang-cuda-prelude.h
@@ -823,6 +823,47 @@ __inline__ __device__ T _waveReadLaneAtMultiple(T inVal, int lane)
return outVal;
}
+__device__ int _wavePrefixSum(int val) // Exclusive prefix sum of val across the lanes reported by __activemask(): each lane receives the sum of val from lower-numbered participating lanes.
+{
+ const int mask = __activemask(); // lanes converged at this call; reused as the participant mask for every shuffle below. NOTE(review): the mask is held in a signed int, so `-remaining` further down overflows if bit 31 is the only bit left - consider unsigned
+ const int offsetSize = _waveCalcPow2Offset(mask); // NOTE(review): helper is defined outside this hunk; presumably non-zero only when the active lanes form a power-of-two-sized run starting at lane 0 - confirm
+
+ const int laneId = _getLaneId();
+ if (offsetSize > 0) // fast path: log-step scan built from shuffles
+ {
+ int sum = val; // running inclusive sum for this lane
+ for (int i = 1; i < offsetSize; i += i) // i doubles each pass: 1, 2, 4, ... while below offsetSize
+ {
+ const int readVal = __shfl_up_sync(mask, sum, i, offsetSize); // read sum from the lane i positions lower, with offsetSize passed as the shuffle width
+ if (laneId >= i) // lanes below i have no source lane i positions down, so they skip the add
+ {
+ sum += readVal;
+ }
+ }
+ return sum - val; // the loop produced an inclusive sum; removing this lane's own value makes it exclusive
+ }
+ else // general path: visit the set bits of mask one at a time
+ {
+ int result = 0;
+ int remaining = mask; // bits of mask not yet visited
+ while (remaining)
+ {
+ const int laneBit = remaining & -remaining; // isolate the lowest set bit still in remaining
+ // Get the sourceLane
+ const int srcLane = __ffs(laneBit) - 1; // __ffs is 1-based, so subtract 1 to get the lane index
+ // Broadcast (can also broadcast to self)
+ int readValue = __shfl_sync(mask, val, srcLane);
+ // Only accumulate if srcLane is less than this lane
+ if (srcLane < laneId)
+ {
+ result += readValue;
+ }
+ remaining &= ~laneBit; // clear the visited bit so the next pass picks the following set bit
+ }
+ return result;
+ }
+}
+
/* !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! */
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 4b717d540..b43cd009f 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -2497,12 +2497,12 @@ bool WaveIsFirstLane();
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupExcusiveMul($0)")
+__target_intrinsic(glsl, "subgroupExclusiveMul($0)")
T WavePrefixProduct(T expr);
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupExcusiveMul($0)")
+__target_intrinsic(glsl, "subgroupExclusiveMul($0)")
vector<T,N> WavePrefixProduct(vector<T,N> expr);
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
matrix<T,N,M> WavePrefixProduct(matrix<T,N,M> expr);
@@ -2510,12 +2510,13 @@ matrix<T,N,M> WavePrefixProduct(matrix<T,N,M> expr);
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupExcusiveAdd($0)")
+__target_intrinsic(glsl, "subgroupExclusiveAdd($0)")
+__target_intrinsic(cuda, "_wavePrefixSum($0)")
T WavePrefixSum(T expr);
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupExcusiveAdd($0)")
+__target_intrinsic(glsl, "subgroupExclusiveAdd($0)")
vector<T,N> WavePrefixSum(vector<T,N> expr);
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
matrix<T,N,M> WavePrefixSum(matrix<T,N,M> expr);
@@ -2523,11 +2524,11 @@ matrix<T,N,M> WavePrefixSum(matrix<T,N,M> expr);
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupExcusiveAnd($0)")
+__target_intrinsic(glsl, "subgroupExclusiveAnd($0)")
T WaveMultiPrefixBitAnd(T expr);
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupExcusiveAnd($0)")
+__target_intrinsic(glsl, "subgroupExclusiveAnd($0)")
__generic<T : __BuiltinArithmeticType, let N : int>
vector<T,N> WaveMultiPrefixBitAnd(vector<T,N> expr);
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
@@ -2536,12 +2537,12 @@ matrix<T,N,M> WaveMultiPrefixBitAnd(matrix<T,N,M> expr);
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupExcusiveOr($0)")
+__target_intrinsic(glsl, "subgroupExclusiveOr($0)")
T WaveMultiPrefixBitOr(T expr);
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupExcusiveOr($0)")
+__target_intrinsic(glsl, "subgroupExclusiveOr($0)")
vector<T,N> WaveMultiPrefixBitOr(vector<T,N> expr);
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
matrix<T,N,M> WaveMultiPrefixBitOr(matrix<T,N,M> expr);
@@ -2549,12 +2550,12 @@ matrix<T,N,M> WaveMultiPrefixBitOr(matrix<T,N,M> expr);
__generic<T : __BuiltinArithmeticType>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupExcusiveXor($0)")
+__target_intrinsic(glsl, "subgroupExclusiveXor($0)")
T WaveMultiPrefixBitXor(T expr);
__generic<T : __BuiltinArithmeticType, let N : int>
__glsl_extension(GL_KHR_shader_subgroup_arithmetic)
__spirv_version(1.3)
-__target_intrinsic(glsl, "subgroupExcusiveXor($0)")
+__target_intrinsic(glsl, "subgroupExclusiveXor($0)")
vector<T,N> WaveMultiPrefixBitXor(vector<T,N> expr);
__generic<T : __BuiltinArithmeticType, let N : int, let M : int>
matrix<T,N,M> WaveMultiPrefixBitXor(matrix<T,N,M> expr);
diff --git a/tests/hlsl-intrinsic/wave-prefix-sum.slang b/tests/hlsl-intrinsic/wave-prefix-sum.slang
new file mode 100644
index 000000000..f8d9bb560
--- /dev/null
+++ b/tests/hlsl-intrinsic/wave-prefix-sum.slang
@@ -0,0 +1,16 @@
+//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute
+//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-slang -compute
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile cs_6_0
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute
+//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+[numthreads(8, 1, 1)] // 8 threads per group; the TEST_INPUT buffer above declares 8 elements
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) // Stores each thread's WavePrefixSum result into outputBuffer at the thread's own x index.
+{
+ int idx = int(dispatchThreadID.x);
+ int val = WavePrefixSum(1 << idx); // thread idx contributes only bit idx, so an exclusive prefix sum over the lower threads gives (1 << idx) - 1
+ outputBuffer[idx] = val;
+} \ No newline at end of file
diff --git a/tests/hlsl-intrinsic/wave-prefix-sum.slang.expected.txt b/tests/hlsl-intrinsic/wave-prefix-sum.slang.expected.txt
new file mode 100644
index 000000000..6ec6deeea
--- /dev/null
+++ b/tests/hlsl-intrinsic/wave-prefix-sum.slang.expected.txt
@@ -0,0 +1,8 @@
+0
+1
+3
+7
+F
+1F
+3F
+7F