diff options
| author | Konstantin <const@const.me> | 2023-01-16 14:52:43 +0100 |
|---|---|---|
| committer | Konstantin <const@const.me> | 2023-01-16 14:52:43 +0100 |
| commit | 8c4603c73675958efc960fbd4bb599a2909d106a (patch) | |
| tree | 714dc6fc9a1672d5fd7f89676b97e10959662abc /Whisper/ML/TempBuffers.cpp | |
| parent | 990a8d0dbaefc996244097397259e92758b15cce (diff) | |
Source codes
Diffstat (limited to 'Whisper/ML/TempBuffers.cpp')
| -rw-r--r-- | Whisper/ML/TempBuffers.cpp | 88 |
1 files changed, 88 insertions, 0 deletions
diff --git a/Whisper/ML/TempBuffers.cpp b/Whisper/ML/TempBuffers.cpp new file mode 100644 index 0000000..8823364 --- /dev/null +++ b/Whisper/ML/TempBuffers.cpp @@ -0,0 +1,88 @@ +#include "stdafx.h" +#include "TempBuffers.h" +#include "../D3D/createBuffer.h" +#include "../D3D/MappedResource.h" +#include "../D3D/shaders.h" +using namespace DirectCompute; + +#define CHECK( hr ) { const HRESULT __hr = ( hr ); if( FAILED( __hr ) ) return __hr; } + +HRESULT TempBuffers::Buffer::resize( DXGI_FORMAT format, size_t elements, size_t cbElement, bool zeroMemory, CComPtr<ID3D11Buffer>& cb ) +{ + if( elements <= capacity ) + { + if( zeroMemory ) + TempBuffers::zeroMemory( *this, (uint32_t)elements, cb ); + return S_OK; + } + clear(); + + CComPtr<ID3D11Buffer> buffer; + const size_t totalBytes = elements * cbElement; + CHECK( createBuffer( eBufferUse::ReadWrite, totalBytes, &buffer, nullptr, nullptr ) ); + CHECK( TensorGpuViews::create( buffer, format, elements, true ) ); + capacity = elements; + return S_OK; +} + +void TempBuffers::zeroMemory( ID3D11UnorderedAccessView* uav, uint32_t length, CComPtr<ID3D11Buffer>& cb ) +{ + const __m128i cbData = _mm_cvtsi32_si128( (int)length ); + if( cb ) + { + MappedResource mapped; + check( mapped.map( cb, false ) ); + store16( mapped.data(), cbData ); + } + else + { + CD3D11_BUFFER_DESC desc{ 16, D3D11_BIND_CONSTANT_BUFFER, D3D11_USAGE_DYNAMIC, D3D11_CPU_ACCESS_WRITE }; + std::array<uint32_t, 4> cbBuffer; + store( cbBuffer, cbData ); + D3D11_SUBRESOURCE_DATA srd{ cbBuffer.data(), 0, 0 }; + check( device()->CreateBuffer( &desc, &srd, &cb ) ); + } + + ID3D11DeviceContext* ctx = context(); + ctx->CSSetUnorderedAccessViews( 0, 1, &uav, nullptr ); + csSetCB( cb ); + + constexpr uint32_t THREADS = 512; + constexpr uint32_t ITERATIONS = 128; + constexpr uint32_t elementsPerGroup = THREADS * ITERATIONS; + const uint32_t countGroups = ( length + elementsPerGroup - 1 ) / elementsPerGroup; + bindShader( eComputeShader::zeroMemory ); + ctx->Dispatch( countGroups, 1, 1 ); +} + +const TensorGpuViews& TempBuffers::fp16( size_t countElements, bool zeroMemory ) +{ + HRESULT hr = m_fp16.resize( DXGI_FORMAT_R16_FLOAT, countElements, 2, zeroMemory, smallCb ); + if( FAILED( hr ) ) + throw hr; + return m_fp16; +} + +const TensorGpuViews& TempBuffers::fp16_2( size_t countElements, bool zeroMemory ) +{ + HRESULT hr = m_fp16_2.resize( DXGI_FORMAT_R16_FLOAT, countElements, 2, zeroMemory, smallCb ); + if( FAILED( hr ) ) + throw hr; + return m_fp16_2; +} + +const TensorGpuViews& TempBuffers::fp32( size_t countElements, bool zeroMemory ) +{ + HRESULT hr = m_fp32.resize( DXGI_FORMAT_R32_FLOAT, countElements, 4, zeroMemory, smallCb ); + if( FAILED( hr ) ) + throw hr; + return m_fp32; +} + +__m128i TempBuffers::getMemoryUse() const +{ + size_t cb = m_fp16.getCapacity() * 2; + cb += m_fp16_2.getCapacity() * 2; + cb += m_fp32.getCapacity() * 4; + return setHigh_size( cb ); +}
\ No newline at end of file |
