summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDarren Wihandi <65404740+fairywreath@users.noreply.github.com>2025-01-16 12:21:17 -0500
committerGitHub <noreply@github.com>2025-01-16 09:21:17 -0800
commitad7d13a8a934a56db87a4ece4b1afb0f1db1c9d9 (patch)
tree5726aa8833be14d298cff4e0c34f2b6106e34679
parent9167e0d04c2d57593506feca94aacf73aad17b65 (diff)
Implement Packed Dot Product intrinsics (#6068)
* implement dot acc intrinsics * fix sm version * fix test * improve comment --------- Co-authored-by: Yong He <yonghe@outlook.com>
-rw-r--r--source/slang/hlsl.meta.slang72
-rw-r--r--tests/hlsl-intrinsic/dot-accumulate.slang55
-rw-r--r--tests/hlsl-intrinsic/dot-accumulate.slang.expected.txt4
3 files changed, 125 insertions, 6 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 11c4ab6f4..d620197f3 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -16760,20 +16760,80 @@ static const uint HIT_KIND_TRIANGLE_BACK_FACE = 255;
//
// Shader Model 6.4
+// @public:
//
-/// Treats `left` and `right` as 4-component vectors of `UInt8` and computes `dot(left, right) + acc`
+/// Treats `x` and `y` as 4-component vectors of `UInt8` and computes `dot(x, y) + acc`
/// @category math
-uint dot4add_u8packed(uint left, uint right, uint acc);
+[__readNone]
+[ForceInline]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_6_4)]
+uint dot4add_u8packed(uint x, uint y, uint acc)
+{
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "dot4add_u8packed";
+ case wgsl: __intrinsic_asm "(dot4U8Packed($0, $1) + $2)";
+ case spirv:
+ // OpUDotAccSat cannot be used as there should not be any saturation.
+ return spirv_asm
+ {
+ OpCapability DotProduct;
+ OpCapability DotProductInput4x8BitPacked;
+ OpExtension "SPV_KHR_integer_dot_product";
+ %dotResult = OpUDot $$uint $x $y 0;
+ result:$$uint = OpIAdd %dotResult $acc;
+ };
+ default:
+ uint4 vecX = unpack_u8u32(uint8_t4_packed(x));
+ uint4 vecY = unpack_u8u32(uint8_t4_packed(y));
+ return dot(vecX, vecY) + acc;
+ }
+}
-/// Treats `left` and `right` as 4-component vectors of `Int8` and computes `dot(left, right) + acc`
+/// Treats `x` and `y` as 4-component vectors of `int8` and computes `dot(x, y) + acc`
/// @category math
-int dot4add_i8packed(uint left, uint right, int acc);
+[__readNone]
+[ForceInline]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_6_4)]
+int dot4add_i8packed(uint x, uint y, int acc)
+{
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "dot4add_i8packed";
+ case wgsl: __intrinsic_asm "(dot4I8Packed($0, $1) + $2)";
+ case spirv:
+ // OpSDottAccSat cannot be used as there should not be any saturation.
+ return spirv_asm
+ {
+ OpCapability DotProduct;
+ OpCapability DotProductInput4x8BitPacked;
+ OpExtension "SPV_KHR_integer_dot_product";
+ %dotResult = OpSDot $$int $x $y 0;
+ result:$$int = OpIAdd %dotResult $acc;
+ };
+ default:
+ int4 vecX = unpack_s8s32(int8_t4_packed(x));
+ int4 vecY = unpack_s8s32(int8_t4_packed(y));
+ return dot(vecX, vecY) + acc;
+ }
+}
-/// Computes `dot(left, right) + acc`.
+/// Computes `dot(x, y) + acc`.
/// May not produce infinities or NaNs for intermediate results that overflow the range of `half`
/// @category math
-float dot2add(float2 left, float2 right, float acc);
+[__readNone]
+[ForceInline]
+[require(cpp_cuda_glsl_hlsl_metal_spirv_wgsl, sm_6_4)]
+float dot2add(half2 x, half2 y, float acc)
+{
+ __target_switch
+ {
+ case hlsl: __intrinsic_asm "dot2add";
+ default:
+ return float(dot(x, y)) + acc;
+ }
+}
//
// Shader Model 6.5
diff --git a/tests/hlsl-intrinsic/dot-accumulate.slang b/tests/hlsl-intrinsic/dot-accumulate.slang
new file mode 100644
index 000000000..113ae40e3
--- /dev/null
+++ b/tests/hlsl-intrinsic/dot-accumulate.slang
@@ -0,0 +1,55 @@
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+// Does not run on DX11 as SM 6.4 is required.
+//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx11
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -dx12 -profile cs_6_4 -use-dxil -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-metal -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-wgsl -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj -g0 -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0], stride=4):out,name outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint outputIndex = 0;
+
+ //
+ // dot4add_u8packed()
+ // [4 3 2 1] dot [1 2 4 2] + 5
+ // (4 * 1) + (3 * 2) + (2 * 4) + (1 * 2) + 5 = 25
+ //
+ uint unsignedX = 0x01020304U;
+ uint unsignedY = 0x02040201U;
+ uint unsignedAcc = 5U;
+ uint unsignedResult = dot4add_u8packed(unsignedX, unsignedY, unsignedAcc);
+ outputBuffer[outputIndex++] = unsignedResult;
+
+ //
+ // dot4add_i8packed()
+ // [6 2 3 -1] dot [-2 -6 2 6] - 100
+ // (6 * -2) + (2 * -6) + (3 * 2) + (-1 * 6) - 100 = -124
+ //
+ int signedX = 0xFF030206;
+ int signedY = 0x0602FAFE;
+ int signedAcc = -100;
+ int signedResult = dot4add_i8packed(signedX, signedY, signedAcc);
+ outputBuffer[outputIndex++] = signedResult;
+
+ //
+ // dot2add()
+ // [10.8 -3.3] dot [1.4 -20.3] - 2.11
+ // (10.8 * 1.4) + (-3.3 * -20.3) - 2.0 = 80.11
+ //
+ half2 half2X = half2(half(10.8), half(-3.3));
+ half2 half2Y = half2(half(1.4), half(-20.3));
+
+ // `half2Acc` is assigned -2.0 here.
+ // Thread index is used so that `half2Acc` will not be implicitly emitted as literal `-2.0` which
+ // may be treated as a double by DXC and cause it to fail to compile because no overload exists for `dot2add` that
+ // accepts double.
+ float half2Acc = float(dispatchThreadID.x + 1) * -2.0f;
+ float half2Result = dot2add(half2X, half2Y, half2Acc);
+ outputBuffer[outputIndex++] = int(half2Result);
+}
diff --git a/tests/hlsl-intrinsic/dot-accumulate.slang.expected.txt b/tests/hlsl-intrinsic/dot-accumulate.slang.expected.txt
new file mode 100644
index 000000000..184864973
--- /dev/null
+++ b/tests/hlsl-intrinsic/dot-accumulate.slang.expected.txt
@@ -0,0 +1,4 @@
+type: int32_t
+25
+-124
+80