summary | refs | log | tree | commit | diff | stats
path: root/Whisper
diff options
context:
space:
mode:
author Konstantin <const@const.me> 2023-01-23 14:38:12 +0100
committer Konstantin <const@const.me> 2023-01-23 14:38:12 +0100
commit 27dfc3428a7016e2d05dd67b6d8b88c0b982baa9 (patch)
tree f969d54ebfb266ecf61285a039295a1da37200a0 /Whisper
parent 01aba39f15a03ed96e034ffc3b6ee9ec12294b0d (diff)
Performance improvement, `softMax` shader
Diffstat (limited to 'Whisper')
-rw-r--r-- Whisper/D3D/shaderNames.cpp 3
-rw-r--r-- Whisper/D3D/shaderNames.h 3
-rw-r--r-- Whisper/ML/MlContext.cpp 10
3 files changed, 13 insertions, 3 deletions
diff --git a/Whisper/D3D/shaderNames.cpp b/Whisper/D3D/shaderNames.cpp
index 0605828..a631f08 100644
--- a/Whisper/D3D/shaderNames.cpp
+++ b/Whisper/D3D/shaderNames.cpp
@@ -2,7 +2,7 @@
#include "stdafx.h"
#include "shaderNames.h"
-static const std::array<const char*, 39> s_shaderNames =
+static const std::array<const char*, 40> s_shaderNames =
{
"add",
"addInPlace",
@@ -42,6 +42,7 @@ static const std::array<const char*, 39> s_shaderNames =
"softMax",
"softMaxCompat",
"softMaxFixed",
+ "softMaxLong",
"zeroMemory",
};
diff --git a/Whisper/D3D/shaderNames.h b/Whisper/D3D/shaderNames.h
index 5942e72..80fc4fa 100644
--- a/Whisper/D3D/shaderNames.h
+++ b/Whisper/D3D/shaderNames.h
@@ -44,7 +44,8 @@ namespace DirectCompute
softMax = 35,
softMaxCompat = 36,
softMaxFixed = 37,
- zeroMemory = 38,
+ softMaxLong = 38,
+ zeroMemory = 39,
};
const char* computeShaderName( eComputeShader cs );
diff --git a/Whisper/ML/MlContext.cpp b/Whisper/ML/MlContext.cpp
index 6eeae09..a226999 100644
--- a/Whisper/ML/MlContext.cpp
+++ b/Whisper/ML/MlContext.cpp
@@ -556,7 +556,15 @@ void MlContext::softMax( Tensor& a, float inputScale )
printSizes.print( a );
#endif
constexpr uint32_t FIXED_ROW_SIZE = 1500;
- eComputeShader cs = ( a.ne[ 0 ] == FIXED_ROW_SIZE ) ? eComputeShader::softMaxFixed : eComputeShader::softMax;
+
+ eComputeShader cs;
+ if( a.ne[ 0 ] == FIXED_ROW_SIZE )
+ cs = eComputeShader::softMaxFixed;
+ else if( a.ne[ 0 ] >= ( 1024 * 4 ) )
+ cs = eComputeShader::softMaxLong;
+ else
+ cs = eComputeShader::softMax;
+
bindShader( cs );
const uint32_t nr = a.countRows();
TensorShape dummyShape;