diff options
| author | Konstantin <const@const.me> | 2023-01-16 14:52:43 +0100 |
|---|---|---|
| committer | Konstantin <const@const.me> | 2023-01-16 14:52:43 +0100 |
| commit | 8c4603c73675958efc960fbd4bb599a2909d106a (patch) | |
| tree | 714dc6fc9a1672d5fd7f89676b97e10959662abc /ComputeShaders/softMaxCompat.hlsl | |
| parent | 990a8d0dbaefc996244097397259e92758b15cce (diff) | |
Source codes
Diffstat (limited to 'ComputeShaders/softMaxCompat.hlsl')
| -rw-r--r-- | ComputeShaders/softMaxCompat.hlsl | 62 |
1 files changed, 62 insertions, 0 deletions
diff --git a/ComputeShaders/softMaxCompat.hlsl b/ComputeShaders/softMaxCompat.hlsl new file mode 100644 index 0000000..2215ebd --- /dev/null +++ b/ComputeShaders/softMaxCompat.hlsl @@ -0,0 +1,62 @@ +// ggml_compute_forward_soft_max_f32 +// Dispatch [ ( nr + 31 ) / 32, 1, 1 ] thread groups of this shader +RWBuffer<float> result: register( u0 ); + +// table_exp_f16 +Buffer<uint> lookupTable: register( t0 ); + +cbuffer Constants: register( b0 ) +{ + uint4 elements: packoffset( c0 ); + uint4 strides: packoffset( c1 ); + uint nr: packoffset( c2.x ); +} + +#include "miscUtils.hlsli" +#include "fp64Utils.hlsli" + +static const float negativeInfinity = asfloat( 0xff800000 ); + +[ numthreads( 32, 1, 1 ) ] +void main( uint3 dtid: SV_DispatchThreadID ) +{ + if( dtid.x >= nr ) + return; + + const uint p = dtid.x * strides[ 1 ]; + const uint nc = elements[ 0 ]; + const uint pEnd = p + nc; + uint i; + + float m = negativeInfinity; + for( i = p; i < pEnd; i++ ) + m = max( m, result[ i ] ); + + double sum = 0; + for( i = p; i < pEnd; i++ ) + { + float f = result[ i ]; + + [branch] + if( f != negativeInfinity ) + { + uint s = fp16Rounded( f - m ); + s = lookupTable[ s ]; + f = f16tof32( s ); + sum += f; + } + else + f = 0; + + result[ i ] = f; + } + + const float scale = (float)div64( 1.0, sum ); + // ggml_vec_scale_f32 + for( i = p; i < pEnd; i++ ) + { + float f = result[ i ]; + f *= scale; + result[ i ] = f; + } +}
\ No newline at end of file |
