blob: 263bcf7196688ddf991ae6050196737cb4af2bbc (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
|
#include "stdafx.h"
#include "LookupTablesData.h"
#include <cmath>
#include <immintrin.h>
using namespace DirectCompute;
namespace
{
	/// Widen one IEEE 754 binary16 value, given as raw bits, to a 32-bit float.
	/// Uses the F16C instruction exposed as `_mm_cvtph_ps`, so the target CPU must support F16C.
	/// Every finite FP16 value and both infinities are exactly representable in FP32.
	/// @param f16 Raw FP16 bit pattern.
	/// @return The value as FP32.
	inline float fp32( uint16_t f16 )
	{
		__m128i i = _mm_cvtsi32_si128( f16 );
		__m128 f = _mm_cvtph_ps( i );
		return _mm_cvtss_f32( f );
	}

	/// Narrow a 32-bit float to raw FP16 bits with the F16C `_mm_cvtps_ph` instruction.
	/// Rounding is fixed to round-to-nearest-even by the immediate operand (independent of MXCSR);
	/// magnitudes beyond the FP16 range round to infinity.
	/// @param fp32 Value to convert.
	/// @return Raw FP16 bit pattern.
	inline uint16_t fp16( float fp32 )
	{
		__m128 f = _mm_set_ss( fp32 );
		// _MM_FROUND_TO_NEAREST_INT == 0: use the rounding mode encoded in the immediate, nearest-even.
		__m128i i = _mm_cvtps_ph( f, _MM_FROUND_TO_NEAREST_INT );
		uint32_t res = (uint32_t)_mm_cvtsi128_si32( i );
		return (uint16_t)res;
	}

	// Coefficients of the tanh-based GELU approximation used by computeGelu().
	constexpr double GELU_COEF_A = 0.044715;
	constexpr double SQRT_2_OVER_PI = 0.79788456080286535587989211986876;

	/// GELU, tanh approximation: 0.5*x*(1 + tanh(sqrt(2/pi)*x*(1 + 0.044715*x^2))), evaluated in double.
	/// @param x Input value.
	/// @return Approximated GELU(x) rounded to float. +inf maps to +inf, -inf maps to -0
	///         (the function's limit), and NaN propagates.
	inline float computeGelu( float x )
	{
		// At x = -inf the formula degenerates to -inf * (1 + tanh(-inf)) = -inf * 0 = NaN.
		// Return the limit instead; -0 matches what very large negative finite inputs produce.
		if( std::isinf( x ) && x < 0.0f )
			return -0.0f;
		return (float)( 0.5 * x * ( 1.0 + std::tanh( SQRT_2_OVER_PI * x * ( 1.0 + GELU_COEF_A * x * x ) ) ) );
	}
}
/// Populate the `gelu` and `exponent` tables.
/// Both tables are indexed by a raw FP16 bit pattern. Each entry holds the FP16 encoding of
/// computeGelu(x) or exp(x) respectively, where x is that pattern decoded to float.
/// All 65536 patterns are visited, including infinities and NaNs.
LookupTablesData::LookupTablesData()
{
	constexpr uint32_t tableSize = 0x10000;
	for( uint32_t bits = 0; bits < tableSize; bits++ )
	{
		// Decode the index itself as an FP16 value; the explicit cast documents the narrowing.
		const float input = fp32( (uint16_t)bits );

		const float geluValue = computeGelu( input );
		const float expValue = (float)exp( input );

		gelu[ bits ] = fp16( geluValue );
		exponent[ bits ] = fp16( expValue );
	}
}
|