summary refs log tree commit diff stats
diff options
context:
space:
mode:
author yum <yum.food.vr@gmail.com> 2023-03-03 20:10:32 -0800
committer yum <yum.food.vr@gmail.com> 2023-03-03 20:42:10 -0800
commit a74ee78dbb79c551851dc182090e7d4292b1e80c (patch)
tree a18542faddbdb1a8cc6285e39d0bcb3aad47a19d
parent f7d5741e5c069d759f8412bd40b279e1d7abac4c (diff)
Use logprobs, fix beam candidate selection
An incorrect sort condition resulted in the worst 5 beams being picked instead of the best 5. Use log probabilities instead of linear probabilities for the joint probability calculation: with linear probabilities, the joint probability of a long beam decays exponentially towards zero, whereas the joint log-probability decreases only linearly towards -INFINITY. Using both transcripts in Evaluation/setup.ps1, I see a small edit distance regression (~5%) using beam search vs. greedy.
-rw-r--r-- Evaluate/.swp bin 12288 -> 0 bytes
-rw-r--r-- Evaluate/evaluate.py 22
-rw-r--r-- Whisper/Whisper/ContextImpl.cpp 47
-rw-r--r-- Whisper/Whisper/ContextImpl.h 6
4 files changed, 34 insertions, 41 deletions
diff --git a/Evaluate/.swp b/Evaluate/.swp
deleted file mode 100644
index c1bc460..0000000
--- a/Evaluate/.swp
+++ /dev/null
Binary files differ
diff --git a/Evaluate/evaluate.py b/Evaluate/evaluate.py
index 5e8c85d..81b3edf 100644
--- a/Evaluate/evaluate.py
+++ b/Evaluate/evaluate.py
@@ -1,11 +1,12 @@
import argparse
import editdistance
-import jiwer
import re
import subprocess
import sys
import time
+from whisper.normalizers import EnglishTextNormalizer
+
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("reference_path", type=str, help="Path to reference transcript")
@@ -33,22 +34,15 @@ if __name__ == "__main__":
with open(args.reference_path, "r") as f:
ref_transcript = f.read()
+ # Normalize transcripts before computing edit distance (as described in
+ # whisper paper).
+ normalize = EnglishTextNormalizer()
+ test_transcript = normalize(test_transcript)
+ ref_transcript = normalize(ref_transcript)
+
dist = editdistance.eval(ref_transcript, test_transcript)
- wer_transform = jiwer.Compose([
- jiwer.ToLowerCase(),
- jiwer.RemoveWhiteSpace(replace_by_space=True),
- jiwer.RemoveMultipleSpaces(),
- jiwer.RemovePunctuation(),
- jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
- ])
- wer = jiwer.wer(
- ref_transcript,
- test_transcript,
- truth_transform=wer_transform,
- hypothesis_transform=wer_transform)
print(f"Duration: {t1 - t0}")
print(f"Levenshtein distance: {dist}")
- print(f"Word error rate: {wer}")
print(f"Transcript: {test_transcript}")
diff --git a/Whisper/Whisper/ContextImpl.cpp b/Whisper/Whisper/ContextImpl.cpp
index 22347cf..c52a167 100644
--- a/Whisper/Whisper/ContextImpl.cpp
+++ b/Whisper/Whisper/ContextImpl.cpp
@@ -68,13 +68,13 @@ HRESULT ContextImpl::decode( const int* tokens, size_t length, int n_past, int t
std::pair<int, int> ContextImpl::beamGetMinJointProb() const
{
- float min_p = 1.1;
- int min_p_beam;
- int min_p_sample;
+ float min_p = INFINITY;
+ int min_p_beam = 0;
+ int min_p_sample = 0;
for (int nth_beam = 0; nth_beam < ctx_.size(); nth_beam++) {
for (int nth_best = 0; nth_best < ctx_[nth_beam].beam_ctx.best_tokens.size(); nth_best++) {
- const float cur_p = ctx_[nth_beam].beam_ctx.joint_prob *
- ctx_[nth_beam].beam_ctx.best_tokens[nth_best].p;
+ const float cur_p = ctx_[nth_beam].beam_ctx.joint_logprob +
+ log(ctx_[nth_beam].beam_ctx.best_tokens[nth_best].p);
if (cur_p < min_p) {
min_p = cur_p;
min_p_beam = nth_beam;
@@ -87,10 +87,10 @@ std::pair<int, int> ContextImpl::beamGetMinJointProb() const
int ContextImpl::beamGetMaxJointProb() const
{
- float max_p = -0.1;
- int max_p_beam;
+ float max_p = -INFINITY;
+ int max_p_beam = 0;
for (int nth_beam = 0; nth_beam < ctx_.size(); nth_beam++) {
- const float cur_p = ctx_[nth_beam].beam_ctx.joint_prob;
+ const float cur_p = ctx_[nth_beam].beam_ctx.joint_logprob;
if (cur_p > max_p) {
max_p = cur_p;
max_p_beam = nth_beam;
@@ -643,7 +643,7 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
ctx.loop_ctx.tokens_cur.clear();
ctx.loop_ctx.result_len = 0;
- ctx.beam_ctx.joint_prob = 1.0;
+ ctx.beam_ctx.joint_logprob = 0.0;
ctx.beam_ctx.beam_done = false;
ctx.seek_delta = 100 * WHISPER_CHUNK_SIZE;
// have we already sampled a non-beg timestamp token for the current segment?
@@ -757,18 +757,12 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
auto& prompt = ctx_[nth_beam].loop_ctx.prompt;
auto& tokens_cur = ctx_[nth_beam].loop_ctx.tokens_cur;
- if (prompt.empty()) {
- // The current beam is already complete.
- // TODO we can probably elide this assignment.
- probs = probs_prev;
- }
- else {
+ if (!prompt.empty()) {
CHECK(decode(prompt.data(), prompt.size(), n_past,
params.cpuThreads, nth_beam));
n_past += (int)prompt.size();
prompt.clear();
probs_prev = probs;
- }
// For each beam, pick the top `beam_wd` most likely
// tokens.
@@ -777,6 +771,7 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
? sampleTimestampN(true, nth_beam, beam_wd)
: sampleBestN(nth_beam, beam_wd);
ctx_[nth_beam].beam_ctx.best_tokens = tokens;
+ }
}
// Of the (`beam_wd` * `n_beams`) selected tokens, identify
@@ -791,7 +786,7 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
best_beams_and_tokens[nth_beam] = { nth_beam, 0 };
}
for (int nth_beam = 0; nth_beam < ctx_.size(); nth_beam++) {
- const auto& joint_prob = ctx_[nth_beam].beam_ctx.joint_prob;
+ const auto& joint_prob = ctx_[nth_beam].beam_ctx.joint_logprob;
for (int nth_best = 0; nth_best < beam_wd; nth_best++) {
// We want to see if this (beam, sample) combo is better than
// something in the current list of possibilities. To do that,
@@ -799,11 +794,11 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
// against it.
const auto& min_p_beam = min_beam_sample.first;
const auto& min_p_sample = min_beam_sample.second;
- const float min_p = ctx_[min_p_beam].beam_ctx.joint_prob *
- ctx_[min_p_beam].beam_ctx.best_tokens[min_p_sample].p;
+ const float min_p = ctx_[min_p_beam].beam_ctx.joint_logprob +
+ log(ctx_[min_p_beam].beam_ctx.best_tokens[min_p_sample].p);
const auto& token = ctx_[nth_beam].beam_ctx.best_tokens[nth_best];
- const float cur_p = ctx_[nth_beam].beam_ctx.joint_prob * token.p;
+ const float cur_p = ctx_[nth_beam].beam_ctx.joint_logprob + log(token.p);
if (cur_p > min_p) {
// Better parse found. Record.
// TODO this could be done in constant time using a heap.
@@ -819,11 +814,11 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
}
}
- // Sort parse in ascending order of beam.
+ // Sort parse in descending order of beam.
std::sort(best_beams_and_tokens.begin(),
best_beams_and_tokens.end(),
[](const std::pair<int, int> a, const std::pair<int, int> b) {
- return a.first < b.first;
+ return a.first > b.first;
});
// Extract tokens for new parse.
@@ -848,7 +843,7 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
bool ts_went_backwards = false;
for (int nth_beam = 0; nth_beam < ctx_.size(); nth_beam++) {
auto& beam_done = ctx_[nth_beam].beam_ctx.beam_done;
- auto& joint_prob = ctx_[nth_beam].beam_ctx.joint_prob;
+ auto& joint_logprob = ctx_[nth_beam].beam_ctx.joint_logprob;
auto& has_ts = ctx_[nth_beam].has_ts;
auto& prompt = ctx_[nth_beam].loop_ctx.prompt;
auto& result_len = ctx_[nth_beam].loop_ctx.result_len;
@@ -856,6 +851,10 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
auto& tokens_cur = ctx_[nth_beam].loop_ctx.tokens_cur;
auto& token = new_tokens[nth_beam];
+ if (!tokens_cur.empty() && tokens_cur[tokens_cur.size() - 1].id != model.vocab.token_eot) {
+ beam_done = false;
+ }
+
if (!beam_done) {
// Timestamp token: update sliding window.
if (token.id > model.vocab.token_beg) {
@@ -871,7 +870,7 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
prompt.push_back(token.id);
tokens_cur.push_back(token);
- joint_prob *= token.p;
+ joint_logprob += log(token.p);
}
if (token.id == model.vocab.token_eot ||
(params.max_tokens > 0 && i >= params.max_tokens) ||
diff --git a/Whisper/Whisper/ContextImpl.h b/Whisper/Whisper/ContextImpl.h
index 2f34126..7357af3 100644
--- a/Whisper/Whisper/ContextImpl.h
+++ b/Whisper/Whisper/ContextImpl.h
@@ -59,9 +59,9 @@ namespace Whisper
// Some beams may finish earlier than others, in which case
// we'd like to avoid re-running inference.
std::vector<float> probs_prev;
- // Joint probability of every token leading up to the current
- // context.
- float joint_prob;
+ // Joint log-probability of every token leading up to the
+ // current context.
+ float joint_logprob;
bool beam_done;
};
AudioFrameContext loop_ctx;