diff options
| author | Konstantin <const@const.me> | 2023-01-22 12:30:54 +0100 |
|---|---|---|
| committer | Konstantin <const@const.me> | 2023-01-22 12:30:54 +0100 |
| commit | 8fa57f680f002f4f636da687e40e21225f1ee392 (patch) | |
| tree | 15c9ad0828b1bb1d706fba9ab8192715ef30e7de /ComputeShaders/addRepeatEx.hlsl | |
| parent | cacec67bb649702db7a877de1b6482a46123f175 (diff) | |
GPU performance, optimized away a few shader dispatches
Diffstat (limited to 'ComputeShaders/addRepeatEx.hlsl')
| -rw-r--r-- | ComputeShaders/addRepeatEx.hlsl | 76 |
1 files changed, 76 insertions, 0 deletions
diff --git a/ComputeShaders/addRepeatEx.hlsl b/ComputeShaders/addRepeatEx.hlsl new file mode 100644 index 0000000..ea510b3 --- /dev/null +++ b/ComputeShaders/addRepeatEx.hlsl @@ -0,0 +1,76 @@ +// An equivalent of "addRepeat.hlsl" followed by "addInPlace.hlsl". +// Merging into a single shader saves some global memory bandwidth and reduces CPU overhead wasted binding resources and dispatching shaders +RWBuffer<float> tensor: register( u0 ); +Buffer<float> pattern: register( t0 ); +Buffer<float> finalAdd: register( t1 ); + +cbuffer Constants: register( b0 ) +{ + uint4 tensorSize: packoffset( c0 ); + uint4 tensorStrides: packoffset( c1 ); + uint4 patternSize: packoffset( c2 ); + uint4 patternStrides: packoffset( c3 ); + // uint4 finalSize: packoffset( c4 ); + uint4 finalStrides: packoffset( c5 ); +} + +#ifndef THREADS +#define THREADS 256 +#endif + +#include "repeatUtils.hlsli" + +// The micro-kernel of the shader, computes tensor[ rsi.x ] += pattern + finalAdd[ rsi.y ] +inline void add2( uint2 rsi, float pattern ) +{ + float f = tensor[ rsi.x ]; + f += pattern; + f += finalAdd[ rsi.y ]; + tensor[ rsi.x ] = f; +} + +[ numthreads( THREADS, 1, 1 ) ] +void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex ) +{ + const uint2 stridesX = uint2( tensorStrides.x, finalStrides.x ); + uint2 rsi; + rsi.x = rowOffset( group, tensorStrides ); + rsi.y = rowOffset( group, finalStrides ); + const uint rsiEnd = rsi.x + tensorSize.x * stridesX.x; + rsi += stridesX * thread; + + uint pat = rowOffset( group % patternSize.yzw, patternStrides ); + + if( patternSize.x == 1 ) + { + // The pattern only has 1 column, broadcasting over the row + const uint2 rsiInc = stridesX * THREADS; + const float p = pattern[ pat ]; + for( ; rsi.x < rsiEnd; rsi += rsiInc ) + add2( rsi, p ); + } + else if( patternSize.x <= THREADS ) + { + // pattern size doesn't exceed thread group size, load outside of the loop + const uint threadsPerGroup = THREADS - ( THREADS % patternSize.x ); + if( thread >= threadsPerGroup ) + return; + + const uint2 rsiInc = stridesX * threadsPerGroup; + pat += ( thread % patternSize.x ) * patternStrides.x; + const float p = pattern[ pat ]; + for( ; rsi.x < rsiEnd; rsi += rsiInc ) + add2( rsi, p ); + } + else + { + // Pattern rows are longer than the thread group, need to stream from both buffers + uint3 rsi3; + rsi3.xy = rsi; + rsi3.z = pat + thread * patternStrides.x; + + const uint3 rsiInc = uint3( stridesX, patternStrides.x ) * THREADS; + for( ; rsi3.x < rsiEnd; rsi3 += rsiInc ) + add2( rsi3.xy, pattern[ rsi3.z ] ); + } +}
\ No newline at end of file |
