summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang-glslang/slang-glslang.cpp2
-rw-r--r--source/slang/hlsl.meta.slang27
-rw-r--r--tests/hlsl-intrinsic/wave-active-count-bits.slang17
-rw-r--r--tests/hlsl-intrinsic/wave-active-count-bits.slang.expected.txt8
4 files changed, 48 insertions, 6 deletions
diff --git a/source/slang-glslang/slang-glslang.cpp b/source/slang-glslang/slang-glslang.cpp
index 1c756cb6c..80087997c 100644
--- a/source/slang-glslang/slang-glslang.cpp
+++ b/source/slang-glslang/slang-glslang.cpp
@@ -168,6 +168,8 @@ static void glslang_optimizeSPIRV(std::vector<unsigned int>& spirv, spv_target_e
break;
case SLANG_OPTIMIZATION_LEVEL_DEFAULT:
// Use a minimal set of performance settings
+ // If we run CreateInlineExhaustivePass, we need to run CreateMergeReturnPass first.
+ optimizer.RegisterPass(spvtools::CreateMergeReturnPass());
optimizer.RegisterPass(spvtools::CreateInlineExhaustivePass());
optimizer.RegisterPass(spvtools::CreateAggressiveDCEPass());
optimizer.RegisterPass(spvtools::CreatePrivateToLocalPass());
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 01fb17851..62a548555 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -2569,11 +2569,12 @@ __target_intrinsic(hlsl, "WaveActiveBallot($1)")
WaveMask WaveMaskBallot(WaveMask mask, bool condition);
__glsl_extension(GL_KHR_shader_subgroup_ballot)
-__spirv_version(1.3)
-__target_intrinsic(glsl, "bitCount(subgroupBallot($1))")
__target_intrinsic(cuda, "__popc(__ballot_sync($0, $1))")
__target_intrinsic(hlsl, "WaveActiveCountBits($1)")
-WaveMask WaveMaskCountBits(WaveMask mask, bool value);
+uint WaveMaskCountBits(WaveMask mask, bool value)
+{
+ return _WaveCountBits(WaveActiveBallot(value));
+}
// Waits until all warp lanes named in mask have executed a WaveMaskSharedSync (with the same mask)
// before resuming execution. Guarantees memory ordering in shared memory among threads participating
@@ -3262,9 +3263,6 @@ uint4 WaveActiveBallot(bool condition)
return WaveMaskBallot(WaveGetActiveMask(), condition);
}
-__glsl_extension(GL_KHR_shader_subgroup_ballot)
-__spirv_version(1.3)
-__target_intrinsic(glsl, "bitCount(subgroupBallot($0))")
__target_intrinsic(hlsl)
uint WaveActiveCountBits(bool value)
{
@@ -3292,6 +3290,23 @@ bool WaveIsFirstLane()
return WaveMaskIsFirstLane(WaveGetActiveMask());
}
+// Counts the set bits of a uint4 mask. Useful because some wave functions (e.g. WaveActiveBallot) return uint4.
+// To limit the work done, only ceil(laneCount / 32) of the 32-bit components are summed, starting from value.x.
+uint _WaveCountBits(uint4 value)
+{
+ // WaveGetLaneCount is expected to be known at compile time, so the switch below should hopefully fold away.
+ const uint waveLaneCount = WaveGetLaneCount();
+ switch ((waveLaneCount - 1) / 32) // yields 0..3 for lane counts 1..128
+ {
+ default: // NOTE(review): shares case 0, so a lane count outside 1..128 (or 0, if the uint subtraction wraps) counts only value.x - confirm intended
+ case 0: return countbits(value.x); // lane count 1..32
+ case 1: return countbits(value.x) + countbits(value.y); // lane count 33..64
+ case 2: return countbits(value.x) + countbits(value.y) + countbits(value.z); // lane count 65..96
+ case 3: return countbits(value.x) + countbits(value.y) + countbits(value.z) + countbits(value.w); // lane count 97..128
+ }
+}
+
+
// Prefix
__generic<T : __BuiltinArithmeticType>
diff --git a/tests/hlsl-intrinsic/wave-active-count-bits.slang b/tests/hlsl-intrinsic/wave-active-count-bits.slang
new file mode 100644
index 000000000..f337a70bb
--- /dev/null
+++ b/tests/hlsl-intrinsic/wave-active-count-bits.slang
@@ -0,0 +1,17 @@
+//TEST_CATEGORY(wave, compute)
+//DISABLE_TEST:COMPARE_COMPUTE_EX:-cpu -compute
+//DISABLE_TEST:COMPARE_COMPUTE_EX:-slang -compute
+//TEST:COMPARE_COMPUTE_EX:-slang -compute -dx12 -use-dxil -profile cs_6_0
+//TEST(vulkan):COMPARE_COMPUTE_EX:-vk -compute
+//TEST:COMPARE_COMPUTE_EX:-cuda -compute -render-features cuda_sm_7_0
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0], stride=4):out,name outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+[numthreads(8, 1, 1)] // eight threads per group, matching the eight-element outputBuffer
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ int idx = int(dispatchThreadID.x);
+
+ outputBuffer[idx] = WaveActiveCountBits(idx & 5); // (idx & 5) is nonzero for idx in {1,3,4,5,6,7}: six of eight threads pass true, hence the expected 6 per element (presumably all eight threads share one wave)
+} \ No newline at end of file
diff --git a/tests/hlsl-intrinsic/wave-active-count-bits.slang.expected.txt b/tests/hlsl-intrinsic/wave-active-count-bits.slang.expected.txt
new file mode 100644
index 000000000..9d96764ee
--- /dev/null
+++ b/tests/hlsl-intrinsic/wave-active-count-bits.slang.expected.txt
@@ -0,0 +1,8 @@
+6
+6
+6
+6
+6
+6
+6
+6