summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authoryum <yum.food.vr@gmail.com>2023-02-27 18:03:02 -0800
committeryum <yum.food.vr@gmail.com>2023-02-27 18:03:02 -0800
commit4f967dbdb7972ec52039bd3e3ce3e1e4cbcf6545 (patch)
tree158cf93f6ec626f32b961acc6d791bd2a2040b94
parent1136acfc365f357d2df13a263714e8ae0614c4f9 (diff)
Begin work on beam search decoding
* ContextImpl.h puts prompts, previous prompts, probabilities, and probability IDs into vectors of size 1 or N_BEAMS, depending on the decoding strategy. * Extend sampleBest and friends to return top N tokens, instead of just the top 1 token.
-rw-r--r--Designs/.BeamSearch.md.swpbin0 -> 12288 bytes
-rw-r--r--Designs/BeamSearch.md33
-rw-r--r--Whisper/Whisper/ContextImpl.cpp252
-rw-r--r--Whisper/Whisper/ContextImpl.h17
-rw-r--r--Whisper/Whisper/ContextImpl.misc.cpp6
5 files changed, 220 insertions, 88 deletions
diff --git a/Designs/.BeamSearch.md.swp b/Designs/.BeamSearch.md.swp
new file mode 100644
index 0000000..8102d31
--- /dev/null
+++ b/Designs/.BeamSearch.md.swp
Binary files differ
diff --git a/Designs/BeamSearch.md b/Designs/BeamSearch.md
new file mode 100644
index 0000000..4902b45
--- /dev/null
+++ b/Designs/BeamSearch.md
@@ -0,0 +1,33 @@
+This is the design for the beam search decoding algorithm.
+
+ContextImpl.cpp is where greedy decoding is implemented.
+
+sampleBest() does the following:
+
+1. Populate `probs_ids` with (probability, idx).
+2. Find the `top_k` most likely tokens.
+3. Filter out tokens matching `token_sot`, `token_solm`, `token_not`.
+ 3.1. sot == start of text, solm = ???, not == not timestamps
+4. Return the result.
+
+We should modify this to return the most likely N tokens. We'll call this
+sampleBestN().
+
+Greedy search works like this:
+
+1. Call decode(prompt, tokens).
+ 1.1. This populates `probs`. Need to wrap this in a vector of size `N_BEAMS`.
+ 1.2. Need to wrap `prompt` in a vector of size `N_BEAMS`.
+2. Extract the most likely token and add it to the prompt.
+3. Repeat until EOT (end of transcript) or max tokens or end of audio stream.
+
+Beam search will work like this:
+
+1. Initialize `prompts` as a vector containing `prompt`.
+2. Call decode(prompt, tokens) for each prompt in `prompts`.
+3. Extract the most likely `N_BEAMS` tokens in each result.
+4. Compute joint probabilities for each token.
+5. Extract the most likely `N_BEAMS` prompts using joint probabilities.
+6. Update `prompts`.
+7. Repeat until end condition is reached for all prompts.
+8. Return most likely prompt.
diff --git a/Whisper/Whisper/ContextImpl.cpp b/Whisper/Whisper/ContextImpl.cpp
index 1160f95..d6d46f9 100644
--- a/Whisper/Whisper/ContextImpl.cpp
+++ b/Whisper/Whisper/ContextImpl.cpp
@@ -42,7 +42,7 @@ HRESULT ContextImpl::encode( iSpectrogram& mel, int seek )
}
}
-HRESULT ContextImpl::decode( const int* tokens, size_t length, int n_past, int threads )
+HRESULT ContextImpl::decode( const int* tokens, size_t length, int n_past, int threads, int nth )
{
// whisper_decode
using namespace DirectCompute;
@@ -57,7 +57,7 @@ HRESULT ContextImpl::decode( const int* tokens, size_t length, int n_past, int t
try
{
- context.decode( tokens, (int)length, dp, probs, threads );
+ context.decode( tokens, (int)length, dp, probs_vec[nth], threads);
return S_OK;
}
catch( HRESULT hr )
@@ -67,27 +67,27 @@ HRESULT ContextImpl::decode( const int* tokens, size_t length, int n_past, int t
}
// the most basic sampling scheme - select the top token
-sTokenData ContextImpl::sampleBest( const float* probs, bool force_timestamp, bool is_initial )
+std::vector<sTokenData> ContextImpl::sampleBestN( const float* probs,
+ bool force_timestamp, bool is_initial, int nth, int n_best )
{
// whisper_sample_best
const Vocabulary& vocab = model.vocab;
- sTokenData result = { 0 };
+ std::vector<sTokenData> result_vec(n_best, { 0 });
size_t n_logits = vocab.size();
- probs_id.clear();
- probs_id.reserve( n_logits );
+ probs_id_vec[nth].clear();
+ probs_id_vec[nth].reserve(n_logits);
for( size_t i = 0; i < n_logits; i++ )
- probs_id.emplace_back( probs[ i ], (int)i );
-
+ probs_id_vec[nth].emplace_back(probs[i], (int)i);
{
double sum_ts = 0.0;
double max_ts = -1.0;
double max_tx = -1.0;
for( int i = 0; i < vocab.token_beg; i++ )
- max_tx = std::max( max_tx, probs_id[ i ].first );
+ max_tx = std::max( max_tx, probs_id_vec[ nth ][ i ].first );
const int i0 = is_initial ? vocab.token_beg + 101 : vocab.token_beg;
const int i1 = is_initial ? vocab.token_beg + 101 : (int)n_logits;
@@ -97,16 +97,18 @@ sTokenData ContextImpl::sampleBest( const float* probs, bool force_timestamp, bo
if( is_initial )
{
for( int i = i0; i < n_logits; i++ )
- probs_id[ i ].first = -INFINITY;
+ probs_id_vec[ nth ][ i ].first = -INFINITY;
}
for( int i = vocab.token_beg; i < i1; i++ )
{
- sum_ts += probs_id[ i ].first;
- if( probs_id[ i ].first > max_ts )
+ sum_ts += probs_id_vec[ nth ][ i ].first;
+ if( probs_id_vec[ nth ][ i ].first > max_ts )
{
- max_ts = probs_id[ i ].first;
- result.tid = probs_id[ i ].second;
+ max_ts = probs_id_vec[ nth ][ i ].first;
+ for (auto& result : result_vec) {
+ result.tid = probs_id_vec[nth][i].second;
+ }
}
}
@@ -116,55 +118,72 @@ sTokenData ContextImpl::sampleBest( const float* probs, bool force_timestamp, bo
{
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
for( int i = 0; i < vocab.token_beg; i++ )
- probs_id[ i ].first = -INFINITY;
+ probs_id_vec[ nth ][ i ].first = -INFINITY;
}
- result.pt = (float)( max_ts / ( sum_ts + 1e-10 ) );
- result.ptsum = (float)sum_ts;
+ for (auto& result : result_vec) {
+ result.pt = (float)(max_ts / (sum_ts + 1e-10));
+ result.ptsum = (float)sum_ts;
+ }
}
// find the top K tokens
- const int top_k = 4;
+ const int top_k = 4 + n_best - 1;
std::partial_sort(
- probs_id.begin(),
- probs_id.begin() + top_k, probs_id.end(),
+ probs_id_vec[nth].begin(),
+ probs_id_vec[nth].begin() + top_k, probs_id_vec[nth].end(),
[]( const std::pair<double, Vocabulary::id>& a, const std::pair<double, Vocabulary::id>& b ) {
return a.first > b.first;
} );
- probs_id.resize( top_k );
+ probs_id_vec[nth].resize(top_k);
//printf("\n");
- //for (int i = 0; i < (int) probs_id.size(); i++) {
- // printf("%d: '%s' %f, %d\n", i, vocab.id_to_token.at(probs_id[i].second).c_str(), probs_id[i].first, probs_id[i].second);
+ //for (int i = 0; i < (int) probs_id_vec[nth].size(); i++) {
+ // printf("%d: '%s' %f, %d\n", i,
+ // vocab.id_to_token.at(probs_id_vec[nth][i].second).c_str(),
+ // probs_id_vec[nth][i].first, probs_id_vec[nth][i].second);
//}
- int res = 0;
- while( ( probs_id[ res ].second == vocab.token_sot ||
- probs_id[ res ].second == vocab.token_solm ||
- probs_id[ res ].second == vocab.token_not ) &&
- res < (int)probs_id.size() - 1 )
- {
- res++;
+ std::vector<int> res_vec(n_best, 0);
+ int i = 0;
+ for (int j = 0; j < n_best; j++) {
+ // Scan past unwanted tokens.
+ while ((probs_id_vec[nth][i].second == vocab.token_sot ||
+ probs_id_vec[nth][i].second == vocab.token_solm ||
+ probs_id_vec[nth][i].second == vocab.token_not) &&
+ i < (int)probs_id_vec[nth].size() - 1)
+ {
+ i++;
+ }
+ res_vec[j] = i;
}
- result.id = probs_id[ res ].second;
- result.p = (float)probs_id[ res ].first;
+ assert(result_vec.size() == res_vec.size());
+ for (int i = 0; i < res_vec.size(); i++) {
+ result_vec[i].id = probs_id_vec[nth][res_vec[i]].second;
+ result_vec[i].p = (float)probs_id_vec[nth][res_vec[i]].first;
+ }
- return result;
+ return result_vec;
}
-sTokenData ContextImpl::sampleBest()
+std::vector<sTokenData> ContextImpl::sampleBestN(int nth, int n_best)
{
const int n_vocab = model.vocab.n_vocab;
- return sampleBest( probs.data() + ( probs.size() - n_vocab ), false, false );
+ return sampleBestN(
+ probs_vec[nth].data() + (probs_vec[nth].size() - n_vocab),
+ false, false, nth, /*n_best=*/1);
}
-sTokenData ContextImpl::sampleTimestamp( bool initial )
+std::vector<sTokenData> ContextImpl::sampleTimestampN(bool initial, int nth,
+ int n_best)
{
const int n_vocab = model.vocab.n_vocab;
- return sampleBest( probs.data() + ( probs.size() - n_vocab ), true, initial );
+ return sampleBestN(
+ probs_vec[nth].data() + (probs_vec[nth].size() - n_vocab),
+ true, initial, nth, n_best);
}
// a cost-function / heuristic that is high for text that takes longer to pronounce
@@ -467,17 +486,34 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
if( seek_end < 100 + seek_start )
return S_FALSE;
+ if (prompt_past_vec.empty()) {
+ if (params.strategy == Whisper::eSamplingStrategy::Greedy) {
+ prompt_past_vec = std::vector<prompt_t>(1);
+ }
+ else if (params.strategy == Whisper::eSamplingStrategy::BeamSearch) {
+ prompt_past_vec = std::vector<prompt_t>(params.beam_search.beam_width);
+ }
+ }
+
// the accumulated text context so far
- if( params.flag( eFullParamsFlags::NoContext ) )
- prompt_past.clear();
+ if (params.flag(eFullParamsFlags::NoContext)) {
+ for (auto& prompt_past : prompt_past_vec) {
+ prompt_past.clear();
+ }
+ }
// prepend the prompt tokens to the prompt_past
if( params.prompt_tokens && params.prompt_n_tokens > 0 )
{
// parse tokens from the pointer
- for( int i = 0; i < params.prompt_n_tokens; i++ )
- prompt_past.push_back( params.prompt_tokens[ i ] );
- std::rotate( prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end() );
+ for (int i = 0; i < params.prompt_n_tokens; i++) {
+ for (auto& prompt_past : prompt_past_vec) {
+ prompt_past.push_back(params.prompt_tokens[i]);
+ }
+ }
+ for (auto& prompt_past : prompt_past_vec) {
+ std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
+ }
}
// overwrite audio_ctx
@@ -507,10 +543,24 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
// int progress_prev = 0;
// int progress_step = 5;
- std::vector<sTokenData> tokens_cur;
- tokens_cur.reserve( model.parameters.n_text_ctx );
- std::vector<whisper_token> prompt;
- prompt.reserve( model.parameters.n_text_ctx );
+ typedef std::vector<sTokenData> tokens_t;
+ std::vector<tokens_t> tokens_cur_vec;
+ typedef std::vector<whisper_token> prompt_t;
+ std::vector<prompt_t> prompt_vec;
+ if (params.strategy == Whisper::eSamplingStrategy::Greedy) {
+ prompt_vec = std::vector<prompt_t>(1);
+ tokens_cur_vec = std::vector<tokens_t>(1);
+ }
+ else if (params.strategy == Whisper::eSamplingStrategy::BeamSearch) {
+ prompt_vec = std::vector<prompt_t>(params.beam_search.beam_width);
+ tokens_cur_vec = std::vector<tokens_t>(params.beam_search.beam_width);
+ }
+ for (auto& prompt : prompt_vec) {
+ prompt.reserve(model.parameters.n_text_ctx);
+ }
+ for (auto& tokens_cur : tokens_cur_vec) {
+ tokens_cur.reserve(model.parameters.n_text_ctx);
+ }
// main loop
int seek = seek_start;
@@ -553,22 +603,38 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
// encode audio features starting at offset seek
CHECK( encode( mel, seek ) );
- int n_past = 0;
- prompt.clear();
+ std::vector<int> n_past_vec;
+ if (params.strategy == Whisper::eSamplingStrategy::Greedy) {
+ n_past_vec = std::vector<int>(1, 0);
+ }
+ else if (params.strategy == Whisper::eSamplingStrategy::BeamSearch) {
+ n_past_vec = std::vector<int>(params.beam_search.beam_width, 0);
+ }
+
+ for (auto& prompt : prompt_vec) {
+ prompt.clear();
+ }
// if we have already generated some text, use it as a prompt to condition the next generation
- if( !prompt_past.empty() )
- {
- int n_take = std::min( std::min( params.n_max_text_ctx, model.parameters.n_text_ctx / 2 ), int( prompt_past.size() ) );
+ assert(prompt_past_vec.size() == prompt_vec.size());
+ for (int i = 0; i < prompt_past_vec.size(); i++) {
+ auto& prompt_past = prompt_past_vec[i];
+ auto& prompt = prompt_vec[i];
+ if (!prompt_past.empty())
+ {
+ int n_take = std::min(std::min(params.n_max_text_ctx, model.parameters.n_text_ctx / 2), int(prompt_past.size()));
- prompt = { model.vocab.token_prev };
- prompt.insert( prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end() );
+ prompt = { model.vocab.token_prev };
+ prompt.insert(prompt.begin() + 1, prompt_past.end() - n_take, prompt_past.end());
- prompt_past.clear();
- prompt_past.insert( prompt_past.end(), prompt.begin() + 1, prompt.end() );
+ prompt_past.clear();
+ prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
+ }
}
- prompt.insert( prompt.end(), prompt_init.begin(), prompt_init.end() );
+ for (auto& prompt : prompt_vec) {
+ prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
+ }
int seek_delta = 100 * WHISPER_CHUNK_SIZE;
@@ -580,8 +646,17 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
//printf("\n\n");
// the accumulated transcription in the current iteration
- int result_len = 0;
- tokens_cur.clear();
+ std::vector<int> result_len_vec;
+ if (params.strategy == Whisper::eSamplingStrategy::Greedy) {
+ result_len_vec = std::vector<int>(1, 0);
+ }
+ else if (params.strategy == Whisper::eSamplingStrategy::BeamSearch) {
+ result_len_vec = std::vector<int>(params.beam_search.beam_width, 0);
+ }
+
+ for (auto& tokens_cur : tokens_cur_vec) {
+ tokens_cur.clear();
+ }
bool failed = false;
bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?
@@ -591,21 +666,29 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
auto prof = context.decodeProfiler();
for( int i = 0, n_max = model.parameters.n_text_ctx / 2 - 4; i < n_max; i++ )
{
- CHECK( decode( prompt.data(), prompt.size(), n_past, params.cpuThreads ) );
-
- n_past += (int)prompt.size();
- prompt.clear();
-
- // very basic greedy sampling strategy:
- //
- // - always take the most probable token
- //
- // more sophisticated sampling strategies could be implemented here, but we keep it simple
- // feel free to experiment!
- //
- {
+ if (params.strategy == Whisper::eSamplingStrategy::Greedy) {
+ auto& prompt = prompt_vec[0];
+ auto& tokens_cur = tokens_cur_vec[0];
+ auto& n_past = n_past_vec[0];
+ auto& result_len = result_len_vec[0];
+
+ CHECK(decode(prompt.data(), prompt.size(), n_past, params.cpuThreads, /*nth=*/0));
+
+ n_past += (int)prompt.size();
+ prompt.clear();
+
+ // very basic greedy sampling strategy:
+ //
+ // - always take the most probable token
+ //
+ // more sophisticated sampling strategies could be implemented here, but we keep it simple
+ // feel free to experiment!
+ //
+
auto p = profiler.cpuBlock( eCpuBlock::Sample );
- const sTokenData token = ( i == 0 ) ? sampleTimestamp( true ) : sampleBest();
+ const sTokenData token = ( i == 0 )
+ ? sampleTimestampN( true, /*nth=*/0, /*n_best=*/1)[0]
+ : sampleBestN(/*nth=*/0, /*n_best=*/1)[0];
// timestamp token - update sliding window
if( token.id > model.vocab.token_beg )
@@ -655,13 +738,22 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
break;
}
- }
- // sometimes, the decoding can get stuck in a repetition loop
- // this is a simple strategy to avoid such cases - we simply flag the decoding as failed and advance
- // the sliding window by 1 second
- if( i == n_max - 1 && ( result_len == 0 || seek_delta < 100 * WHISPER_CHUNK_SIZE / 2 ) )
- {
+ // sometimes, the decoding can get stuck in a repetition loop
+ // this is a simple strategy to avoid such cases - we simply flag the decoding as failed and advance
+ // the sliding window by 1 second
+ if (i == n_max - 1 && (result_len == 0 || seek_delta < 100 * WHISPER_CHUNK_SIZE / 2))
+ {
+ failed = true;
+ break;
+ }
+ } else if (params.strategy == Whisper::eSamplingStrategy::BeamSearch) {
+ logError(u8"%s: beam search not implemented", __func__);
+ failed = true;
+ break;
+ }
+ else {
+ logError(u8"%s: unsupported decoding strategy: %d", __func__, (int)params.strategy);
failed = true;
break;
}
@@ -675,6 +767,10 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
}
// shrink down to result_len
+ // TODO assign these based on results of either decoding strategy
+ auto& tokens_cur = tokens_cur_vec[0];
+ auto& result_len = result_len_vec[0];
+ auto& prompt_past = prompt_past_vec[0];
tokens_cur.resize( result_len );
for( const auto& r : tokens_cur )
diff --git a/Whisper/Whisper/ContextImpl.h b/Whisper/Whisper/ContextImpl.h
index c404b8f..cdad36b 100644
--- a/Whisper/Whisper/ContextImpl.h
+++ b/Whisper/Whisper/ContextImpl.h
@@ -38,7 +38,8 @@ namespace Whisper
};
std::vector<Segment> result_all;
- std::vector<whisper_token> prompt_past;
+ typedef std::vector<whisper_token> prompt_t;
+ std::vector<prompt_t> prompt_past_vec;
// [EXPERIMENTAL] token-level timestamps data
int64_t t_beg = 0;
@@ -50,15 +51,17 @@ namespace Whisper
int32_t exp_n_audio_ctx = 0; // 0 - use default
HRESULT encode( iSpectrogram& mel, int seek );
- HRESULT decode( const int* tokens, size_t length, int n_past, int threads );
- sTokenData sampleBest( const float* probs, bool force_timestamp, bool is_initial );
- sTokenData sampleBest();
- sTokenData sampleTimestamp( bool initial );
+ HRESULT decode( const int* tokens, size_t length, int n_past, int threads, int nth );
+ std::vector<sTokenData> sampleBestN( const float* probs, bool force_timestamp, bool is_initial, int nth, int n_best );
+ std::vector<sTokenData> sampleBestN(int nth, int n_best);
+ std::vector<sTokenData> sampleTimestampN( bool initial, int nth, int n_best );
int wrapSegment( int max_len );
void expComputeTokenLevelTimestamps( int i_segment, float thold_pt, float thold_ptsum );
- std::vector<float> probs;
- std::vector<std::pair<double, Vocabulary::id>> probs_id;
+ typedef std::vector<float> probs_t;
+ std::vector<probs_t> probs_vec{ 1 };
+ typedef std::vector<std::pair<double, Vocabulary::id>> probs_id_t;
+ std::vector<probs_id_t> probs_id_vec{ 1 };
mutable TranscribeResultStatic results;
diff --git a/Whisper/Whisper/ContextImpl.misc.cpp b/Whisper/Whisper/ContextImpl.misc.cpp
index 9a156fb..90a1bec 100644
--- a/Whisper/Whisper/ContextImpl.misc.cpp
+++ b/Whisper/Whisper/ContextImpl.misc.cpp
@@ -114,10 +114,10 @@ __m128i ContextImpl::getMemoryUse() const
size_t cb = vectorMemoryUse( result_all );
for( const auto& r : result_all )
cb += r.memoryUsage();
- cb += vectorMemoryUse( prompt_past );
+ cb += vectorMemoryUse( prompt_past_vec );
cb += vectorMemoryUse( energy );
- cb += vectorMemoryUse( probs );
- cb += vectorMemoryUse( probs_id );
+ cb += vectorMemoryUse( probs_vec );
+ cb += vectorMemoryUse( probs_id_vec );
cb += vectorMemoryUse( results.segments );
cb += vectorMemoryUse( results.tokens );
cb += spectrogram.memoryUsage();