summaryrefslogtreecommitdiffstats
path: root/Whisper/Hybrid/KeyValueDownloader.cpp
diff options
context:
space:
mode:
authorKonstantin <const@const.me>2023-01-16 14:52:43 +0100
committerKonstantin <const@const.me>2023-01-16 14:52:43 +0100
commit8c4603c73675958efc960fbd4bb599a2909d106a (patch)
tree714dc6fc9a1672d5fd7f89676b97e10959662abc /Whisper/Hybrid/KeyValueDownloader.cpp
parent990a8d0dbaefc996244097397259e92758b15cce (diff)
Source codes
Diffstat (limited to 'Whisper/Hybrid/KeyValueDownloader.cpp')
-rw-r--r--Whisper/Hybrid/KeyValueDownloader.cpp32
1 files changed, 32 insertions, 0 deletions
diff --git a/Whisper/Hybrid/KeyValueDownloader.cpp b/Whisper/Hybrid/KeyValueDownloader.cpp
new file mode 100644
index 0000000..ad50136
--- /dev/null
+++ b/Whisper/Hybrid/KeyValueDownloader.cpp
@@ -0,0 +1,32 @@
+#include "stdafx.h"
+#include "KeyValueDownloader.h"
+
+HRESULT KeyValueDownloader::create( const Whisper::sModelParams& mp )
+{
+ const uint32_t n_audio_ctx = mp.n_audio_ctx;
+ const uint32_t n_mem = mp.n_text_layer * mp.n_audio_ctx;
+ const uint32_t n_elements = mp.n_text_state * n_mem;
+
+ CD3D11_BUFFER_DESC desc{ n_elements * 2, 0, D3D11_USAGE_STAGING, D3D11_CPU_ACCESS_READ };
+ ID3D11Device* dev = DirectCompute::device();
+ CHECK( dev->CreateBuffer( &desc, nullptr, &keys ) );
+ CHECK( dev->CreateBuffer( &desc, nullptr, &values ) );
+
+ length = n_elements;
+ return S_OK;
+}
+
+HRESULT KeyValueDownloader::download( const DirectCompute::KeyValueBuffers& source )
+{
+ ID3D11DeviceContext* ctx = DirectCompute::context();
+ ctx->CopyResource( keys, source.keys.getBuffer() );
+ ctx->CopyResource( values, source.values.getBuffer() );
+ return S_OK;
+}
+
+KeyValueDownloader::ReadMap::ReadMap( KeyValueDownloader& owner ) :
+ length( owner.length )
+{
+ check( mappedKeys.map( owner.keys, true ) );
+ check( mappedValues.map( owner.values, true ) );
+} \ No newline at end of file