summary | refs | log | tree | commit | diff | stats
path: root/Whisper/ML
diff options
context:
space:
mode:
author Konstantin <const@const.me> 2023-01-23 14:38:12 +0100
committer Konstantin <const@const.me> 2023-01-23 14:38:12 +0100
commit 27dfc3428a7016e2d05dd67b6d8b88c0b982baa9 (patch)
tree f969d54ebfb266ecf61285a039295a1da37200a0 /Whisper/ML
parent 01aba39f15a03ed96e034ffc3b6ee9ec12294b0d (diff)
Performance improvement, `softMax` shader
Diffstat (limited to 'Whisper/ML')
-rw-r--r-- Whisper/ML/MlContext.cpp 10
1 file changed, 9 insertions, 1 deletion
diff --git a/Whisper/ML/MlContext.cpp b/Whisper/ML/MlContext.cpp
index 6eeae09..a226999 100644
--- a/Whisper/ML/MlContext.cpp
+++ b/Whisper/ML/MlContext.cpp
@@ -556,7 +556,15 @@ void MlContext::softMax( Tensor& a, float inputScale )
printSizes.print( a );
#endif
constexpr uint32_t FIXED_ROW_SIZE = 1500;
- eComputeShader cs = ( a.ne[ 0 ] == FIXED_ROW_SIZE ) ? eComputeShader::softMaxFixed : eComputeShader::softMax;
+
+ eComputeShader cs;
+ if( a.ne[ 0 ] == FIXED_ROW_SIZE )
+ cs = eComputeShader::softMaxFixed;
+ else if( a.ne[ 0 ] >= ( 1024 * 4 ) )
+ cs = eComputeShader::softMaxLong;
+ else
+ cs = eComputeShader::softMax;
+
bindShader( cs );
const uint32_t nr = a.countRows();
TensorShape dummyShape;