From 50d9781b7387b0f7f56d19c72afcf390cca72b72 Mon Sep 17 00:00:00 2001 From: sricker-nvidia <115114531+sricker-nvidia@users.noreply.github.com> Date: Mon, 5 May 2025 15:30:33 -0700 Subject: Add countbits 16-bit and 8-bit support (#6433) (#6897) Change adds 16-bit and 8-bit support for countbits intrinsic. In cases where a backend's native counbits lacks support, support is emulated. New tests are added for 16-bit and 8-bit support. Additional testing added for 32-bit and minor updates made to 64-bit countbits. --- source/slang/hlsl.meta.slang | 81 ++++++++++++++++++++++++++++++++++++-------- 1 file changed, 66 insertions(+), 15 deletions(-) (limited to 'source') diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 0f04006e5..07160ae9d 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -8047,6 +8047,16 @@ vector cospi(vector x) } } +// emulate 64-bit countbits when not natively supported. +[__readNone] +[ForceInline] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_5_0)] +internal uint __emulatedCountbits64(uint64_t value) +{ + uint2 value_uint2 = bit_cast(value); + uint2 counted_bits_uint2 = countbits(value_uint2); + return counted_bits_uint2.x + counted_bits_uint2.y; +} /// Population count. /// Counts the number of set bits in the binary representation of a value. @@ -8060,19 +8070,32 @@ vector cospi(vector x) __generic uint countbits(T value) { + // Emulate 8-bit support + // 8-bit support is not currently supported anywhere natively + if (T is int8_t || T is uint8_t) + { + return countbits(__intCast(value)); + } + __target_switch { case hlsl: + // 64-bit support dependent on SM6.0 and dxil + // 16-bit support dependent on SM6.2 and dxil __intrinsic_asm "countbits"; case glsl: + // 64-bit support dependent on GL_ARB_gpu_shader_int64 + // 16-bit support dependent on GL_EXT_shader_16bit_storage __intrinsic_asm "bitCount"; case metal: - if(T is int64_t || T is uint64_t) + if (T is int64_t || T is uint64_t) + { + return __emulatedCountbits64(__intCast(value)); + } + else if (T is int16_t || T is uint16_t) { - // emulate 64-bit - uint2 value_uint2 = bit_cast(value); - uint2 counted_bits_uint2 = countbits(value_uint2); - return counted_bits_uint2.x + counted_bits_uint2.y; + // emulate 16-bit + return countbits(__intCast(value)); } else { @@ -8084,17 +8107,28 @@ uint countbits(T value) case spirv: if(T is int64_t || T is uint64_t) { - // emulate 64-bit - uint2 value_uint2 = bit_cast(value); - uint2 counted_bits_uint2 = countbits(value_uint2); - return counted_bits_uint2.x + counted_bits_uint2.y; + return __emulatedCountbits64(__intCast(value)); + } + else if (T is int16_t || T is uint16_t) + { + // emulate 16-bit + return countbits(__intCast(value)); } else { + // OpBitCount only supports 32-bit return spirv_asm {OpBitCount $$uint result $value}; } case wgsl: - __intrinsic_asm "countOneBits"; + // wgsl only supports 32-bit integers + if (T is int32_t) + { + // wgsl countOneBits returns the same type as the + // one it was given. Cast signed ints to unsigned + // so we can provide the correct return value. + return countbits(__intCast(value)); + } + __intrinsic_asm "countOneBits"; } } @@ -8104,6 +8138,13 @@ __generic [require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, shader5_sm_5_0)] vector countbits(vector value) { + // Emulate 8-bit support + // 8-bit support is not currently supported anywhere natively + if (T is int8_t || T is uint8_t) + { + VECTOR_MAP_UNARY(uint, N, countbits, value); + } + __target_switch { case hlsl: @@ -8111,9 +8152,9 @@ vector countbits(vector value) case glsl: __intrinsic_asm "bitCount"; case metal: - if(T is int64_t || T is uint64_t) + if(T is int64_t || T is uint64_t || T is int16_t || T is uint16_t) { - // emulate 64-bit + // Emulate 64-bit and 16-bit VECTOR_MAP_UNARY(uint, N, countbits, value); } else @@ -8121,9 +8162,9 @@ vector countbits(vector value) __intrinsic_asm "popcount"; } case spirv: - if(T is int64_t || T is uint64_t) + if(T is int64_t || T is uint64_t || T is int16_t || T is uint16_t) { - // emulate 64-bit + // Emulate 64-bit and 16-bit VECTOR_MAP_UNARY(uint, N, countbits, value); } else @@ -8131,7 +8172,17 @@ vector countbits(vector value) return spirv_asm {OpBitCount $$vector result $value}; } case wgsl: - __intrinsic_asm "countOneBits"; + // wgsl only supports 32-bit integers + if (T is int32_t) + { + vector ret; + for (int i = 0; i < N; i++) + { + ret[i] = countbits(__intCast(value[i])); + } + return ret; + } + __intrinsic_asm "countOneBits"; default: VECTOR_MAP_UNARY(uint, N, countbits, value); } -- cgit v1.2.3