summaryrefslogtreecommitdiffstats
path: root/Whisper/ML/TensorsArena.h
blob: df6139efea0a96e80646cc43bdfe4829703aa213 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#pragma once
#include "Tensor.h"

namespace DirectCompute
{
	// Pointer to a function which receives a `current` and a `requested` value and returns a uint32_t.
	// Judging by the names, the result is the capacity to use for a (re)allocation — TODO confirm against the implementations.
	using pfnNewCapacity = uint32_t( * )( uint32_t current, uint32_t requested );

	// A function matching the pfnNewCapacity signature; only declared here, the definition lives elsewhere.
	// NOTE(review): presumably the default growth policy — verify against the definition.
	uint32_t defaultNewCapacity( uint32_t current, uint32_t requested );

	// Pairs a TensorGpuViews object with a capacity counter, and hands out Tensor objects through tensor().
	// NOTE(review): presumably a reusable GPU buffer which is only re-created when a request exceeds the capacity — confirm in the implementation.
	class PooledTensor
	{
		// GPU views owned by this object
		TensorGpuViews views;
		// Current capacity; zero for a new or cleared object. The unit (elements vs. bytes) is not visible in this header.
		uint32_t capacity = 0;
	public:
		// Produce a tensor of the specified data type and size `ne` (4 values).
		// `pfnNewCap` is a callback with the ( current, requested ) -> uint32_t signature; presumably called to pick the new capacity when growing — TODO confirm.
		// Declared only, the definition lives elsewhere.
		Tensor tensor( eDataType type, const std::array<uint32_t, 4>& ne, pfnNewCapacity pfnNewCap );
		// Current capacity, widened from uint32_t to size_t
		size_t getCapacity() const { return capacity; }
		// Clear the views, and reset the capacity to zero
		void clear()
		{
			views.clear();
			capacity = 0;
		}
		// Declared only. Takes a reference to a D3D11 buffer smart pointer and returns an HRESULT status.
		// NOTE(review): judging by the name, fills the tensor memory with zeros, with `cb` being a constant buffer used for that — verify against the definition.
		HRESULT zeroMemory( CComPtr<ID3D11Buffer>& cb );
	};

	// Abstract interface of a source of tensors, declared with the Microsoft-specific `__interface` keyword.
	// TensorsArena in this header is one implementation.
	__interface iTensorArena
	{
		// Get a tensor of the specified data type and size `ne` (4 values)
		Tensor tensor( eDataType type, const std::array<uint32_t, 4>& ne );
		// Reset the arena. See the implementation for what happens with the previously returned tensors.
		void reset();
	};

	// Implementation of iTensorArena which keeps `countTypes` (2) ArenaImpl instances,
	// each of them owning a vector of PooledTensor objects and a position index in that vector.
	// NOTE(review): presumably one ArenaImpl per data type, FP16 and FP32, matching the two configs in sArenaConfigs — confirm in the constructor.
	class TensorsArena: public iTensorArena
	{
	public:
		// Configuration of a single ArenaImpl
		struct sArenaConfig
		{
			// Capacity callback; presumably forwarded to PooledTensor::tensor() for the pooled tensors — TODO confirm
			pfnNewCapacity pfnCapInner;
			// NOTE(review): looks like the initial capacity of the pool vector — verify against the ArenaImpl constructor
			size_t initialCapOuter;
		};

		// A pair of configurations, one named fp16 and another one fp32
		struct sArenaConfigs
		{
			sArenaConfig fp16, fp32;
		};

		// Construct the arena from the pair of configurations. Declared only, the definition lives elsewhere.
		TensorsArena( const sArenaConfigs& configs );

		// iTensorArena method: get a tensor of the specified data type and size
		Tensor tensor( eDataType type, const std::array<uint32_t, 4>& ne ) override final;
		// iTensorArena method: reset the arena
		void reset() override final;

		// Declared only. NOTE(review): presumably calls clear() on every ArenaImpl, dropping the pooled tensors — confirm.
		void clear();
		// Declared only. Returns a 128-bit integer vector; the meaning of the individual lanes is not visible in this header.
		__m128i getMemoryUse() const;
		// Declared only. Takes a reference to a D3D11 buffer smart pointer and returns an HRESULT status.
		// NOTE(review): presumably forwards to zeroMemory() of the individual arenas — verify.
		HRESULT zeroMemory( CComPtr<ID3D11Buffer>& cb );

	private:

		// A pool of PooledTensor objects for one fixed data type, along with the current position in the pool
		struct ArenaImpl
		{
			// Declared only. Presumably stores the data type, and takes the capacity callback from the config — TODO confirm.
			ArenaImpl( eDataType dataType, const sArenaConfig& config );

			// Rewind the position to the start of the pool. The pool itself is left intact.
			void reset()
			{
				index = 0;
			}

			// Rewind the position, and also remove all elements from the pool
			void clear()
			{
				index = 0;
				pool.clear();
			}

			// Get a tensor of the specified size; the data type is not a parameter here, the struct has a constant `type` field.
			// NOTE(review): presumably uses pool[ index ] and advances the index, growing the pool when needed — confirm in the implementation.
			Tensor tensor( const std::array<uint32_t, 4>& ne );
			// Declared only; the layout of the returned 128-bit vector is not visible in this header
			__m128i getMemoryUse() const;
			// Declared only; see the definition for what `cb` is used for
			HRESULT zeroMemory( CComPtr<ID3D11Buffer>& cb );

		private:

			// Data type of this arena, immutable after construction
			const eDataType type;
			// Capacity callback, immutable after construction
			const pfnNewCapacity pfnNewCap;

			// Pooled tensors owned by this arena
			std::vector<PooledTensor> pool;
			// Position in the pool; reset() and clear() set it back to zero
			size_t index = 0;
		};

		// Count of the ArenaImpl instances in this class
		static constexpr size_t countTypes = 2;
		std::array<ArenaImpl, countTypes> arenas;
	};
}