summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorKonstantin <const@const.me>2023-01-23 21:03:05 +0100
committerKonstantin <const@const.me>2023-01-23 21:03:05 +0100
commit1a52ce8301aa0f93c82cece3e3db0986beb3d41a (patch)
treef95a236f1b732671846acc84d444ea7d2d1a1b90
parent15dbcacdbc5db68c1ea86bb330d07ec70de75af6 (diff)
GPU performance, mulMatByRowTiled shader
-rw-r--r--ComputeShaders/mulMatByRowTiled.hlsl9
1 files changed, 5 insertions, 4 deletions
diff --git a/ComputeShaders/mulMatByRowTiled.hlsl b/ComputeShaders/mulMatByRowTiled.hlsl
index fea2fcb..11c7c18 100644
--- a/ComputeShaders/mulMatByRowTiled.hlsl
+++ b/ComputeShaders/mulMatByRowTiled.hlsl
@@ -39,7 +39,8 @@ void main( uint3 group: SV_GroupID, uint3 thread : SV_GroupThreadID, uint thread
// Zero out the shared buffer
for( i = thread.y; i < TILE_Y; i += THREADS_Y )
resTemp[ i ][ thread.x ] = 0.0;
- GroupMemoryBarrierWithGroupSync();
+ // Before the reduction at the end of this shader, each thread only loads/stores the [ thread.y + THREADS_Y * N ][ thread.x ] elements of the shared buffer,
+ // where N is an integer. That's why until the end, we don't need these thread sync instructions.
// Count of rows to compute in this thread group
const uint height = min( TILE_Y, arg0Size.y - group.x * TILE_Y );
@@ -70,7 +71,6 @@ void main( uint3 group: SV_GroupID, uint3 thread : SV_GroupThreadID, uint thread
acc = mad( v0, v1, acc );
resTemp[ i ][ thread.x ] = acc;
}
- GroupMemoryBarrierWithGroupSync();
}
const uint rem = arg0Size.x % THREADS_X;
@@ -92,10 +92,11 @@ void main( uint3 group: SV_GroupID, uint3 thread : SV_GroupThreadID, uint thread
acc = mad( v0, v1, acc );
resTemp[ i ][ thread.x ] = acc;
}
- GroupMemoryBarrierWithGroupSync();
}
- // Now we need horizontal sums of these shared accumulators, i.e. reduce [height][THREADS_X] shared array into [height][1] column
+ // Now we need horizontal sum of these shared accumulators, reducing [height][THREADS_X] shared array into [height][1] column
+ GroupMemoryBarrierWithGroupSync();
+
for( i = THREADS_X / 2; i > 0; i /= 2 )
{
if( thread.x < i )