diff options
| author | sricker-nvidia <115114531+sricker-nvidia@users.noreply.github.com> | 2025-05-05 15:30:33 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-05-05 22:30:33 +0000 |
| commit | 50d9781b7387b0f7f56d19c72afcf390cca72b72 (patch) | |
| tree | 7b6f1401f7a8257fa378930a052ca63f0fda91f4 /source/slang | |
| parent | 698e43372cefe0fff13150925aeb7f389c21a938 (diff) | |
Add countbits 16-bit and 8-bit support (#6433) (#6897)
This change adds 16-bit and 8-bit support for the countbits intrinsic. In
cases where a backend's native countbits lacks support, support
is emulated.
New tests are added for 16-bit and 8-bit support. Additional testing
is added for 32-bit, and minor updates are made to 64-bit countbits.
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/hlsl.meta.slang | 81 |
1 files changed, 66 insertions, 15 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 0f04006e5..07160ae9d 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -8047,6 +8047,16 @@ vector<T,N> cospi(vector<T,N> x) } } +// emulate 64-bit countbits when not natively supported. +[__readNone] +[ForceInline] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_5_0)] +internal uint __emulatedCountbits64(uint64_t value) +{ + uint2 value_uint2 = bit_cast<uint2>(value); + uint2 counted_bits_uint2 = countbits(value_uint2); + return counted_bits_uint2.x + counted_bits_uint2.y; +} /// Population count. /// Counts the number of set bits in the binary representation of a value. @@ -8060,19 +8070,32 @@ vector<T,N> cospi(vector<T,N> x) __generic<T : __BuiltinIntegerType> uint countbits(T value) { + // Emulate 8-bit support + // 8-bit support is not currently supported anywhere natively + if (T is int8_t || T is uint8_t) + { + return countbits(__intCast<uint32_t>(value)); + } + __target_switch { case hlsl: + // 64-bit support dependent on SM6.0 and dxil + // 16-bit support dependent on SM6.2 and dxil __intrinsic_asm "countbits"; case glsl: + // 64-bit support dependent on GL_ARB_gpu_shader_int64 + // 16-bit support dependent on GL_EXT_shader_16bit_storage __intrinsic_asm "bitCount"; case metal: - if(T is int64_t || T is uint64_t) + if (T is int64_t || T is uint64_t) + { + return __emulatedCountbits64(__intCast<uint64_t>(value)); + } + else if (T is int16_t || T is uint16_t) { - // emulate 64-bit - uint2 value_uint2 = bit_cast<uint2>(value); - uint2 counted_bits_uint2 = countbits(value_uint2); - return counted_bits_uint2.x + counted_bits_uint2.y; + // emulate 16-bit + return countbits(__intCast<uint32_t>(value)); } else { @@ -8084,17 +8107,28 @@ uint countbits(T value) case spirv: if(T is int64_t || T is uint64_t) { - // emulate 64-bit - uint2 value_uint2 = bit_cast<uint2>(value); - uint2 counted_bits_uint2 = countbits(value_uint2); - return 
counted_bits_uint2.x + counted_bits_uint2.y; + return __emulatedCountbits64(__intCast<uint64_t>(value)); + } + else if (T is int16_t || T is uint16_t) + { + // emulate 16-bit + return countbits(__intCast<uint32_t>(value)); } else { + // OpBitCount only supports 32-bit return spirv_asm {OpBitCount $$uint result $value}; } case wgsl: - __intrinsic_asm "countOneBits"; + // wgsl only supports 32-bit integers + if (T is int32_t) + { + // wgsl countOneBits returns the same type as the + // one it was given. Cast signed ints to unsigned + // so we can provide the correct return value. + return countbits(__intCast<uint32_t>(value)); + } + __intrinsic_asm "countOneBits"; } } @@ -8104,6 +8138,13 @@ __generic<T : __BuiltinIntegerType, let N : int> [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_5_0)] vector<uint, N> countbits(vector<T, N> value) { + // Emulate 8-bit support + // 8-bit support is not currently supported anywhere natively + if (T is int8_t || T is uint8_t) + { + VECTOR_MAP_UNARY(uint, N, countbits, value); + } + __target_switch { case hlsl: @@ -8111,9 +8152,9 @@ vector<uint, N> countbits(vector<T, N> value) case glsl: __intrinsic_asm "bitCount"; case metal: - if(T is int64_t || T is uint64_t) + if(T is int64_t || T is uint64_t || T is int16_t || T is uint16_t) { - // emulate 64-bit + // Emulate 64-bit and 16-bit VECTOR_MAP_UNARY(uint, N, countbits, value); } else @@ -8121,9 +8162,9 @@ vector<uint, N> countbits(vector<T, N> value) __intrinsic_asm "popcount"; } case spirv: - if(T is int64_t || T is uint64_t) + if(T is int64_t || T is uint64_t || T is int16_t || T is uint16_t) { - // emulate 64-bit + // Emulate 64-bit and 16-bit VECTOR_MAP_UNARY(uint, N, countbits, value); } else @@ -8131,7 +8172,17 @@ vector<uint, N> countbits(vector<T, N> value) return spirv_asm {OpBitCount $$vector<uint, N> result $value}; } case wgsl: - __intrinsic_asm "countOneBits"; + // wgsl only supports 32-bit integers + if (T is int32_t) + { + vector<uint32_t, N> ret; + for (int 
i = 0; i < N; i++) + { + ret[i] = countbits(__intCast<uint32_t>(value[i])); + } + return ret; + } + __intrinsic_asm "countOneBits"; default: VECTOR_MAP_UNARY(uint, N, countbits, value); } |
