blob: 7de6fb8b8c26a995def16ebe23a5897a8435c468 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
|
#include "stdafx.h"
#include "TensorShape.h"
#include "../source/ggml.h"
using namespace DirectCompute;
// Default constructor: delegates all initialization to setZero().
// NOTE(review): setZero() is defined outside this file; presumably it clears both sizes and strides - confirm.
TensorShape::TensorShape()
{
	this->setZero();
}
// Copy constructor: duplicates the source shape with two 128-bit stores,
// one into the sizes array (ne) and one into the strides array (nb).
TensorShape::TensorShape( const TensorShape& that )
{
	// Fetch both vectors from the source before writing anything
	const __m128i sizes = that.sizeVec();
	const __m128i strides = that.stridesVec();

	// Unaligned stores: no alignment of the destination arrays is assumed
	__m128i* const destSizes = reinterpret_cast<__m128i*>( ne.data() );
	__m128i* const destStrides = reinterpret_cast<__m128i*>( nb.data() );
	_mm_storeu_si128( destSizes, sizes );
	_mm_storeu_si128( destStrides, strides );
}
// Copy assignment: overwrites the sizes (ne) and strides (nb) of this shape
// with the vectors returned by that.sizeVec() and that.stridesVec().
// Returns nothing, so assignments can't be chained.
void TensorShape::operator=( const TensorShape& that )
{
	// Both vectors are read before either store, so self-assignment is harmless
	const __m128i sizes = that.sizeVec();
	const __m128i strides = that.stridesVec();

	// Unaligned stores: no alignment of the destination arrays is assumed
	_mm_storeu_si128( reinterpret_cast<__m128i*>( ne.data() ), sizes );
	_mm_storeu_si128( reinterpret_cast<__m128i*>( nb.data() ), strides );
}
// Initialize this shape from a ggml tensor.
// Sizes are copied from ggml.ne; the strides in ggml.nb are divided by the element size
// reported by ggml_type_size(), so the stored strides are expressed in elements.
// Returns:
//   S_OK             on success
//   E_INVALIDARG     when the element size is zero, or a stride is not a multiple of the element size
//   DISP_E_OVERFLOW  when a size or a converted stride does not fit in uint32_t
// On failure, the shape may be left partially updated.
HRESULT TensorShape::create( const ggml_tensor& ggml )
{
	for( size_t i = 0; i < 4; i++ )
	{
		// Range-check the size before narrowing it, same as for the strides below.
		// Negative values wrap to huge numbers in the unsigned conversion, so they fail this check as well.
		const uint64_t size = static_cast<uint64_t>( ggml.ne[ i ] );
		if( size > UINT_MAX )
			return DISP_E_OVERFLOW;
		ne[ i ] = static_cast<uint32_t>( size );
	}

	const ggml_type dataType = ggml.type;
	const uint32_t cbElement = (uint32_t)ggml_type_size( dataType );
	// A zero element size would cause an integer division by zero below
	if( 0 == cbElement )
		return E_INVALIDARG;

	for( size_t i = 0; i < 4; i++ )
	{
		const size_t stride = ggml.nb[ i ];
		// Every stride must be a whole number of elements
		if( 0 != stride % cbElement )
			return E_INVALIDARG;

		const size_t nn = stride / cbElement;
		if( nn > UINT_MAX )
			return DISP_E_OVERFLOW;
		nb[ i ] = (uint32_t)nn;
	}
	return S_OK;
}
// Construct a shape from a ggml tensor by calling create().
// Throws the failing HRESULT value when create() reports an error.
TensorShape::TensorShape( const ggml_tensor& ggml )
{
	const HRESULT status = create( ggml );
	if( FAILED( status ) )
	{
		// The exception object is the raw HRESULT, not a class derived from std::exception
		throw status;
	}
}
// Recompute the strides from the current sizes, with no gaps between elements:
// each stride is the product of all the sizes of the lower dimensions, and nb[ 0 ] is 1.
// The products use 32-bit unsigned arithmetic, and are not checked for overflow.
void TensorShape::setDenseStrides()
{
	const uint32_t rowLength = ne[ 0 ];
	const uint32_t planeLength = rowLength * ne[ 1 ];
	const uint32_t volumeLength = planeLength * ne[ 2 ];

	nb[ 0 ] = 1;
	nb[ 1 ] = rowLength;
	nb[ 2 ] = planeLength;
	nb[ 3 ] = volumeLength;
}
// Test whether two shapes have matching sizes in dimensions 0, 2 and 3; dimension 1 is ignored.
// Scalar equivalent of the vectorized code below:
//   ( t0.ne[ 0 ] == t1.ne[ 0 ] ) && ( t0.ne[ 2 ] == t1.ne[ 2 ] ) && ( t0.ne[ 3 ] == t1.ne[ 3 ] )
bool DirectCompute::canMulMat( const TensorShape& t0, const TensorShape& t1 )
{
	// XOR leaves a zero lane wherever the two sizes are equal
	const __m128i difference = _mm_xor_si128( t0.sizeVec(), t1.sizeVec() );

	// Overwrite lane 1 with a copy of lane 0, so the size of dimension 1 takes no part in the test
	const __m128i relevant = _mm_shuffle_epi32( difference, _MM_SHUFFLE( 3, 2, 0, 0 ) );

	// _mm_testz_si128 returns 1 when the bitwise AND of the arguments is all zeros,
	// i.e. when every remaining lane compared equal
	return 0 != _mm_testz_si128( relevant, relevant );
}
|