diff options
| author | yum <yum.food.vr@gmail.com> | 2023-03-02 20:32:49 -0800 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2023-03-02 20:33:22 -0800 |
| commit | d743645ba27cc85d36fe6820cd9d21f0fc4a11f2 (patch) | |
| tree | 0e3d27ad5e63911f23bbefebf6ca017a17cf38f3 | |
| parent | dcd7f3b60e3b9ad8df83d444f8bc67091b411529 (diff) | |
Finish beam search rough draft
Seems to work. Doesn't crash. Lots of room for optimization and cleanup.
| -rw-r--r-- | Whisper/Whisper/ContextImpl.cpp | 223 | ||||
| -rw-r--r-- | Whisper/Whisper/ContextImpl.h | 23 |
2 files changed, 229 insertions, 17 deletions
diff --git a/Whisper/Whisper/ContextImpl.cpp b/Whisper/Whisper/ContextImpl.cpp index 5428408..4e384ca 100644 --- a/Whisper/Whisper/ContextImpl.cpp +++ b/Whisper/Whisper/ContextImpl.cpp @@ -66,6 +66,39 @@ HRESULT ContextImpl::decode( const int* tokens, size_t length, int n_past, int t } } +std::pair<int, int> ContextImpl::beamGetMinJointProb() const +{ + float min_p = 1.1; + int min_p_beam; + int min_p_sample; + for (int nth_beam = 0; nth_beam < ctx_.size(); nth_beam++) { + for (int nth_best = 0; nth_best < ctx_[nth_beam].beam_ctx.best_tokens.size(); nth_best++) { + const float cur_p = ctx_[nth_beam].beam_ctx.joint_prob * + ctx_[nth_beam].beam_ctx.best_tokens[nth_best].p; + if (cur_p < min_p) { + min_p = cur_p; + min_p_beam = nth_beam; + min_p_sample = nth_best; + } + } + } + return { min_p_beam, min_p_sample }; +} + +int ContextImpl::beamGetMaxJointProb() const +{ + float max_p = -0.1; + int max_p_beam; + for (int nth_beam = 0; nth_beam < ctx_.size(); nth_beam++) { + const float cur_p = ctx_[nth_beam].beam_ctx.joint_prob; + if (cur_p > max_p) { + max_p = cur_p; + max_p_beam = nth_beam; + } + } + return max_p_beam; +} + // the most basic sampling scheme - select the top token std::vector<sTokenData> ContextImpl::sampleBestN( const float* probs, bool force_timestamp, bool is_initial, int nth, int n_best ) @@ -181,9 +214,10 @@ std::vector<sTokenData> ContextImpl::sampleTimestampN(bool initial, int nth, int n_best) { const int n_vocab = model.vocab.n_vocab; - return sampleBestN( + auto ts = sampleBestN( ctx_[nth].probs.data() + (ctx_[nth].probs.size() - n_vocab), - true, initial, nth, n_best); + true, initial, nth, 1)[0]; + return std::vector<sTokenData>(n_best, ts); } // a cost-function / heuristic that is high for text that takes longer to pronounce @@ -591,6 +625,7 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const auto& prompt_past = ctx.prompt_past; auto& prompt = ctx.loop_ctx.prompt; + ctx.n_past = 0; prompt.clear(); if (!prompt_past.empty()) @@ -605,12 +640,16 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const } prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end()); + ctx.loop_ctx.tokens_cur.clear(); ctx.loop_ctx.result_len = 0; + ctx.beam_ctx.joint_prob = 1.0; + ctx.beam_ctx.beam_done = false; + ctx.seek_delta = 100 * WHISPER_CHUNK_SIZE; + // have we already sampled a non-beg timestamp token for the current segment? + ctx.has_ts = false; } - int seek_delta = 100 * WHISPER_CHUNK_SIZE; - // print the prompt //printf("\n\n"); //for (int i = 0; i < prompt.size(); i++) { @@ -619,7 +658,6 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const //printf("\n\n"); bool failed = false; - bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment? { // Measure "Decode" profiler value, both CPU and GPU times @@ -627,10 +665,12 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const for( int i = 0, n_max = model.parameters.n_text_ctx / 2 - 4; i < n_max; i++ ) { if (params.strategy == Whisper::eSamplingStrategy::Greedy) { + auto& has_ts = ctx_[0].has_ts; + auto& n_past = ctx_[0].n_past; auto& prompt = ctx_[0].loop_ctx.prompt; - auto& tokens_cur = ctx_[0].loop_ctx.tokens_cur; - auto& n_past = ctx_[0].loop_ctx.n_past; auto& result_len = ctx_[0].loop_ctx.result_len; + auto& seek_delta = ctx_[0].seek_delta; + auto& tokens_cur = ctx_[0].loop_ctx.tokens_cur; CHECK(decode(prompt.data(), prompt.size(), n_past, params.cpuThreads, /*nth=*/0)); @@ -708,9 +748,160 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const break; } } else if (params.strategy == Whisper::eSamplingStrategy::BeamSearch) { - logError(u8"%s: beam search not implemented", __func__); - failed = true; - break; + // Get the most likely `beam_wd` tokens for each beam. + const int beam_wd = params.beam_search.n_best; + for (int nth_beam = 0; nth_beam < ctx_.size(); nth_beam++) { + auto& n_past = ctx_[nth_beam].n_past; + auto& probs = ctx_[nth_beam].probs; + auto& probs_prev = ctx_[nth_beam].beam_ctx.probs_prev; + auto& prompt = ctx_[nth_beam].loop_ctx.prompt; + auto& tokens_cur = ctx_[nth_beam].loop_ctx.tokens_cur; + + if (prompt.empty()) { + // The current beam is already complete. + // TODO we can probably elide this assignment. + probs = probs_prev; + } + else { + CHECK(decode(prompt.data(), prompt.size(), n_past, + params.cpuThreads, nth_beam)); + n_past += (int)prompt.size(); + prompt.clear(); + probs_prev = probs; + } + + // For each beam, pick the top `beam_wd` most likely + // tokens. + auto p = profiler.cpuBlock(eCpuBlock::Sample); + const std::vector<sTokenData> tokens = (i == 0) + ? sampleTimestampN(true, nth_beam, beam_wd) + : sampleBestN(nth_beam, beam_wd); + ctx_[nth_beam].beam_ctx.best_tokens = tokens; + } + + // Of the (`beam_wd` * `n_beams`) selected tokens, identify + // the `n_beams` tokens with the highest joint probability. + std::vector<std::pair<int, int>> best_beams_and_tokens(n_ctxt); + { + std::pair<int, int> min_beam_sample = beamGetMinJointProb(); + + for (int nth_beam = 0; nth_beam < best_beams_and_tokens.size(); nth_beam++) { + // Initialize to simply pick the first token in each beam. + // The only important thing is that each element is different. + best_beams_and_tokens[nth_beam] = { nth_beam, 0 }; + } + for (int nth_beam = 0; nth_beam < ctx_.size(); nth_beam++) { + const auto& joint_prob = ctx_[nth_beam].beam_ctx.joint_prob; + for (int nth_best = 0; nth_best < beam_wd; nth_best++) { + // We want to see if this (beam, sample) combo is better than + // something in the current list of possibilities. To do that, + // find the lowest-joint-probability parse in the list, and compare + // against it. + const auto& min_p_beam = min_beam_sample.first; + const auto& min_p_sample = min_beam_sample.second; + const float min_p = ctx_[min_p_beam].beam_ctx.joint_prob * + ctx_[min_p_beam].beam_ctx.best_tokens[min_p_sample].p; + + const auto& token = ctx_[nth_beam].beam_ctx.best_tokens[nth_best]; + const float cur_p = ctx_[nth_beam].beam_ctx.joint_prob * token.p; + if (cur_p > min_p) { + // Better parse found. Record. + // TODO this could be done in constant time using a heap. + for (auto& [cur_beam, cur_sample] : best_beams_and_tokens) { + if (cur_beam == min_p_beam && cur_sample == min_p_sample) { + cur_beam = nth_beam; + cur_sample = nth_best; + } + } + min_beam_sample = beamGetMinJointProb(); + } + } + } + } + + // Sort parse in ascending order of beam. + std::sort(best_beams_and_tokens.begin(), + best_beams_and_tokens.end(), + [](const std::pair<int, int> a, const std::pair<int, int> b) { + return a.first < b.first; + }); + + // Extract tokens for new parse. + std::vector<sTokenData> new_tokens; + for (const auto& [nth_beam, nth_best] : best_beams_and_tokens) { + new_tokens.push_back( + ctx_[nth_beam].beam_ctx.best_tokens[nth_best]); + } + + // Update context to reflect new parse. Because the beams are sorted + // in ascending order, we're always copying from old or + // identical data. + for (int nth_beam = 0; nth_beam < ctx_.size(); nth_beam++) { +#if 0 + // Trivial optimization: only copy if beams differ. + if (nth_beam == best_beams_and_tokens[nth_beam].first) { + continue; + } +#endif + ctx_[nth_beam] = ctx_[best_beams_and_tokens[nth_beam].first]; + } + + // Emit. + bool ts_went_backwards = false; + for (int nth_beam = 0; nth_beam < ctx_.size(); nth_beam++) { + auto& beam_done = ctx_[nth_beam].beam_ctx.beam_done; + auto& joint_prob = ctx_[nth_beam].beam_ctx.joint_prob; + auto& has_ts = ctx_[nth_beam].has_ts; + auto& prompt = ctx_[nth_beam].loop_ctx.prompt; + auto& result_len = ctx_[nth_beam].loop_ctx.result_len; + auto& seek_delta = ctx_[nth_beam].seek_delta; + auto& tokens_cur = ctx_[nth_beam].loop_ctx.tokens_cur; + auto& token = new_tokens[nth_beam]; + + if (!beam_done) { + // Timestamp token: update sliding window. + if (token.id > model.vocab.token_beg) { + const int seek_delta_new = 2 * (token.id - model.vocab.token_beg); + if (has_ts && seek_delta > seek_delta_new && result_len < i) { + ts_went_backwards = true; + break; + } + seek_delta = seek_delta_new; + result_len = i + 1; + has_ts = true; + } + + prompt.push_back(token.id); + tokens_cur.push_back(token); + joint_prob *= token.p; + } + if (token.id == model.vocab.token_eot || + (params.max_tokens > 0 && i >= params.max_tokens) || + (has_ts && seek + seek_delta + 100 >= seek_end)) { + if (result_len == 0) + { + if (seek + seek_delta + 100 >= seek_end) + result_len = i + 1; + else + { + failed = true; + break; + } + } + beam_done = true; + } + } + + bool all_done = true; + for (int nth_beam = 0; nth_beam < ctx_.size(); nth_beam++) { + if (!ctx_[nth_beam].beam_ctx.beam_done) { + all_done = false; + break; + } + } + if (all_done || ts_went_backwards || failed) { + break; + } } else { logError(u8"%s: unsupported decoding strategy: %d", __func__, (int)params.strategy); @@ -726,12 +917,14 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const continue; } + const int best_beam = beamGetMaxJointProb(); + auto& prompt_past = ctx_[best_beam].prompt_past; + auto& result_len = ctx_[best_beam].loop_ctx.result_len; + auto& seek_delta = ctx_[best_beam].seek_delta; + auto& tokens_cur = ctx_[best_beam].loop_ctx.tokens_cur; + // shrink down to result_len - // TODO assign these based on results of either decoding strategy - auto& tokens_cur = ctx_[0].loop_ctx.tokens_cur; - auto& result_len = ctx_[0].loop_ctx.result_len; - auto& prompt_past = ctx_[0].prompt_past; - tokens_cur.resize( result_len ); + tokens_cur.resize(result_len); for( const auto& r : tokens_cur ) prompt_past.push_back( r.id ); diff --git a/Whisper/Whisper/ContextImpl.h b/Whisper/Whisper/ContextImpl.h index 33d3007..2f34126 100644 --- a/Whisper/Whisper/ContextImpl.h +++ b/Whisper/Whisper/ContextImpl.h @@ -42,16 +42,30 @@ namespace Whisper std::vector<whisper_token> prompt_past; std::vector<float> probs; std::vector<std::pair<double, Vocabulary::id>> probs_id; + int seek_delta; + bool has_ts; + int n_past = 0; // These are cleared on every frame of audio processed. struct AudioFrameContext { std::vector<whisper_token> prompt; std::vector<sTokenData> tokens_cur; - std::vector<float> joint_probs; int result_len = 0; - int n_past = 0; + }; + struct BeamSearchContext { + // Each beam picks the N most likely tokens and accumulates + // them here. + std::vector<sTokenData> best_tokens; + // Some beams may finish earlier than others, in which case + // we'd like to avoid re-running inference. + std::vector<float> probs_prev; + // Joint probability of every token leading up to the current + // context. + float joint_prob; + bool beam_done; }; AudioFrameContext loop_ctx; + BeamSearchContext beam_ctx; }; std::vector<Context> ctx_; @@ -72,6 +86,11 @@ namespace Whisper int wrapSegment( int max_len ); void expComputeTokenLevelTimestamps( int i_segment, float thold_pt, float thold_ptsum ); + // Return the (nth beam, nth best) pair with the lowest joint probability. + std::pair<int, int> beamGetMinJointProb() const; + // Return the (nth beam, nth best) pair with the highest joint probability. + int beamGetMaxJointProb() const; + mutable TranscribeResultStatic results; HRESULT COMLIGHTCALL makeResults( eResultFlags flags, TranscribeResult& res ) const noexcept; |
