summaryrefslogtreecommitdiffstats
path: root/ComputeShaders/mulMatDotMain.hlsl
blob: 47c6d3eff79e2668704976ea07d6504218f2c86f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
// GGML_TASK_COMPUTE step of the matrix*matrix product, for the case where nb01 >= nb00.
// Dispatch with [ ne11, ne01*ne02*ne03 ] thread groups, 32 threads each.
// Each thread group computes a single dot product, i.e. one element of the result.
// First input buffer; main() indexes it with offsets built from the src0 strides
Buffer<float> arg0: register( t0 );
// Second input buffer; main() indexes it with offsets built from the src1 element counts
Buffer<float> arg1: register( t1 );
// Output buffer; thread 0 of every group stores one float into it
RWBuffer<float> result: register( u0 );

// Shapes and strides of the tensors involved in the product.
// Components x, y, z, w of each vector correspond to dimensions 0..3, using ggml-style names in main().
cbuffer Constants: register( b0 )
{
	// Element counts of the first argument: ne00, ne01, ne02, ne03
	uint4 src0_elements: packoffset( c0 );
	// Strides of the first argument: nb00, nb01, nb02, nb03
	// NOTE(review): these are added directly to Buffer<float> indices, so presumably expressed in elements rather than bytes - confirm against the host code
	uint4 src0_strides: packoffset( c1 );
	// Element counts of the second argument: ne10, ne11, ne12, ne13
	uint4 src1_elements: packoffset( c2 );
	// Register c3 is not declared in this shader
	// NOTE(review): presumably it holds the src1 strides on the host side - confirm
	// Element counts of the result: ne0 in x; the other components are not read by main()
	uint4 result_elements: packoffset( c4 );
	// Strides of the result: nb0, nb1, nb2, nb3
	uint4 result_strides: packoffset( c5 );
}

// Multiply together the three components of the vector
inline uint product( uint3 vec )
{
	const uint xy = vec.x * vec.y;
	return xy * vec.z;
}

// Multiply together the four components of the vector
inline uint product( uint4 vec )
{
	// Pair the components the same way as a vec.xy * vec.zw vector multiply would
	const uint xz = vec.x * vec.z;
	const uint yw = vec.y * vec.w;
	return xz * yw;
}

// Per-thread partial dot product of two rows.
// i0 and i1 are the start offsets of the rows in arg0 and arg1, length is the count of elements in each row.
// The calling thread handles elements thread, thread + 32, thread + 64, ...; the step of 32 equals the group size declared on main().
// Returns the sum of the products over the elements visited by this thread; 0 when thread >= length.
inline float dotProductInner( uint i0, uint i1, uint length, uint thread )
{
	float acc = 0;
	uint idx = thread;
	while( idx < length )
	{
		// Fused multiply-add: acc += arg0[ i0 + idx ] * arg1[ i1 + idx ]
		acc = mad( arg0[ i0 + idx ], arg1[ i1 + idx ], acc );
		idx += 32;
	}
	return acc;
}

#include "groupReduce.hlsli"

// Entry point. Every group of 32 threads produces one element of the result:
// group.x selects the column of the second argument (ic), group.y selects the flattened row of the first argument (ir).
// Each thread accumulates a strided partial dot product, the partial sums are combined with horizontalSumCompatNew,
// and thread 0 stores the value into the result buffer.
[numthreads( 32, 1, 1 )]
void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
{
	// Only the constants which are actually consumed below are unpacked
	const uint ne00 = src0_elements.x;
	const uint ne01 = src0_elements.y;
	const uint ne02 = src0_elements.z;

	const uint ne11 = src1_elements.y;
	const uint ne12 = src1_elements.z;

	// Strides stay unsigned, matching both the cbuffer fields and the indices they are multiplied with
	const uint nb01 = src0_strides.y;
	const uint nb02 = src0_strides.z;
	const uint nb03 = src0_strides.w;

	// Flattened row index into src0, in the range [ 0 .. ne01*ne02*ne03 ) according to the dispatch size
	const uint ir = group.y;

	// src0 indices: split the flattened row index into the 3 coordinates
	const uint i03 = ir / ( ne02 * ne01 );
	const uint i02 = ( ir - i03 * ne02 * ne01 ) / ne01;
	const uint i01 = ( ir - i03 * ne02 * ne01 - i02 * ne01 );

	// src1 uses the same coordinates for the two outer dimensions
	const uint i13 = i03;
	const uint i12 = i02;

	// Coordinates of the output element, except the column which comes from group.x
	const uint i0 = i01;
	const uint i2 = i02;
	const uint i3 = i03;

	// src0_row = (ggml_fp16_t *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
	// src1_col = wdata + ( i13 * ne12 * ne11 + i12 * ne11 + 0 ) * ne00;
	const uint src0_row = i01 * nb01 + i02 * nb02 + i03 * nb03;
	// src1 is addressed as densely packed rows of ne00 elements each
	const uint src1_col = ( i13 * ne12 * ne11 + i12 * ne11 ) * ne00;

	const uint ic = group.x;
	float curr = dotProductInner( src0_row, src1_col + ic * ne00, ne00, thread );
	// NOTE(review): defined in groupReduce.hlsli; presumably sums `curr` across the 32 threads and leaves the total in thread 0 - confirm
	horizontalSumCompatNew( thread, curr );

	// Only the first thread of the group stores the result
	if( 0 != thread )
		return;

	const uint nb0 = result_strides.x;
	const uint nb2 = result_strides.z;
	const uint nb3 = result_strides.w;

	const uint ne0 = result_elements.x;
	// float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
	const uint dst_col = i0 * nb0 + i2 * nb2 + i3 * nb3;
	// NOTE(review): the column step is ne0 instead of the nb1 stride, same as in the reference code quoted above;
	// the two are only equivalent when the rows of the result are densely packed - confirm
	result[ dst_col + ic * ne0 ] = curr;
}