blob: ad501360827c16033075f107e5bd10f2161ca6e6 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
|
#include "stdafx.h"
#include "KeyValueDownloader.h"
// Allocates the two CPU-readable staging buffers (keys and values) that download() copies into.
// Each buffer holds n_text_state * n_text_layer * n_audio_ctx elements of 2 bytes each
// (presumably FP16 — TODO confirm against DirectCompute::KeyValueBuffers).
// Returns E_INVALIDARG when the required byte width does not fit the 32-bit ByteWidth of a
// D3D11 buffer; any CreateBuffer failure is propagated through CHECK.
// On success, `length` receives the element count of each buffer.
HRESULT KeyValueDownloader::create( const Whisper::sModelParams& mp )
{
	constexpr uint64_t bytesPerElement = 2;

	// Do the size math in 64 bits so large parameters cannot silently wrap the 32-bit byte width.
	const uint64_t n_mem = static_cast<uint64_t>( mp.n_text_layer ) * mp.n_audio_ctx;
	if( n_mem > UINT32_MAX )
		return E_INVALIDARG;
	// Both factors are now <= UINT32_MAX, so this product cannot overflow 64 bits.
	const uint64_t n_elements = static_cast<uint64_t>( mp.n_text_state ) * n_mem;
	if( n_elements > UINT32_MAX / bytesPerElement )
		return E_INVALIDARG;

	const uint32_t cbBuffer = static_cast<uint32_t>( n_elements * bytesPerElement );
	CD3D11_BUFFER_DESC desc{ cbBuffer, 0, D3D11_USAGE_STAGING, D3D11_CPU_ACCESS_READ };
	ID3D11Device* dev = DirectCompute::device();
	CHECK( dev->CreateBuffer( &desc, nullptr, &keys ) );
	CHECK( dev->CreateBuffer( &desc, nullptr, &values ) );
	length = static_cast<uint32_t>( n_elements );
	return S_OK;
}
// Queues copies of the source key and value buffers into this object's staging buffers
// (created by create()). CopyResource returns no status, so this always reports S_OK;
// the copied data becomes readable on the CPU once the staging buffers are mapped.
HRESULT KeyValueDownloader::download( const DirectCompute::KeyValueBuffers& source )
{
	ID3D11DeviceContext* const immediate = DirectCompute::context();

	// Keys first, then values — same order as the original implementation.
	immediate->CopyResource( keys, source.keys.getBuffer() );
	immediate->CopyResource( values, source.values.getBuffer() );

	return S_OK;
}
// Maps the owner's key and value staging buffers so their contents can be read on the CPU,
// and records the owner's element count. Any mapping failure is reported through check().
KeyValueDownloader::ReadMap::ReadMap( KeyValueDownloader& owner )
	: length( owner.length )
{
	check( mappedKeys.map( owner.keys, true ) );		// keys staging buffer
	check( mappedValues.map( owner.values, true ) );	// values staging buffer
}
|