diff options
| author | Konstantin <const@const.me> | 2023-02-14 20:51:48 +0100 |
|---|---|---|
| committer | Konstantin <const@const.me> | 2023-02-14 20:51:48 +0100 |
| commit | 2ec5b19e8384538bcd9e440ad99aba9e172f8924 (patch) | |
| tree | 15366fd0f59398873945be98dd2675b24a7f122b | |
| parent | eabba4f72411617b891f260433af2978d78d0d21 (diff) | |
Restored missing token-level timestamps experimental feature
| -rw-r--r-- | Whisper/Whisper/ContextImpl.cpp | 245 |
1 files changed, 244 insertions, 1 deletions
diff --git a/Whisper/Whisper/ContextImpl.cpp b/Whisper/Whisper/ContextImpl.cpp index 5a58e5f..1160f95 100644 --- a/Whisper/Whisper/ContextImpl.cpp +++ b/Whisper/Whisper/ContextImpl.cpp @@ -167,10 +167,253 @@ sTokenData ContextImpl::sampleTimestamp( bool initial ) return sampleBest( probs.data() + ( probs.size() - n_vocab ), true, initial ); } +// a cost-function / heuristic that is high for text that takes longer to pronounce +// Obviously, can be improved +static float voice_length( const char* text ) +{ + if( nullptr == text ) + return 0; + + float res = 0.0f; + while( true ) + { + const char c = *text; + if( c == '\0' ) + return res; + text++; + + // Figure out the increment + float inc; + if( c >= '0' && c <= '9' ) + inc = 3.0f; + else + { + switch( c ) + { + case ' ': inc = 0.01f; break; + case ',': inc = 2.00f; break; + case '.': + case '!': + case '?': + inc = 3.00f; break; + default: + inc = 1.0f; + } + } + + res += inc; + } +} + +static int timestamp_to_sample( int64_t t, int n_samples ) +{ + return std::max( 0, std::min( (int)n_samples - 1, (int)( ( t * SAMPLE_RATE ) / 100 ) ) ); +} + +static int64_t sample_to_timestamp( int i_sample ) +{ + return ( 100 * i_sample ) / SAMPLE_RATE; +} + void ContextImpl::expComputeTokenLevelTimestamps( int i_segment, float thold_pt, float thold_ptsum ) { // whisper_exp_compute_token_level_timestamps - throw E_NOTIMPL; + + auto& segment = result_all[ i_segment ]; + auto& tokens = segment.tokens; + + const int n_samples = energy.size(); + + if( n_samples == 0 ) + { + logWarning( u8"%s: no signal data available", __func__ ); + return; + } + + const int64_t t0 = segment.t0; + const int64_t t1 = segment.t1; + const int n = tokens.size(); + + if( n == 0 ) + return; + + if( n == 1 ) + { + tokens[ 0 ].t0 = t0; + tokens[ 0 ].t1 = t1; + return; + } + + auto& t_beg = this->t_beg; + auto& t_last = this->t_last; + auto& tid_last = this->tid_last; + + for( int j = 0; j < n; ++j ) + { + auto& token = tokens[ j ]; + + if( j == 0 ) + { + if( token.id == model.vocab.token_beg ) + { + tokens[ j ].t0 = t0; + tokens[ j ].t1 = t0; + tokens[ j + 1 ].t0 = t0; + + t_beg = t0; + t_last = t0; + tid_last = model.vocab.token_beg; + } + else + { + tokens[ j ].t0 = t_last; + } + } + + const int64_t tt = t_beg + 2 * ( token.tid - model.vocab.token_beg ); + + tokens[ j ].id = token.id; + tokens[ j ].tid = token.tid; + tokens[ j ].p = token.p; + tokens[ j ].pt = token.pt; + tokens[ j ].ptsum = token.ptsum; + tokens[ j ].vlen = voice_length( model.vocab.string( token.id ) ); + + if( token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1 ) + { + if( j > 0 ) + tokens[ j - 1 ].t1 = tt; + tokens[ j ].t0 = tt; + tid_last = token.tid; + } + } + + tokens[ n - 2 ].t1 = t1; + tokens[ n - 1 ].t0 = t1; + tokens[ n - 1 ].t1 = t1; + t_last = t1; + + // find intervals of tokens with unknown timestamps + // fill the timestamps by proportionally splitting the interval based on the token voice lengths + { + int p0 = 0; + int p1 = 0; + + while( true ) + { + while( p1 < n && tokens[ p1 ].t1 < 0 ) + p1++; + + if( p1 >= n ) + p1--; + + if( p1 > p0 ) + { + double psum = 0.0; + for( int j = p0; j <= p1; j++ ) + psum += tokens[ j ].vlen; + + //printf("analyzing %d - %d, psum = %f\n", p0, p1, psum); + const double dt = tokens[ p1 ].t1 - tokens[ p0 ].t0; + + // split the time proportionally to the voice length + for( int j = p0 + 1; j <= p1; j++ ) + { + const double ct = tokens[ j - 1 ].t0 + dt * tokens[ j - 1 ].vlen / psum; + tokens[ j - 1 ].t1 = ct; + tokens[ j ].t0 = ct; + } + } + + p1++; + p0 = p1; + if( p1 >= n ) + break; + } + } + + // fix up (just in case) + for( int j = 0; j < n - 1; j++ ) + { + if( tokens[ j ].t1 < 0 ) + tokens[ j + 1 ].t0 = tokens[ j ].t1; + + if( j > 0 ) + { + if( tokens[ j - 1 ].t1 > tokens[ j ].t0 ) { + tokens[ j ].t0 = tokens[ j - 1 ].t1; + tokens[ j ].t1 = std::max( tokens[ j ].t0, tokens[ j ].t1 ); + } + } + } + + // VAD + // expand or contract tokens based on voice activity + { + constexpr int hw = SAMPLE_RATE / 8; + + for( int j = 0; j < n; j++ ) + { + if( tokens[ j ].id >= model.vocab.token_eot ) + continue; + + int s0 = timestamp_to_sample( tokens[ j ].t0, n_samples ); + int s1 = timestamp_to_sample( tokens[ j ].t1, n_samples ); + + const int ss0 = std::max( s0 - hw, 0 ); + const int ss1 = std::min( s1 + hw, n_samples ); + + const int ns = ss1 - ss0; + + float sum = 0.0f; + for( int k = ss0; k < ss1; k++ ) + sum += this->energy[ k ]; + + const float thold = 0.5 * sum / ns; + + { + int k = s0; + if( this->energy[ k ] > thold && j > 0 ) + { + while( k > 0 && this->energy[ k ] > thold ) + k--; + tokens[ j ].t0 = sample_to_timestamp( k ); + if( tokens[ j ].t0 < tokens[ j - 1 ].t1 ) + tokens[ j ].t0 = tokens[ j - 1 ].t1; + else + s0 = k; + } + else + { + while( this->energy[ k ] < thold && k < s1 ) + k++; + s0 = k; + tokens[ j ].t0 = sample_to_timestamp( k ); + } + } + + { + int k = s1; + if( this->energy[ k ] > thold ) + { + while( k < n_samples - 1 && this->energy[ k ] > thold ) + k++; + tokens[ j ].t1 = sample_to_timestamp( k ); + if( j < ns - 1 && tokens[ j ].t1 > tokens[ j + 1 ].t0 ) + tokens[ j ].t1 = tokens[ j + 1 ].t0; + else + s1 = k; + } + else + { + while( this->energy[ k ] < thold && k > s0 ) + k--; + s1 = k; + tokens[ j ].t1 = sample_to_timestamp( k ); + } + } + } + } } static std::string to_timestamp( int64_t t, bool comma = false ) |
