diff options
| author | yum <yum.food.vr@gmail.com> | 2023-03-03 20:10:32 -0800 |
|---|---|---|
| committer | yum <yum.food.vr@gmail.com> | 2023-03-03 20:42:10 -0800 |
| commit | a74ee78dbb79c551851dc182090e7d4292b1e80c (patch) | |
| tree | a18542faddbdb1a8cc6285e39d0bcb3aad47a19d /Whisper | |
| parent | f7d5741e5c069d759f8412bd40b279e1d7abac4c (diff) | |
Use logprobs, fix beam candidate selection
An incorrect sort condition resulted in the worst 5 beams being picked
instead of the best 5.
Use log probabilities instead of linear probabilities for the joint
probability calculation. With linear probabilities, the joint
probability of a long beam decayed exponentially towards zero; the
joint log-probability instead decreases linearly towards -INFINITY.
Using both transcripts in Evaluation/setup.ps1, I see a small
edit-distance regression (~5%) with beam search compared to greedy
decoding.
Diffstat (limited to 'Whisper')
| -rw-r--r-- | Whisper/Whisper/ContextImpl.cpp | 47 | ||||
| -rw-r--r-- | Whisper/Whisper/ContextImpl.h | 6 |
2 files changed, 26 insertions, 27 deletions
diff --git a/Whisper/Whisper/ContextImpl.cpp b/Whisper/Whisper/ContextImpl.cpp index 22347cf..c52a167 100644 --- a/Whisper/Whisper/ContextImpl.cpp +++ b/Whisper/Whisper/ContextImpl.cpp @@ -68,13 +68,13 @@ HRESULT ContextImpl::decode( const int* tokens, size_t length, int n_past, int t std::pair<int, int> ContextImpl::beamGetMinJointProb() const { - float min_p = 1.1; - int min_p_beam; - int min_p_sample; + float min_p = INFINITY; + int min_p_beam = 0; + int min_p_sample = 0; for (int nth_beam = 0; nth_beam < ctx_.size(); nth_beam++) { for (int nth_best = 0; nth_best < ctx_[nth_beam].beam_ctx.best_tokens.size(); nth_best++) { - const float cur_p = ctx_[nth_beam].beam_ctx.joint_prob * - ctx_[nth_beam].beam_ctx.best_tokens[nth_best].p; + const float cur_p = ctx_[nth_beam].beam_ctx.joint_logprob + + log(ctx_[nth_beam].beam_ctx.best_tokens[nth_best].p); if (cur_p < min_p) { min_p = cur_p; min_p_beam = nth_beam; @@ -87,10 +87,10 @@ std::pair<int, int> ContextImpl::beamGetMinJointProb() const int ContextImpl::beamGetMaxJointProb() const { - float max_p = -0.1; - int max_p_beam; + float max_p = -INFINITY; + int max_p_beam = 0; for (int nth_beam = 0; nth_beam < ctx_.size(); nth_beam++) { - const float cur_p = ctx_[nth_beam].beam_ctx.joint_prob; + const float cur_p = ctx_[nth_beam].beam_ctx.joint_logprob; if (cur_p > max_p) { max_p = cur_p; max_p_beam = nth_beam; @@ -643,7 +643,7 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const ctx.loop_ctx.tokens_cur.clear(); ctx.loop_ctx.result_len = 0; - ctx.beam_ctx.joint_prob = 1.0; + ctx.beam_ctx.joint_logprob = 0.0; ctx.beam_ctx.beam_done = false; ctx.seek_delta = 100 * WHISPER_CHUNK_SIZE; // have we already sampled a non-beg timestamp token for the current segment? 
@@ -757,18 +757,12 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const auto& prompt = ctx_[nth_beam].loop_ctx.prompt; auto& tokens_cur = ctx_[nth_beam].loop_ctx.tokens_cur; - if (prompt.empty()) { - // The current beam is already complete. - // TODO we can probably elide this assignment. - probs = probs_prev; - } - else { + if (!prompt.empty()) { CHECK(decode(prompt.data(), prompt.size(), n_past, params.cpuThreads, nth_beam)); n_past += (int)prompt.size(); prompt.clear(); probs_prev = probs; - } // For each beam, pick the top `beam_wd` most likely // tokens. @@ -777,6 +771,7 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const ? sampleTimestampN(true, nth_beam, beam_wd) : sampleBestN(nth_beam, beam_wd); ctx_[nth_beam].beam_ctx.best_tokens = tokens; + } } // Of the (`beam_wd` * `n_beams`) selected tokens, identify @@ -791,7 +786,7 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const best_beams_and_tokens[nth_beam] = { nth_beam, 0 }; } for (int nth_beam = 0; nth_beam < ctx_.size(); nth_beam++) { - const auto& joint_prob = ctx_[nth_beam].beam_ctx.joint_prob; + const auto& joint_prob = ctx_[nth_beam].beam_ctx.joint_logprob; for (int nth_best = 0; nth_best < beam_wd; nth_best++) { // We want to see if this (beam, sample) combo is better than // something in the current list of possibilities. To do that, @@ -799,11 +794,11 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const // against it. 
const auto& min_p_beam = min_beam_sample.first; const auto& min_p_sample = min_beam_sample.second; - const float min_p = ctx_[min_p_beam].beam_ctx.joint_prob * - ctx_[min_p_beam].beam_ctx.best_tokens[min_p_sample].p; + const float min_p = ctx_[min_p_beam].beam_ctx.joint_logprob + + log(ctx_[min_p_beam].beam_ctx.best_tokens[min_p_sample].p); const auto& token = ctx_[nth_beam].beam_ctx.best_tokens[nth_best]; - const float cur_p = ctx_[nth_beam].beam_ctx.joint_prob * token.p; + const float cur_p = ctx_[nth_beam].beam_ctx.joint_logprob + log(token.p); if (cur_p > min_p) { // Better parse found. Record. // TODO this could be done in constant time using a heap. @@ -819,11 +814,11 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const } } - // Sort parse in ascending order of beam. + // Sort parse in descending order of beam. std::sort(best_beams_and_tokens.begin(), best_beams_and_tokens.end(), [](const std::pair<int, int> a, const std::pair<int, int> b) { - return a.first < b.first; + return a.first > b.first; }); // Extract tokens for new parse. @@ -848,7 +843,7 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const bool ts_went_backwards = false; for (int nth_beam = 0; nth_beam < ctx_.size(); nth_beam++) { auto& beam_done = ctx_[nth_beam].beam_ctx.beam_done; - auto& joint_prob = ctx_[nth_beam].beam_ctx.joint_prob; + auto& joint_logprob = ctx_[nth_beam].beam_ctx.joint_logprob; auto& has_ts = ctx_[nth_beam].has_ts; auto& prompt = ctx_[nth_beam].loop_ctx.prompt; auto& result_len = ctx_[nth_beam].loop_ctx.result_len; @@ -856,6 +851,10 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const auto& tokens_cur = ctx_[nth_beam].loop_ctx.tokens_cur; auto& token = new_tokens[nth_beam]; + if (!tokens_cur.empty() && tokens_cur[tokens_cur.size() - 1].id != model.vocab.token_eot) { + beam_done = false; + } + if (!beam_done) { // Timestamp token: update sliding window. 
if (token.id > model.vocab.token_beg) { @@ -871,7 +870,7 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const prompt.push_back(token.id); tokens_cur.push_back(token); - joint_prob *= token.p; + joint_logprob += log(token.p); } if (token.id == model.vocab.token_eot || (params.max_tokens > 0 && i >= params.max_tokens) || diff --git a/Whisper/Whisper/ContextImpl.h b/Whisper/Whisper/ContextImpl.h index 2f34126..7357af3 100644 --- a/Whisper/Whisper/ContextImpl.h +++ b/Whisper/Whisper/ContextImpl.h @@ -59,9 +59,9 @@ namespace Whisper // Some beams may finish earlier than others, in which case // we'd like to avoid re-running inference. std::vector<float> probs_prev; - // Joint probability of every token leading up to the current - // context. - float joint_prob; + // Joint log-probability of every token leading up to the + // current context. + float joint_logprob; bool beam_done; }; AudioFrameContext loop_ctx; |
