summaryrefslogtreecommitdiffstats
path: root/ComputeShaders/addRepeatScale.hlsl
diff options
context:
space:
mode:
authorKonstantin <const@const.me>2023-01-16 14:52:43 +0100
committerKonstantin <const@const.me>2023-01-16 14:52:43 +0100
commit8c4603c73675958efc960fbd4bb599a2909d106a (patch)
tree714dc6fc9a1672d5fd7f89676b97e10959662abc /ComputeShaders/addRepeatScale.hlsl
parent990a8d0dbaefc996244097397259e92758b15cce (diff)
Source codes
Diffstat (limited to 'ComputeShaders/addRepeatScale.hlsl')
-rw-r--r--ComputeShaders/addRepeatScale.hlsl73
1 files changed, 73 insertions, 0 deletions
diff --git a/ComputeShaders/addRepeatScale.hlsl b/ComputeShaders/addRepeatScale.hlsl
new file mode 100644
index 0000000..8c24088
--- /dev/null
+++ b/ComputeShaders/addRepeatScale.hlsl
@@ -0,0 +1,73 @@
+// Compute tensor = ( tensor + repeat( pattern, tensor ) ) * scale in 1 shot, without VRAM allocations
+// Dispatch [ nb[ 1 ], nb[ 2 ], nb[ 3 ] ] thread groups of this shader, where nb is size of the destination tensor
+RWBuffer<float> tensor: register( u0 );
+Buffer<float> pattern: register( t0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 tensorSize: packoffset( c0 );
+ uint4 tensorStrides: packoffset( c1 );
+ uint4 patternSize: packoffset( c2 );
+ uint4 patternStrides: packoffset( c3 );
+ float scalingMul : packoffset( c4.x );
+}
+
+#ifndef THREADS
+#define THREADS 512
+#endif
+
+#include "repeatUtils.hlsli"
+
+inline void computeSimple( uint idx, float add )
+{
+ float f = tensor[ idx ];
+ f += add;
+ f *= scalingMul;
+ tensor[ idx ] = f;
+}
+
+[ numthreads( THREADS, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ uint3 it = tensorIteratorState( group, thread, tensorSize, tensorStrides );
+ uint rsi = rowOffset( group % patternSize.yzw, patternStrides );
+
+ if( patternSize[ 0 ] == 1 )
+ {
+ // The pattern only has 1 column - broadcasting over the row
+ const float p = pattern[ rsi ];
+ ROW_LOOP( it )
+ computeSimple( it.x, p );
+ }
+ else if( patternSize[ 0 ] <= THREADS )
+ {
+ // pattern size doesn't exceed thread group size: load pattern value outside of the loop
+ const uint threadsPerGroup = THREADS - ( THREADS % patternSize[ 0 ] );
+ if( thread >= threadsPerGroup )
+ return;
+
+ const float p = pattern[ rsi + ( thread % patternSize[ 0 ] ) * patternStrides[ 0 ] ];
+ ROW_LOOP_EX( it, threadsPerGroup, tensorStrides )
+ computeSimple( it.x, p );
+ }
+ else
+ {
+ // Pattern rows are larger than the thread group, need to stream from both buffers
+ const uint rsiInc = THREADS * patternStrides[ 0 ];
+ const uint rsiDec = patternSize[ 0 ] * patternStrides[ 0 ];
+ const uint rsiEnd = rsi + rsiDec;
+ rsi += thread * patternStrides[ 0 ];
+
+ ROW_LOOP( it )
+ {
+ float f = tensor[ it.x ];
+ float p = pattern[ rsi ];
+ rsi += rsiInc;
+ if( rsi >= rsiEnd )
+ rsi -= rsiDec;
+ f += p;
+ f *= scalingMul;
+ tensor[ it.x ] = f;
+ }
+ }
+} \ No newline at end of file