summaryrefslogtreecommitdiffstats
path: root/ComputeShaders/repeatUtils.hlsli
blob: 118150103066fb23af7307adba5c4b9e12db6af3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
// Offset of the first element of the row addressed by idx.
// Component k of idx (k = 0..2) is scaled by stride component k + 1; the within-row
// stride (strides.x) is deliberately not used here, since it only applies inside the row.
inline uint rowOffset( uint3 idx, uint4 strides )
{
	const uint fromX = idx.x * strides.y;
	const uint fromY = idx.y * strides.z;
	const uint fromZ = idx.z * strides.w;
	return fromX + fromY + fromZ;
}

// Build the starting iterator for one row of the output tensor.
// The returned vector packs three offsets:
//   x = offset of the first element visited by this thread
//       (start of the row from rowOffset() plus thread * stride[ 0 ]),
//   y = amount x advances per step (THREADS * stride[ 0 ]),
//   z = exclusive end offset of the row (row start plus size[ 0 ] * stride[ 0 ]).
// Only size[ 0 ] is read; size[ 1..3 ] are ignored. THREADS must be defined by the including shader.
// NOTE(review): presumably `thread` is in [0, THREADS) so the threads interleave over the row — confirm at call sites.
inline uint3 tensorIteratorState( uint3 group, uint thread, uint4 size, uint4 stride )
{
	const uint rowStart = rowOffset( group, stride );
	const uint elementStride = stride.x;
	return uint3(
		rowStart + thread * elementStride,
		THREADS * elementStride,
		rowStart + size.x * elementStride );
}

// Walk a complete row of the output tensor with an iterator produced by tensorIteratorState():
// the loop body runs while ts.x is below the end offset ts.z, advancing ts.x by ts.y after each pass.
// The iterator is updated in place, so `ts` must be a modifiable lvalue.
#define ROW_LOOP( ts ) for( ; ( ts ).z > ( ts ).x; ( ts ).x += ( ts ).y )
// Same loop as ROW_LOOP(): runs while ts.x < ts.z, using the same end offset ts.z.
// The difference is the step: ts.x advances by `len * stride[ 0 ]` instead of by ts.y,
// so `len` plays the role that THREADS plays in tensorIteratorState().
// Every argument is parenthesized so expression arguments (e.g. `a + b` for len) bind as a unit;
// `len` and `stride` are re-evaluated on every iteration, so pass side-effect-free expressions.
#define ROW_LOOP_EX( ts, len, stride ) for( ; ( ts ).x < ( ts ).z; ( ts ).x += ( len ) * ( stride )[ 0 ] )