blob: 68ed30be683222fc96c23aaa4c1e0e13896d3952 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
|
// Ported from ggml_compute_forward_flash_attn_f16
// Dispatch with [ neq1*neq2*neq3, 1, 1 ] thread groups
Buffer<float> q: register( t0 );
Buffer<float> k: register( t1 );
Buffer<float> v: register( t2 );
RWBuffer<float> result: register( u0 );
// This temporary buffer should fit tempBufferStride * neq1 * neq2 * neq3 elements, FP32 precision
RWBuffer<float> temp: register( u1 );
cbuffer Constants: register( b0 )
{
uint4 q_elements: packoffset( c0 );
uint4 q_strides: packoffset( c1 );
uint4 k_elements: packoffset( c2 );
uint4 k_strides: packoffset( c3 );
uint4 v_elements: packoffset( c4 );
uint4 v_strides: packoffset( c5 );
uint4 res_elements: packoffset( c6 );
uint4 res_strides: packoffset( c7 );
bool masked : packoffset( c8.x );
// 1.0 / sqrt( (double) D )
float scale : packoffset( c8.y );
// This number is required to be >= nek1, and ideally rounded up to either 32 (L2 line) or 128 (L1 line) bytes
uint tempBufferStride: packoffset( c8.z );
}
static const float negativeInfinity = asfloat( 0xff800000 );
// Convert FP32 number to FP16 using rounding to nearest, then upcast back to FP32
inline float roundToFp16( const float src )
{
const uint trunc16 = f32tof16( src );
const float trunc32 = f16tof32( trunc16 );
const uint truncExp = ( trunc16 >> 10 ) & 0x1F;
if( truncExp != 0x1F )
{
const uint next16 = trunc16 + 1;
const float next32 = f16tof32( next16 );
const float errTrunc = abs( src - trunc32 );
const float errNext = abs( src - next32 );
if( errTrunc < errNext )
{
// Truncated was closer to the source
return trunc32;
}
else if( errTrunc > errNext )
{
// Truncated + 1 was closer to the source
return next32;
}
else
{
// Exactly half, doing banker's rounding to nearest even
return ( 0 == ( trunc16 & 1 ) ) ? trunc32 : next32;
}
}
else
{
// INF or NAN
return trunc32;
}
}
|