diff options
| author | yum <yum.food.vr@gmail.com> | 2023-03-02 17:43:26 -0800 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2023-03-02 17:43:26 -0800 |
| commit | dcd7f3b60e3b9ad8df83d444f8bc67091b411529 (patch) | |
| tree | 177fa9358b2a160f6247fb074d50e624322b0f4a | |
| parent | 4f967dbdb7972ec52039bd3e3ce3e1e4cbcf6545 (diff) | |
Continue work on beam search
Define ContextImpl::Context, wrapping all the data used in decoding.
Using a vector of these is much simpler than using N vectors of all
the random stuff we need.
| -rw-r--r-- | Whisper/Whisper/ContextImpl.capture.cpp | 2 | ||||
| -rw-r--r-- | Whisper/Whisper/ContextImpl.cpp | 148 | ||||
| -rw-r--r-- | Whisper/Whisper/ContextImpl.h | 23 | ||||
| -rw-r--r-- | Whisper/Whisper/ContextImpl.misc.cpp | 4 |
4 files changed, 72 insertions, 105 deletions
diff --git a/Whisper/Whisper/ContextImpl.capture.cpp b/Whisper/Whisper/ContextImpl.capture.cpp index 0100fcd..642eef1 100644 --- a/Whisper/Whisper/ContextImpl.capture.cpp +++ b/Whisper/Whisper/ContextImpl.capture.cpp @@ -65,7 +65,7 @@ namespace __m128i ints = _mm_cvtps_epi32( floats ); store16( &minDuration, ints ); - retainDuration = std::round(retainDuration * SAMPLE_RATE); + retainDuration = (int) (retainDuration * SAMPLE_RATE + 0.5); flags = cp.flags; } diff --git a/Whisper/Whisper/ContextImpl.cpp b/Whisper/Whisper/ContextImpl.cpp index d6d46f9..5428408 100644 --- a/Whisper/Whisper/ContextImpl.cpp +++ b/Whisper/Whisper/ContextImpl.cpp @@ -57,7 +57,7 @@ HRESULT ContextImpl::decode( const int* tokens, size_t length, int n_past, int t try { - context.decode( tokens, (int)length, dp, probs_vec[nth], threads); + context.decode( tokens, (int)length, dp, ctx_[nth].probs, threads); return S_OK; } catch( HRESULT hr ) @@ -76,18 +76,18 @@ std::vector<sTokenData> ContextImpl::sampleBestN( const float* probs, size_t n_logits = vocab.size(); - probs_id_vec[nth].clear(); - probs_id_vec[nth].reserve(n_logits); + ctx_[nth].probs_id.clear(); + ctx_[nth].probs_id.reserve(n_logits); for( size_t i = 0; i < n_logits; i++ ) - probs_id_vec[nth].emplace_back(probs[i], (int)i); + ctx_[nth].probs_id.emplace_back(probs[i], (int)i); { double sum_ts = 0.0; double max_ts = -1.0; double max_tx = -1.0; for( int i = 0; i < vocab.token_beg; i++ ) - max_tx = std::max( max_tx, probs_id_vec[ nth ][ i ].first ); + max_tx = std::max( max_tx, ctx_[ nth ].probs_id[ i ].first ); const int i0 = is_initial ? vocab.token_beg + 101 : vocab.token_beg; const int i1 = is_initial ? vocab.token_beg + 101 : (int)n_logits; @@ -97,17 +97,17 @@ std::vector<sTokenData> ContextImpl::sampleBestN( const float* probs, if( is_initial ) { for( int i = i0; i < n_logits; i++ ) - probs_id_vec[ nth ][ i ].first = -INFINITY; + ctx_[nth].probs_id[i].first = -INFINITY; } for( int i = vocab.token_beg; i < i1; i++ ) { - sum_ts += probs_id_vec[ nth ][ i ].first; - if( probs_id_vec[ nth ][ i ].first > max_ts ) + sum_ts += ctx_[nth].probs_id[i].first; + if (ctx_[nth].probs_id[i].first > max_ts) { - max_ts = probs_id_vec[ nth ][ i ].first; + max_ts = ctx_[nth].probs_id[i].first; for (auto& result : result_vec) { - result.tid = probs_id_vec[nth][i].second; + result.tid = ctx_[nth].probs_id[i].second; } } } @@ -118,7 +118,7 @@ std::vector<sTokenData> ContextImpl::sampleBestN( const float* probs, { // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438 for( int i = 0; i < vocab.token_beg; i++ ) - probs_id_vec[ nth ][ i ].first = -INFINITY; + ctx_[nth].probs_id[ i ].first = -INFINITY; } for (auto& result : result_vec) { @@ -131,13 +131,13 @@ std::vector<sTokenData> ContextImpl::sampleBestN( const float* probs, const int top_k = 4 + n_best - 1; std::partial_sort( - probs_id_vec[nth].begin(), - probs_id_vec[nth].begin() + top_k, probs_id_vec[nth].end(), + ctx_[nth].probs_id.begin(), + ctx_[nth].probs_id.begin() + top_k, ctx_[nth].probs_id.end(), []( const std::pair<double, Vocabulary::id>& a, const std::pair<double, Vocabulary::id>& b ) { return a.first > b.first; } ); - probs_id_vec[nth].resize(top_k); + ctx_[nth].probs_id.resize(top_k); //printf("\n"); //for (int i = 0; i < (int) probs_id_vec[nth].size(); i++) { @@ -150,10 +150,10 @@ std::vector<sTokenData> ContextImpl::sampleBestN( const float* probs, int i = 0; for (int j = 0; j < n_best; j++) { // Scan past unwanted tokens. - while ((probs_id_vec[nth][i].second == vocab.token_sot || - probs_id_vec[nth][i].second == vocab.token_solm || - probs_id_vec[nth][i].second == vocab.token_not) && - i < (int)probs_id_vec[nth].size() - 1) + while ((ctx_[nth].probs_id[i].second == vocab.token_sot || + ctx_[nth].probs_id[i].second == vocab.token_solm || + ctx_[nth].probs_id[i].second == vocab.token_not) && + i < (int)ctx_[nth].probs_id.size() - 1) { i++; } @@ -162,8 +162,8 @@ std::vector<sTokenData> ContextImpl::sampleBestN( const float* probs, assert(result_vec.size() == res_vec.size()); for (int i = 0; i < res_vec.size(); i++) { - result_vec[i].id = probs_id_vec[nth][res_vec[i]].second; - result_vec[i].p = (float)probs_id_vec[nth][res_vec[i]].first; + result_vec[i].id = ctx_[nth].probs_id[res_vec[i]].second; + result_vec[i].p = (float)ctx_[nth].probs_id[res_vec[i]].first; } return result_vec; @@ -173,8 +173,8 @@ std::vector<sTokenData> ContextImpl::sampleBestN(int nth, int n_best) { const int n_vocab = model.vocab.n_vocab; return sampleBestN( - probs_vec[nth].data() + (probs_vec[nth].size() - n_vocab), - false, false, nth, /*n_best=*/1); + ctx_[nth].probs.data() + (ctx_[nth].probs.size() - n_vocab), + false, false, nth, /*n_best=*/n_best); } std::vector<sTokenData> ContextImpl::sampleTimestampN(bool initial, int nth, @@ -182,7 +182,7 @@ std::vector<sTokenData> ContextImpl::sampleTimestampN(bool initial, int nth, { const int n_vocab = model.vocab.n_vocab; return sampleBestN( - probs_vec[nth].data() + (probs_vec[nth].size() - n_vocab), + ctx_[nth].probs.data() + (ctx_[nth].probs.size() - n_vocab), true, initial, nth, n_best); } @@ -486,19 +486,15 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const if( seek_end < 100 + seek_start ) return S_FALSE; - if (prompt_past_vec.empty()) { - if (params.strategy == Whisper::eSamplingStrategy::Greedy) { - prompt_past_vec = std::vector<prompt_t>(1); - } - else if (params.strategy == Whisper::eSamplingStrategy::BeamSearch) { - prompt_past_vec = std::vector<prompt_t>(params.beam_search.beam_width); - } + const int n_ctxt = params.strategy == Whisper::eSamplingStrategy::Greedy ? 1 : params.beam_search.beam_width; + if (ctx_.size() != n_ctxt) { + ctx_.assign(n_ctxt, Context()); } // the accumulated text context so far if (params.flag(eFullParamsFlags::NoContext)) { - for (auto& prompt_past : prompt_past_vec) { - prompt_past.clear(); + for (auto& ctx : ctx_) { + ctx.prompt_past.clear(); } } @@ -506,14 +502,20 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const if( params.prompt_tokens && params.prompt_n_tokens > 0 ) { // parse tokens from the pointer - for (int i = 0; i < params.prompt_n_tokens; i++) { - for (auto& prompt_past : prompt_past_vec) { - prompt_past.push_back(params.prompt_tokens[i]); + for (auto& ctx : ctx_) { + for (int i = 0; i < params.prompt_n_tokens; i++) { + ctx.prompt_past.push_back(params.prompt_tokens[i]); } + std::rotate( + ctx.prompt_past.begin(), + ctx.prompt_past.end() - params.prompt_n_tokens, + ctx.prompt_past.end()); } - for (auto& prompt_past : prompt_past_vec) { - std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end()); - } + } + + for (auto& ctx : ctx_) { + ctx.loop_ctx.prompt.reserve(model.parameters.n_text_ctx); + ctx.loop_ctx.tokens_cur.reserve(model.parameters.n_text_ctx); } // overwrite audio_ctx @@ -543,25 +545,6 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const // int progress_prev = 0; // int progress_step = 5; - typedef std::vector<sTokenData> tokens_t; - std::vector<tokens_t> tokens_cur_vec; - typedef std::vector<whisper_token> prompt_t; - std::vector<prompt_t> prompt_vec; - if (params.strategy == Whisper::eSamplingStrategy::Greedy) { - prompt_vec = std::vector<prompt_t>(1); - tokens_cur_vec = std::vector<tokens_t>(1); - } - else if (params.strategy == Whisper::eSamplingStrategy::BeamSearch) { - prompt_vec = std::vector<prompt_t>(params.beam_search.beam_width); - tokens_cur_vec = std::vector<tokens_t>(params.beam_search.beam_width); - } - for (auto& prompt : prompt_vec) { - prompt.reserve(model.parameters.n_text_ctx); - } - for (auto& tokens_cur : tokens_cur_vec) { - tokens_cur.reserve(model.parameters.n_text_ctx); - } - // main loop int seek = seek_start; // Start measuring "Run" profiler value, both CPU and GPU times @@ -603,23 +586,13 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const // encode audio features starting at offset seek CHECK( encode( mel, seek ) ); - std::vector<int> n_past_vec; - if (params.strategy == Whisper::eSamplingStrategy::Greedy) { - n_past_vec = std::vector<int>(1, 0); - } - else if (params.strategy == Whisper::eSamplingStrategy::BeamSearch) { - n_past_vec = std::vector<int>(params.beam_search.beam_width, 0); - } + for (auto& ctx : ctx_) { + // if we have already generated some text, use it as a prompt to condition the next generation + auto& prompt_past = ctx.prompt_past; + auto& prompt = ctx.loop_ctx.prompt; - for (auto& prompt : prompt_vec) { prompt.clear(); - } - // if we have already generated some text, use it as a prompt to condition the next generation - assert(prompt_past_vec.size() == prompt_vec.size()); - for (int i = 0; i < prompt_past_vec.size(); i++) { - auto& prompt_past = prompt_past_vec[i]; - auto& prompt = prompt_vec[i]; if (!prompt_past.empty()) { int n_take = std::min(std::min(params.n_max_text_ctx, model.parameters.n_text_ctx / 2), int(prompt_past.size())); @@ -630,10 +603,10 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const prompt_past.clear(); prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end()); } - } - for (auto& prompt : prompt_vec) { prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end()); + ctx.loop_ctx.tokens_cur.clear(); + ctx.loop_ctx.result_len = 0; } int seek_delta = 100 * WHISPER_CHUNK_SIZE; @@ -645,19 +618,6 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const //} //printf("\n\n"); - // the accumulated transcription in the current iteration - std::vector<int> result_len_vec; - if (params.strategy == Whisper::eSamplingStrategy::Greedy) { - result_len_vec = std::vector<int>(1, 0); - } - else if (params.strategy == Whisper::eSamplingStrategy::BeamSearch) { - result_len_vec = std::vector<int>(params.beam_search.beam_width, 0); - } - - for (auto& tokens_cur : tokens_cur_vec) { - tokens_cur.clear(); - } - bool failed = false; bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment? @@ -667,10 +627,10 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const for( int i = 0, n_max = model.parameters.n_text_ctx / 2 - 4; i < n_max; i++ ) { if (params.strategy == Whisper::eSamplingStrategy::Greedy) { - auto& prompt = prompt_vec[0]; - auto& tokens_cur = tokens_cur_vec[0]; - auto& n_past = n_past_vec[0]; - auto& result_len = result_len_vec[0]; + auto& prompt = ctx_[0].loop_ctx.prompt; + auto& tokens_cur = ctx_[0].loop_ctx.tokens_cur; + auto& n_past = ctx_[0].loop_ctx.n_past; + auto& result_len = ctx_[0].loop_ctx.result_len; CHECK(decode(prompt.data(), prompt.size(), n_past, params.cpuThreads, /*nth=*/0)); @@ -768,9 +728,9 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const // shrink down to result_len // TODO assign these based on results of either decoding strategy - auto& tokens_cur = tokens_cur_vec[0]; - auto& result_len = result_len_vec[0]; - auto& prompt_past = prompt_past_vec[0]; + auto& tokens_cur = ctx_[0].loop_ctx.tokens_cur; + auto& result_len = ctx_[0].loop_ctx.result_len; + auto& prompt_past = ctx_[0].prompt_past; tokens_cur.resize( result_len ); for( const auto& r : tokens_cur ) @@ -882,4 +842,4 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const CHECK( progress.pfn( 1.0, this, progress.pv ) ); } return S_OK; -}
\ No newline at end of file +} diff --git a/Whisper/Whisper/ContextImpl.h b/Whisper/Whisper/ContextImpl.h index cdad36b..33d3007 100644 --- a/Whisper/Whisper/ContextImpl.h +++ b/Whisper/Whisper/ContextImpl.h @@ -38,8 +38,22 @@ namespace Whisper }; std::vector<Segment> result_all; - typedef std::vector<whisper_token> prompt_t; - std::vector<prompt_t> prompt_past_vec; + struct Context { + std::vector<whisper_token> prompt_past; + std::vector<float> probs; + std::vector<std::pair<double, Vocabulary::id>> probs_id; + // These are cleared on every frame of audio processed. + struct AudioFrameContext { + std::vector<whisper_token> prompt; + std::vector<sTokenData> tokens_cur; + std::vector<float> joint_probs; + + int result_len = 0; + int n_past = 0; + }; + AudioFrameContext loop_ctx; + }; + std::vector<Context> ctx_; // [EXPERIMENTAL] token-level timestamps data int64_t t_beg = 0; @@ -58,11 +72,6 @@ namespace Whisper int wrapSegment( int max_len ); void expComputeTokenLevelTimestamps( int i_segment, float thold_pt, float thold_ptsum ); - typedef std::vector<float> probs_t; - std::vector<probs_t> probs_vec{ 1 }; - typedef std::vector<std::pair<double, Vocabulary::id>> probs_id_t; - std::vector<probs_id_t> probs_id_vec{ 1 }; - mutable TranscribeResultStatic results; HRESULT COMLIGHTCALL makeResults( eResultFlags flags, TranscribeResult& res ) const noexcept; diff --git a/Whisper/Whisper/ContextImpl.misc.cpp b/Whisper/Whisper/ContextImpl.misc.cpp index 90a1bec..52ef811 100644 --- a/Whisper/Whisper/ContextImpl.misc.cpp +++ b/Whisper/Whisper/ContextImpl.misc.cpp @@ -114,10 +114,8 @@ __m128i ContextImpl::getMemoryUse() const size_t cb = vectorMemoryUse( result_all ); for( const auto& r : result_all ) cb += r.memoryUsage(); - cb += vectorMemoryUse( prompt_past_vec ); + cb += vectorMemoryUse( ctx_ ); cb += vectorMemoryUse( energy ); - cb += vectorMemoryUse( probs_vec ); - cb += vectorMemoryUse( probs_id_vec ); cb += vectorMemoryUse( results.segments ); cb += vectorMemoryUse( results.tokens ); cb += spectrogram.memoryUsage(); |
