summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authoryum <yum.food.vr@gmail.com>2023-03-02 20:32:49 -0800
committeryum <yum.food.vr@gmail.com>2023-03-02 20:33:22 -0800
commitd743645ba27cc85d36fe6820cd9d21f0fc4a11f2 (patch)
tree0e3d27ad5e63911f23bbefebf6ca017a17cf38f3
parentdcd7f3b60e3b9ad8df83d444f8bc67091b411529 (diff)
Finish beam search rough draft
Seems to work. Doesn't crash. Lots of room for optimization and cleanup.
-rw-r--r--Whisper/Whisper/ContextImpl.cpp223
-rw-r--r--Whisper/Whisper/ContextImpl.h23
2 files changed, 229 insertions, 17 deletions
diff --git a/Whisper/Whisper/ContextImpl.cpp b/Whisper/Whisper/ContextImpl.cpp
index 5428408..4e384ca 100644
--- a/Whisper/Whisper/ContextImpl.cpp
+++ b/Whisper/Whisper/ContextImpl.cpp
@@ -66,6 +66,39 @@ HRESULT ContextImpl::decode( const int* tokens, size_t length, int n_past, int t
}
}
+// Scan every candidate token of every beam in ctx_ and return the indices
+// {beam, candidate} of the pair whose extended joint probability
+// (the beam's joint_prob multiplied by the candidate token's p) is lowest.
+// Ties keep the first pair encountered (lowest beam, then lowest candidate).
+// If there is no candidate at all (ctx_ empty, or every best_tokens list
+// empty) the result is {0, 0} rather than indeterminate values.
+std::pair<int, int> ContextImpl::beamGetMinJointProb() const
+{
+    // Sentinel meant to exceed any real candidate so the first one wins.
+    // NOTE(review): presumes joint_prob and p are probabilities in [0, 1] - confirm.
+    float min_p = 1.1f;
+    int min_p_beam = 0;
+    int min_p_sample = 0;
+
+    const int n_beams = static_cast<int>(ctx_.size());
+    for (int nth_beam = 0; nth_beam < n_beams; nth_beam++) {
+        const auto& beam = ctx_[nth_beam].beam_ctx;
+        const int n_best = static_cast<int>(beam.best_tokens.size());
+        for (int nth_best = 0; nth_best < n_best; nth_best++) {
+            const float cur_p = beam.joint_prob * beam.best_tokens[nth_best].p;
+            if (cur_p < min_p) {
+                min_p = cur_p;
+                min_p_beam = nth_beam;
+                min_p_sample = nth_best;
+            }
+        }
+    }
+    return { min_p_beam, min_p_sample };
+}
+
+// Return the index into ctx_ of the beam whose accumulated joint_prob is
+// highest. Ties keep the lowest index. If no beam qualifies (ctx_ empty, or
+// no joint_prob compares greater than the sentinel, e.g. NaN) the result is
+// 0 instead of an indeterminate value.
+int ContextImpl::beamGetMaxJointProb() const
+{
+    // Sentinel meant to sit below any real joint probability.
+    // NOTE(review): presumes joint_prob is never negative - confirm.
+    float max_p = -0.1f;
+    int max_p_beam = 0;
+
+    const int n_beams = static_cast<int>(ctx_.size());
+    for (int nth_beam = 0; nth_beam < n_beams; nth_beam++) {
+        const float cur_p = ctx_[nth_beam].beam_ctx.joint_prob;
+        if (cur_p > max_p) {
+            max_p = cur_p;
+            max_p_beam = nth_beam;
+        }
+    }
+    return max_p_beam;
+}
+
// the most basic sampling scheme - select the top token
std::vector<sTokenData> ContextImpl::sampleBestN( const float* probs,
bool force_timestamp, bool is_initial, int nth, int n_best )
@@ -181,9 +214,10 @@ std::vector<sTokenData> ContextImpl::sampleTimestampN(bool initial, int nth,
int n_best)
{
const int n_vocab = model.vocab.n_vocab;
- return sampleBestN(
+ auto ts = sampleBestN(
ctx_[nth].probs.data() + (ctx_[nth].probs.size() - n_vocab),
- true, initial, nth, n_best);
+ true, initial, nth, 1)[0];
+ return std::vector<sTokenData>(n_best, ts);
}
// a cost-function / heuristic that is high for text that takes longer to pronounce
@@ -591,6 +625,7 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
auto& prompt_past = ctx.prompt_past;
auto& prompt = ctx.loop_ctx.prompt;
+ ctx.n_past = 0;
prompt.clear();
if (!prompt_past.empty())
@@ -605,12 +640,16 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
}
prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
+
ctx.loop_ctx.tokens_cur.clear();
ctx.loop_ctx.result_len = 0;
+ ctx.beam_ctx.joint_prob = 1.0;
+ ctx.beam_ctx.beam_done = false;
+ ctx.seek_delta = 100 * WHISPER_CHUNK_SIZE;
+ // have we already sampled a non-beg timestamp token for the current segment?
+ ctx.has_ts = false;
}
- int seek_delta = 100 * WHISPER_CHUNK_SIZE;
-
// print the prompt
//printf("\n\n");
//for (int i = 0; i < prompt.size(); i++) {
@@ -619,7 +658,6 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
//printf("\n\n");
bool failed = false;
- bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?
{
// Measure "Decode" profiler value, both CPU and GPU times
@@ -627,10 +665,12 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
for( int i = 0, n_max = model.parameters.n_text_ctx / 2 - 4; i < n_max; i++ )
{
if (params.strategy == Whisper::eSamplingStrategy::Greedy) {
+ auto& has_ts = ctx_[0].has_ts;
+ auto& n_past = ctx_[0].n_past;
auto& prompt = ctx_[0].loop_ctx.prompt;
- auto& tokens_cur = ctx_[0].loop_ctx.tokens_cur;
- auto& n_past = ctx_[0].loop_ctx.n_past;
auto& result_len = ctx_[0].loop_ctx.result_len;
+ auto& seek_delta = ctx_[0].seek_delta;
+ auto& tokens_cur = ctx_[0].loop_ctx.tokens_cur;
CHECK(decode(prompt.data(), prompt.size(), n_past, params.cpuThreads, /*nth=*/0));
@@ -708,9 +748,160 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
break;
}
} else if (params.strategy == Whisper::eSamplingStrategy::BeamSearch) {
- logError(u8"%s: beam search not implemented", __func__);
- failed = true;
- break;
+ // Get the most likely `beam_wd` tokens for each beam.
+ const int beam_wd = params.beam_search.n_best;
+ for (int nth_beam = 0; nth_beam < ctx_.size(); nth_beam++) {
+ auto& n_past = ctx_[nth_beam].n_past;
+ auto& probs = ctx_[nth_beam].probs;
+ auto& probs_prev = ctx_[nth_beam].beam_ctx.probs_prev;
+ auto& prompt = ctx_[nth_beam].loop_ctx.prompt;
+ auto& tokens_cur = ctx_[nth_beam].loop_ctx.tokens_cur;
+
+ if (prompt.empty()) {
+ // The current beam is already complete.
+ // TODO we can probably elide this assignment.
+ probs = probs_prev;
+ }
+ else {
+ CHECK(decode(prompt.data(), prompt.size(), n_past,
+ params.cpuThreads, nth_beam));
+ n_past += (int)prompt.size();
+ prompt.clear();
+ probs_prev = probs;
+ }
+
+ // For each beam, pick the top `beam_wd` most likely
+ // tokens.
+ auto p = profiler.cpuBlock(eCpuBlock::Sample);
+ const std::vector<sTokenData> tokens = (i == 0)
+ ? sampleTimestampN(true, nth_beam, beam_wd)
+ : sampleBestN(nth_beam, beam_wd);
+ ctx_[nth_beam].beam_ctx.best_tokens = tokens;
+ }
+
+ // Of the (`beam_wd` * `n_beams`) selected tokens, identify
+ // the `n_beams` tokens with the highest joint probability.
+ std::vector<std::pair<int, int>> best_beams_and_tokens(n_ctxt);
+ {
+ std::pair<int, int> min_beam_sample = beamGetMinJointProb();
+
+ for (int nth_beam = 0; nth_beam < best_beams_and_tokens.size(); nth_beam++) {
+ // Initialize to simply pick the first token in each beam.
+ // The only important thing is that each element is different.
+ best_beams_and_tokens[nth_beam] = { nth_beam, 0 };
+ }
+ for (int nth_beam = 0; nth_beam < ctx_.size(); nth_beam++) {
+ const auto& joint_prob = ctx_[nth_beam].beam_ctx.joint_prob;
+ for (int nth_best = 0; nth_best < beam_wd; nth_best++) {
+ // We want to see if this (beam, sample) combo is better than
+ // something in the current list of possibilities. To do that,
+ // find the lowest-joint-probability parse in the list, and compare
+ // against it.
+ const auto& min_p_beam = min_beam_sample.first;
+ const auto& min_p_sample = min_beam_sample.second;
+ const float min_p = ctx_[min_p_beam].beam_ctx.joint_prob *
+ ctx_[min_p_beam].beam_ctx.best_tokens[min_p_sample].p;
+
+ const auto& token = ctx_[nth_beam].beam_ctx.best_tokens[nth_best];
+ const float cur_p = ctx_[nth_beam].beam_ctx.joint_prob * token.p;
+ if (cur_p > min_p) {
+ // Better parse found. Record.
+ // TODO this could be done in constant time using a heap.
+ for (auto& [cur_beam, cur_sample] : best_beams_and_tokens) {
+ if (cur_beam == min_p_beam && cur_sample == min_p_sample) {
+ cur_beam = nth_beam;
+ cur_sample = nth_best;
+ }
+ }
+ min_beam_sample = beamGetMinJointProb();
+ }
+ }
+ }
+ }
+
+ // Sort parse in ascending order of beam.
+ std::sort(best_beams_and_tokens.begin(),
+ best_beams_and_tokens.end(),
+ [](const std::pair<int, int> a, const std::pair<int, int> b) {
+ return a.first < b.first;
+ });
+
+ // Extract tokens for new parse.
+ std::vector<sTokenData> new_tokens;
+ for (const auto& [nth_beam, nth_best] : best_beams_and_tokens) {
+ new_tokens.push_back(
+ ctx_[nth_beam].beam_ctx.best_tokens[nth_best]);
+ }
+
+ // Update context to reflect new parse. Because the beams are sorted
+ // in ascending order, we're always copying from old or
+ // identical data.
+ for (int nth_beam = 0; nth_beam < ctx_.size(); nth_beam++) {
+#if 0
+ // Trivial optimization: only copy if beams differ.
+ if (nth_beam == best_beams_and_tokens[nth_beam].first) {
+ continue;
+ }
+#endif
+ ctx_[nth_beam] = ctx_[best_beams_and_tokens[nth_beam].first];
+ }
+
+ // Emit.
+ bool ts_went_backwards = false;
+ for (int nth_beam = 0; nth_beam < ctx_.size(); nth_beam++) {
+ auto& beam_done = ctx_[nth_beam].beam_ctx.beam_done;
+ auto& joint_prob = ctx_[nth_beam].beam_ctx.joint_prob;
+ auto& has_ts = ctx_[nth_beam].has_ts;
+ auto& prompt = ctx_[nth_beam].loop_ctx.prompt;
+ auto& result_len = ctx_[nth_beam].loop_ctx.result_len;
+ auto& seek_delta = ctx_[nth_beam].seek_delta;
+ auto& tokens_cur = ctx_[nth_beam].loop_ctx.tokens_cur;
+ auto& token = new_tokens[nth_beam];
+
+ if (!beam_done) {
+ // Timestamp token: update sliding window.
+ if (token.id > model.vocab.token_beg) {
+ const int seek_delta_new = 2 * (token.id - model.vocab.token_beg);
+ if (has_ts && seek_delta > seek_delta_new && result_len < i) {
+ ts_went_backwards = true;
+ break;
+ }
+ seek_delta = seek_delta_new;
+ result_len = i + 1;
+ has_ts = true;
+ }
+
+ prompt.push_back(token.id);
+ tokens_cur.push_back(token);
+ joint_prob *= token.p;
+ }
+ if (token.id == model.vocab.token_eot ||
+ (params.max_tokens > 0 && i >= params.max_tokens) ||
+ (has_ts && seek + seek_delta + 100 >= seek_end)) {
+ if (result_len == 0)
+ {
+ if (seek + seek_delta + 100 >= seek_end)
+ result_len = i + 1;
+ else
+ {
+ failed = true;
+ break;
+ }
+ }
+ beam_done = true;
+ }
+ }
+
+ bool all_done = true;
+ for (int nth_beam = 0; nth_beam < ctx_.size(); nth_beam++) {
+ if (!ctx_[nth_beam].beam_ctx.beam_done) {
+ all_done = false;
+ break;
+ }
+ }
+ if (all_done || ts_went_backwards || failed) {
+ break;
+ }
}
else {
logError(u8"%s: unsupported decoding strategy: %d", __func__, (int)params.strategy);
@@ -726,12 +917,14 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
continue;
}
+ const int best_beam = beamGetMaxJointProb();
+ auto& prompt_past = ctx_[best_beam].prompt_past;
+ auto& result_len = ctx_[best_beam].loop_ctx.result_len;
+ auto& seek_delta = ctx_[best_beam].seek_delta;
+ auto& tokens_cur = ctx_[best_beam].loop_ctx.tokens_cur;
+
// shrink down to result_len
- // TODO assign these based on results of either decoding strategy
- auto& tokens_cur = ctx_[0].loop_ctx.tokens_cur;
- auto& result_len = ctx_[0].loop_ctx.result_len;
- auto& prompt_past = ctx_[0].prompt_past;
- tokens_cur.resize( result_len );
+ tokens_cur.resize(result_len);
for( const auto& r : tokens_cur )
prompt_past.push_back( r.id );
diff --git a/Whisper/Whisper/ContextImpl.h b/Whisper/Whisper/ContextImpl.h
index 33d3007..2f34126 100644
--- a/Whisper/Whisper/ContextImpl.h
+++ b/Whisper/Whisper/ContextImpl.h
@@ -42,16 +42,30 @@ namespace Whisper
std::vector<whisper_token> prompt_past;
std::vector<float> probs;
std::vector<std::pair<double, Vocabulary::id>> probs_id;
+ int seek_delta;
+ bool has_ts;
+ int n_past = 0;
// These are cleared on every frame of audio processed.
struct AudioFrameContext {
std::vector<whisper_token> prompt;
std::vector<sTokenData> tokens_cur;
- std::vector<float> joint_probs;
int result_len = 0;
- int n_past = 0;
+ };
+ // Per-beam bookkeeping used by the beam search decoding strategy.
+ struct BeamSearchContext {
+     // Each beam picks the N most likely tokens and accumulates
+     // them here.
+     std::vector<sTokenData> best_tokens;
+     // Some beams may finish earlier than others, in which case
+     // we'd like to avoid re-running inference.
+     std::vector<float> probs_prev;
+     // Joint probability of every token leading up to the current
+     // context. Defaults to 1 (the multiplicative identity, and the
+     // value runFullImpl resets it to) so a freshly constructed
+     // context never holds an indeterminate value.
+     float joint_prob = 1.0f;
+     // Set once this beam has finished decoding; defaults to not done.
+     bool beam_done = false;
+ };
AudioFrameContext loop_ctx;
+ BeamSearchContext beam_ctx;
};
std::vector<Context> ctx_;
@@ -72,6 +86,11 @@ namespace Whisper
int wrapSegment( int max_len );
void expComputeTokenLevelTimestamps( int i_segment, float thold_pt, float thold_ptsum );
+ // Return the (nth beam, nth best) pair with the lowest joint probability.
+ std::pair<int, int> beamGetMinJointProb() const;
+ // Return the index of the beam with the highest joint probability.
+ int beamGetMaxJointProb() const;
+
mutable TranscribeResultStatic results;
HRESULT COMLIGHTCALL makeResults( eResultFlags flags, TranscribeResult& res ) const noexcept;