From a74ee78dbb79c551851dc182090e7d4292b1e80c Mon Sep 17 00:00:00 2001 From: yum Date: Fri, 3 Mar 2023 20:10:32 -0800 Subject: Use logprobs, fix beam candidate selection Incorrect sort condition resulted in worst 5 beams being picked instead of best 5. Use log probabilities for joint probability calculation instead of linear probabilities. Long beams would have probabilities converge exponentially towards zero; now they converge linearly towards -INFINITY. Using both transcripts in Evaluation/setup.ps1, I see a small edit distance regression (~5%) using beam search vs. greedy. --- Evaluate/.swp | Bin 12288 -> 0 bytes Evaluate/evaluate.py | 22 +++++++------------ Whisper/Whisper/ContextImpl.cpp | 47 ++++++++++++++++++++-------------------- Whisper/Whisper/ContextImpl.h | 6 ++--- 4 files changed, 34 insertions(+), 41 deletions(-) delete mode 100644 Evaluate/.swp diff --git a/Evaluate/.swp b/Evaluate/.swp deleted file mode 100644 index c1bc460..0000000 Binary files a/Evaluate/.swp and /dev/null differ diff --git a/Evaluate/evaluate.py b/Evaluate/evaluate.py index 5e8c85d..81b3edf 100644 --- a/Evaluate/evaluate.py +++ b/Evaluate/evaluate.py @@ -1,11 +1,12 @@ import argparse import editdistance -import jiwer import re import subprocess import sys import time +from whisper.normalizers import EnglishTextNormalizer + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("reference_path", type=str, help="Path to reference transcript") @@ -33,22 +34,15 @@ if __name__ == "__main__": with open(args.reference_path, "r") as f: ref_transcript = f.read() + # Normalize transcripts before computing edit distance (as described in + # whisper paper). + normalize = EnglishTextNormalizer() + test_transcript = normalize(test_transcript) + ref_transcript = normalize(ref_transcript) + dist = editdistance.eval(ref_transcript, test_transcript) - wer_transform = jiwer.Compose([ - jiwer.ToLowerCase(), - jiwer.RemoveWhiteSpace(replace_by_space=True), - jiwer.RemoveMultipleSpaces(), - jiwer.RemovePunctuation(), - jiwer.ReduceToListOfListOfWords(word_delimiter=" "), - ]) - wer = jiwer.wer( - ref_transcript, - test_transcript, - truth_transform=wer_transform, - hypothesis_transform=wer_transform) print(f"Duration: {t1 - t0}") print(f"Levenshtein distance: {dist}") - print(f"Word error rate: {wer}") print(f"Transcript: {test_transcript}") diff --git a/Whisper/Whisper/ContextImpl.cpp b/Whisper/Whisper/ContextImpl.cpp index 22347cf..c52a167 100644 --- a/Whisper/Whisper/ContextImpl.cpp +++ b/Whisper/Whisper/ContextImpl.cpp @@ -68,13 +68,13 @@ HRESULT ContextImpl::decode( const int* tokens, size_t length, int n_past, int t std::pair ContextImpl::beamGetMinJointProb() const { - float min_p = 1.1; - int min_p_beam; - int min_p_sample; + float min_p = INFINITY; + int min_p_beam = 0; + int min_p_sample = 0; for (int nth_beam = 0; nth_beam < ctx_.size(); nth_beam++) { for (int nth_best = 0; nth_best < ctx_[nth_beam].beam_ctx.best_tokens.size(); nth_best++) { - const float cur_p = ctx_[nth_beam].beam_ctx.joint_prob * - ctx_[nth_beam].beam_ctx.best_tokens[nth_best].p; + const float cur_p = ctx_[nth_beam].beam_ctx.joint_logprob + + log(ctx_[nth_beam].beam_ctx.best_tokens[nth_best].p); if (cur_p < min_p) { min_p = cur_p; min_p_beam = nth_beam; @@ -87,10 +87,10 @@ std::pair ContextImpl::beamGetMinJointProb() const int ContextImpl::beamGetMaxJointProb() const { - float max_p = -0.1; - int max_p_beam; + float max_p = -INFINITY; + int max_p_beam = 0; for (int nth_beam = 0; nth_beam < ctx_.size(); nth_beam++) { - const float cur_p = ctx_[nth_beam].beam_ctx.joint_prob; + const float cur_p = ctx_[nth_beam].beam_ctx.joint_logprob; if (cur_p > max_p) { max_p = cur_p; max_p_beam = nth_beam; @@ -643,7 +643,7 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const ctx.loop_ctx.tokens_cur.clear(); ctx.loop_ctx.result_len = 0; - ctx.beam_ctx.joint_prob = 1.0; + ctx.beam_ctx.joint_logprob = 0.0; ctx.beam_ctx.beam_done = false; ctx.seek_delta = 100 * WHISPER_CHUNK_SIZE; // have we already sampled a non-beg timestamp token for the current segment? @@ -757,18 +757,12 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const auto& prompt = ctx_[nth_beam].loop_ctx.prompt; auto& tokens_cur = ctx_[nth_beam].loop_ctx.tokens_cur; - if (prompt.empty()) { - // The current beam is already complete. - // TODO we can probably elide this assignment. - probs = probs_prev; - } - else { + if (!prompt.empty()) { CHECK(decode(prompt.data(), prompt.size(), n_past, params.cpuThreads, nth_beam)); n_past += (int)prompt.size(); prompt.clear(); probs_prev = probs; - } // For each beam, pick the top `beam_wd` most likely // tokens. @@ -777,6 +771,7 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const ? sampleTimestampN(true, nth_beam, beam_wd) : sampleBestN(nth_beam, beam_wd); ctx_[nth_beam].beam_ctx.best_tokens = tokens; + } } // Of the (`beam_wd` * `n_beams`) selected tokens, identify @@ -791,7 +786,7 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const best_beams_and_tokens[nth_beam] = { nth_beam, 0 }; } for (int nth_beam = 0; nth_beam < ctx_.size(); nth_beam++) { - const auto& joint_prob = ctx_[nth_beam].beam_ctx.joint_prob; + const auto& joint_prob = ctx_[nth_beam].beam_ctx.joint_logprob; for (int nth_best = 0; nth_best < beam_wd; nth_best++) { // We want to see if this (beam, sample) combo is better than // something in the current list of possibilities. To do that, @@ -799,11 +794,11 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const // against it. const auto& min_p_beam = min_beam_sample.first; const auto& min_p_sample = min_beam_sample.second; - const float min_p = ctx_[min_p_beam].beam_ctx.joint_prob * - ctx_[min_p_beam].beam_ctx.best_tokens[min_p_sample].p; + const float min_p = ctx_[min_p_beam].beam_ctx.joint_logprob + + log(ctx_[min_p_beam].beam_ctx.best_tokens[min_p_sample].p); const auto& token = ctx_[nth_beam].beam_ctx.best_tokens[nth_best]; - const float cur_p = ctx_[nth_beam].beam_ctx.joint_prob * token.p; + const float cur_p = ctx_[nth_beam].beam_ctx.joint_logprob + log(token.p); if (cur_p > min_p) { // Better parse found. Record. // TODO this could be done in constant time using a heap. @@ -819,11 +814,11 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const } } - // Sort parse in ascending order of beam. + // Sort parse in descending order of beam. std::sort(best_beams_and_tokens.begin(), best_beams_and_tokens.end(), [](const std::pair a, const std::pair b) { - return a.first < b.first; + return a.first > b.first; }); // Extract tokens for new parse. @@ -848,7 +843,7 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const bool ts_went_backwards = false; for (int nth_beam = 0; nth_beam < ctx_.size(); nth_beam++) { auto& beam_done = ctx_[nth_beam].beam_ctx.beam_done; - auto& joint_prob = ctx_[nth_beam].beam_ctx.joint_prob; + auto& joint_logprob = ctx_[nth_beam].beam_ctx.joint_logprob; auto& has_ts = ctx_[nth_beam].has_ts; auto& prompt = ctx_[nth_beam].loop_ctx.prompt; auto& result_len = ctx_[nth_beam].loop_ctx.result_len; @@ -856,6 +851,10 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const auto& tokens_cur = ctx_[nth_beam].loop_ctx.tokens_cur; auto& token = new_tokens[nth_beam]; + if (!tokens_cur.empty() && tokens_cur[tokens_cur.size() - 1].id != model.vocab.token_eot) { + beam_done = false; + } + if (!beam_done) { // Timestamp token: update sliding window. if (token.id > model.vocab.token_beg) { @@ -871,7 +870,7 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const prompt.push_back(token.id); tokens_cur.push_back(token); - joint_prob *= token.p; + joint_logprob += log(token.p); } if (token.id == model.vocab.token_eot || (params.max_tokens > 0 && i >= params.max_tokens) || diff --git a/Whisper/Whisper/ContextImpl.h b/Whisper/Whisper/ContextImpl.h index 2f34126..7357af3 100644 --- a/Whisper/Whisper/ContextImpl.h +++ b/Whisper/Whisper/ContextImpl.h @@ -59,9 +59,9 @@ namespace Whisper // Some beams may finish earlier than others, in which case // we'd like to avoid re-running inference. std::vector probs_prev; - // Joint probability of every token leading up to the current - // context. - float joint_prob; + // Joint log-probability of every token leading up to the + // current context. + float joint_logprob; bool beam_done; }; AudioFrameContext loop_ctx; -- cgit v1.2.3