diff options
| author | Konstantin <const@const.me> | 2023-01-16 14:52:43 +0100 |
|---|---|---|
| committer | Konstantin <const@const.me> | 2023-01-16 14:52:43 +0100 |
| commit | 8c4603c73675958efc960fbd4bb599a2909d106a (patch) | |
| tree | 714dc6fc9a1672d5fd7f89676b97e10959662abc /ComputeShaders/fmaRepeat2.hlsl | |
| parent | 990a8d0dbaefc996244097397259e92758b15cce (diff) | |
Source codes
Diffstat (limited to 'ComputeShaders/fmaRepeat2.hlsl')
| -rw-r--r-- | ComputeShaders/fmaRepeat2.hlsl | 45 |
1 files changed, 45 insertions, 0 deletions
// ComputeShaders/fmaRepeat2.hlsl
// Implementation of fmaRepeat() when source arguments have different shape or VRAM layout:
// every element of the destination tensor is updated in place as tensor = tensor * mul + add,
// where mul and add are fetched from two independently shaped pattern tensors.
// Dispatch [ nb[ 1 ], nb[ 2 ], nb[ 3 ] ] thread groups of this shader, where nb is size of the destination tensor

// Destination tensor, read and written in place
RWBuffer<float> tensor: register( u0 );
// Source of the multipliers
Buffer<float> patternMul: register( t0 );
// Source of the addends
Buffer<float> patternAdd: register( t1 );

cbuffer Constants: register( b0 )
{
	// Size and strides of the destination tensor
	uint4 tensorSize: packoffset( c0 );
	uint4 tensorStrides: packoffset( c1 );
	// Size and strides of the multiplier pattern
	uint4 patternSizeMul: packoffset( c2 );
	uint4 patternStridesMul: packoffset( c3 );
	// Size and strides of the addend pattern
	uint4 patternSizeAdd: packoffset( c4 );
	uint4 patternStridesAdd: packoffset( c5 );
}

// Count of threads in a thread group; can be overridden by defining the macro before compiling
#ifndef THREADS
#define THREADS 32
#endif

// Provides rowOffset() and tensorIteratorState() used below
#include "repeatUtils.hlsli"

// Load one element from a row of a pattern tensor.
// buffer:   the pattern buffer to read
// rowStart: offset of the first element of the row in the buffer
// i:        index of the destination element within its row; wrapped around the pattern's row length
// size:     size of the pattern tensor, only the x component is used
// stride:   strides of the pattern tensor, only the x component is used
inline float loadPattern( Buffer<float> buffer, uint rowStart, uint i, uint4 size, uint4 stride )
{
	// Repeat the pattern along the row when it is shorter than the destination row
	const uint wrapped = i % size.x;
	return buffer[ rowStart + wrapped * stride.x ];
}

[ numthreads( THREADS, 1, 1 ) ]
void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
{
	// The loop below uses the state as: x = current offset in the tensor, y = offset increment, z = end offset
	uint3 it = tensorIteratorState( group, thread, tensorSize, tensorStrides );

	// Offsets of the pattern rows; the group index is wrapped around the outer dimensions of each pattern
	const uint rsiMul = rowOffset( group % patternSizeMul.yzw, patternStridesMul );
	const uint rsiAdd = rowOffset( group % patternSizeAdd.yzw, patternStridesAdd );

	// `i` is the index of the current element within the row, used to address the patterns.
	// The threads of the group start at consecutive elements ( i = thread ), so each iteration is expected
	// to move THREADS elements forward. The original code incremented `i` by 1, which makes the pattern
	// index drift away from the tensor element after the first iteration.
	// NOTE(review): assumes tensorIteratorState() advances it.x by THREADS elements per step - confirm in repeatUtils.hlsli
	for( uint i = thread; it.x < it.z; it.x += it.y, i += THREADS )
	{
		const float mul = loadPattern( patternMul, rsiMul, i, patternSizeMul, patternStridesMul );
		const float add = loadPattern( patternAdd, rsiAdd, i, patternSizeAdd, patternStridesAdd );

		// Multiply and add are kept as two separate statements on a `precise` value, as in the original
		precise float f = tensor[ it.x ];
		f *= mul;
		f += add;
		tensor[ it.x ] = f;
	}
}
\ No newline at end of file |
