summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorKonstantin <const@const.me>2023-02-03 16:07:54 +0100
committerKonstantin <const@const.me>2023-02-03 16:07:54 +0100
commit3ba8e6389679007445f4fc1c52439cb0df3ddba0 (patch)
tree0bffa67da50e7042f745f4c9aad72b2f046feeec
parent671ca710ad3f0d3a64af7b84af1025e2a0b68296 (diff)
Bugfix, incorrect output of command-line examples when launched with multiple input files
-rw-r--r--Examples/TranscribeCS/TranscribeCS.cs2
-rw-r--r--Examples/main/main.cpp2
-rw-r--r--Whisper/ML/MlContext.h2
-rw-r--r--Whisper/ML/TensorsArena.cpp30
-rw-r--r--Whisper/ML/TensorsArena.h8
-rw-r--r--Whisper/Whisper/ContextImpl.cpp6
-rw-r--r--Whisper/Whisper/DecoderInputBuffers.cpp11
-rw-r--r--Whisper/Whisper/DecoderInputBuffers.h4
-rw-r--r--Whisper/Whisper/KeyValueBuffers.cpp28
-rw-r--r--Whisper/Whisper/KeyValueBuffers.h4
-rw-r--r--Whisper/Whisper/WhisperContext.cpp19
-rw-r--r--Whisper/Whisper/WhisperContext.h10
12 files changed, 125 insertions, 1 deletions
diff --git a/Examples/TranscribeCS/TranscribeCS.cs b/Examples/TranscribeCS/TranscribeCS.cs
index d94ed21..65239f1 100644
--- a/Examples/TranscribeCS/TranscribeCS.cs
+++ b/Examples/TranscribeCS/TranscribeCS.cs
@@ -25,6 +25,8 @@ namespace TranscribeCS
using iModel model = Library.loadModel( cla.model );
using Context context = model.createContext();
cla.apply( ref context.parameters );
+ // When there're multiple input files, assuming they're independent clips
+ context.parameters.setFlag( eFullParamsFlags.NoContext, true );
using iMediaFoundation mf = Library.initMediaFoundation();
Transcribe transcribe = new Transcribe( cla );
diff --git a/Examples/main/main.cpp b/Examples/main/main.cpp
index 88ddc6d..7706b8f 100644
--- a/Examples/main/main.cpp
+++ b/Examples/main/main.cpp
@@ -246,6 +246,8 @@ int wmain( int argc, wchar_t* argv[] )
wparams.setFlag( eFullParamsFlags::PrintTimestamps, !params.no_timestamps );
wparams.setFlag( eFullParamsFlags::PrintSpecial, params.print_special );
wparams.setFlag( eFullParamsFlags::Translate, params.translate );
+ // When there're multiple input files, assuming they're independent clips
+ wparams.setFlag( eFullParamsFlags::NoContext );
wparams.language = Whisper::makeLanguageKey( params.language.c_str() );
wparams.cpuThreads = params.n_threads;
if( params.max_context != UINT_MAX )
diff --git a/Whisper/ML/MlContext.h b/Whisper/ML/MlContext.h
index fe0d48d..aea1f24 100644
--- a/Whisper/ML/MlContext.h
+++ b/Whisper/ML/MlContext.h
@@ -43,6 +43,8 @@ namespace DirectCompute
GpuProfiler profiler;
+ CComPtr<ID3D11Buffer>& getSmallConstantBuffer() { return temp.smallCb; }
+
public:
MlContext( Whisper::ProfileCollection& profileColl );
MlContext( const MlContext& ) = delete;
diff --git a/Whisper/ML/TensorsArena.cpp b/Whisper/ML/TensorsArena.cpp
index 8e5d449..873200b 100644
--- a/Whisper/ML/TensorsArena.cpp
+++ b/Whisper/ML/TensorsArena.cpp
@@ -1,6 +1,7 @@
#include "stdafx.h"
#include "TensorsArena.h"
#include "../D3D/createBuffer.h"
+#include "TempBuffers.h"
static inline uint32_t roundUpPower2( uint32_t x )
{
@@ -132,4 +133,33 @@ __m128i TensorsArena::getMemoryUse() const
for( const auto& a : arenas )
res = _mm_add_epi64( res, a.getMemoryUse() );
return res;
+}
+
+HRESULT PooledTensor::zeroMemory( CComPtr<ID3D11Buffer>& cb )
+{
+ if( 0 == capacity )
+ return S_FALSE;
+ try
+ {
+ TempBuffers::zeroMemory( views, capacity, cb );
+ return S_OK;
+ }
+ catch( HRESULT hr )
+ {
+ return hr;
+ }
+}
+
+HRESULT TensorsArena::ArenaImpl::zeroMemory( CComPtr<ID3D11Buffer>& cb )
+{
+ for( PooledTensor& e : pool )
+ CHECK( e.zeroMemory( cb ) );
+ return S_OK;
+}
+
+HRESULT TensorsArena::zeroMemory( CComPtr<ID3D11Buffer>& cb )
+{
+ for( ArenaImpl& e : arenas )
+ CHECK( e.zeroMemory( cb ) );
+ return S_OK;
} \ No newline at end of file
diff --git a/Whisper/ML/TensorsArena.h b/Whisper/ML/TensorsArena.h
index acfaf86..df6139e 100644
--- a/Whisper/ML/TensorsArena.h
+++ b/Whisper/ML/TensorsArena.h
@@ -14,6 +14,12 @@ namespace DirectCompute
public:
Tensor tensor( eDataType type, const std::array<uint32_t, 4>& ne, pfnNewCapacity pfnNewCap );
size_t getCapacity() const { return capacity; }
+ void clear()
+ {
+ views.clear();
+ capacity = 0;
+ }
+ HRESULT zeroMemory( CComPtr<ID3D11Buffer>& cb );
};
__interface iTensorArena
@@ -43,6 +49,7 @@ namespace DirectCompute
void clear();
__m128i getMemoryUse() const;
+ HRESULT zeroMemory( CComPtr<ID3D11Buffer>& cb );
private:
@@ -63,6 +70,7 @@ namespace DirectCompute
Tensor tensor( const std::array<uint32_t, 4>& ne );
__m128i getMemoryUse() const;
+ HRESULT zeroMemory( CComPtr<ID3D11Buffer>& cb );
private:
diff --git a/Whisper/Whisper/ContextImpl.cpp b/Whisper/Whisper/ContextImpl.cpp
index eccda35..5a58e5f 100644
--- a/Whisper/Whisper/ContextImpl.cpp
+++ b/Whisper/Whisper/ContextImpl.cpp
@@ -274,6 +274,12 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
// Start measuring "Run" profiler value, both CPU and GPU times
auto prof = context.completeProfiler();
bool stoppedPrematurely = false;
+
+ if( params.flag( eFullParamsFlags::NoContext ) )
+ {
+ CHECK( context.clearState() );
+ }
+
while( true )
{
if( nullptr != progress.pfn )
diff --git a/Whisper/Whisper/DecoderInputBuffers.cpp b/Whisper/Whisper/DecoderInputBuffers.cpp
index 68d3cec..54b4dd5 100644
--- a/Whisper/Whisper/DecoderInputBuffers.cpp
+++ b/Whisper/Whisper/DecoderInputBuffers.cpp
@@ -63,4 +63,15 @@ void DecoderInputBuffers::clear()
embd = nullptr;
m_size = 0;
m_capacity = 0;
+}
+
+HRESULT DecoderInputBuffers::zeroMemory() const
+{
+ if( nullptr == embd || m_size == 0 )
+ return S_FALSE;
+
+ MappedResource mapped;
+ CHECK( mapped.map( embd, false ) );
+ __stosd( (DWORD*)mapped.data(), 0, m_capacity );
+ return S_OK;
} \ No newline at end of file
diff --git a/Whisper/Whisper/DecoderInputBuffers.h b/Whisper/Whisper/DecoderInputBuffers.h
index 9ce8f75..01b9011 100644
--- a/Whisper/Whisper/DecoderInputBuffers.h
+++ b/Whisper/Whisper/DecoderInputBuffers.h
@@ -3,7 +3,7 @@
namespace DirectCompute
{
- // Two dynamic buffers
+ // A dynamic buffer
class DecoderInputBuffers
{
CComPtr<ID3D11Buffer> embd;
@@ -25,5 +25,7 @@ namespace DirectCompute
i *= sizeof( uint32_t );
return _mm_set_epi64x( (int64_t)i, 0 );
}
+
+ HRESULT zeroMemory() const;
};
} \ No newline at end of file
diff --git a/Whisper/Whisper/KeyValueBuffers.cpp b/Whisper/Whisper/KeyValueBuffers.cpp
index b932fdb..0d5772d 100644
--- a/Whisper/Whisper/KeyValueBuffers.cpp
+++ b/Whisper/Whisper/KeyValueBuffers.cpp
@@ -1,6 +1,7 @@
#include "stdafx.h"
#include "KeyValueBuffers.h"
#include "../D3D/createBuffer.h"
+#include "../ML/TempBuffers.h"
using namespace DirectCompute;
void AttentionBuffer::resize( uint32_t size )
@@ -39,4 +40,31 @@ void KeyValueBuffers::resize( uint32_t size )
{
keys.resize( size );
values.resize( size );
+}
+
+HRESULT AttentionBuffer::zeroMemory( CComPtr<ID3D11Buffer>& cb ) const
+{
+ if( 0 == m_size )
+ return S_FALSE;
+
+ CComPtr<ID3D11UnorderedAccessView> uav;
+ CD3D11_UNORDERED_ACCESS_VIEW_DESC uavDesc{ D3D11_UAV_DIMENSION_BUFFER, DXGI_FORMAT_R16_FLOAT, 0, m_size };
+ check( device()->CreateUnorderedAccessView( buffer, &uavDesc, &uav ) );
+
+ try
+ {
+ TempBuffers::zeroMemory( uav, m_size, cb );
+ return S_OK;
+ }
+ catch( HRESULT hr )
+ {
+ return hr;
+ }
+}
+
+HRESULT KeyValueBuffers::zeroMemory( CComPtr<ID3D11Buffer>& cb ) const
+{
+ CHECK( keys.zeroMemory( cb ) );
+ CHECK( values.zeroMemory( cb ) );
+ return S_OK;
} \ No newline at end of file
diff --git a/Whisper/Whisper/KeyValueBuffers.h b/Whisper/Whisper/KeyValueBuffers.h
index 9c737be..6bebfda 100644
--- a/Whisper/Whisper/KeyValueBuffers.h
+++ b/Whisper/Whisper/KeyValueBuffers.h
@@ -25,6 +25,8 @@ namespace DirectCompute
ID3D11Buffer* getBuffer() const { return buffer; }
uint32_t getSize() const { return m_size; }
+
+ HRESULT zeroMemory( CComPtr<ID3D11Buffer>& cb ) const;
};
struct KeyValueBuffers
@@ -46,5 +48,7 @@ namespace DirectCompute
i *= sizeof( uint16_t );
return setHigh_size( (int64_t)i ); // They both are in VRAM
}
+
+ HRESULT zeroMemory( CComPtr<ID3D11Buffer>& cb ) const;
};
} \ No newline at end of file
diff --git a/Whisper/Whisper/WhisperContext.cpp b/Whisper/Whisper/WhisperContext.cpp
index a3a85b3..1aa7048 100644
--- a/Whisper/Whisper/WhisperContext.cpp
+++ b/Whisper/Whisper/WhisperContext.cpp
@@ -663,4 +663,23 @@ __m128i WhisperContext::getMemoryUse() const
res = _mm_add_epi64( res, decoderInput.getMemoryUse() );
res = _mm_add_epi64( res, decoderOutput.getMemoryUse() );
return res;
+}
+
+HRESULT WhisperContext::clearState()
+{
+ CComPtr<ID3D11Buffer>& cb = getSmallConstantBuffer();
+
+ // CHECK( kv.zeroMemory( cb ) );
+ // CHECK( kvCross.zeroMemory( cb ) );
+ // The above code doesn't work for some reason.
+ // Ideally need to debug, but destroying and re-creating these two buffers is not a huge deal. Unlike the buffers in the pools, only a few megabytes of VRAM.
+ kv.clear();
+ kvCross.clear();
+
+ CHECK( arenas.outer.zeroMemory( cb ) );
+ CHECK( arenas.layer.zeroMemory( cb ) );
+ CHECK( decPool.zeroMemory( cb ) );
+ CHECK( decoderInput.zeroMemory() );
+ decoderOutput.clear();
+ return S_OK;
} \ No newline at end of file
diff --git a/Whisper/Whisper/WhisperContext.h b/Whisper/Whisper/WhisperContext.h
index 0e4319f..de0e987 100644
--- a/Whisper/Whisper/WhisperContext.h
+++ b/Whisper/Whisper/WhisperContext.h
@@ -39,6 +39,14 @@ namespace DirectCompute
Tensor tensor( eDataType type, const std::array<uint32_t, 4>& ne ) override final;
void reset() override final { }
__m128i getMemoryUse() const;
+ void clear()
+ {
+ result.clear();
+ }
+ HRESULT zeroMemory( CComPtr<ID3D11Buffer>& cb )
+ {
+ return result.zeroMemory( cb );
+ }
};
DecoderLayerPool decPool;
@@ -119,5 +127,7 @@ namespace DirectCompute
}
__m128i getMemoryUse() const;
+
+ HRESULT clearState();
};
} \ No newline at end of file