diff options
| author | Konstantin <const@const.me> | 2023-02-03 16:07:54 +0100 |
|---|---|---|
| committer | Konstantin <const@const.me> | 2023-02-03 16:07:54 +0100 |
| commit | 3ba8e6389679007445f4fc1c52439cb0df3ddba0 (patch) | |
| tree | 0bffa67da50e7042f745f4c9aad72b2f046feeec | |
| parent | 671ca710ad3f0d3a64af7b84af1025e2a0b68296 (diff) | |
Bugfix: incorrect output of command-line examples when launched with multiple input files
| -rw-r--r-- | Examples/TranscribeCS/TranscribeCS.cs | 2 | ||||
| -rw-r--r-- | Examples/main/main.cpp | 2 | ||||
| -rw-r--r-- | Whisper/ML/MlContext.h | 2 | ||||
| -rw-r--r-- | Whisper/ML/TensorsArena.cpp | 30 | ||||
| -rw-r--r-- | Whisper/ML/TensorsArena.h | 8 | ||||
| -rw-r--r-- | Whisper/Whisper/ContextImpl.cpp | 6 | ||||
| -rw-r--r-- | Whisper/Whisper/DecoderInputBuffers.cpp | 11 | ||||
| -rw-r--r-- | Whisper/Whisper/DecoderInputBuffers.h | 4 | ||||
| -rw-r--r-- | Whisper/Whisper/KeyValueBuffers.cpp | 28 | ||||
| -rw-r--r-- | Whisper/Whisper/KeyValueBuffers.h | 4 | ||||
| -rw-r--r-- | Whisper/Whisper/WhisperContext.cpp | 19 | ||||
| -rw-r--r-- | Whisper/Whisper/WhisperContext.h | 10 |
12 files changed, 125 insertions, 1 deletion
diff --git a/Examples/TranscribeCS/TranscribeCS.cs b/Examples/TranscribeCS/TranscribeCS.cs index d94ed21..65239f1 100644 --- a/Examples/TranscribeCS/TranscribeCS.cs +++ b/Examples/TranscribeCS/TranscribeCS.cs @@ -25,6 +25,8 @@ namespace TranscribeCS using iModel model = Library.loadModel( cla.model ); using Context context = model.createContext(); cla.apply( ref context.parameters ); + // When there're multiple input files, assuming they're independent clips + context.parameters.setFlag( eFullParamsFlags.NoContext, true ); using iMediaFoundation mf = Library.initMediaFoundation(); Transcribe transcribe = new Transcribe( cla ); diff --git a/Examples/main/main.cpp b/Examples/main/main.cpp index 88ddc6d..7706b8f 100644 --- a/Examples/main/main.cpp +++ b/Examples/main/main.cpp @@ -246,6 +246,8 @@ int wmain( int argc, wchar_t* argv[] ) wparams.setFlag( eFullParamsFlags::PrintTimestamps, !params.no_timestamps ); wparams.setFlag( eFullParamsFlags::PrintSpecial, params.print_special ); wparams.setFlag( eFullParamsFlags::Translate, params.translate ); + // When there're multiple input files, assuming they're independent clips + wparams.setFlag( eFullParamsFlags::NoContext ); wparams.language = Whisper::makeLanguageKey( params.language.c_str() ); wparams.cpuThreads = params.n_threads; if( params.max_context != UINT_MAX ) diff --git a/Whisper/ML/MlContext.h b/Whisper/ML/MlContext.h index fe0d48d..aea1f24 100644 --- a/Whisper/ML/MlContext.h +++ b/Whisper/ML/MlContext.h @@ -43,6 +43,8 @@ namespace DirectCompute GpuProfiler profiler; + CComPtr<ID3D11Buffer>& getSmallConstantBuffer() { return temp.smallCb; } + public: MlContext( Whisper::ProfileCollection& profileColl ); MlContext( const MlContext& ) = delete; diff --git a/Whisper/ML/TensorsArena.cpp b/Whisper/ML/TensorsArena.cpp index 8e5d449..873200b 100644 --- a/Whisper/ML/TensorsArena.cpp +++ b/Whisper/ML/TensorsArena.cpp @@ -1,6 +1,7 @@ #include "stdafx.h" #include "TensorsArena.h" #include "../D3D/createBuffer.h" +#include 
"TempBuffers.h" static inline uint32_t roundUpPower2( uint32_t x ) { @@ -132,4 +133,33 @@ __m128i TensorsArena::getMemoryUse() const for( const auto& a : arenas ) res = _mm_add_epi64( res, a.getMemoryUse() ); return res; +} + +HRESULT PooledTensor::zeroMemory( CComPtr<ID3D11Buffer>& cb ) +{ + if( 0 == capacity ) + return S_FALSE; + try + { + TempBuffers::zeroMemory( views, capacity, cb ); + return S_OK; + } + catch( HRESULT hr ) + { + return hr; + } +} + +HRESULT TensorsArena::ArenaImpl::zeroMemory( CComPtr<ID3D11Buffer>& cb ) +{ + for( PooledTensor& e : pool ) + CHECK( e.zeroMemory( cb ) ); + return S_OK; +} + +HRESULT TensorsArena::zeroMemory( CComPtr<ID3D11Buffer>& cb ) +{ + for( ArenaImpl& e : arenas ) + CHECK( e.zeroMemory( cb ) ); + return S_OK; }
\ No newline at end of file diff --git a/Whisper/ML/TensorsArena.h b/Whisper/ML/TensorsArena.h index acfaf86..df6139e 100644 --- a/Whisper/ML/TensorsArena.h +++ b/Whisper/ML/TensorsArena.h @@ -14,6 +14,12 @@ namespace DirectCompute public: Tensor tensor( eDataType type, const std::array<uint32_t, 4>& ne, pfnNewCapacity pfnNewCap ); size_t getCapacity() const { return capacity; } + void clear() + { + views.clear(); + capacity = 0; + } + HRESULT zeroMemory( CComPtr<ID3D11Buffer>& cb ); }; __interface iTensorArena @@ -43,6 +49,7 @@ namespace DirectCompute void clear(); __m128i getMemoryUse() const; + HRESULT zeroMemory( CComPtr<ID3D11Buffer>& cb ); private: @@ -63,6 +70,7 @@ namespace DirectCompute Tensor tensor( const std::array<uint32_t, 4>& ne ); __m128i getMemoryUse() const; + HRESULT zeroMemory( CComPtr<ID3D11Buffer>& cb ); private: diff --git a/Whisper/Whisper/ContextImpl.cpp b/Whisper/Whisper/ContextImpl.cpp index eccda35..5a58e5f 100644 --- a/Whisper/Whisper/ContextImpl.cpp +++ b/Whisper/Whisper/ContextImpl.cpp @@ -274,6 +274,12 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const // Start measuring "Run" profiler value, both CPU and GPU times auto prof = context.completeProfiler(); bool stoppedPrematurely = false; + + if( params.flag( eFullParamsFlags::NoContext ) ) + { + CHECK( context.clearState() ); + } + while( true ) { if( nullptr != progress.pfn ) diff --git a/Whisper/Whisper/DecoderInputBuffers.cpp b/Whisper/Whisper/DecoderInputBuffers.cpp index 68d3cec..54b4dd5 100644 --- a/Whisper/Whisper/DecoderInputBuffers.cpp +++ b/Whisper/Whisper/DecoderInputBuffers.cpp @@ -63,4 +63,15 @@ void DecoderInputBuffers::clear() embd = nullptr; m_size = 0; m_capacity = 0; +} + +HRESULT DecoderInputBuffers::zeroMemory() const +{ + if( nullptr == embd || m_size == 0 ) + return S_FALSE; + + MappedResource mapped; + CHECK( mapped.map( embd, false ) ); + __stosd( (DWORD*)mapped.data(), 0, m_capacity ); + return S_OK; }
\ No newline at end of file diff --git a/Whisper/Whisper/DecoderInputBuffers.h b/Whisper/Whisper/DecoderInputBuffers.h index 9ce8f75..01b9011 100644 --- a/Whisper/Whisper/DecoderInputBuffers.h +++ b/Whisper/Whisper/DecoderInputBuffers.h @@ -3,7 +3,7 @@ namespace DirectCompute { - // Two dynamic buffers + // A dynamic buffer class DecoderInputBuffers { CComPtr<ID3D11Buffer> embd; @@ -25,5 +25,7 @@ namespace DirectCompute i *= sizeof( uint32_t ); return _mm_set_epi64x( (int64_t)i, 0 ); } + + HRESULT zeroMemory() const; }; }
\ No newline at end of file diff --git a/Whisper/Whisper/KeyValueBuffers.cpp b/Whisper/Whisper/KeyValueBuffers.cpp index b932fdb..0d5772d 100644 --- a/Whisper/Whisper/KeyValueBuffers.cpp +++ b/Whisper/Whisper/KeyValueBuffers.cpp @@ -1,6 +1,7 @@ #include "stdafx.h" #include "KeyValueBuffers.h" #include "../D3D/createBuffer.h" +#include "../ML/TempBuffers.h" using namespace DirectCompute; void AttentionBuffer::resize( uint32_t size ) @@ -39,4 +40,31 @@ void KeyValueBuffers::resize( uint32_t size ) { keys.resize( size ); values.resize( size ); +} + +HRESULT AttentionBuffer::zeroMemory( CComPtr<ID3D11Buffer>& cb ) const +{ + if( 0 == m_size ) + return S_FALSE; + + CComPtr<ID3D11UnorderedAccessView> uav; + CD3D11_UNORDERED_ACCESS_VIEW_DESC uavDesc{ D3D11_UAV_DIMENSION_BUFFER, DXGI_FORMAT_R16_FLOAT, 0, m_size }; + check( device()->CreateUnorderedAccessView( buffer, &uavDesc, &uav ) ); + + try + { + TempBuffers::zeroMemory( uav, m_size, cb ); + return S_OK; + } + catch( HRESULT hr ) + { + return hr; + } +} + +HRESULT KeyValueBuffers::zeroMemory( CComPtr<ID3D11Buffer>& cb ) const +{ + CHECK( keys.zeroMemory( cb ) ); + CHECK( values.zeroMemory( cb ) ); + return S_OK; }
\ No newline at end of file diff --git a/Whisper/Whisper/KeyValueBuffers.h b/Whisper/Whisper/KeyValueBuffers.h index 9c737be..6bebfda 100644 --- a/Whisper/Whisper/KeyValueBuffers.h +++ b/Whisper/Whisper/KeyValueBuffers.h @@ -25,6 +25,8 @@ namespace DirectCompute ID3D11Buffer* getBuffer() const { return buffer; } uint32_t getSize() const { return m_size; } + + HRESULT zeroMemory( CComPtr<ID3D11Buffer>& cb ) const; }; struct KeyValueBuffers @@ -46,5 +48,7 @@ namespace DirectCompute i *= sizeof( uint16_t ); return setHigh_size( (int64_t)i ); // They both are in VRAM } + + HRESULT zeroMemory( CComPtr<ID3D11Buffer>& cb ) const; }; }
\ No newline at end of file diff --git a/Whisper/Whisper/WhisperContext.cpp b/Whisper/Whisper/WhisperContext.cpp index a3a85b3..1aa7048 100644 --- a/Whisper/Whisper/WhisperContext.cpp +++ b/Whisper/Whisper/WhisperContext.cpp @@ -663,4 +663,23 @@ __m128i WhisperContext::getMemoryUse() const res = _mm_add_epi64( res, decoderInput.getMemoryUse() ); res = _mm_add_epi64( res, decoderOutput.getMemoryUse() ); return res; +} + +HRESULT WhisperContext::clearState() +{ + CComPtr<ID3D11Buffer>& cb = getSmallConstantBuffer(); + + // CHECK( kv.zeroMemory( cb ) ); + // CHECK( kvCross.zeroMemory( cb ) ); + // The above code doesn't work for some reason. + // Ideally need to debug, but destroying and re-creating these two buffers is not a huge deal. Unlike the buffers in the pools, only a few megabytes of VRAM. + kv.clear(); + kvCross.clear(); + + CHECK( arenas.outer.zeroMemory( cb ) ); + CHECK( arenas.layer.zeroMemory( cb ) ); + CHECK( decPool.zeroMemory( cb ) ); + CHECK( decoderInput.zeroMemory() ); + decoderOutput.clear(); + return S_OK; }
\ No newline at end of file diff --git a/Whisper/Whisper/WhisperContext.h b/Whisper/Whisper/WhisperContext.h index 0e4319f..de0e987 100644 --- a/Whisper/Whisper/WhisperContext.h +++ b/Whisper/Whisper/WhisperContext.h @@ -39,6 +39,14 @@ namespace DirectCompute Tensor tensor( eDataType type, const std::array<uint32_t, 4>& ne ) override final; void reset() override final { } __m128i getMemoryUse() const; + void clear() + { + result.clear(); + } + HRESULT zeroMemory( CComPtr<ID3D11Buffer>& cb ) + { + return result.zeroMemory( cb ); + } }; DecoderLayerPool decPool; @@ -119,5 +127,7 @@ namespace DirectCompute } __m128i getMemoryUse() const; + + HRESULT clearState(); }; }
\ No newline at end of file |
