summaryrefslogtreecommitdiffstats
path: root/ComputeShaders/softMaxCompat.hlsl
diff options
context:
space:
mode:
authorKonstantin <const@const.me>2023-01-16 14:52:43 +0100
committerKonstantin <const@const.me>2023-01-16 14:52:43 +0100
commit8c4603c73675958efc960fbd4bb599a2909d106a (patch)
tree714dc6fc9a1672d5fd7f89676b97e10959662abc /ComputeShaders/softMaxCompat.hlsl
parent990a8d0dbaefc996244097397259e92758b15cce (diff)
Source codes
Diffstat (limited to 'ComputeShaders/softMaxCompat.hlsl')
-rw-r--r--ComputeShaders/softMaxCompat.hlsl62
1 files changed, 62 insertions, 0 deletions
diff --git a/ComputeShaders/softMaxCompat.hlsl b/ComputeShaders/softMaxCompat.hlsl
new file mode 100644
index 0000000..2215ebd
--- /dev/null
+++ b/ComputeShaders/softMaxCompat.hlsl
@@ -0,0 +1,62 @@
+// ggml_compute_forward_soft_max_f32
+// Dispatch [ ( nr + 31 ) / 32, 1, 1 ] thread groups of this shader
+RWBuffer<float> result: register( u0 );
+
+// table_exp_f16
+Buffer<uint> lookupTable: register( t0 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 elements: packoffset( c0 );
+ uint4 strides: packoffset( c1 );
+ uint nr: packoffset( c2.x );
+}
+
+#include "miscUtils.hlsli"
+#include "fp64Utils.hlsli"
+
+static const float negativeInfinity = asfloat( 0xff800000 );
+
+[ numthreads( 32, 1, 1 ) ]
+void main( uint3 dtid: SV_DispatchThreadID )
+{
+ if( dtid.x >= nr )
+ return;
+
+ const uint p = dtid.x * strides[ 1 ];
+ const uint nc = elements[ 0 ];
+ const uint pEnd = p + nc;
+ uint i;
+
+ float m = negativeInfinity;
+ for( i = p; i < pEnd; i++ )
+ m = max( m, result[ i ] );
+
+ double sum = 0;
+ for( i = p; i < pEnd; i++ )
+ {
+ float f = result[ i ];
+
+ [branch]
+ if( f != negativeInfinity )
+ {
+ uint s = fp16Rounded( f - m );
+ s = lookupTable[ s ];
+ f = f16tof32( s );
+ sum += f;
+ }
+ else
+ f = 0;
+
+ result[ i ] = f;
+ }
+
+ const float scale = (float)div64( 1.0, sum );
+ // ggml_vec_scale_f32
+ for( i = p; i < pEnd; i++ )
+ {
+ float f = result[ i ];
+ f *= scale;
+ result[ i ] = f;
+ }
+} \ No newline at end of file