// path: root/ComputeShaders/mulMatMadMain.hlsl
// blob: 0bd57533365baffb134ff9f18e5046ab780f65dc
// GGML_TASK_COMPUTE step for matrix*matrix product, where nb01 < nb00
// Read-only input; vectorMad() reads consecutive elements of it as the vector operand of the multiply-add
Buffer<float> arg0: register( t0 );
// Read-only input; main() loads one scalar multiplier from it per column index
Buffer<float> arg1: register( t1 );
// Output; copyRow() stores into it and addRow() accumulates into it
RWBuffer<float> resultTensor: register( u0 );
// Scratch storage for partial sums: zeroed by writeTempZeros(), accumulated by vectorMad(), consumed by copyRow() / addRow()
RWBuffer<float> tempBuffer: register( u1 );

// Shader parameters. main() reads the uint4 fields one component at a time.
cbuffer Constants: register( b0 )
{
	// Sizes of the first argument; main() reads component 1 as the length of every row it processes
	uint4 aSize: packoffset( c0 );
	// Strides of the first argument; main() uses them directly as element offsets into arg0
	uint4 aStride: packoffset( c1 );
	// Sizes of the second argument; main() reads component 0 as the total count of columns to accumulate
	uint4 bSize: packoffset( c2 );
	// Strides of the second argument; main() uses them directly as element offsets into arg1
	uint4 bStride: packoffset( c3 );
	// Sizes of the result; main() uses components 0 .. 2 to compute the row offset of the current group
	uint4 resSize: packoffset( c4 );
	// When set, vectorMad() passes every accumulated value through adjustFp16() - presumably declared in miscUtils.hlsli
	bool resultFp16 : packoffset( c5.x );
	// Offset in elements between consecutive partial-sum slices of tempBuffer; main() advances by this much per emulated thread
	uint ne: packoffset( c5.y );
}

#include "miscUtils.hlsli"

// Clears `len` elements of tempBuffer, starting at element `rdi`.
// The caller passes its index as `thread`; it handles elements thread, thread + 32, thread + 64, ... of the range.
inline void writeTempZeros( uint rdi, const uint len, const uint thread )
{
	const uint stop = rdi + len;
	uint idx = rdi + thread;
	while( idx < stop )
	{
		tempBuffer[ idx ] = 0.0;
		idx += 32;
	}
}

// Accumulates a scaled vector: for `len` elements, tempBuffer[ rdi + i ] = mul * arg0[ rsi + i ] + tempBuffer[ rdi + i ]
// When resultFp16 is set, each new value goes through adjustFp16() before being stored back.
// The caller passes its index as `thread`; it handles elements thread, thread + 32, thread + 64, ... of the range.
inline void vectorMad( uint rsi, uint rdi, const uint len, const float mul, const uint thread )
{
	const uint srcEnd = rsi + len;
	uint src = rsi + thread;
	uint dst = rdi + thread;
	while( src < srcEnd )
	{
		float acc = mad( mul, arg0[ src ], tempBuffer[ dst ] );
		[branch]
		if( resultFp16 )
			acc = adjustFp16( acc );
		tempBuffer[ dst ] = acc;

		src += 32;
		dst += 32;
	}
}

// Copies `len` elements from tempBuffer, starting at `rsi`, into resultTensor, starting at `rdi`.
// The caller passes its index as `thread`; it handles elements thread, thread + 32, thread + 64, ... of the range.
inline void copyRow( uint rsi, uint rdi, const uint len, const uint thread )
{
	const uint srcEnd = rsi + len;
	uint src = rsi + thread;
	uint dst = rdi + thread;
	while( src < srcEnd )
	{
		resultTensor[ dst ] = tempBuffer[ src ];
		src += 32;
		dst += 32;
	}
}

