blob: c82f3d82454dbaa611874e3214258da7ab6ddf0a (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
|
#pragma once
#include "Tensor.h"
namespace DirectCompute
{
// A tensor which supports dynamic updates from CPU, or downloads from VRAM to system RAM.
// Extends Tensor with the D3D11 buffer objects needed for those transfers.
class TensorEx : public Tensor
{
protected:
	// D3D11 buffer object; NOTE(review): presumably the GPU-side storage of the tensor data — confirm in the .cpp
	CComPtr<ID3D11Buffer> buffer;
	// NOTE(review): presumably a CPU-readable staging copy used by download() — confirm in the .cpp
	CComPtr<ID3D11Buffer> stagingBuffer;

	// Reports the size of one element in bytes (cbElement) and the count of elements (countElements)
	// of this tensor; the vector overload of download() multiplies them to get the total byte count.
	HRESULT getViewSize( uint32_t& cbElement, uint32_t& countElements ) const;

public:
	// Create from a ggml tensor; NOTE(review): uploadData presumably selects whether the ggml payload is copied to VRAM
	HRESULT create( const ggml_tensor& ggml, eBufferUse usage, bool uploadData );
	// Create a tensor of the given element type, usage, and 4D size in elements
	HRESULT create( eDataType type, eBufferUse usage, const std::array<uint32_t, 4>& sizeElements );

	// Download into caller-owned memory; NOTE(review): presumably copies cb bytes to rdi — confirm in the .cpp
	HRESULT download( void* rdi, size_t cb ) const;
	// NOTE(review): presumably downloads the complete tensor; caller must supply a large enough buffer — confirm
	HRESULT download( void* rdi ) const;

	// Download the complete tensor into a vector, resizing the vector as needed.
	// The vector is sized by bytes: ceil( cbElement * numElements / sizeof( E ) ) elements.
	// When sizeof( E ) equals the tensor's element size, that is exactly one E per tensor element.
	// E is expected to be a trivially copyable type, since the data is copied in as raw bytes.
	// Returns E_OUTOFMEMORY when the vector can't be resized, otherwise the result of the raw download.
	template<class E>
	HRESULT download( std::vector<E>& vec ) const
	{
		uint32_t cbElement = 0, numElements = 0;
		CHECK( getViewSize( cbElement, numElements ) );

		// Total byte count, computed in size_t so the product can't wrap around in 32 bits
		const size_t cbTotal = static_cast<size_t>( cbElement ) * numElements;
		// Size the vector from the byte count rather than the element count:
		// resizing to numElements would overflow the vector when sizeof( E ) < cbElement
		const size_t countE = ( cbTotal + sizeof( E ) - 1 ) / sizeof( E );
		try
		{
			vec.resize( countE );
		}
		catch( const std::bad_alloc& )
		{
			return E_OUTOFMEMORY;
		}
		return download( vec.data(), cbTotal );
	}
};
}
|