summaryrefslogtreecommitdiffstats
path: root/ComputeShaders/mulMatByRowTiledEx.hlsl
blob: d286ec20cd680c2b5a9e5c95c1850a2b751065e8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
// matrix*row vector product, needs first argument reshaped into a sequence of horizontal column major panels
#ifndef TILE_SIZE
static const uint TILE_SIZE = 32;
#endif
#ifndef THREADS_Y
static const uint THREADS_Y = 8;
#endif

// First tensor, reshaped into dense column major horizontal panels of size [ width, TILE_SIZE ]
Buffer<float> arg0: register( t0 );
// Second tensor, reshaped into dense column major horizontal panels of size [ width, TILE_SIZE ]
Buffer<float> arg1: register( t1 );
// FP32 output tensor, row major and continuous
RWBuffer<float> result: register( u0 );

cbuffer Constants: register( b0 )
{
	uint4 arg0Size: packoffset( c0 );
	uint arg0panel: packoffset( c1.y );
	uint2 arg0LayerStrides: packoffset( c1.z );
	// uint4 arg1Size: packoffset( c2 );
	uint4 arg1Strides: packoffset( c3 );
	uint4 resultSize: packoffset( c4 );
	uint4 resultStrides: packoffset( c5 );
}

inline uint hadd4( const uint4 v )
{
	const uint2 v2 = v.xy + v.zw;
	return v2.x + v2.y;
}

inline float hadd4( const float4 v )
{
	const float2 v2 = v.xy + v.zw;
	return v2.x + v2.y;
}

groupshared float reductionBuffer[ THREADS_Y ][ TILE_SIZE ];

[numthreads( TILE_SIZE, THREADS_Y, 1 )]
void main( const uint3 group: SV_GroupID, const uint3 thread : SV_GroupThreadID )
{
	const uint2 layer = group.yz;
	// Source offsets for the complete thread group
	uint2 rsi;
	rsi.x = group.x * arg0panel + layer.x * arg0LayerStrides.x + layer.y * arg0LayerStrides.y;
	rsi.y = layer.x * arg1Strides.z + layer.y * arg1Strides.w;
	// Apply source offsets for this particular thread
	rsi.x += thread.y * TILE_SIZE + thread.x;
	rsi.y += thread.y * arg1Strides.x;

	const uint2 rsiInc = uint2( THREADS_Y * TILE_SIZE, THREADS_Y * arg1Strides.x );

	const uint completeTiles = arg0Size.x / ( THREADS_Y * 4 );
	uint i;
	float4 acc = 0.0;
	for( i = 0; i < completeTiles; i++ )
	{
		// Each iteration of this loop consumes THREADS_Y*4 columns from the arg0 panel, and THREADS_Y*4 values from arg1
		float4 v0, v1;
		[unroll]
		for( uint j = 0; j < 4; j++, rsi += rsiInc )
		{
			// Load [ TILE_SIZE, THREADS_Y ] block from the first source tensor
			v0[ j ] = arg0[ rsi.x ];
			// Broadcast [ THREADS_Y ] row from the second source tensor
			v1[ j ] = arg1[ rsi.y ];
		}

		// Now we have [ TILE_SIZE, THREADS_Y * 4 ] block from the first source tensor in the v0 vector,
		// and [ THREADS_Y * 4 ] row from the second one in the v1 vector
		// Multiply and accumulate.
		acc = mad( v0, v1, acc );
	}

	// Handle the remainder columns, if any.
	// When present, their count is in [ 1 .. THREADS_Y * 4 - 1 ] interval
	const uint rem = arg0Size.x % ( THREADS_Y * 4 );
	if( rem != 0 )
	{
		float4 v0 = 0.0, v1 = 0.0;
		[unroll]
		for( uint j = 0; j < 4; j++, rsi += rsiInc )
		{
			const uint x = ( j * THREADS_Y ) + thread.y;
			if( x < rem )
			{
				v0[ j ] = arg0[ rsi.x ];
				v1[ j ] = arg1[ rsi.y ];
			}
		}
		acc = mad( v0, v1, acc );
	}

	// We now have [ TILE_SIZE, THREADS_Y * 4 ] block in the local variables of this thread group
	// The group however only outputs [ TILE_SIZE ] elements max, need a reduction
	float acc1 = hadd4( acc );
	reductionBuffer[ thread.y ][ thread.x ] = acc1;
	GroupMemoryBarrierWithGroupSync();

	for( i = THREADS_Y / 2; i > 1; i /= 2 )
	{
		if( thread.y < i )
		{
			acc1 += reductionBuffer[ thread.y + i ][ thread.x ];
			reductionBuffer[ thread.y ][ thread.x ] = acc1;
		}
		GroupMemoryBarrierWithGroupSync();
	}

	if( thread.y != 0 )
		return;

	const uint resultPos = group.x * TILE_SIZE;
	const uint outputSize = min( TILE_SIZE, resultSize.x - resultPos );
	if( thread.x >= outputSize )
		return;

	const uint4 resultPos4 = uint4( resultPos + thread.x, 0, layer );
	const uint rdi = hadd4( resultPos4 * resultStrides );
	result[ rdi ] = acc1 + reductionBuffer[ 1 ][ thread.x ];
}