From 27dfc3428a7016e2d05dd67b6d8b88c0b982baa9 Mon Sep 17 00:00:00 2001 From: Konstantin Date: Mon, 23 Jan 2023 14:38:12 +0100 Subject: Performance improvement, `softMax` shader --- Whisper/ML/MlContext.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) (limited to 'Whisper/ML/MlContext.cpp') diff --git a/Whisper/ML/MlContext.cpp b/Whisper/ML/MlContext.cpp index 6eeae09..a226999 100644 --- a/Whisper/ML/MlContext.cpp +++ b/Whisper/ML/MlContext.cpp @@ -556,7 +556,15 @@ void MlContext::softMax( Tensor& a, float inputScale ) printSizes.print( a ); #endif constexpr uint32_t FIXED_ROW_SIZE = 1500; - eComputeShader cs = ( a.ne[ 0 ] == FIXED_ROW_SIZE ) ? eComputeShader::softMaxFixed : eComputeShader::softMax; + + eComputeShader cs; + if( a.ne[ 0 ] == FIXED_ROW_SIZE ) + cs = eComputeShader::softMaxFixed; + else if( a.ne[ 0 ] >= ( 1024 * 4 ) ) + cs = eComputeShader::softMaxLong; + else + cs = eComputeShader::softMax; + bindShader( cs ); const uint32_t nr = a.countRows(); TensorShape dummyShape; -- cgit v1.2.3