summaryrefslogtreecommitdiffstats
path: root/ComputeShaders/fmaRepeat2.hlsl
diff options
context:
space:
mode:
authorKonstantin <const@const.me>2023-01-16 14:52:43 +0100
committerKonstantin <const@const.me>2023-01-16 14:52:43 +0100
commit8c4603c73675958efc960fbd4bb599a2909d106a (patch)
tree714dc6fc9a1672d5fd7f89676b97e10959662abc /ComputeShaders/fmaRepeat2.hlsl
parent990a8d0dbaefc996244097397259e92758b15cce (diff)
Source codes
Diffstat (limited to 'ComputeShaders/fmaRepeat2.hlsl')
-rw-r--r--ComputeShaders/fmaRepeat2.hlsl45
1 files changed, 45 insertions, 0 deletions
diff --git a/ComputeShaders/fmaRepeat2.hlsl b/ComputeShaders/fmaRepeat2.hlsl
new file mode 100644
index 0000000..edadf0a
--- /dev/null
+++ b/ComputeShaders/fmaRepeat2.hlsl
@@ -0,0 +1,45 @@
+// Implementation of fmaRepeat() when source arguments have different shape or VRAM layout
+// Dispatch [ nb[ 1 ], nb[ 2 ], nb[ 3 ] ] thread groups of this shader, where nb is size of the destination tensor
+RWBuffer<float> tensor: register( u0 );
+Buffer<float> patternMul: register( t0 );
+Buffer<float> patternAdd: register( t1 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 tensorSize: packoffset( c0 );
+ uint4 tensorStrides: packoffset( c1 );
+ uint4 patternSizeMul: packoffset( c2 );
+ uint4 patternStridesMul: packoffset( c3 );
+ uint4 patternSizeAdd: packoffset( c4 );
+ uint4 patternStridesAdd: packoffset( c5 );
+}
+
+#ifndef THREADS
+#define THREADS 32
+#endif
+
+#include "repeatUtils.hlsli"
+
+inline float loadPattern( Buffer<float> buffer, uint rowStart, uint i, uint4 size, uint4 stride )
+{
+ i %= size.x;
+ return buffer[ i * stride.x + rowStart ];
+}
+
+[ numthreads( THREADS, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ uint3 it = tensorIteratorState( group, thread, tensorSize, tensorStrides );
+ const uint rsiMul = rowOffset( group % patternSizeMul.yzw, patternStridesMul );
+ const uint rsiAdd = rowOffset( group % patternSizeAdd.yzw, patternStridesAdd );
+
+ for( uint i = thread; it.x < it.z; it.x += it.y, i++ )
+ {
+ precise float f = tensor[ it.x ];
+ float mul = loadPattern( patternMul, rsiMul, i, patternSizeMul, patternStridesMul );
+ float add = loadPattern( patternAdd, rsiAdd, i, patternSizeAdd, patternStridesAdd );
+ f *= mul;
+ f += add;
+ tensor[ it.x ] = f;
+ }
+} \ No newline at end of file