summaryrefslogtreecommitdiffstats
path: root/Whisper/ML/TensorEx.cpp
diff options
context:
space:
mode:
authorKonstantin <const@const.me>2023-01-16 14:52:43 +0100
committerKonstantin <const@const.me>2023-01-16 14:52:43 +0100
commit8c4603c73675958efc960fbd4bb599a2909d106a (patch)
tree714dc6fc9a1672d5fd7f89676b97e10959662abc /Whisper/ML/TensorEx.cpp
parent990a8d0dbaefc996244097397259e92758b15cce (diff)
Source codes
Diffstat (limited to 'Whisper/ML/TensorEx.cpp')
-rw-r--r--Whisper/ML/TensorEx.cpp97
1 files changed, 97 insertions, 0 deletions
diff --git a/Whisper/ML/TensorEx.cpp b/Whisper/ML/TensorEx.cpp
new file mode 100644
index 0000000..97e4e30
--- /dev/null
+++ b/Whisper/ML/TensorEx.cpp
@@ -0,0 +1,97 @@
+#include "stdafx.h"
+#include "TensorEx.h"
+#include "../D3D/createBuffer.h"
+#include "../source/ggml.h"
+#include "../D3D/MappedResource.h"
+using namespace DirectCompute;
+
+HRESULT TensorEx::create( const ggml_tensor& ggml, eBufferUse usage, bool uploadData )
+{
+ TensorGpuViews::clear();
+ buffer = nullptr;
+ stagingBuffer = nullptr;
+
+ CHECK( TensorShape::create( ggml ) );
+ const ggml_type dataType = ggml.type;
+ const uint32_t cbElement = (uint32_t)ggml_type_size( dataType );
+
+ const size_t totalBytes = ggml_nbytes( &ggml );
+ if( totalBytes > INT_MAX )
+ return DISP_E_OVERFLOW;
+ const uint32_t countElements = (uint32_t)( totalBytes / cbElement );
+
+ {
+ const void* const rsi = uploadData ? ggml.data : nullptr;
+ ID3D11Buffer** ppStagingBuffer = ( usage == eBufferUse::ReadWriteDownload ) ? &stagingBuffer : nullptr;
+ CHECK( createBuffer( usage, totalBytes, &buffer, rsi, ppStagingBuffer ) );
+ }
+
+ DXGI_FORMAT format;
+ switch( dataType )
+ {
+ case GGML_TYPE_F16:
+ format = DXGI_FORMAT_R16_FLOAT;
+ break;
+ case GGML_TYPE_F32:
+ format = DXGI_FORMAT_R32_FLOAT;
+ break;
+ default:
+ return E_NOTIMPL;
+ }
+
+ const bool makeUav = usage == eBufferUse::ReadWrite || usage == eBufferUse::ReadWriteDownload;
+ return TensorGpuViews::create( buffer, format, totalBytes / cbElement, makeUav );
+}
+
+HRESULT TensorEx::create( eDataType type, eBufferUse usage, const std::array<uint32_t, 4>& sizeElements )
+{
+ TensorGpuViews::clear();
+ buffer = nullptr;
+ stagingBuffer = nullptr;
+ std::initializer_list<uint32_t> il( sizeElements.data(), sizeElements.data() + 4 );
+
+ ID3D11Buffer** ppStaging = ( usage == eBufferUse::ReadWriteDownload ) ? &stagingBuffer : nullptr;
+ return Tensor::create( type, il, usage, buffer, nullptr, ppStaging );
+}
+
+HRESULT TensorEx::getViewSize( uint32_t& cbElement, uint32_t& countElements ) const
+{
+ ID3D11ShaderResourceView* const srv = *this;
+ if( nullptr == srv )
+ return OLE_E_BLANK;
+
+ D3D11_SHADER_RESOURCE_VIEW_DESC viewDesc;
+ srv->GetDesc( &viewDesc );
+
+ cbElement = dxgiSizeof( viewDesc.Format );
+
+ assert( viewDesc.ViewDimension == D3D_SRV_DIMENSION_BUFFER );
+ assert( viewDesc.Buffer.FirstElement == 0 );
+ countElements = viewDesc.Buffer.NumElements;
+
+ return S_OK;
+}
+
+HRESULT TensorEx::download( void* rdi, size_t cb ) const
+{
+ if( nullptr == stagingBuffer )
+ return HRESULT_FROM_WIN32( ERROR_GPIO_OPERATION_DENIED ); // The requested operation is not supported for the specified handle.
+
+ ID3D11DeviceContext* const ctx = context();
+ ctx->CopyResource( stagingBuffer, buffer );
+
+ MappedResource mapped;
+ CHECK( mapped.map( stagingBuffer, true ) );
+ memcpy( rdi, mapped.data(), cb );
+
+ return S_OK;
+}
+
+HRESULT TensorEx::download( void* rdi ) const
+{
+ uint32_t cbElement, numElements;
+ CHECK( getViewSize( cbElement, numElements ) );
+
+ size_t cb = (size_t)cbElement * numElements;
+ return download( rdi, cb );
+} \ No newline at end of file