summaryrefslogtreecommitdiffstats
path: root/ComputeShaders/matReshapePanels.hlsl
blob: f26f24665ef6f4fdac573e8d243f501a42e2cc62 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
// This shader reshapes a matrix into the layout expected by the mulMatTiledEx.hlsl and mulMatByRowTiledEx.hlsl compute shaders.
// It's called at runtime, and also while loading models from disk.
// So far, it's only used when running on AMD GPUs.
// Side length of the square tile. In this shader the same value is the thread group size, the height of an output panel,
// and both dimensions of the group shared transposition buffer.
// The default of 32 applies only when no TILE_SIZE macro was supplied to the compiler.
#ifndef TILE_SIZE
static const uint TILE_SIZE = 32;
#endif

// Input tensor, read-only; addressed by element index computed from arg0Strides
Buffer<float> source: register( t0 );
// Output tensor; receives the reshaped panels, addressed by element index computed from panelSize and layerStrides
RWBuffer<float> result: register( u0 );

// Shader parameters. The packoffset annotations pin every field to an explicit register.
// NOTE(review): c2.x is not declared here; presumably the host-side structure has another field in that slot, confirm against the CPU code.
cbuffer Constants: register( b0 )
{
	// Size of the input tensor; this shader reads .x as the row width and .y as the count of rows
	uint4 arg0Size: packoffset( c0 );
	// Strides of the input tensor, in elements: .x between elements of a row, .y between rows, .zw for the two outer dimensions indexed by group.yz
	uint4 arg0Strides: packoffset( c1 );
	// Count of elements per panel
	uint panelSize : packoffset( c2.y );
	// Layer strides of the output matrix, multiplied by group.yz
	uint2 layerStrides: packoffset( c2.z );
}

// Horizontal add: the sum of the two lanes of an unsigned 2D vector
inline uint hadd( uint2 v2 )
{
	const uint lanesTotal = v2.x + v2.y;
	return lanesTotal;
}

// Group shared staging buffer used to transpose one tile.
// main() stores into it as [ column within the tile ][ row within the panel ], and reads it back one column per iteration.
groupshared float tileBuffer[ TILE_SIZE ][ TILE_SIZE ];

// Entry point. Each group of TILE_SIZE threads produces one output panel: TILE_SIZE consecutive rows of the source matrix,
// written column by column with TILE_SIZE elements per column. Panel rows past the end of the source matrix are written as zeros.
// group.x is the panel index; group.yz select the matrix through arg0Strides.zw on input and layerStrides on output.
// thread is the index of the thread within the group, in [ 0 .. TILE_SIZE - 1 ]
[ numthreads( TILE_SIZE, 1, 1 ) ]
void main( const uint3 group: SV_GroupID, const uint thread : SV_GroupIndex )
{
	// rdi is the running store index into the output buffer.
	// It starts at element #thread of the first column of this group's panel, and advances by TILE_SIZE for every column written.
	uint rdi = hadd( group.yz * layerStrides );
	rdi += group.x * panelSize;
	rdi += thread;

	// rsi is the running load index into the source buffer; so far it only has the offset of the matrix selected by group.yz
	uint rsi = hadd( group.yz * arg0Strides.zw );
	// Index of the first source row covered by this panel
	const uint baseY = group.x * TILE_SIZE;
	// Index of the source row which corresponds to this thread
	const uint dispatchThread = baseY + thread;
	// Reshaping into a column major horizontal panel, height = TILE_SIZE, width = width of the source matrix
	uint width = arg0Size.x;
	// Usually TILE_SIZE; can be less for the last panel on the matrix when we need to generate zeros instead of loading these numbers
	// NOTE(review): the subtraction is unsigned, this assumes no group is dispatched with baseY >= arg0Size.y; confirm against the dispatch code
	const uint height = min( TILE_SIZE, arg0Size.y - baseY );

	if( arg0Strides.x == 1 )
	{
		// The input matrix is row major, can improve performance with coalesced loads and group shared buffer.
		rsi += baseY * arg0Strides.y;

		// Count of complete [ TILE_SIZE ] * [ TILE_SIZE ] tiles in the panel; the remaining columns are handled by the loop at the end of the function
		const uint widthCompleteTiles = width / TILE_SIZE;

		if( height < TILE_SIZE )
		{
			// This thread group was dispatched for the last panel of the matrix, it doesn't have enough rows
			// Write zeros to the corresponding elements of the groupshared buffer
			// These elements are written once: the load loop below only stores rows [ 0 .. height - 1 ], so the zeros survive all iterations
			for( uint j = height; j < TILE_SIZE; j++ )
				tileBuffer[ thread ][ j ] = 0.0;
		}

		// The trip count comes from the constant buffer, so it is the same for all threads of the group and they all reach both barriers
		for( uint i = 0; i < widthCompleteTiles; i++, rsi += TILE_SIZE )
		{
			// Load [ TILE_SIZE ] * [ TILE_SIZE ] block with fully coalesced loads, store to group shared buffer in transposed order
			// In this loop, each thread owns one column of the tile
			uint rsiTile = rsi + thread;
			uint j;
			for( j = 0; j < height; j++, rsiTile += arg0Strides.y )
			{
				// Each iteration of the loop loads a row of [ TILE_SIZE ] elements from the corresponding row of the source tensor
				// Fully coalesced load
				float f = source[ rsiTile ];
				// Random store but the local memory's fast, this works rather well in practice
				tileBuffer[ thread ][ j ] = f;
			}

			// Wait until all threads of the group have stored their columns, the loop below reads values written by other threads
			GroupMemoryBarrierWithGroupSync();

			// Copy from group shared buffer to output tensor
			// In this loop, each thread owns one row of the panel; every iteration writes one complete column of TILE_SIZE elements
			for( j = 0; j < TILE_SIZE; j++, rdi += TILE_SIZE )
			{
				// Fully coalesced loads and stores
				float f = tileBuffer[ j ][ thread ];
				result[ rdi ] = f;
			}

			// Wait until all threads have finished reading the tile, the next iteration overwrites the buffer
			GroupMemoryBarrierWithGroupSync();
		}

		// Handle the remaining columns, if any, with the generic loop below
		width %= TILE_SIZE;
		if( 0 == width )
			return;
		// rsi already points at the first remaining column of the first row of the panel, move to the row of this thread
		rsi += thread * arg0Strides.y;
	}
	else
		rsi += dispatchThread * arg0Strides.y;	// Generic strides: every thread walks along its own source row

	// Generic path: each thread copies [ width ] elements of its row, one element per output column
	for( uint i = 0; i < width; i++ )
	{
		float f;
		[branch]
		if( thread < height )
			f = source[ rsi ];
		else
			f = 0.0;	// This row is outside of the source matrix, pad the panel with zeros
		rsi += arg0Strides.x;

		result[ rdi ] = f;
		// Advance to the same row of the next output column
		rdi += TILE_SIZE;
	}
}