// path: root/Whisper/ML/LookupTablesData.cpp
// blob: 263bcf7196688ddf991ae6050196737cb4af2bbc
#include "stdafx.h"
#include "LookupTablesData.h"
#include <immintrin.h>
using namespace DirectCompute;

namespace
{
	inline float fp32( uint16_t f16 )
	{
		__m128i i = _mm_cvtsi32_si128( f16 );
		__m128 f = _mm_cvtph_ps( i );
		return _mm_cvtss_f32( f );
	}

	inline uint16_t fp16( float fp32 )
	{
		__m128 f = _mm_set_ss( fp32 );
		__m128i i = _mm_cvtps_ph( f, 0 );
		uint32_t res = (uint32_t)_mm_cvtsi128_si32( i );
		return (uint16_t)res;
	}

	// Coefficient of the cubic term in the tanh-based GELU approximation
	constexpr double GELU_COEF_A = 0.044715;
	// sqrt( 2 / pi )
	constexpr double SQRT_2_OVER_PI = 0.79788456080286535587989211986876;

	// Evaluates 0.5 * x * ( 1 + tanh( sqrt(2/pi) * x * ( 1 + 0.044715 * x^2 ) ) ).
	// All intermediate math is done in double precision; only the final result is narrowed to float.
	inline float computeGelu( float x )
	{
		const double xd = x;
		// 1 + a * x^2, multiplied left to right
		const double cubicFactor = 1.0 + GELU_COEF_A * xd * xd;
		const double tanhArg = SQRT_2_OVER_PI * xd * cubicFactor;
		const double halfX = 0.5 * xd;
		return (float)( halfX * ( 1.0 + tanh( tanhArg ) ) );
	}
}

// Fills the gelu and exponent tables: for every possible 16-bit pattern, the pattern is interpreted as
// a half-precision value, the function is evaluated in higher precision, and the result is stored back as half-precision bits.
LookupTablesData::LookupTablesData()
{
	// Number of distinct 16-bit patterns
	constexpr int tableSize = 0x10000;

	for( int bits = 0; bits < tableSize; bits++ )
	{
		// The table index itself is the FP16 bit pattern of the input
		const float x = fp32( bits );
		const float geluValue = computeGelu( x );
		const float expValue = (float)exp( x );

		gelu[ bits ] = fp16( geluValue );
		exponent[ bits ] = fp16( expValue );
	}
}