summaryrefslogtreecommitdiffstats
path: root/ComputeShaders/addRepeatEx.hlsl
blob: ea510b30560ccdc20f73014b89a9ea6742a6653d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
// Equivalent to "addRepeat.hlsl" followed by "addInPlace.hlsl".
// Merging them into a single shader saves some global memory bandwidth, and reduces the CPU overhead spent on binding resources and dispatching shaders.
// Destination tensor, updated in place: every element gets pattern + finalAdd added to it
RWBuffer<float> tensor: register( u0 );
// Source tensor repeated (tiled) over the destination; the shader wraps its indices by patternSize
Buffer<float> pattern: register( t0 );
// Second source tensor, addressed with finalStrides at the same coordinates as the destination.
// NOTE(review): no finalSize is read, so this presumably must have the same shape as `tensor` — confirm on the CPU side.
Buffer<float> finalAdd: register( t1 );

// Sizes and strides of the three tensors, 4 dimensions each.
// Component .x is the dimension iterated within a row by the threads of a group; .yzw are the outer dimensions.
// Strides are in elements, since they are used directly as indices into the float buffers.
// The packoffset layout must match the constant buffer structure uploaded by the CPU code.
cbuffer Constants: register( b0 )
{
	uint4 tensorSize: packoffset( c0 );
	uint4 tensorStrides: packoffset( c1 );
	uint4 patternSize: packoffset( c2 );
	uint4 patternStrides: packoffset( c3 );
	// Slot c4 is not read by this shader; kept commented out to document the layout
	// uint4 finalSize: packoffset( c4 );
	uint4 finalStrides: packoffset( c5 );
}

// Thread group size; can be overridden with a compile-time define
#ifndef THREADS
#define THREADS 256
#endif

#include "repeatUtils.hlsli"

// Per-element kernel: tensor[ rsi.x ] += pattern + finalAdd[ rsi.y ]
// rsi.x - element index into the destination `tensor`
// rsi.y - element index into the `finalAdd` buffer
// pattern - value of the repeated pattern for this element (shadows the global buffer of the same name)
// The sum is evaluated as ( tensor + pattern ) + finalAdd, with one load and one store of the destination.
inline void add2( uint2 rsi, float pattern )
{
	const float withPattern = tensor[ rsi.x ] + pattern;
	tensor[ rsi.x ] = withPattern + finalAdd[ rsi.y ];
}

// Entry point: every element of `tensor` gets pattern (repeated over the tensor) + finalAdd added to it.
// Each thread group processes one row of the tensor, i.e. all tensorSize.x elements along x.
// The row is selected by the group ID via rowOffset() from repeatUtils.hlsli
// (assumed to combine group.xyz with the .yzw strides — TODO confirm).
// Within the row, thread `thread` handles columns thread, thread + step, thread + 2 * step, ...
[ numthreads( THREADS, 1, 1 ) ]
void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
{
	// Strides along x, in elements: .x for tensor, .y for finalAdd
	const uint2 stridesX = uint2( tensorStrides.x, finalStrides.x );
	// rsi.x indexes tensor, rsi.y indexes finalAdd
	uint2 rsi;
	rsi.x = rowOffset( group, tensorStrides );
	rsi.y = rowOffset( group, finalStrides );
	const uint rsiEnd = rsi.x + tensorSize.x * stridesX.x;
	// Move both indices to the first column handled by this thread
	rsi += stridesX * thread;

	// Start of the pattern row; the modulo repeats the pattern over the outer dimensions
	uint pat = rowOffset( group % patternSize.yzw, patternStrides );

	if( patternSize.x == 1 )
	{
		// The pattern only has 1 column, broadcasting over the row
		const uint2 rsiInc = stridesX * THREADS;
		const float p = pattern[ pat ];
		for( ; rsi.x < rsiEnd; rsi += rsiInc )
			add2( rsi, p );
	}
	else if( patternSize.x <= THREADS )
	{
		// Pattern size doesn't exceed the thread group size, load it outside of the loop.
		// threadsPerGroup is the largest multiple of patternSize.x not exceeding THREADS.
		// Stepping by it keeps every column visited by a thread on the same pattern column.
		// The remaining threads have nothing to do.
		const uint threadsPerGroup = THREADS - ( THREADS % patternSize.x );
		if( thread >= threadsPerGroup )
			return;

		const uint2 rsiInc = stridesX * threadsPerGroup;
		pat += ( thread % patternSize.x ) * patternStrides.x;
		const float p = pattern[ pat ];
		for( ; rsi.x < rsiEnd; rsi += rsiInc )
			add2( rsi, p );
	}
	else
	{
		// Pattern rows are longer than the thread group, need to stream from both buffers.
		// When the tensor row is longer than the pattern row, the pattern index must wrap back
		// to the start of the pattern row to implement the repeat; without the wrap it would
		// read past the end of the pattern row.
		// Since THREADS < patternSize.x in this branch, the index overshoots the row by less
		// than one full row per step, so a single conditional subtraction is enough.
		const uint2 rsiInc = stridesX * THREADS;
		const uint patRow = patternSize.x * patternStrides.x;
		const uint patEnd = pat + patRow;
		const uint patInc = patternStrides.x * THREADS;
		uint ps = pat + thread * patternStrides.x;
		for( ; rsi.x < rsiEnd; rsi += rsiInc )
		{
			add2( rsi, pattern[ ps ] );
			ps += patInc;
			if( ps >= patEnd )
				ps -= patRow;
		}
	}
}