summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author: yum <yum.food.vr@gmail.com> 2023-03-02 17:43:26 -0800
committer: yum <yum.food.vr@gmail.com> 2023-03-02 17:43:26 -0800
commit: dcd7f3b60e3b9ad8df83d444f8bc67091b411529 (patch)
tree: 177fa9358b2a160f6247fb074d50e624322b0f4a
parent: 4f967dbdb7972ec52039bd3e3ce3e1e4cbcf6545 (diff)
Continue work on beam search
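Define ContextImpl::Context, a struct wrapping all the data used in decoding (past prompt tokens, probabilities, and the per-audio-frame prompt, token buffer and counters). One instance is kept per beam, or a single instance for greedy sampling. Using a vector of these is much simpler than maintaining N separate parallel vectors, one for each piece of state.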
-rw-r--r-- Whisper/Whisper/ContextImpl.capture.cpp | 2
-rw-r--r-- Whisper/Whisper/ContextImpl.cpp | 148
-rw-r--r-- Whisper/Whisper/ContextImpl.h | 23
-rw-r--r-- Whisper/Whisper/ContextImpl.misc.cpp | 4
4 files changed, 72 insertions, 105 deletions
diff --git a/Whisper/Whisper/ContextImpl.capture.cpp b/Whisper/Whisper/ContextImpl.capture.cpp
index 0100fcd..642eef1 100644
--- a/Whisper/Whisper/ContextImpl.capture.cpp
+++ b/Whisper/Whisper/ContextImpl.capture.cpp
@@ -65,7 +65,7 @@ namespace
__m128i ints = _mm_cvtps_epi32( floats );
store16( &minDuration, ints );
- retainDuration = std::round(retainDuration * SAMPLE_RATE);
+ retainDuration = (int) (retainDuration * SAMPLE_RATE + 0.5);
flags = cp.flags;
}
diff --git a/Whisper/Whisper/ContextImpl.cpp b/Whisper/Whisper/ContextImpl.cpp
index d6d46f9..5428408 100644
--- a/Whisper/Whisper/ContextImpl.cpp
+++ b/Whisper/Whisper/ContextImpl.cpp
@@ -57,7 +57,7 @@ HRESULT ContextImpl::decode( const int* tokens, size_t length, int n_past, int t
try
{
- context.decode( tokens, (int)length, dp, probs_vec[nth], threads);
+ context.decode( tokens, (int)length, dp, ctx_[nth].probs, threads);
return S_OK;
}
catch( HRESULT hr )
@@ -76,18 +76,18 @@ std::vector<sTokenData> ContextImpl::sampleBestN( const float* probs,
size_t n_logits = vocab.size();
- probs_id_vec[nth].clear();
- probs_id_vec[nth].reserve(n_logits);
+ ctx_[nth].probs_id.clear();
+ ctx_[nth].probs_id.reserve(n_logits);
for( size_t i = 0; i < n_logits; i++ )
- probs_id_vec[nth].emplace_back(probs[i], (int)i);
+ ctx_[nth].probs_id.emplace_back(probs[i], (int)i);
{
double sum_ts = 0.0;
double max_ts = -1.0;
double max_tx = -1.0;
for( int i = 0; i < vocab.token_beg; i++ )
- max_tx = std::max( max_tx, probs_id_vec[ nth ][ i ].first );
+ max_tx = std::max( max_tx, ctx_[ nth ].probs_id[ i ].first );
const int i0 = is_initial ? vocab.token_beg + 101 : vocab.token_beg;
const int i1 = is_initial ? vocab.token_beg + 101 : (int)n_logits;
@@ -97,17 +97,17 @@ std::vector<sTokenData> ContextImpl::sampleBestN( const float* probs,
if( is_initial )
{
for( int i = i0; i < n_logits; i++ )
- probs_id_vec[ nth ][ i ].first = -INFINITY;
+ ctx_[nth].probs_id[i].first = -INFINITY;
}
for( int i = vocab.token_beg; i < i1; i++ )
{
- sum_ts += probs_id_vec[ nth ][ i ].first;
- if( probs_id_vec[ nth ][ i ].first > max_ts )
+ sum_ts += ctx_[nth].probs_id[i].first;
+ if (ctx_[nth].probs_id[i].first > max_ts)
{
- max_ts = probs_id_vec[ nth ][ i ].first;
+ max_ts = ctx_[nth].probs_id[i].first;
for (auto& result : result_vec) {
- result.tid = probs_id_vec[nth][i].second;
+ result.tid = ctx_[nth].probs_id[i].second;
}
}
}
@@ -118,7 +118,7 @@ std::vector<sTokenData> ContextImpl::sampleBestN( const float* probs,
{
// ref: https://github.com/openai/whisper/blob/0b1ba3d46ebf7fe6f953acfd8cad62a4f851b49f/whisper/decoding.py#L430-L438
for( int i = 0; i < vocab.token_beg; i++ )
- probs_id_vec[ nth ][ i ].first = -INFINITY;
+ ctx_[nth].probs_id[ i ].first = -INFINITY;
}
for (auto& result : result_vec) {
@@ -131,13 +131,13 @@ std::vector<sTokenData> ContextImpl::sampleBestN( const float* probs,
const int top_k = 4 + n_best - 1;
std::partial_sort(
- probs_id_vec[nth].begin(),
- probs_id_vec[nth].begin() + top_k, probs_id_vec[nth].end(),
+ ctx_[nth].probs_id.begin(),
+ ctx_[nth].probs_id.begin() + top_k, ctx_[nth].probs_id.end(),
[]( const std::pair<double, Vocabulary::id>& a, const std::pair<double, Vocabulary::id>& b ) {
return a.first > b.first;
} );
- probs_id_vec[nth].resize(top_k);
+ ctx_[nth].probs_id.resize(top_k);
//printf("\n");
//for (int i = 0; i < (int) probs_id_vec[nth].size(); i++) {
@@ -150,10 +150,10 @@ std::vector<sTokenData> ContextImpl::sampleBestN( const float* probs,
int i = 0;
for (int j = 0; j < n_best; j++) {
// Scan past unwanted tokens.
- while ((probs_id_vec[nth][i].second == vocab.token_sot ||
- probs_id_vec[nth][i].second == vocab.token_solm ||
- probs_id_vec[nth][i].second == vocab.token_not) &&
- i < (int)probs_id_vec[nth].size() - 1)
+ while ((ctx_[nth].probs_id[i].second == vocab.token_sot ||
+ ctx_[nth].probs_id[i].second == vocab.token_solm ||
+ ctx_[nth].probs_id[i].second == vocab.token_not) &&
+ i < (int)ctx_[nth].probs_id.size() - 1)
{
i++;
}
@@ -162,8 +162,8 @@ std::vector<sTokenData> ContextImpl::sampleBestN( const float* probs,
assert(result_vec.size() == res_vec.size());
for (int i = 0; i < res_vec.size(); i++) {
- result_vec[i].id = probs_id_vec[nth][res_vec[i]].second;
- result_vec[i].p = (float)probs_id_vec[nth][res_vec[i]].first;
+ result_vec[i].id = ctx_[nth].probs_id[res_vec[i]].second;
+ result_vec[i].p = (float)ctx_[nth].probs_id[res_vec[i]].first;
}
return result_vec;
@@ -173,8 +173,8 @@ std::vector<sTokenData> ContextImpl::sampleBestN(int nth, int n_best)
{
const int n_vocab = model.vocab.n_vocab;
return sampleBestN(
- probs_vec[nth].data() + (probs_vec[nth].size() - n_vocab),
- false, false, nth, /*n_best=*/1);
+ ctx_[nth].probs.data() + (ctx_[nth].probs.size() - n_vocab),
+ false, false, nth, /*n_best=*/n_best);
}
std::vector<sTokenData> ContextImpl::sampleTimestampN(bool initial, int nth,
@@ -182,7 +182,7 @@ std::vector<sTokenData> ContextImpl::sampleTimestampN(bool initial, int nth,
{
const int n_vocab = model.vocab.n_vocab;
return sampleBestN(
- probs_vec[nth].data() + (probs_vec[nth].size() - n_vocab),
+ ctx_[nth].probs.data() + (ctx_[nth].probs.size() - n_vocab),
true, initial, nth, n_best);
}
@@ -486,19 +486,15 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
if( seek_end < 100 + seek_start )
return S_FALSE;
- if (prompt_past_vec.empty()) {
- if (params.strategy == Whisper::eSamplingStrategy::Greedy) {
- prompt_past_vec = std::vector<prompt_t>(1);
- }
- else if (params.strategy == Whisper::eSamplingStrategy::BeamSearch) {
- prompt_past_vec = std::vector<prompt_t>(params.beam_search.beam_width);
- }
+ const int n_ctxt = params.strategy == Whisper::eSamplingStrategy::Greedy ? 1 : params.beam_search.beam_width;
+ if (ctx_.size() != n_ctxt) {
+ ctx_.assign(n_ctxt, Context());
}
// the accumulated text context so far
if (params.flag(eFullParamsFlags::NoContext)) {
- for (auto& prompt_past : prompt_past_vec) {
- prompt_past.clear();
+ for (auto& ctx : ctx_) {
+ ctx.prompt_past.clear();
}
}
@@ -506,14 +502,20 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
if( params.prompt_tokens && params.prompt_n_tokens > 0 )
{
// parse tokens from the pointer
- for (int i = 0; i < params.prompt_n_tokens; i++) {
- for (auto& prompt_past : prompt_past_vec) {
- prompt_past.push_back(params.prompt_tokens[i]);
+ for (auto& ctx : ctx_) {
+ for (int i = 0; i < params.prompt_n_tokens; i++) {
+ ctx.prompt_past.push_back(params.prompt_tokens[i]);
}
+ std::rotate(
+ ctx.prompt_past.begin(),
+ ctx.prompt_past.end() - params.prompt_n_tokens,
+ ctx.prompt_past.end());
}
- for (auto& prompt_past : prompt_past_vec) {
- std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
- }
+ }
+
+ for (auto& ctx : ctx_) {
+ ctx.loop_ctx.prompt.reserve(model.parameters.n_text_ctx);
+ ctx.loop_ctx.tokens_cur.reserve(model.parameters.n_text_ctx);
}
// overwrite audio_ctx
@@ -543,25 +545,6 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
// int progress_prev = 0;
// int progress_step = 5;
- typedef std::vector<sTokenData> tokens_t;
- std::vector<tokens_t> tokens_cur_vec;
- typedef std::vector<whisper_token> prompt_t;
- std::vector<prompt_t> prompt_vec;
- if (params.strategy == Whisper::eSamplingStrategy::Greedy) {
- prompt_vec = std::vector<prompt_t>(1);
- tokens_cur_vec = std::vector<tokens_t>(1);
- }
- else if (params.strategy == Whisper::eSamplingStrategy::BeamSearch) {
- prompt_vec = std::vector<prompt_t>(params.beam_search.beam_width);
- tokens_cur_vec = std::vector<tokens_t>(params.beam_search.beam_width);
- }
- for (auto& prompt : prompt_vec) {
- prompt.reserve(model.parameters.n_text_ctx);
- }
- for (auto& tokens_cur : tokens_cur_vec) {
- tokens_cur.reserve(model.parameters.n_text_ctx);
- }
-
// main loop
int seek = seek_start;
// Start measuring "Run" profiler value, both CPU and GPU times
@@ -603,23 +586,13 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
// encode audio features starting at offset seek
CHECK( encode( mel, seek ) );
- std::vector<int> n_past_vec;
- if (params.strategy == Whisper::eSamplingStrategy::Greedy) {
- n_past_vec = std::vector<int>(1, 0);
- }
- else if (params.strategy == Whisper::eSamplingStrategy::BeamSearch) {
- n_past_vec = std::vector<int>(params.beam_search.beam_width, 0);
- }
+ for (auto& ctx : ctx_) {
+ // if we have already generated some text, use it as a prompt to condition the next generation
+ auto& prompt_past = ctx.prompt_past;
+ auto& prompt = ctx.loop_ctx.prompt;
- for (auto& prompt : prompt_vec) {
prompt.clear();
- }
- // if we have already generated some text, use it as a prompt to condition the next generation
- assert(prompt_past_vec.size() == prompt_vec.size());
- for (int i = 0; i < prompt_past_vec.size(); i++) {
- auto& prompt_past = prompt_past_vec[i];
- auto& prompt = prompt_vec[i];
if (!prompt_past.empty())
{
int n_take = std::min(std::min(params.n_max_text_ctx, model.parameters.n_text_ctx / 2), int(prompt_past.size()));
@@ -630,10 +603,10 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
prompt_past.clear();
prompt_past.insert(prompt_past.end(), prompt.begin() + 1, prompt.end());
}
- }
- for (auto& prompt : prompt_vec) {
prompt.insert(prompt.end(), prompt_init.begin(), prompt_init.end());
+ ctx.loop_ctx.tokens_cur.clear();
+ ctx.loop_ctx.result_len = 0;
}
int seek_delta = 100 * WHISPER_CHUNK_SIZE;
@@ -645,19 +618,6 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
//}
//printf("\n\n");
- // the accumulated transcription in the current iteration
- std::vector<int> result_len_vec;
- if (params.strategy == Whisper::eSamplingStrategy::Greedy) {
- result_len_vec = std::vector<int>(1, 0);
- }
- else if (params.strategy == Whisper::eSamplingStrategy::BeamSearch) {
- result_len_vec = std::vector<int>(params.beam_search.beam_width, 0);
- }
-
- for (auto& tokens_cur : tokens_cur_vec) {
- tokens_cur.clear();
- }
-
bool failed = false;
bool has_ts = false; // have we already sampled a non-beg timestamp token for the current segment?
@@ -667,10 +627,10 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
for( int i = 0, n_max = model.parameters.n_text_ctx / 2 - 4; i < n_max; i++ )
{
if (params.strategy == Whisper::eSamplingStrategy::Greedy) {
- auto& prompt = prompt_vec[0];
- auto& tokens_cur = tokens_cur_vec[0];
- auto& n_past = n_past_vec[0];
- auto& result_len = result_len_vec[0];
+ auto& prompt = ctx_[0].loop_ctx.prompt;
+ auto& tokens_cur = ctx_[0].loop_ctx.tokens_cur;
+ auto& n_past = ctx_[0].loop_ctx.n_past;
+ auto& result_len = ctx_[0].loop_ctx.result_len;
CHECK(decode(prompt.data(), prompt.size(), n_past, params.cpuThreads, /*nth=*/0));
@@ -768,9 +728,9 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
// shrink down to result_len
// TODO assign these based on results of either decoding strategy
- auto& tokens_cur = tokens_cur_vec[0];
- auto& result_len = result_len_vec[0];
- auto& prompt_past = prompt_past_vec[0];
+ auto& tokens_cur = ctx_[0].loop_ctx.tokens_cur;
+ auto& result_len = ctx_[0].loop_ctx.result_len;
+ auto& prompt_past = ctx_[0].prompt_past;
tokens_cur.resize( result_len );
for( const auto& r : tokens_cur )
@@ -882,4 +842,4 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
CHECK( progress.pfn( 1.0, this, progress.pv ) );
}
return S_OK;
-} \ No newline at end of file
+}
diff --git a/Whisper/Whisper/ContextImpl.h b/Whisper/Whisper/ContextImpl.h
index cdad36b..33d3007 100644
--- a/Whisper/Whisper/ContextImpl.h
+++ b/Whisper/Whisper/ContextImpl.h
@@ -38,8 +38,22 @@ namespace Whisper
};
std::vector<Segment> result_all;
- typedef std::vector<whisper_token> prompt_t;
- std::vector<prompt_t> prompt_past_vec;
+ struct Context {
+ std::vector<whisper_token> prompt_past;
+ std::vector<float> probs;
+ std::vector<std::pair<double, Vocabulary::id>> probs_id;
+ // These are cleared on every frame of audio processed.
+ struct AudioFrameContext {
+ std::vector<whisper_token> prompt;
+ std::vector<sTokenData> tokens_cur;
+ std::vector<float> joint_probs;
+
+ int result_len = 0;
+ int n_past = 0;
+ };
+ AudioFrameContext loop_ctx;
+ };
+ std::vector<Context> ctx_;
// [EXPERIMENTAL] token-level timestamps data
int64_t t_beg = 0;
@@ -58,11 +72,6 @@ namespace Whisper
int wrapSegment( int max_len );
void expComputeTokenLevelTimestamps( int i_segment, float thold_pt, float thold_ptsum );
- typedef std::vector<float> probs_t;
- std::vector<probs_t> probs_vec{ 1 };
- typedef std::vector<std::pair<double, Vocabulary::id>> probs_id_t;
- std::vector<probs_id_t> probs_id_vec{ 1 };
-
mutable TranscribeResultStatic results;
HRESULT COMLIGHTCALL makeResults( eResultFlags flags, TranscribeResult& res ) const noexcept;
diff --git a/Whisper/Whisper/ContextImpl.misc.cpp b/Whisper/Whisper/ContextImpl.misc.cpp
index 90a1bec..52ef811 100644
--- a/Whisper/Whisper/ContextImpl.misc.cpp
+++ b/Whisper/Whisper/ContextImpl.misc.cpp
@@ -114,10 +114,8 @@ __m128i ContextImpl::getMemoryUse() const
size_t cb = vectorMemoryUse( result_all );
for( const auto& r : result_all )
cb += r.memoryUsage();
- cb += vectorMemoryUse( prompt_past_vec );
+ cb += vectorMemoryUse( ctx_ );
cb += vectorMemoryUse( energy );
- cb += vectorMemoryUse( probs_vec );
- cb += vectorMemoryUse( probs_id_vec );
cb += vectorMemoryUse( results.segments );
cb += vectorMemoryUse( results.tokens );
cb += spectrogram.memoryUsage();