From ad7d13a8a934a56db87a4ece4b1afb0f1db1c9d9 Mon Sep 17 00:00:00 2001 From: Darren Wihandi <65404740+fairywreath@users.noreply.github.com> Date: Thu, 16 Jan 2025 12:21:17 -0500 Subject: Implement Packed Dot Product intrinsics (#6068) * implement dot acc intrinsics * fix sm version * fix test * improve comment --------- Co-authored-by: Yong He --- source/slang/hlsl.meta.slang | 72 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 66 insertions(+), 6 deletions(-) (limited to 'source') diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 11c4ab6f4..d620197f3 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -16760,20 +16760,80 @@ static const uint HIT_KIND_TRIANGLE_BACK_FACE = 255; // // Shader Model 6.4 +// @public: // -/// Treats `left` and `right` as 4-component vectors of `UInt8` and computes `dot(left, right) + acc` +/// Treats `x` and `y` as 4-component vectors of `UInt8` and computes `dot(x, y) + acc` /// @category math -uint dot4add_u8packed(uint left, uint right, uint acc); +[__readNone] +[ForceInline] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_6_4)] +uint dot4add_u8packed(uint x, uint y, uint acc) +{ + __target_switch + { + case hlsl: __intrinsic_asm "dot4add_u8packed"; + case wgsl: __intrinsic_asm "(dot4U8Packed($0, $1) + $2)"; + case spirv: + // OpUDotAccSat cannot be used as there should not be any saturation. + return spirv_asm + { + OpCapability DotProduct; + OpCapability DotProductInput4x8BitPacked; + OpExtension "SPV_KHR_integer_dot_product"; + %dotResult = OpUDot $$uint $x $y 0; + result:$$uint = OpIAdd %dotResult $acc; + }; + default: + uint4 vecX = unpack_u8u32(uint8_t4_packed(x)); + uint4 vecY = unpack_u8u32(uint8_t4_packed(y)); + return dot(vecX, vecY) + acc; + } +} -/// Treats `left` and `right` as 4-component vectors of `Int8` and computes `dot(left, right) + acc` +/// Treats `x` and `y` as 4-component vectors of `int8` and computes `dot(x, y) + acc` /// @category math -int dot4add_i8packed(uint left, uint right, int acc); +[__readNone] +[ForceInline] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_6_4)] +int dot4add_i8packed(uint x, uint y, int acc) +{ + __target_switch + { + case hlsl: __intrinsic_asm "dot4add_i8packed"; + case wgsl: __intrinsic_asm "(dot4I8Packed($0, $1) + $2)"; + case spirv: + // OpSDottAccSat cannot be used as there should not be any saturation. + return spirv_asm + { + OpCapability DotProduct; + OpCapability DotProductInput4x8BitPacked; + OpExtension "SPV_KHR_integer_dot_product"; + %dotResult = OpSDot $$int $x $y 0; + result:$$int = OpIAdd %dotResult $acc; + }; + default: + int4 vecX = unpack_s8s32(int8_t4_packed(x)); + int4 vecY = unpack_s8s32(int8_t4_packed(y)); + return dot(vecX, vecY) + acc; + } +} -/// Computes `dot(left, right) + acc`. +/// Computes `dot(x, y) + acc`. /// May not produce infinities or NaNs for intermediate results that overflow the range of `half` /// @category math -float dot2add(float2 left, float2 right, float acc); +[__readNone] +[ForceInline] +[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_6_4)] +float dot2add(half2 x, half2 y, float acc) +{ + __target_switch + { + case hlsl: __intrinsic_asm "dot2add"; + default: + return float(dot(x, y)) + acc; + } +} // // Shader Model 6.5 -- cgit v1.2.3