summaryrefslogtreecommitdiffstats
path: root/Whisper/Hybrid/KeyValueDownloader.h
blob: e0e9644c2bf4ea58164d7b615baea944e57474d2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#pragma once
#include "../Whisper/sModelParams.h"
#include "../Whisper/KeyValueBuffers.h"
#include "../D3D/MappedResource.h"
#include "../CPU/Tensor.h"

class KeyValueDownloader
{
	CComPtr<ID3D11Buffer> keys, values;
	uint32_t length = 0;

	using E = uint16_t;
	static constexpr DirectCompute::eDataType dataType = DirectCompute::eDataType::FP16;

public:
	// Create the staging resources to download kvCross tensors produced by the GPGPU encoder
	HRESULT create( const Whisper::sModelParams& mp );

	// Download these two tensors from VRAM to the staging buffers in system RAM
	HRESULT download( const DirectCompute::KeyValueBuffers& source );

	class ReadMap
	{
		const uint32_t length;
		DirectCompute::MappedResource mappedKeys, mappedValues;

	public:
		ReadMap( KeyValueDownloader& owner );
		~ReadMap() = default;
		ReadMap( const ReadMap& ) = delete;

		// A slice of model.memory_k tensor
		CpuCompute::Tensor keysView( uint32_t len, uint32_t off ) const
		{
			if( len + off <= length )
			{
				E* rsi = (E*)mappedKeys.data();
				rsi += off;
				return CpuCompute::Tensor::fromData( rsi, dataType, len );
			}
			throw E_BOUNDS;
		}

		// A slice of model.memory_v tensor
		CpuCompute::Tensor valuesView( uint32_t len, uint32_t off ) const
		{
			if( len + off <= length )
			{
				E* rsi = (E*)mappedValues.data();
				rsi += off;
				return CpuCompute::Tensor::fromData( rsi, dataType, len );
			}
			throw E_BOUNDS;
		}
	};

	// Map both staging buffers, return RAII object which unmaps when destroyed,
	// which can supply the data in the shape of CpuCompute::Tensor vector
	decltype( auto ) map()
	{
		return ReadMap( *this );
	}
};