summaryrefslogtreecommitdiffstats
path: root/ComputeShaders/softMax.hlsl
diff options
context:
space:
mode:
Diffstat (limited to 'ComputeShaders/softMax.hlsl')
-rw-r--r--ComputeShaders/softMax.hlsl61
1 files changed, 45 insertions, 16 deletions
diff --git a/ComputeShaders/softMax.hlsl b/ComputeShaders/softMax.hlsl
index 6ebd0f2..259e457 100644
--- a/ComputeShaders/softMax.hlsl
+++ b/ComputeShaders/softMax.hlsl
@@ -1,9 +1,6 @@
// Dispatch [ nr, 1, 1 ] thread groups of this shader
RWBuffer<float> result: register( u0 );
-// table_exp_f16
-Buffer<uint> lookupTable: register( t0 );
-
cbuffer Constants: register( b0 )
{
uint4 elements: packoffset( c0 );
@@ -12,12 +9,50 @@ cbuffer Constants: register( b0 )
float inputScale: packoffset( c2.y );
}
-#include "miscUtils.hlsli"
-#include "groupReduce.hlsli"
+#ifndef THREADS
+static const uint THREADS = 32;
+#endif
+
+groupshared float sharedAccumulators[ THREADS ];
+
+// Compute horizontal maximum of the numbers, and broadcast to all threads of the group.
+void horizontalMaxBroadcast( const uint thread, inout float ax )
+{
+ sharedAccumulators[ thread ] = ax;
+ for( uint i = THREADS / 2; i > 0; i /= 2 )
+ {
+ GroupMemoryBarrierWithGroupSync();
+ if( thread < i )
+ {
+ ax = max( ax, sharedAccumulators[ thread + i ] );
+ sharedAccumulators[ thread ] = ax;
+ }
+ }
+ GroupMemoryBarrierWithGroupSync();
+ ax = sharedAccumulators[ 0 ];
+}
+
+// Compute horisontal sum of the numbers. The result is only correct on the thread #0 of the group.
+void horizontalSum( const uint thread, inout float sum )
+{
+ sharedAccumulators[ thread ] = sum;
+ for( uint i = THREADS / 2; i > 1; i /= 2 )
+ {
+ GroupMemoryBarrierWithGroupSync();
+ if( thread < i )
+ {
+ sum += sharedAccumulators[ thread + i ];
+ sharedAccumulators[ thread ] = sum;
+ }
+ }
+ GroupMemoryBarrierWithGroupSync();
+ if( 0 == thread )
+ sum += sharedAccumulators[ 1 ];
+}
static const float negativeInfinity = asfloat( 0xff800000 );
-[ numthreads( 32, 1, 1 ) ]
+[numthreads( THREADS, 1, 1 )]
void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
{
const uint p = group.x * strides[ 1 ];
@@ -26,12 +61,12 @@ void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
uint i;
float m = negativeInfinity;
- for( i = p + thread; i < pEnd; i += 32 )
+ for( i = p + thread; i < pEnd; i += THREADS )
m = max( m, result[ i ] );
horizontalMaxBroadcast( thread, m );
float sum = 0;
- for( i = p + thread; i < pEnd; i += 32 )
+ for( i = p + thread; i < pEnd; i += THREADS )
{
float f = result[ i ];
@@ -39,14 +74,8 @@ void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
if( f != negativeInfinity )
{
f = ( f - m ) * inputScale;
-#if 1
- // Similar to Radeon Graphics, computing the exponent on nVidia 1080Ti is also slightly faster than loading from the lookup table
+ // On both Radeon Graphics and nVidia 1080Ti, computing the exponent is slightly faster than loading from the lookup table
f = exp( f );
-#else
- uint s = fp16Rounded( f );
- s = lookupTable[ s ];
- f = f16tof32( s );
-#endif
sum += f;
}
else
@@ -62,7 +91,7 @@ void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
const float scale = sharedAccumulators[ 0 ];
// ggml_vec_scale_f32
- for( i = p + thread; i < pEnd; i += 32 )
+ for( i = p + thread; i < pEnd; i += THREADS )
{
float f = result[ i ];
f *= scale;