diff options
| author | Konstantin <const@const.me> | 2023-01-16 14:52:43 +0100 |
|---|---|---|
| committer | Konstantin <const@const.me> | 2023-01-16 14:52:43 +0100 |
| commit | 8c4603c73675958efc960fbd4bb599a2909d106a (patch) | |
| tree | 714dc6fc9a1672d5fd7f89676b97e10959662abc /ComputeShaders/fmaRepeat1.hlsl | |
| parent | 990a8d0dbaefc996244097397259e92758b15cce (diff) | |
Source codes
Diffstat (limited to 'ComputeShaders/fmaRepeat1.hlsl')
| -rw-r--r-- | ComputeShaders/fmaRepeat1.hlsl | 77 |
1 files changed, 77 insertions, 0 deletions
diff --git a/ComputeShaders/fmaRepeat1.hlsl b/ComputeShaders/fmaRepeat1.hlsl new file mode 100644 index 0000000..3db3827 --- /dev/null +++ b/ComputeShaders/fmaRepeat1.hlsl @@ -0,0 +1,77 @@ +// Implementation of fmaRepeat() when both source arguments have same size and strides +// Dispatch [ nb[ 1 ], nb[ 2 ], nb[ 3 ] ] thread groups of this shader, where nb is size of the destination tensor +RWBuffer<float> tensor: register( u0 ); +Buffer<float> patternMul: register( t0 ); +Buffer<float> patternAdd: register( t1 ); + +cbuffer Constants: register( b0 ) +{ + uint4 tensorSize: packoffset( c0 ); + uint4 tensorStrides: packoffset( c1 ); + uint4 patternSize: packoffset( c2 ); + uint4 patternStrides: packoffset( c3 ); +} + +#ifndef THREADS +#define THREADS 512 +#endif + +#include "repeatUtils.hlsli" + +inline void computeSimple( uint idx, float mul, float add ) +{ + precise float f = tensor[ idx ]; + f *= mul; + f += add; + tensor[ idx ] = f; +} + +[ numthreads( THREADS, 1, 1 ) ] +void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex ) +{ + uint3 it = tensorIteratorState( group, thread, tensorSize, tensorStrides ); + uint rsi = rowOffset( group % patternSize.yzw, patternStrides ); + + if( patternSize[ 0 ] == 1 ) + { + // The pattern only has 1 column - broadcasting over the row + const float pMul = patternMul[ rsi ]; + const float pAdd = patternAdd[ rsi ]; + ROW_LOOP( it ) + computeSimple( it.x, pMul, pAdd ); + } + else if( patternSize[ 0 ] <= THREADS ) + { + // pattern size doesn't exceed thread group size: load pattern value outside of the loop + const uint threadsPerGroup = THREADS - ( THREADS % patternSize[ 0 ] ); + if( thread >= threadsPerGroup ) + return; + + rsi += ( thread % patternSize[ 0 ] ) * patternStrides[ 0 ]; + const float pMul = patternMul[ rsi ]; + const float pAdd = patternAdd[ rsi ]; + ROW_LOOP_EX( it, threadsPerGroup, tensorStrides ) + computeSimple( it.x, pMul, pAdd ); + } + else + { + // Pattern rows are larger than the thread group, need to stream from both buffers + const uint rsiInc = THREADS * patternStrides[ 0 ]; + const uint rsiDec = patternSize[ 0 ] * patternStrides[ 0 ]; + const uint rsiEnd = rsi + rsiDec; + rsi += thread * patternStrides[ 0 ]; + + ROW_LOOP( it ) + { + precise float f = tensor[ it.x ]; + float mul = patternMul[ rsi ]; + float add = patternAdd[ rsi ]; + rsi += rsiInc; + if( rsi >= rsiEnd ) + rsi -= rsiDec; + f *= mul; + f += add; + tensor[ it.x ] = f; + } + } +}
\ No newline at end of file |
