summaryrefslogtreecommitdiffstats
path: root/Whisper/ML/TensorEx.h
blob: c82f3d82454dbaa611874e3214258da7ab6ddf0a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
#pragma once
#include "Tensor.h"

namespace DirectCompute
{
	// A tensor which supports dynamic updates from CPU, or downloads from VRAM to system RAM.
	// Extends Tensor with two D3D11 buffer references and a family of download() helpers
	// which copy the tensor's content into caller-supplied memory.
	class TensorEx : public Tensor
	{
	protected:
		// D3D11 buffer owned by this tensor.
		// NOTE(review): presumably the GPU-side storage behind the tensor's views; confirm in the .cpp
		CComPtr<ID3D11Buffer> buffer;
		// Second D3D11 buffer.
		// NOTE(review): judging by the name, a staging resource used to read data back on the CPU; confirm in the .cpp
		CComPtr<ID3D11Buffer> stagingBuffer;

		// Query the layout of the tensor's data.
		// On success, cbElement receives the size of one element in bytes, and countElements the number of elements;
		// the product of the two is the byte count passed to download( void*, size_t ) by the template overload below.
		HRESULT getViewSize( uint32_t& cbElement, uint32_t& countElements ) const;

	public:

		// Create the tensor from a GGML tensor description.
		// NOTE(review): uploadData presumably requests copying the initial content of the GGML tensor to the GPU; confirm in the .cpp
		HRESULT create( const ggml_tensor& ggml, eBufferUse usage, bool uploadData );
		// Create the tensor from an element type, a buffer usage, and a size expressed as 4 element counts.
		HRESULT create( eDataType type, eBufferUse usage, const std::array<uint32_t, 4>& sizeElements );

		// Download the tensor into the memory at rdi; cb is the size in bytes.
		// NOTE(review): whether cb must match the tensor size exactly, or is only an upper bound, is not visible from this header
		HRESULT download( void* rdi, size_t cb ) const;

		// Download the tensor into the memory at rdi, without an explicit size.
		// NOTE(review): the caller presumably must supply a buffer large enough for the complete tensor; confirm in the .cpp
		HRESULT download( void* rdi ) const;

		// Download the tensor into a vector, resizing the vector to the element count of the tensor.
		// E must have the same size as the elements of the tensor, otherwise the method fails with E_INVALIDARG
		// and leaves the vector untouched.
		// Returns E_OUTOFMEMORY when the vector can't be resized, or whatever status getViewSize() / download() returned.
		template<class E>
		HRESULT download( std::vector<E>& vec ) const
		{
			uint32_t cbElement = 0;
			uint32_t numElements = 0;
			// CHECK is defined elsewhere in the project; presumably it returns from this method when the HRESULT is a failure
			CHECK( getViewSize( cbElement, numElements ) );

			// The vector is sized in units of E, while the copy below is sized in units of cbElement.
			// When E is smaller than the tensor's element, the copy would write past the end of the vector;
			// when it's larger, the caller would silently receive reinterpreted data. Refuse both cases.
			if( cbElement != sizeof( E ) )
				return E_INVALIDARG;

			try
			{
				vec.resize( numElements );
			}
			catch( const std::bad_alloc& )
			{
				return E_OUTOFMEMORY;
			}

			// Widen before multiplying, so the byte count is not computed in 32-bit arithmetic
			return download( vec.data(), (size_t)cbElement * numElements );
		}
	};
}