diff options
| author | Konstantin <const@const.me> | 2023-01-28 18:38:26 +0100 |
|---|---|---|
| committer | Konstantin <const@const.me> | 2023-01-28 18:38:26 +0100 |
| commit | c75bb96b531414a12fb1bfe599383348e56dbdb5 (patch) | |
| tree | abe00cac30fabd54542f686643efd7f32ce59801 | |
| parent | 9253de193022e78cc4f91f4f1f7e14ba099e6388 (diff) | |
Minor, micro-optimization
| -rw-r--r-- | Whisper/Whisper/ContextImpl.diarize.cpp | 17 |
1 files changed, 14 insertions, 3 deletions
diff --git a/Whisper/Whisper/ContextImpl.diarize.cpp b/Whisper/Whisper/ContextImpl.diarize.cpp index 90acb04..9d88fac 100644 --- a/Whisper/Whisper/ContextImpl.diarize.cpp +++ b/Whisper/Whisper/ContextImpl.diarize.cpp @@ -16,22 +16,33 @@ namespace // and return left / right numbers in the lower 2 lanes of the SSE vector inline __m128 __vectorcall computeChannelsEnergy( const std::vector<StereoSample>& sourceVector ) { + // Might be possible to implement way more sophisticated, and precise, version of this function. + // For example, compute these 3 metrics with VAD code, and cluster the numbers somehow. + // Not doing that currently; instead, replicating the simple version from the whisper.cpp original version. + const StereoSample* rsi = sourceVector.data(); const StereoSample* const rsiEnd = rsi + sourceVector.size(); const StereoSample* const rsiEndAligned = rsi + ( sourceVector.size() & ( ~(size_t)1 ) ); - const __m128 absMask = _mm_set1_ps( -0.0f ); + // Move 0x7FFFFFFF to lowest lane of the int32 vector; + // unlike float scalars or all vectors, integer scalar constants are in the instruction stream + __m128i absMaskInt = _mm_cvtsi32_si128( (int)0x7FFFFFFFu ); + // Broadcast over the complete vector + absMaskInt = _mm_shuffle_epi32( absMaskInt, 0 ); + // Bitcast to FP32 vector, for _mm_and_ps instruction + const __m128 absMask = _mm_castsi128_ps( absMaskInt ); + __m128 acc = _mm_setzero_ps(); for( ; rsi < rsiEndAligned; rsi += 2 ) { __m128 v = _mm_loadu_ps( (const float*)rsi ); - v = _mm_andnot_ps( absMask, v ); + v = _mm_and_ps( v, absMask ); acc = _mm_add_ps( acc, v ); } if( rsi != rsiEnd ) { __m128 v = _mm_castpd_ps( _mm_load_sd( (const double*)rsi ) ); - v = _mm_andnot_ps( absMask, v ); + v = _mm_and_ps( v, absMask ); acc = _mm_add_ps( acc, v ); } |
