summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author Konstantin <const@const.me> 2023-01-28 18:38:26 +0100
committer Konstantin <const@const.me> 2023-01-28 18:38:26 +0100
commit c75bb96b531414a12fb1bfe599383348e56dbdb5 (patch)
tree abe00cac30fabd54542f686643efd7f32ce59801
parent 9253de193022e78cc4f91f4f1f7e14ba099e6388 (diff)
Minor, micro-optimization
-rw-r--r-- Whisper/Whisper/ContextImpl.diarize.cpp 17
1 files changed, 14 insertions, 3 deletions
diff --git a/Whisper/Whisper/ContextImpl.diarize.cpp b/Whisper/Whisper/ContextImpl.diarize.cpp
index 90acb04..9d88fac 100644
--- a/Whisper/Whisper/ContextImpl.diarize.cpp
+++ b/Whisper/Whisper/ContextImpl.diarize.cpp
@@ -16,22 +16,33 @@ namespace
// and return left / right numbers in the lower 2 lanes of the SSE vector
inline __m128 __vectorcall computeChannelsEnergy( const std::vector<StereoSample>& sourceVector )
{
+ // It might be possible to implement a far more sophisticated, and more precise, version of this function.
+ // For example, compute these 3 metrics with the VAD code, and cluster the numbers somehow.
+ // Not doing that currently; instead, replicating the simple version from the original whisper.cpp.
+
const StereoSample* rsi = sourceVector.data();
const StereoSample* const rsiEnd = rsi + sourceVector.size();
const StereoSample* const rsiEndAligned = rsi + ( sourceVector.size() & ( ~(size_t)1 ) );
- const __m128 absMask = _mm_set1_ps( -0.0f );
+ // Move 0x7FFFFFFF into the lowest lane of the int32 vector;
+ // unlike float scalar constants or any vector constants, integer scalar constants are embedded in the instruction stream
+ __m128i absMaskInt = _mm_cvtsi32_si128( (int)0x7FFFFFFFu );
+ // Broadcast over the complete vector
+ absMaskInt = _mm_shuffle_epi32( absMaskInt, 0 );
+ // Bitcast to FP32 vector, for _mm_and_ps instruction
+ const __m128 absMask = _mm_castsi128_ps( absMaskInt );
+
__m128 acc = _mm_setzero_ps();
for( ; rsi < rsiEndAligned; rsi += 2 )
{
__m128 v = _mm_loadu_ps( (const float*)rsi );
- v = _mm_andnot_ps( absMask, v );
+ v = _mm_and_ps( v, absMask );
acc = _mm_add_ps( acc, v );
}
if( rsi != rsiEnd )
{
__m128 v = _mm_castpd_ps( _mm_load_sd( (const double*)rsi ) );
- v = _mm_andnot_ps( absMask, v );
+ v = _mm_and_ps( v, absMask );
acc = _mm_add_ps( acc, v );
}