diff options
| author | yum <yum.food.vr@gmail.com> | 2023-02-27 18:03:02 -0800 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2023-02-27 18:03:02 -0800 |
| commit | 4f967dbdb7972ec52039bd3e3ce3e1e4cbcf6545 (patch) | |
| tree | 158cf93f6ec626f32b961acc6d791bd2a2040b94 | |
| parent | 1136acfc365f357d2df13a263714e8ae0614c4f9 (diff) | |
Begin work on beam search decoding
* ContextImpl.h puts prompts, previous prompts, probabilities, and
probability IDs into vectors of size 1 or N_BEAMS, depending on
the decoding strategy.
* Extend sampleBest and friends to return top N tokens, instead of
just the top 1 token.
| -rw-r--r-- | Designs/.BeamSearch.md.swp | bin | 0 -> 12288 bytes | |||
| -rw-r--r-- | Designs/BeamSearch.md | 33 | ||||
| -rw-r--r-- | Whisper/Whisper/ContextImpl.cpp | 252 | ||||
| -rw-r--r-- | Whisper/Whisper/ContextImpl.h | 17 | ||||
| -rw-r--r-- | Whisper/Whisper/ContextImpl.misc.cpp | 6 |
5 files changed, 220 insertions, 88 deletions
diff --git a/Designs/.BeamSearch.md.swp b/Designs/.BeamSearch.md.swp Binary files differ new file mode 100644 index 0000000..8102d31 --- /dev/null +++ b/Designs/.BeamSearch.md.swp diff --git a/Designs/BeamSearch.md b/Designs/BeamSearch.md new file mode 100644 index 0000000..4902b45 --- /dev/null +++ b/Designs/BeamSearch.md @@ -0,0 +1,33 @@ +This is the design for the beam search decoding algorithm. + +ContextImpl.cpp is where greedy decoding is implemented. + +sampleBest() does the following: + +1. Populate `probs_id` with (probability, idx). +2. Find the `top_k` most likely tokens. +3. Filter out tokens matching `token_sot`, `token_solm`, `token_not`. + 3.1. sot == start of transcript, solm == start of LM, not == no timestamps +4. Return the result. + +We should modify this to return the most likely N tokens. We'll call this +sampleBestN(). + +Greedy search works like this: + +1. Call decode(prompt, tokens). + 1.1. This populates `probs`. Need to wrap this in a vector of size `N_BEAMS`. + 1.2. Need to wrap `prompt` in a vector of size `N_BEAMS`. +2. Extract the most likely token and add it to the prompt. +3. Repeat until EOT (end of text) or max tokens or end of audio stream. + +Beam search will work like this: + +1. Initialize `prompts` as a vector containing `prompt`. +2. Call decode(prompt, tokens) for each prompt in `prompts`. +3. Extract the most likely `N_BEAMS` tokens in each result. +4. Compute joint probabilities for each token. +5. Extract the most likely `N_BEAMS` prompts using joint probabilities. +6. Update `prompts`. +7. Repeat until end condition is reached for all prompts. +8. Return most likely prompt. 
diff --git a/Whisper/Whisper/ContextImpl.cpp b/Whisper/Whisper/ContextImpl.cpp index 1160f95..d6d46f9 100644 --- a/Whisper/Whisper/ContextImpl.cpp +++ b/Whisper/Whisper/ContextImpl.cpp @@ -42,7 +42,7 @@ HRESULT ContextImpl::encode( iSpectrogram& mel, int seek ) } } -HRESULT ContextImpl::decode( const int* tokens, size_t length, int n_past, int threads ) +HRESULT ContextImpl::decode( const int* tokens, size_t length, int n_past, int threads, int nth ) { // whisper_decode using namespace DirectCompute; @@ -57,7 +57,7 @@ HRESULT ContextImpl::decode( const int* tokens, size_t length, int n_past, int t try { - context.decode( tokens, (int)length, dp, probs, threads ); + context.decode( tokens, (int)length, dp, probs_vec[nth], threads); return S_OK; } catch( HRESULT hr ) @@ -67,27 +67,27 @@ HRESULT ContextImpl::decode( const int* tokens, size_t length, int n_past, int t } // the most basic sampling scheme - select the top token -sTokenData ContextImpl::sampleBest( const float* probs, bool force_timestamp, bool is_initial ) +std::vector<sTokenData> ContextImpl::sampleBestN( const float* probs, + bool force_timestamp, bool is_initial, int nth, int n_best ) { // whisper_sample_best const Vocabulary& vocab = model.vocab; - sTokenData result = { 0 }; + std::vector<sTokenData> result_vec(n_best, { 0 }); size_t n_logits = vocab.size(); - probs_id.clear(); - probs_id.reserve( n_logits ); + probs_id_vec[nth].clear(); + probs_id_vec[nth].reserve(n_logits); for( size_t i = 0; i < n_logits; i++ ) - probs_id.emplace_back( probs[ i ], (int)i ); - + probs_id_vec[nth].emplace_back(probs[i], (int)i); { double sum_ts = 0.0; double max_ts = -1.0; double max_tx = -1.0; for( int i = 0; i < vocab.token_beg; i++ ) - max_tx = std::max( max_tx, probs_id[ i ].first ); + max_tx = std::max( max_tx, probs_id_vec[ nth ][ i ].first ); const int i0 = is_initial ? vocab.token_beg + 101 : vocab.token_beg; const int i1 = is_initial ? 
vocab.token_beg + 101 : (int)n_logits; @@ -97,16 +97,18 @@ sTokenData ContextImpl::sampleBest( const float* probs, bool force_timestamp, bo if( is_initial ) { for( int i = i0; i < n_logits; i++ ) - probs_id[ i ].first = -INFINITY; + probs_id_vec[ nth ][ i ].first = -INFINITY; } for( int i = vocab.token_beg; i < i1; i++ ) { - sum_ts += probs_id[ i ].first; - if( probs_id[ i ].first > max_ts ) + sum_ts += probs_id_vec[ nth ][ i ].first; + if( probs_id_vec[ nth ][ i ].first > max_ts ) { - max_ts = probs_id[ i ].first; - result.tid = probs_id[ i ].second; + max_ts = probs_id_vec[ nth ][ i ].first; + for (auto& result : result_vec) { + result.tid = probs_id_vec[nth][i].second; + } } } @@ -116,55 +118,72 @@ sTokenData ContextImpl::sampleBest( const float* probs, bool force_timestamp, bo { // ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438 for( int i = 0; i < vocab.token_beg; i++ ) - probs_id[ i ].first = -INFINITY; + probs_id_vec[ nth ][ i ].first = -INFINITY; } - result.pt = (float)( max_ts / ( sum_ts + 1e-10 ) ); - result.ptsum = (float)sum_ts; + for (auto& result : result_vec) { + result.pt = (float)(max_ts / (sum_ts + 1e-10)); + result.ptsum = (float)sum_ts; + } } // find the top K tokens - const int top_k = 4; + const int top_k = 4 + n_best - 1; std::partial_sort( - probs_id.begin(), - probs_id.begin() + top_k, probs_id.end(), + probs_id_vec[nth].begin(), + probs_id_vec[nth].begin() + top_k, probs_id_vec[nth].end(), []( const std::pair<double, Vocabulary::id>& a, const std::pair<double, Vocabulary::id>& b ) { return a.first > b.first; } ); - probs_id.resize( top_k ); + probs_id_vec[nth].resize(top_k); //printf("\n"); - //for (int i = 0; i < (int) probs_id.size(); i++) { - // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second); + //for (int i = 0; i < (int) probs_id_vec[nth].size(); i++) { + // printf("%d: '%s' %f, %d\n", i, + // 
vocab.id_to_token.at(probs_id_vec[nth][i].second).c_str(), + // probs_id_vec[nth][i].first, probs_id_vec[nth][i].second); //} - int res = 0; - while( ( probs_id[ res ].second == vocab.token_sot || - probs_id[ res ].second == vocab.token_solm || - probs_id[ res ].second == vocab.token_not ) && - res < (int)probs_id.size() - 1 ) - { - res++; + std::vector<int> res_vec(n_best, 0); + int i = 0; + for (int j = 0; j < n_best; j++) { + // Scan past unwanted tokens. + while ((probs_id_vec[nth][i].second == vocab.token_sot || + probs_id_vec[nth][i].second == vocab.token_solm || + probs_id_vec[nth][i].second == vocab.token_not) && + i < (int)probs_id_vec[nth].size() - 1) + { + i++; + } + res_vec[j] = i; } - result.id = probs_id[ res ].second; - result.p = (float)probs_id[ res ].first; + assert(result_vec.size() == res_vec.size()); + for (int i = 0; i < res_vec.size(); i++) { + result_vec[i].id = probs_id_vec[nth][res_vec[i]].second; + result_vec[i].p = (float)probs_id_vec[nth][res_vec[i]].first; + } - return result; + return result_vec; } -sTokenData ContextImpl::sampleBest() +std::vector<sTokenData> ContextImpl::sampleBestN(int nth, int n_best) { const int n_vocab = model.vocab.n_vocab; - return sampleBest( probs.data() + ( probs.size() - n_vocab ), false, false ); + return sampleBestN( + probs_vec[nth].data() + (probs_vec[nth].size() - n_vocab), + false, false, nth, /*n_best=*/1); } -sTokenData ContextImpl::sampleTimestamp( bool initial ) +std::vector<sTokenData> ContextImpl::sampleTimestampN(bool initial, int nth, + int n_best) { const int n_vocab = model.vocab.n_vocab; - return sampleBest( probs.data() + ( probs.size() - n_vocab ), true, initial ); + return sampleBestN( + probs_vec[nth].data() + (probs_vec[nth].size() - n_vocab), + true, initial, nth, n_best); } // a cost-function / heuristic that is high for text that takes longer to pronounce @@ -467,17 +486,34 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const if( seek_end < 100 + 
seek_start ) return S_FALSE; + if (prompt_past_vec.empty()) { + if (params.strategy == Whisper::eSamplingStrategy::Greedy) { + prompt_past_vec = std::vector<prompt_t>(1); + } + else if (params.strategy == Whisper::eSamplingStrategy::BeamSearch) { + prompt_past_vec = std::vector<prompt_t>(params.beam_search.beam_width); + } + } + // the accumulated text context so far - if( params.flag( eFullParamsFlags::NoContext ) ) - prompt_past.clear(); + if (params.flag(eFullParamsFlags::NoContext)) { + for (auto& prompt_past : prompt_past_vec) { + prompt_past.clear(); + } + } // prepend the prompt tokens to the prompt_past if( params.prompt_tokens && params.prompt_n_tokens > 0 ) { // parse tokens from the pointer - for( int i = 0; i < params.prompt_n_tokens; i++ ) - prompt_past.push_back( params.prompt_tokens[ i ] ); - std::rotate( prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end() ); + for (int i = 0; i < params.prompt_n_tokens; i++) { + for (auto& prompt_past : prompt_past_vec) { + prompt_past.push_back(params.prompt_tokens[i]); + } + } + for (auto& prompt_past : prompt_past_vec) { + std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end()); + } } // overwrite audio_ctx @@ -507,10 +543,24 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const // int progress_prev = 0; // int progress_step = 5; - std::vector<sTokenData> tokens_cur; - tokens_cur.reserve( model.parameters.n_text_ctx ); - std::vector<whisper_token> prompt; - prompt.reserve( model.parameters.n_text_ctx ); + typedef std::vector<sTokenData> tokens_t; + std::vector<tokens_t> tokens_cur_vec; + typedef std::vector<whisper_token> prompt_t; + std::vector<prompt_t> prompt_vec; + if (params.strategy == Whisper::eSamplingStrategy::Greedy) { + prompt_vec = std::vector<prompt_t>(1); + tokens_cur_vec = std::vector<tokens_t>(1); + } + else if (params.strategy == Whisper::eSamplingStrategy::BeamSearch) { + prompt_vec = 
std::vector<prompt_t>(params.beam_search.beam_width); + tokens_cur_vec = std::vector<tokens_t>(params.beam_search.beam_width); + } + for (auto& prompt : prompt_vec) { + prompt.reserve(model.parameters.n_text_ctx); + } + for (auto& tokens_cur : tokens_cur_vec) { + tokens_cur.reserve(model.parameters.n_text_ctx); + } // main loop int seek = seek_start; @@ -553,22 +603,38 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const // encode audio features starting at offset seek CHECK( encode( mel, seek ) ); - int n_past = 0; - prompt.clear(); + std::vector<int> n_past_vec; + if (params.strategy == Whisper::eSamplingStrategy::Greedy) { + n_past_vec = std::vector<int>(1, 0); + } + else if (params.strategy == Whisper::eSamplingStrategy::BeamSearch) { + n_past_vec = std::vector<int>(params.beam_search.beam_width, 0); + } + + for (auto& prompt : prompt_vec) { + prompt.clear(); + } // if we have already generated some text, use it as a prompt to condition the next generation - if( !prompt_past.empty() ) - { - int n_take = std::min( std::min( params.n_max_text_ctx, model.parameters.n_text_ctx / 2 ), int( prompt_past.size() ) ); + assert(prompt_past_vec.size() == prompt_vec.size()); + for (int i = 0; i < prompt_past_vec.size(); i++) { + auto& prompt_past = prompt_past_vec[i]; + auto& prompt = prompt_vec[i]; + if (!prompt_past.empty()) + { + int n_take = std::min(std::min(params.n_max_text_ctx, model.parameters.n_text_ctx / 2), int(prompt_past.size())); - prompt = { model.vocab.token_prev }; - prompt.insert( prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end() ); + prompt = { model.vocab.token_prev }; + prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end()); - prompt_past.clear(); - prompt_past.insert( prompt_past.end(), prompt.begin() + 1, prompt.end() ); + prompt_past.clear(); + prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end()); + } } - prompt.insert( prompt.end(), prompt_init.begin(), 
prompt_init.end() ); + for (auto& prompt : prompt_vec) { + prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end()); + } int seek_delta = 100 * WHISPER_CHUNK_SIZE; @@ -580,8 +646,17 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const //printf("\n\n"); // the accumulated transcription in the current iteration - int result_len = 0; - tokens_cur.clear(); + std::vector<int> result_len_vec; + if (params.strategy == Whisper::eSamplingStrategy::Greedy) { + result_len_vec = std::vector<int>(1, 0); + } + else if (params.strategy == Whisper::eSamplingStrategy::BeamSearch) { + result_len_vec = std::vector<int>(params.beam_search.beam_width, 0); + } + + for (auto& tokens_cur : tokens_cur_vec) { + tokens_cur.clear(); + } bool failed = false; bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment? @@ -591,21 +666,29 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const auto prof = context.decodeProfiler(); for( int i = 0, n_max = model.parameters.n_text_ctx / 2 - 4; i < n_max; i++ ) { - CHECK( decode( prompt.data(), prompt.size(), n_past, params.cpuThreads ) ); - - n_past += (int)prompt.size(); - prompt.clear(); - - // very basic greedy sampling strategy: - // - // - always take the most probable token - // - // more sophisticated sampling strategies could be implemented here, but we keep it simple - // feel free to experiment! 
- // - { + if (params.strategy == Whisper::eSamplingStrategy::Greedy) { + auto& prompt = prompt_vec[0]; + auto& tokens_cur = tokens_cur_vec[0]; + auto& n_past = n_past_vec[0]; + auto& result_len = result_len_vec[0]; + + CHECK(decode(prompt.data(), prompt.size(), n_past, params.cpuThreads, /*nth=*/0)); + + n_past += (int)prompt.size(); + prompt.clear(); + + // very basic greedy sampling strategy: + // + // - always take the most probable token + // + // more sophisticated sampling strategies could be implemented here, but we keep it simple + // feel free to experiment! + // + auto p = profiler.cpuBlock( eCpuBlock::Sample ); - const sTokenData token = ( i == 0 ) ? sampleTimestamp( true ) : sampleBest(); + const sTokenData token = ( i == 0 ) + ? sampleTimestampN( true, /*nth=*/0, /*n_best=*/1)[0] + : sampleBestN(/*nth=*/0, /*n_best=*/1)[0]; // timestamp token - update sliding window if( token.id > model.vocab.token_beg ) @@ -655,13 +738,22 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const break; } - } - // sometimes, the decoding can get stuck in a repetition loop - // this is a simple strategy to avoid such cases - we simply flag the decoding as failed and advance - // the sliding window by 1 second - if( i == n_max - 1 && ( result_len == 0 || seek_delta < 100 * WHISPER_CHUNK_SIZE / 2 ) ) - { + // sometimes, the decoding can get stuck in a repetition loop + // this is a simple strategy to avoid such cases - we simply flag the decoding as failed and advance + // the sliding window by 1 second + if (i == n_max - 1 && (result_len == 0 || seek_delta < 100 * WHISPER_CHUNK_SIZE / 2)) + { + failed = true; + break; + } + } else if (params.strategy == Whisper::eSamplingStrategy::BeamSearch) { + logError(u8"%s: beam search not implemented", __func__); + failed = true; + break; + } + else { + logError(u8"%s: unsupported decoding strategy: %d", __func__, (int)params.strategy); failed = true; break; } @@ -675,6 +767,10 @@ HRESULT COMLIGHTCALL 
ContextImpl::runFullImpl( const sFullParams& params, const } // shrink down to result_len + // TODO assign these based on results of either decoding strategy + auto& tokens_cur = tokens_cur_vec[0]; + auto& result_len = result_len_vec[0]; + auto& prompt_past = prompt_past_vec[0]; tokens_cur.resize( result_len ); for( const auto& r : tokens_cur ) diff --git a/Whisper/Whisper/ContextImpl.h b/Whisper/Whisper/ContextImpl.h index c404b8f..cdad36b 100644 --- a/Whisper/Whisper/ContextImpl.h +++ b/Whisper/Whisper/ContextImpl.h @@ -38,7 +38,8 @@ namespace Whisper }; std::vector<Segment> result_all; - std::vector<whisper_token> prompt_past; + typedef std::vector<whisper_token> prompt_t; + std::vector<prompt_t> prompt_past_vec; // [EXPERIMENTAL] token-level timestamps data int64_t t_beg = 0; @@ -50,15 +51,17 @@ namespace Whisper int32_t exp_n_audio_ctx = 0; // 0 - use default HRESULT encode( iSpectrogram& mel, int seek ); - HRESULT decode( const int* tokens, size_t length, int n_past, int threads ); - sTokenData sampleBest( const float* probs, bool force_timestamp, bool is_initial ); - sTokenData sampleBest(); - sTokenData sampleTimestamp( bool initial ); + HRESULT decode( const int* tokens, size_t length, int n_past, int threads, int nth ); + std::vector<sTokenData> sampleBestN( const float* probs, bool force_timestamp, bool is_initial, int nth, int n_best ); + std::vector<sTokenData> sampleBestN(int nth, int n_best); + std::vector<sTokenData> sampleTimestampN( bool initial, int nth, int n_best ); int wrapSegment( int max_len ); void expComputeTokenLevelTimestamps( int i_segment, float thold_pt, float thold_ptsum ); - std::vector<float> probs; - std::vector<std::pair<double, Vocabulary::id>> probs_id; + typedef std::vector<float> probs_t; + std::vector<probs_t> probs_vec{ 1 }; + typedef std::vector<std::pair<double, Vocabulary::id>> probs_id_t; + std::vector<probs_id_t> probs_id_vec{ 1 }; mutable TranscribeResultStatic results; diff --git 
a/Whisper/Whisper/ContextImpl.misc.cpp b/Whisper/Whisper/ContextImpl.misc.cpp index 9a156fb..90a1bec 100644 --- a/Whisper/Whisper/ContextImpl.misc.cpp +++ b/Whisper/Whisper/ContextImpl.misc.cpp @@ -114,10 +114,10 @@ __m128i ContextImpl::getMemoryUse() const size_t cb = vectorMemoryUse( result_all ); for( const auto& r : result_all ) cb += r.memoryUsage(); - cb += vectorMemoryUse( prompt_past ); + cb += vectorMemoryUse( prompt_past_vec ); cb += vectorMemoryUse( energy ); - cb += vectorMemoryUse( probs ); - cb += vectorMemoryUse( probs_id ); + cb += vectorMemoryUse( probs_vec ); + cb += vectorMemoryUse( probs_id_vec ); cb += vectorMemoryUse( results.segments ); cb += vectorMemoryUse( results.tokens ); cb += spectrogram.memoryUsage(); |
