summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorKonstantin <const@const.me>2023-02-14 20:51:48 +0100
committerKonstantin <const@const.me>2023-02-14 20:51:48 +0100
commit2ec5b19e8384538bcd9e440ad99aba9e172f8924 (patch)
tree15366fd0f59398873945be98dd2675b24a7f122b
parenteabba4f72411617b891f260433af2978d78d0d21 (diff)
Restored missing token-level timestamps experimental feature
-rw-r--r--Whisper/Whisper/ContextImpl.cpp245
1 files changed, 244 insertions, 1 deletions
diff --git a/Whisper/Whisper/ContextImpl.cpp b/Whisper/Whisper/ContextImpl.cpp
index 5a58e5f..1160f95 100644
--- a/Whisper/Whisper/ContextImpl.cpp
+++ b/Whisper/Whisper/ContextImpl.cpp
@@ -167,10 +167,253 @@ sTokenData ContextImpl::sampleTimestamp( bool initial )
return sampleBest( probs.data() + ( probs.size() - n_vocab ), true, initial );
}
+// a cost-function / heuristic that is high for text that takes longer to pronounce
+// Obviously, can be improved
+static float voice_length( const char* text )
+{
+ if( nullptr == text )
+ return 0;
+
+ float res = 0.0f;
+ while( true )
+ {
+ const char c = *text;
+ if( c == '\0' )
+ return res;
+ text++;
+
+ // Figure out the increment
+ float inc;
+ if( c >= '0' && c <= '9' )
+ inc = 3.0f;
+ else
+ {
+ switch( c )
+ {
+ case ' ': inc = 0.01f; break;
+ case ',': inc = 2.00f; break;
+ case '.':
+ case '!':
+ case '?':
+ inc = 3.00f; break;
+ default:
+ inc = 1.0f;
+ }
+ }
+
+ res += inc;
+ }
+}
+
+static int timestamp_to_sample( int64_t t, int n_samples )
+{
+ return std::max( 0, std::min( (int)n_samples - 1, (int)( ( t * SAMPLE_RATE ) / 100 ) ) );
+}
+
+static int64_t sample_to_timestamp( int i_sample )
+{
+ return ( 100 * i_sample ) / SAMPLE_RATE;
+}
+
void ContextImpl::expComputeTokenLevelTimestamps( int i_segment, float thold_pt, float thold_ptsum )
{
// whisper_exp_compute_token_level_timestamps
- throw E_NOTIMPL;
+
+ auto& segment = result_all[ i_segment ];
+ auto& tokens = segment.tokens;
+
+ const int n_samples = energy.size();
+
+ if( n_samples == 0 )
+ {
+ logWarning( u8"%s: no signal data available", __func__ );
+ return;
+ }
+
+ const int64_t t0 = segment.t0;
+ const int64_t t1 = segment.t1;
+ const int n = tokens.size();
+
+ if( n == 0 )
+ return;
+
+ if( n == 1 )
+ {
+ tokens[ 0 ].t0 = t0;
+ tokens[ 0 ].t1 = t1;
+ return;
+ }
+
+ auto& t_beg = this->t_beg;
+ auto& t_last = this->t_last;
+ auto& tid_last = this->tid_last;
+
+ for( int j = 0; j < n; ++j )
+ {
+ auto& token = tokens[ j ];
+
+ if( j == 0 )
+ {
+ if( token.id == model.vocab.token_beg )
+ {
+ tokens[ j ].t0 = t0;
+ tokens[ j ].t1 = t0;
+ tokens[ j + 1 ].t0 = t0;
+
+ t_beg = t0;
+ t_last = t0;
+ tid_last = model.vocab.token_beg;
+ }
+ else
+ {
+ tokens[ j ].t0 = t_last;
+ }
+ }
+
+ const int64_t tt = t_beg + 2 * ( token.tid - model.vocab.token_beg );
+
+ tokens[ j ].id = token.id;
+ tokens[ j ].tid = token.tid;
+ tokens[ j ].p = token.p;
+ tokens[ j ].pt = token.pt;
+ tokens[ j ].ptsum = token.ptsum;
+ tokens[ j ].vlen = voice_length( model.vocab.string( token.id ) );
+
+ if( token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1 )
+ {
+ if( j > 0 )
+ tokens[ j - 1 ].t1 = tt;
+ tokens[ j ].t0 = tt;
+ tid_last = token.tid;
+ }
+ }
+
+ tokens[ n - 2 ].t1 = t1;
+ tokens[ n - 1 ].t0 = t1;
+ tokens[ n - 1 ].t1 = t1;
+ t_last = t1;
+
+ // find intervals of tokens with unknown timestamps
+ // fill the timestamps by proportionally splitting the interval based on the token voice lengths
+ {
+ int p0 = 0;
+ int p1 = 0;
+
+ while( true )
+ {
+ while( p1 < n && tokens[ p1 ].t1 < 0 )
+ p1++;
+
+ if( p1 >= n )
+ p1--;
+
+ if( p1 > p0 )
+ {
+ double psum = 0.0;
+ for( int j = p0; j <= p1; j++ )
+ psum += tokens[ j ].vlen;
+
+ //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);
+ const double dt = tokens[ p1 ].t1 - tokens[ p0 ].t0;
+
+ // split the time proportionally to the voice length
+ for( int j = p0 + 1; j <= p1; j++ )
+ {
+ const double ct = tokens[ j - 1 ].t0 + dt * tokens[ j - 1 ].vlen / psum;
+ tokens[ j - 1 ].t1 = ct;
+ tokens[ j ].t0 = ct;
+ }
+ }
+
+ p1++;
+ p0 = p1;
+ if( p1 >= n )
+ break;
+ }
+ }
+
+ // fix up (just in case)
+ for( int j = 0; j < n - 1; j++ )
+ {
+ if( tokens[ j ].t1 < 0 )
+ tokens[ j + 1 ].t0 = tokens[ j ].t1;
+
+ if( j > 0 )
+ {
+ if( tokens[ j - 1 ].t1 > tokens[ j ].t0 ) {
+ tokens[ j ].t0 = tokens[ j - 1 ].t1;
+ tokens[ j ].t1 = std::max( tokens[ j ].t0, tokens[ j ].t1 );
+ }
+ }
+ }
+
+ // VAD
+ // expand or contract tokens based on voice activity
+ {
+ constexpr int hw = SAMPLE_RATE / 8;
+
+ for( int j = 0; j < n; j++ )
+ {
+ if( tokens[ j ].id >= model.vocab.token_eot )
+ continue;
+
+ int s0 = timestamp_to_sample( tokens[ j ].t0, n_samples );
+ int s1 = timestamp_to_sample( tokens[ j ].t1, n_samples );
+
+ const int ss0 = std::max( s0 - hw, 0 );
+ const int ss1 = std::min( s1 + hw, n_samples );
+
+ const int ns = ss1 - ss0;
+
+ float sum = 0.0f;
+ for( int k = ss0; k < ss1; k++ )
+ sum += this->energy[ k ];
+
+ const float thold = 0.5 * sum / ns;
+
+ {
+ int k = s0;
+ if( this->energy[ k ] > thold && j > 0 )
+ {
+ while( k > 0 && this->energy[ k ] > thold )
+ k--;
+ tokens[ j ].t0 = sample_to_timestamp( k );
+ if( tokens[ j ].t0 < tokens[ j - 1 ].t1 )
+ tokens[ j ].t0 = tokens[ j - 1 ].t1;
+ else
+ s0 = k;
+ }
+ else
+ {
+ while( this->energy[ k ] < thold && k < s1 )
+ k++;
+ s0 = k;
+ tokens[ j ].t0 = sample_to_timestamp( k );
+ }
+ }
+
+ {
+ int k = s1;
+ if( this->energy[ k ] > thold )
+ {
+ while( k < n_samples - 1 && this->energy[ k ] > thold )
+ k++;
+ tokens[ j ].t1 = sample_to_timestamp( k );
+ if( j < ns - 1 && tokens[ j ].t1 > tokens[ j + 1 ].t0 )
+ tokens[ j ].t1 = tokens[ j + 1 ].t0;
+ else
+ s1 = k;
+ }
+ else
+ {
+ while( this->energy[ k ] < thold && k > s0 )
+ k--;
+ s1 = k;
+ tokens[ j ].t1 = sample_to_timestamp( k );
+ }
+ }
+ }
+ }
}
static std::string to_timestamp( int64_t t, bool comma = false )