summaryrefslogtreecommitdiff
path: root/source/slang
diff options
context:
space:
mode:
authorsricker-nvidia <115114531+sricker-nvidia@users.noreply.github.com>2025-05-05 15:30:33 -0700
committerGitHub <noreply@github.com>2025-05-05 22:30:33 +0000
commit50d9781b7387b0f7f56d19c72afcf390cca72b72 (patch)
tree7b6f1401f7a8257fa378930a052ca63f0fda91f4 /source/slang
parent698e43372cefe0fff13150925aeb7f389c21a938 (diff)
Add countbits 16-bit and 8-bit support (#6433) (#6897)
This change adds 16-bit and 8-bit support for the countbits intrinsic. Where a backend's native countbits does not support a given bit width, that width is emulated. New tests are added for 16-bit and 8-bit support. Additional testing is added for 32-bit, and minor updates are made to 64-bit countbits.
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/hlsl.meta.slang81
1 files changed, 66 insertions, 15 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 0f04006e5..07160ae9d 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -8047,6 +8047,16 @@ vector<T,N> cospi(vector<T,N> x)
}
}
+// Emulates 64-bit countbits for targets without native support: reinterprets `value` as two 32-bit words, counts the set bits of each, and returns the sum.
+[__readNone]
+[ForceInline]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_5_0)]
+internal uint __emulatedCountbits64(uint64_t value)
+{
+ uint2 value_uint2 = bit_cast<uint2>(value); // reinterpret the 64 bits as a pair of 32-bit uints
+ uint2 counted_bits_uint2 = countbits(value_uint2); // per-component population count via the vector overload
+ return counted_bits_uint2.x + counted_bits_uint2.y; // total is the sum of the two per-word counts
+}
/// Population count.
/// Counts the number of set bits in the binary representation of a value.
@@ -8060,19 +8070,32 @@ vector<T,N> cospi(vector<T,N> x)
__generic<T : __BuiltinIntegerType>
uint countbits(T value)
{
+ // Emulate 8-bit support
+ // 8-bit countbits is not currently natively supported on any target
+ if (T is int8_t || T is uint8_t)
+ {
+ return countbits(__intCast<uint32_t>(value));
+ }
+
__target_switch
{
case hlsl:
+ // 64-bit support dependent on SM6.0 and dxil
+ // 16-bit support dependent on SM6.2 and dxil
__intrinsic_asm "countbits";
case glsl:
+ // 64-bit support dependent on GL_ARB_gpu_shader_int64
+ // 16-bit support dependent on GL_EXT_shader_16bit_storage
__intrinsic_asm "bitCount";
case metal:
- if(T is int64_t || T is uint64_t)
+ if (T is int64_t || T is uint64_t)
+ {
+ return __emulatedCountbits64(__intCast<uint64_t>(value));
+ }
+ else if (T is int16_t || T is uint16_t)
{
- // emulate 64-bit
- uint2 value_uint2 = bit_cast<uint2>(value);
- uint2 counted_bits_uint2 = countbits(value_uint2);
- return counted_bits_uint2.x + counted_bits_uint2.y;
+ // emulate 16-bit
+ return countbits(__intCast<uint32_t>(value));
}
else
{
@@ -8084,17 +8107,28 @@ uint countbits(T value)
case spirv:
if(T is int64_t || T is uint64_t)
{
- // emulate 64-bit
- uint2 value_uint2 = bit_cast<uint2>(value);
- uint2 counted_bits_uint2 = countbits(value_uint2);
- return counted_bits_uint2.x + counted_bits_uint2.y;
+ return __emulatedCountbits64(__intCast<uint64_t>(value));
+ }
+ else if (T is int16_t || T is uint16_t)
+ {
+ // emulate 16-bit
+ return countbits(__intCast<uint32_t>(value));
}
else
{
+ // OpBitCount only supports 32-bit
return spirv_asm {OpBitCount $$uint result $value};
}
case wgsl:
- __intrinsic_asm "countOneBits";
+ // wgsl only supports 32-bit integers
+ if (T is int32_t)
+ {
+ // wgsl countOneBits returns the same type as the
+ // one it was given. Cast signed ints to unsigned
+ // so we can provide the correct return value.
+ return countbits(__intCast<uint32_t>(value));
+ }
+ __intrinsic_asm "countOneBits";
}
}
@@ -8104,6 +8138,13 @@ __generic<T : __BuiltinIntegerType, let N : int>
[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_5_0)]
vector<uint, N> countbits(vector<T, N> value)
{
+ // Emulate 8-bit support
+ // 8-bit countbits is not currently natively supported on any target
+ if (T is int8_t || T is uint8_t)
+ {
+ VECTOR_MAP_UNARY(uint, N, countbits, value);
+ }
+
__target_switch
{
case hlsl:
@@ -8111,9 +8152,9 @@ vector<uint, N> countbits(vector<T, N> value)
case glsl:
__intrinsic_asm "bitCount";
case metal:
- if(T is int64_t || T is uint64_t)
+ if(T is int64_t || T is uint64_t || T is int16_t || T is uint16_t)
{
- // emulate 64-bit
+ // Emulate 64-bit and 16-bit
VECTOR_MAP_UNARY(uint, N, countbits, value);
}
else
@@ -8121,9 +8162,9 @@ vector<uint, N> countbits(vector<T, N> value)
__intrinsic_asm "popcount";
}
case spirv:
- if(T is int64_t || T is uint64_t)
+ if(T is int64_t || T is uint64_t || T is int16_t || T is uint16_t)
{
- // emulate 64-bit
+ // Emulate 64-bit and 16-bit
VECTOR_MAP_UNARY(uint, N, countbits, value);
}
else
@@ -8131,7 +8172,17 @@ vector<uint, N> countbits(vector<T, N> value)
return spirv_asm {OpBitCount $$vector<uint, N> result $value};
}
case wgsl:
- __intrinsic_asm "countOneBits";
+ // wgsl only supports 32-bit integers
+ if (T is int32_t)
+ {
+ vector<uint32_t, N> ret;
+ for (int i = 0; i < N; i++)
+ {
+ ret[i] = countbits(__intCast<uint32_t>(value[i]));
+ }
+ return ret;
+ }
+ __intrinsic_asm "countOneBits";
default:
VECTOR_MAP_UNARY(uint, N, countbits, value);
}