blob: e0e9644c2bf4ea58164d7b615baea944e57474d2 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
|
#pragma once
#include "../Whisper/sModelParams.h"
#include "../Whisper/KeyValueBuffers.h"
#include "../D3D/MappedResource.h"
#include "../CPU/Tensor.h"
// Copies the key and value tensors from GPU buffers into staging buffers,
// and exposes slices of the downloaded data as CPU tensors.
class KeyValueDownloader
{
	// Staging buffers, one for the keys and one for the values
	CComPtr<ID3D11Buffer> keys, values;
	// Size of the buffers; NOTE(review): presumably counted in elements of type E, and set by create() - confirm in the .cpp
	uint32_t length = 0;

	// Storage type of a single element: FP16 values are kept as raw 16-bit integers
	using E = uint16_t;
	static constexpr DirectCompute::eDataType dataType = DirectCompute::eDataType::FP16;

public:
	// Create the staging resources to download kvCross tensors produced by the GPGPU encoder
	HRESULT create( const Whisper::sModelParams& mp );

	// Download these two tensors from VRAM to the staging buffers in system RAM
	HRESULT download( const DirectCompute::KeyValueBuffers& source );

	// Holds both staging buffers mapped, and hands out bounds-checked slices of them.
	// The object is not copyable.
	class ReadMap
	{
		// Upper bound for `off + len` of any slice, in elements of type E
		const uint32_t length;
		DirectCompute::MappedResource mappedKeys, mappedValues;

		// Build a tensor over `len` elements starting `off` elements into the mapped buffer.
		// Throws E_BOUNDS when the requested range does not fit into `length` elements.
		CpuCompute::Tensor sliceView( const DirectCompute::MappedResource& mapped, uint32_t len, uint32_t off ) const
		{
			// Written as two comparisons instead of `len + off <= length`:
			// the sum of two uint32_t values wraps around, which would let an out-of-range slice pass the check.
			if( off > length || len > length - off )
				throw E_BOUNDS;

			E* const begin = (E*)mapped.data() + off;
			return CpuCompute::Tensor::fromData( begin, dataType, len );
		}

	public:
		ReadMap( KeyValueDownloader& owner );
		~ReadMap() = default;
		ReadMap( const ReadMap& ) = delete;

		// A slice of model.memory_k tensor: `len` elements, starting at element `off`.
		// Throws E_BOUNDS if the slice is out of range.
		CpuCompute::Tensor keysView( uint32_t len, uint32_t off ) const
		{
			return sliceView( mappedKeys, len, off );
		}

		// A slice of model.memory_v tensor: `len` elements, starting at element `off`.
		// Throws E_BOUNDS if the slice is out of range.
		CpuCompute::Tensor valuesView( uint32_t len, uint32_t off ) const
		{
			return sliceView( mappedValues, len, off );
		}
	};

	// Map both staging buffers, return RAII object which unmaps when destroyed,
	// which can supply the data in the shape of CpuCompute::Tensor vector
	decltype( auto ) map()
	{
		return ReadMap( *this );
	}
};
|