diff options
| -rw-r--r-- | ComputeShaders/ComputeShaders.vcxproj | 1 | ||||
| -rw-r--r-- | ComputeShaders/ComputeShaders.vcxproj.filters | 1 | ||||
| -rw-r--r-- | ComputeShaders/addRepeatEx.hlsl | 76 | ||||
| -rw-r--r-- | Whisper/D3D/shaderNames.cpp | 3 | ||||
| -rw-r--r-- | Whisper/D3D/shaderNames.h | 71 | ||||
| -rw-r--r-- | Whisper/ML/MlContext.cpp | 15 | ||||
| -rw-r--r-- | Whisper/ML/MlContext.h | 3 | ||||
| -rw-r--r-- | Whisper/Whisper/WhisperContext.cpp | 14 |
8 files changed, 138 insertions, 46 deletions
diff --git a/ComputeShaders/ComputeShaders.vcxproj b/ComputeShaders/ComputeShaders.vcxproj index 350d266..1d9343d 100644 --- a/ComputeShaders/ComputeShaders.vcxproj +++ b/ComputeShaders/ComputeShaders.vcxproj @@ -160,6 +160,7 @@ <FxCompile Include="addInPlace.hlsl" /> <FxCompile Include="addRepeat.hlsl" /> <FxCompile Include="addRepeat64.hlsl" /> + <FxCompile Include="addRepeatEx.hlsl" /> <FxCompile Include="addRepeatGelu.hlsl" /> <FxCompile Include="addRepeatGelu64.hlsl" /> <FxCompile Include="addRepeatScale.hlsl" /> diff --git a/ComputeShaders/ComputeShaders.vcxproj.filters b/ComputeShaders/ComputeShaders.vcxproj.filters index 12f1559..b827710 100644 --- a/ComputeShaders/ComputeShaders.vcxproj.filters +++ b/ComputeShaders/ComputeShaders.vcxproj.filters @@ -50,6 +50,7 @@ <FxCompile Include="mulMatTiledEx.hlsl" /> <FxCompile Include="matReshapePanels.hlsl" /> <FxCompile Include="mulMatByRowTiledEx.hlsl" /> + <FxCompile Include="addRepeatEx.hlsl" /> </ItemGroup> <ItemGroup> <None Include="componentwiseBinaryOp.hlsli" /> diff --git a/ComputeShaders/addRepeatEx.hlsl b/ComputeShaders/addRepeatEx.hlsl new file mode 100644 index 0000000..ea510b3 --- /dev/null +++ b/ComputeShaders/addRepeatEx.hlsl @@ -0,0 +1,76 @@ +// An equivalent of "addRepeat.hlsl" followed by "addInPlace.hlsl". +// Merging into a single shader saves some global memory bandwidth and reduces CPU overhead wasted binding resources and dispatching shaders +RWBuffer<float> tensor: register( u0 ); +Buffer<float> pattern: register( t0 ); +Buffer<float> finalAdd: register( t1 ); + +cbuffer Constants: register( b0 ) +{ + uint4 tensorSize: packoffset( c0 ); + uint4 tensorStrides: packoffset( c1 ); + uint4 patternSize: packoffset( c2 ); + uint4 patternStrides: packoffset( c3 ); + // uint4 finalSize: packoffset( c4 ); + uint4 finalStrides: packoffset( c5 ); +} + +#ifndef THREADS +#define THREADS 256 +#endif + +#include "repeatUtils.hlsli" + +// The micro-kernel of the shader, computes tensor[ rsi.x ] += pattern + finalAdd[ rsi.y ] +inline void add2( uint2 rsi, float pattern ) +{ + float f = tensor[ rsi.x ]; + f += pattern; + f += finalAdd[ rsi.y ]; + tensor[ rsi.x ] = f; +} + +[ numthreads( THREADS, 1, 1 ) ] +void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex ) +{ + const uint2 stridesX = uint2( tensorStrides.x, finalStrides.x ); + uint2 rsi; + rsi.x = rowOffset( group, tensorStrides ); + rsi.y = rowOffset( group, finalStrides ); + const uint rsiEnd = rsi.x + tensorSize.x * stridesX.x; + rsi += stridesX * thread; + + uint pat = rowOffset( group % patternSize.yzw, patternStrides ); + + if( patternSize.x == 1 ) + { + // The pattern only has 1 column, broadcasting over the row + const uint2 rsiInc = stridesX * THREADS; + const float p = pattern[ pat ]; + for( ; rsi.x < rsiEnd; rsi += rsiInc ) + add2( rsi, p ); + } + else if( patternSize.x <= THREADS ) + { + // pattern size doesn't exceed thread group size, load outside of the loop + const uint threadsPerGroup = THREADS - ( THREADS % patternSize.x ); + if( thread >= threadsPerGroup ) + return; + + const uint2 rsiInc = stridesX * threadsPerGroup; + pat += ( thread % patternSize.x ) * patternStrides.x; + const float p = pattern[ pat ]; + for( ; rsi.x < rsiEnd; rsi += rsiInc ) + add2( rsi, p ); + } + else + { + // Pattern rows are longer than the thread group, need to stream from both buffers + uint3 rsi3; + rsi3.xy = rsi; + rsi3.z = pat + thread * patternStrides.x; + + const uint3 rsiInc = uint3( stridesX, patternStrides.x ) * THREADS; + for( ; rsi3.x < rsiEnd; rsi3 += rsiInc ) + add2( rsi3.xy, pattern[ rsi3.z ] ); + } +}
\ No newline at end of file diff --git a/Whisper/D3D/shaderNames.cpp b/Whisper/D3D/shaderNames.cpp index b52f5db..0605828 100644 --- a/Whisper/D3D/shaderNames.cpp +++ b/Whisper/D3D/shaderNames.cpp @@ -2,11 +2,12 @@ #include "stdafx.h" #include "shaderNames.h" -static const std::array<const char*, 38> s_shaderNames = +static const std::array<const char*, 39> s_shaderNames = { "add", "addInPlace", "addRepeat", + "addRepeatEx", "addRepeatGelu", "addRepeatScale", "addRows", diff --git a/Whisper/D3D/shaderNames.h b/Whisper/D3D/shaderNames.h index ccfab86..5942e72 100644 --- a/Whisper/D3D/shaderNames.h +++ b/Whisper/D3D/shaderNames.h @@ -9,41 +9,42 @@ namespace DirectCompute add = 0, addInPlace = 1, addRepeat = 2, - addRepeatGelu = 3, - addRepeatScale = 4, - addRows = 5, - convolutionMain = 6, - convolutionMain2 = 7, - convolutionMain2Fixed = 8, - convolutionPrep1 = 9, - convolutionPrep2 = 10, - copyConvert = 11, - copyTranspose = 12, - diagMaskInf = 13, - flashAttention = 14, - flashAttentionCompat1 = 15, - flashAttentionCompat2 = 16, - flashAttentionCompat3 = 17, - fmaRepeat1 = 18, - fmaRepeat2 = 19, - matReshapePanels = 20, - mulMatByRow = 21, - mulMatByRowTiled = 22, - mulMatByRowTiledEx = 23, - mulMatByScalar = 24, - mulMatDotMain = 25, - mulMatDotReshape = 26, - mulMatMadMain = 27, - mulMatTiled = 28, - mulMatTiledEx = 29, - norm = 30, - normCompat = 31, - normFixed = 32, - scaleInPlace = 33, - softMax = 34, - softMaxCompat = 35, - softMaxFixed = 36, - zeroMemory = 37, + addRepeatEx = 3, + addRepeatGelu = 4, + addRepeatScale = 5, + addRows = 6, + convolutionMain = 7, + convolutionMain2 = 8, + convolutionMain2Fixed = 9, + convolutionPrep1 = 10, + convolutionPrep2 = 11, + copyConvert = 12, + copyTranspose = 13, + diagMaskInf = 14, + flashAttention = 15, + flashAttentionCompat1 = 16, + flashAttentionCompat2 = 17, + flashAttentionCompat3 = 18, + fmaRepeat1 = 19, + fmaRepeat2 = 20, + matReshapePanels = 21, + mulMatByRow = 22, + mulMatByRowTiled = 23, + mulMatByRowTiledEx = 24, + mulMatByScalar = 25, + mulMatDotMain = 26, + mulMatDotReshape = 27, + mulMatMadMain = 28, + mulMatTiled = 29, + mulMatTiledEx = 30, + norm = 31, + normCompat = 32, + normFixed = 33, + scaleInPlace = 34, + softMax = 35, + softMaxCompat = 36, + softMaxFixed = 37, + zeroMemory = 38, }; const char* computeShaderName( eComputeShader cs ); diff --git a/Whisper/ML/MlContext.cpp b/Whisper/ML/MlContext.cpp index 5a29b85..6eeae09 100644 --- a/Whisper/ML/MlContext.cpp +++ b/Whisper/ML/MlContext.cpp @@ -734,6 +734,21 @@ Tensor MlContext::mulMatByRowTiledEx( const Tensor& a, const Tensor& b ) return res; } +void MlContext::addRepeatEx( Tensor& dest, const Tensor& pattern, const Tensor& finalAdd ) +{ + if( !isSameShape( dest, finalAdd ) ) + throw E_INVALIDARG; + assert( dest.getType() == eDataType::FP32 ); + + check( cb.update( dest, pattern, finalAdd ) ); + bindShader( eComputeShader::addRepeatEx ); + cb.bind(); + + Binder bind; + bind.bind( pattern, finalAdd, dest ); + context()->Dispatch( dest.ne[ 1 ], dest.ne[ 2 ], dest.ne[ 3 ] ); +} + __m128i MlContext::getMemoryUse() const { __m128i v = cb.getMemoryUse(); diff --git a/Whisper/ML/MlContext.h b/Whisper/ML/MlContext.h index d0c2d9e..fe0d48d 100644 --- a/Whisper/ML/MlContext.h +++ b/Whisper/ML/MlContext.h @@ -106,6 +106,9 @@ namespace DirectCompute Tensor mulMatTiledEx( const Tensor& a, const Tensor& b ); Tensor mulMatByRowTiledEx( const Tensor& a, const Tensor& b ); + // An equivalent of addRepeat( dest, pattern ) followed by addInPlace( dest, finalAdd ) + void addRepeatEx( Tensor& dest, const Tensor& pattern, const Tensor& finalAdd ); + __m128i getMemoryUse() const; }; }
\ No newline at end of file diff --git a/Whisper/Whisper/WhisperContext.cpp b/Whisper/Whisper/WhisperContext.cpp index c983039..d558aa6 100644 --- a/Whisper/Whisper/WhisperContext.cpp +++ b/Whisper/Whisper/WhisperContext.cpp @@ -244,10 +244,9 @@ Tensor WhisperContext::encodeLayer( const Tensor& source, size_t index, uint32_t profiler.setNextTag( "enc.layer.4" ); cur = mulMat( layer.attnLn1.w, cur ); } - addRepeat( cur, layer.attnLn1.b ); // add the input - addInPlace( cur, source ); + addRepeatEx( cur, layer.attnLn1.b, source ); // feed-forward network Tensor inpFF = cur; @@ -284,10 +283,8 @@ Tensor WhisperContext::encodeLayer( const Tensor& source, size_t index, uint32_t cur = mulMat( layer.mlp1.w, cur ); } - addRepeat( cur, layer.mlp1.b ); - // output from this layer - addInPlace( cur, inpFF ); + addRepeatEx( cur, layer.mlp1.b, inpFF ); return cur; } @@ -523,10 +520,9 @@ Tensor WhisperContext::decodeLayer( const Tensor& inpL, size_t il, const sLayerD { profiler.setNextTag( "dec.layer.10" ); cur = mulMat( layer.crossAttnLn1.w, cur ); - addRepeat( cur, layer.crossAttnLn1.b ); } // add the input - addInPlace( cur, inpCA ); + addRepeatEx( cur, layer.crossAttnLn1.b, inpCA ); Tensor inpFF = cur; // feed-forward network @@ -573,11 +569,9 @@ Tensor WhisperContext::decodeLayer( const Tensor& inpL, size_t il, const sLayerD profiler.setNextTag( "dec.layer.12" ); cur = mulMat( layer.mlp1.w, cur ); } - addRepeat( cur, layer.mlp1.b ); } - // output from this layer - addInPlace( cur, inpFF ); + addRepeatEx( cur, layer.mlp1.b, inpFF ); return cur; } |