// Adds `len` elements of tempBuffer, starting at `rsi`, to the elements of resultTensor starting at `rdi`.
// The caller passes its index as `thread`; it handles elements thread, thread + 32, thread + 64, ... of the range.
inline void addRow( uint rsi, uint rdi, const uint len, const uint thread )
{
	const uint srcEnd = rsi + len;
	uint src = rsi + thread;
	uint dst = rdi + thread;
	while( src < srcEnd )
	{
		const float sum = resultTensor[ dst ] + tempBuffer[ src ];
		resultTensor[ dst ] = sum;
		src += 32;
		dst += 32;
	}
}

// Entry point. Each thread group of 32 threads handles the row selected by the group coordinates ( i1, i2, i3 ).
// Phase 1 builds `nth` partial sums of that row in tempBuffer, each covering a contiguous range of columns.
// Phase 2 stores the first partial sum into resultTensor, then adds the remaining ones in order.
// Every thread only touches elements thread, thread + 32, ... of a row in both phases,
// so a thread never reads a tempBuffer element written by another thread of the group.
[numthreads( 32, 1, 1 )]
void main( const uint3 group: SV_GroupID, const uint thread : SV_GroupIndex )
{
	const uint i1 = group[ 0 ];
	const uint i2 = group[ 1 ];
	const uint i3 = group[ 2 ];

	// Only the sizes and strides this kernel reads are given names; the names follow the ggml convention
	// Length of every row processed by the helpers
	const uint ne01 = aSize[ 1 ];
	// Count of columns to accumulate
	const uint ne10 = bSize[ 0 ];

	const uint ne0 = resSize[ 0 ];
	const uint ne1 = resSize[ 1 ];
	const uint ne2 = resSize[ 2 ];

	const uint nb00 = aStride[ 0 ];
	const uint nb02 = aStride[ 2 ];
	const uint nb03 = aStride[ 3 ];

	const uint nb10 = bStride[ 0 ];
	const uint nb11 = bStride[ 1 ];
	const uint nb12 = bStride[ 2 ];
	const uint nb13 = bStride[ 3 ];

	// dst_row = wdata + wo + i3*ne2*ne1*ne0 + i2*ne1*ne0 + i1*ne0;
	// Offset of this group's row, both in the first slice of tempBuffer and in resultTensor
	const uint tempRowThread0 = i3 * ne2 * ne1 * ne0 + i2 * ne1 * ne0 + i1 * ne0;

	// Faking 4 CPU threads trying to achieve bitwise compatibility with the CPU version
	const uint nth = 4;

	// GGML_TASK_COMPUTE
	{
		// src0_col = src0->data + ( i00 * nb00 + i02 * nb02 + i03 * nb03 );
		const uint aBase = i2 * nb02 + i3 * nb03;
		// src1_val = *      (float *) ((char *) src1->data + (i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
		const uint bBase = i1 * nb11 + i2 * nb12 + i3 * nb13;

		// total columns in src1
		const uint nc = ne10;
		// columns per thread, rounded up
		const uint dc = ( nc + nth - 1 ) / nth;

		// Slices of consecutive emulated threads are `ne` elements apart in tempBuffer
		uint tempRow = tempRowThread0;
		for( uint ith = 0; ith < nth; ith++, tempRow += ne )
		{
			writeTempZeros( tempRow, ne01, thread );

			// column range for this emulated thread; empty when ic0 >= nc, which leaves the slice zeroed
			const uint ic0 = dc * ith;
			const uint ic1 = min( ic0 + dc, nc );

			for( uint ic = ic0; ic < ic1; ic++ )
			{
				const uint idxA = aBase + ic * nb00;
				const uint idxB = bBase + ic * nb10;
				const float bValue = arg1[ idxB ];
				vectorMad( idxA, tempRow, ne01, bValue, thread );
			}
		}
	}

	// GGML_TASK_FINALIZE
	{
		// The first slice initializes the output row, the remaining slices are added in ascending order
		uint tempRow = tempRowThread0;
		copyRow( tempRow, tempRowThread0, ne01, thread );

		tempRow += ne;
		for( uint ith = 1; ith < nth; ith++, tempRow += ne )
			addRow( tempRow, tempRowThread0, ne01, thread );
	}
}