summaryrefslogtreecommitdiffstats
path: root/Whisper/ML/TempBuffers.h
blob: f376ed65cbfdd53f8e2b5d7947dfe78bee7343af (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
#pragma once
#include "TensorGpuViews.h"

namespace DirectCompute
{
	// Holds three reusable scratch GPU buffer slots (m_fp16, m_fp16_2, m_fp32).
	// Each slot is a TensorGpuViews paired with a recorded capacity.
	// Everything except the inline helpers below is implemented out of line.
	class TempBuffers
	{
		// A single scratch slot: the views inherited from TensorGpuViews plus the
		// capacity recorded for them. The capacity is reset to 0 by clear().
		class Buffer : public TensorGpuViews
		{
			size_t capacity{ 0 };

		public:

			// Drop the views through TensorGpuViews::clear(), then forget the recorded capacity.
			void clear()
			{
				TensorGpuViews::clear();
				capacity = 0;
			}

			// Defined out of line.
			// NOTE(review): presumably (re)creates the views when `elements` exceeds
			// the current capacity, optionally zero-filling them, with `cb` passed
			// through to the zero-fill helper. Confirm against the implementation.
			HRESULT resize( DXGI_FORMAT format, size_t elements, size_t cbElement, bool zeroMemory, CComPtr<ID3D11Buffer>& cb );

			// Capacity currently recorded for this slot (0 after clear()).
			size_t getCapacity() const
			{
				return capacity;
			}
		};

		Buffer m_fp16;
		Buffer m_fp16_2;
		Buffer m_fp32;

	public:

		// Publicly accessible constant buffer.
		// NOTE(review): presumably the object handed to zeroMemory()/resize() as
		// their `cb` argument. Confirm in the implementation.
		CComPtr<ID3D11Buffer> smallCb;

		// Defined out of line.
		// NOTE(review): going by the name, zero-fills `length` elements of `uav`;
		// the role of `cb` is not visible here. Confirm before relying on it.
		static void zeroMemory( ID3D11UnorderedAccessView* uav, uint32_t length, CComPtr<ID3D11Buffer>& cb );

		// Accessors for the three scratch slots, defined out of line.
		// Each takes the required element count and an optional zero-fill flag
		// (default false) and returns views by const reference.
		// NOTE(review): presumably each one refers to the matching member slot. If so,
		// the reference must not outlive this object, and it may be invalidated by a
		// later call that grows the slot. Confirm.
		const TensorGpuViews& fp16( size_t countElements, bool zeroMemory = false );
		const TensorGpuViews& fp16_2( size_t countElements, bool zeroMemory = false );
		const TensorGpuViews& fp32( size_t countElements, bool zeroMemory = false );

		// Release every scratch slot, in declaration order: m_fp16, m_fp16_2, m_fp32.
		void clear()
		{
			Buffer* const slots[] = { &m_fp16, &m_fp16_2, &m_fp32 };
			for( Buffer* slot : slots )
				slot->clear();
		}

		// Defined out of line.
		// NOTE(review): the meaning of the packed __m128i lanes is not visible here
		// (presumably per-category memory counters). Confirm before interpreting it.
		__m128i getMemoryUse() const;
	};
}