diff options
| author | Konstantin <const@const.me> | 2023-01-23 14:38:12 +0100 |
|---|---|---|
| committer | Konstantin <const@const.me> | 2023-01-23 14:38:12 +0100 |
| commit | 27dfc3428a7016e2d05dd67b6d8b88c0b982baa9 (patch) | |
| tree | f969d54ebfb266ecf61285a039295a1da37200a0 /Whisper | |
| parent | 01aba39f15a03ed96e034ffc3b6ee9ec12294b0d (diff) | |
Performance improvement, `softMax` shader
Diffstat (limited to 'Whisper')
| Mode | File | Lines changed |
|---|---|---|
| -rw-r--r-- | Whisper/D3D/shaderNames.cpp | 3 |
| -rw-r--r-- | Whisper/D3D/shaderNames.h | 3 |
| -rw-r--r-- | Whisper/ML/MlContext.cpp | 10 |
3 files changed, 13 insertions, 3 deletions
```diff
diff --git a/Whisper/D3D/shaderNames.cpp b/Whisper/D3D/shaderNames.cpp
index 0605828..a631f08 100644
--- a/Whisper/D3D/shaderNames.cpp
+++ b/Whisper/D3D/shaderNames.cpp
@@ -2,7 +2,7 @@
 #include "stdafx.h"
 #include "shaderNames.h"
 
-static const std::array<const char*, 39> s_shaderNames =
+static const std::array<const char*, 40> s_shaderNames =
 {
     "add",
     "addInPlace",
@@ -42,6 +42,7 @@ static const std::array<const char*, 39> s_shaderNames =
     "softMax",
     "softMaxCompat",
     "softMaxFixed",
+    "softMaxLong",
     "zeroMemory",
 };
 
diff --git a/Whisper/D3D/shaderNames.h b/Whisper/D3D/shaderNames.h
index 5942e72..80fc4fa 100644
--- a/Whisper/D3D/shaderNames.h
+++ b/Whisper/D3D/shaderNames.h
@@ -44,7 +44,8 @@ namespace DirectCompute
         softMax = 35,
         softMaxCompat = 36,
         softMaxFixed = 37,
-        zeroMemory = 38,
+        softMaxLong = 38,
+        zeroMemory = 39,
     };
 
     const char* computeShaderName( eComputeShader cs );
diff --git a/Whisper/ML/MlContext.cpp b/Whisper/ML/MlContext.cpp
index 6eeae09..a226999 100644
--- a/Whisper/ML/MlContext.cpp
+++ b/Whisper/ML/MlContext.cpp
@@ -556,7 +556,15 @@ void MlContext::softMax( Tensor& a, float inputScale )
     printSizes.print( a );
 #endif
     constexpr uint32_t FIXED_ROW_SIZE = 1500;
-    eComputeShader cs = ( a.ne[ 0 ] == FIXED_ROW_SIZE ) ? eComputeShader::softMaxFixed : eComputeShader::softMax;
+
+    eComputeShader cs;
+    if( a.ne[ 0 ] == FIXED_ROW_SIZE )
+        cs = eComputeShader::softMaxFixed;
+    else if( a.ne[ 0 ] >= ( 1024 * 4 ) )
+        cs = eComputeShader::softMaxLong;
+    else
+        cs = eComputeShader::softMax;
+
     bindShader( cs );
     const uint32_t nr = a.countRows();
     TensorShape dummyShape;
```
