summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--ComputeShaders/ComputeShaders.vcxproj1
-rw-r--r--ComputeShaders/ComputeShaders.vcxproj.filters1
-rw-r--r--ComputeShaders/addRepeatEx.hlsl76
-rw-r--r--Whisper/D3D/shaderNames.cpp3
-rw-r--r--Whisper/D3D/shaderNames.h71
-rw-r--r--Whisper/ML/MlContext.cpp15
-rw-r--r--Whisper/ML/MlContext.h3
-rw-r--r--Whisper/Whisper/WhisperContext.cpp14
8 files changed, 138 insertions, 46 deletions
diff --git a/ComputeShaders/ComputeShaders.vcxproj b/ComputeShaders/ComputeShaders.vcxproj
index 350d266..1d9343d 100644
--- a/ComputeShaders/ComputeShaders.vcxproj
+++ b/ComputeShaders/ComputeShaders.vcxproj
@@ -160,6 +160,7 @@
<FxCompile Include="addInPlace.hlsl" />
<FxCompile Include="addRepeat.hlsl" />
<FxCompile Include="addRepeat64.hlsl" />
+ <FxCompile Include="addRepeatEx.hlsl" />
<FxCompile Include="addRepeatGelu.hlsl" />
<FxCompile Include="addRepeatGelu64.hlsl" />
<FxCompile Include="addRepeatScale.hlsl" />
diff --git a/ComputeShaders/ComputeShaders.vcxproj.filters b/ComputeShaders/ComputeShaders.vcxproj.filters
index 12f1559..b827710 100644
--- a/ComputeShaders/ComputeShaders.vcxproj.filters
+++ b/ComputeShaders/ComputeShaders.vcxproj.filters
@@ -50,6 +50,7 @@
<FxCompile Include="mulMatTiledEx.hlsl" />
<FxCompile Include="matReshapePanels.hlsl" />
<FxCompile Include="mulMatByRowTiledEx.hlsl" />
+ <FxCompile Include="addRepeatEx.hlsl" />
</ItemGroup>
<ItemGroup>
<None Include="componentwiseBinaryOp.hlsli" />
diff --git a/ComputeShaders/addRepeatEx.hlsl b/ComputeShaders/addRepeatEx.hlsl
new file mode 100644
index 0000000..ea510b3
--- /dev/null
+++ b/ComputeShaders/addRepeatEx.hlsl
@@ -0,0 +1,76 @@
+// An equivalent of "addRepeat.hlsl" followed by "addInPlace.hlsl".
+// Merging into a single shader saves some global memory bandwidth and reduces CPU overhead wasted binding resources and dispatching shaders
+RWBuffer<float> tensor: register( u0 );
+Buffer<float> pattern: register( t0 );
+Buffer<float> finalAdd: register( t1 );
+
+cbuffer Constants: register( b0 )
+{
+ uint4 tensorSize: packoffset( c0 );
+ uint4 tensorStrides: packoffset( c1 );
+ uint4 patternSize: packoffset( c2 );
+ uint4 patternStrides: packoffset( c3 );
+ // uint4 finalSize: packoffset( c4 );
+ uint4 finalStrides: packoffset( c5 );
+}
+
+#ifndef THREADS
+#define THREADS 256
+#endif
+
+#include "repeatUtils.hlsli"
+
+// The micro-kernel of the shader, computes tensor[ rsi.x ] += pattern + finalAdd[ rsi.y ]
+inline void add2( uint2 rsi, float pattern )
+{
+ float f = tensor[ rsi.x ];
+ f += pattern;
+ f += finalAdd[ rsi.y ];
+ tensor[ rsi.x ] = f;
+}
+
+[ numthreads( THREADS, 1, 1 ) ]
+void main( uint3 group: SV_GroupID, uint thread : SV_GroupIndex )
+{
+ const uint2 stridesX = uint2( tensorStrides.x, finalStrides.x );
+ uint2 rsi;
+ rsi.x = rowOffset( group, tensorStrides );
+ rsi.y = rowOffset( group, finalStrides );
+ const uint rsiEnd = rsi.x + tensorSize.x * stridesX.x;
+ rsi += stridesX * thread;
+
+ uint pat = rowOffset( group % patternSize.yzw, patternStrides );
+
+ if( patternSize.x == 1 )
+ {
+ // The pattern only has 1 column, broadcasting over the row
+ const uint2 rsiInc = stridesX * THREADS;
+ const float p = pattern[ pat ];
+ for( ; rsi.x < rsiEnd; rsi += rsiInc )
+ add2( rsi, p );
+ }
+ else if( patternSize.x <= THREADS )
+ {
+ // pattern size doesn't exceed thread group size, load outside of the loop
+ const uint threadsPerGroup = THREADS - ( THREADS % patternSize.x );
+ if( thread >= threadsPerGroup )
+ return;
+
+ const uint2 rsiInc = stridesX * threadsPerGroup;
+ pat += ( thread % patternSize.x ) * patternStrides.x;
+ const float p = pattern[ pat ];
+ for( ; rsi.x < rsiEnd; rsi += rsiInc )
+ add2( rsi, p );
+ }
+ else
+ {
+ // Pattern rows are longer than the thread group, need to stream from both buffers
+ uint3 rsi3;
+ rsi3.xy = rsi;
+ rsi3.z = pat + thread * patternStrides.x;
+
+ const uint3 rsiInc = uint3( stridesX, patternStrides.x ) * THREADS;
+ for( ; rsi3.x < rsiEnd; rsi3 += rsiInc )
+ add2( rsi3.xy, pattern[ rsi3.z ] );
+ }
+} \ No newline at end of file
diff --git a/Whisper/D3D/shaderNames.cpp b/Whisper/D3D/shaderNames.cpp
index b52f5db..0605828 100644
--- a/Whisper/D3D/shaderNames.cpp
+++ b/Whisper/D3D/shaderNames.cpp
@@ -2,11 +2,12 @@
#include "stdafx.h"
#include "shaderNames.h"
-static const std::array<const char*, 38> s_shaderNames =
+static const std::array<const char*, 39> s_shaderNames =
{
"add",
"addInPlace",
"addRepeat",
+ "addRepeatEx",
"addRepeatGelu",
"addRepeatScale",
"addRows",
diff --git a/Whisper/D3D/shaderNames.h b/Whisper/D3D/shaderNames.h
index ccfab86..5942e72 100644
--- a/Whisper/D3D/shaderNames.h
+++ b/Whisper/D3D/shaderNames.h
@@ -9,41 +9,42 @@ namespace DirectCompute
add = 0,
addInPlace = 1,
addRepeat = 2,
- addRepeatGelu = 3,
- addRepeatScale = 4,
- addRows = 5,
- convolutionMain = 6,
- convolutionMain2 = 7,
- convolutionMain2Fixed = 8,
- convolutionPrep1 = 9,
- convolutionPrep2 = 10,
- copyConvert = 11,
- copyTranspose = 12,
- diagMaskInf = 13,
- flashAttention = 14,
- flashAttentionCompat1 = 15,
- flashAttentionCompat2 = 16,
- flashAttentionCompat3 = 17,
- fmaRepeat1 = 18,
- fmaRepeat2 = 19,
- matReshapePanels = 20,
- mulMatByRow = 21,
- mulMatByRowTiled = 22,
- mulMatByRowTiledEx = 23,
- mulMatByScalar = 24,
- mulMatDotMain = 25,
- mulMatDotReshape = 26,
- mulMatMadMain = 27,
- mulMatTiled = 28,
- mulMatTiledEx = 29,
- norm = 30,
- normCompat = 31,
- normFixed = 32,
- scaleInPlace = 33,
- softMax = 34,
- softMaxCompat = 35,
- softMaxFixed = 36,
- zeroMemory = 37,
+ addRepeatEx = 3,
+ addRepeatGelu = 4,
+ addRepeatScale = 5,
+ addRows = 6,
+ convolutionMain = 7,
+ convolutionMain2 = 8,
+ convolutionMain2Fixed = 9,
+ convolutionPrep1 = 10,
+ convolutionPrep2 = 11,
+ copyConvert = 12,
+ copyTranspose = 13,
+ diagMaskInf = 14,
+ flashAttention = 15,
+ flashAttentionCompat1 = 16,
+ flashAttentionCompat2 = 17,
+ flashAttentionCompat3 = 18,
+ fmaRepeat1 = 19,
+ fmaRepeat2 = 20,
+ matReshapePanels = 21,
+ mulMatByRow = 22,
+ mulMatByRowTiled = 23,
+ mulMatByRowTiledEx = 24,
+ mulMatByScalar = 25,
+ mulMatDotMain = 26,
+ mulMatDotReshape = 27,
+ mulMatMadMain = 28,
+ mulMatTiled = 29,
+ mulMatTiledEx = 30,
+ norm = 31,
+ normCompat = 32,
+ normFixed = 33,
+ scaleInPlace = 34,
+ softMax = 35,
+ softMaxCompat = 36,
+ softMaxFixed = 37,
+ zeroMemory = 38,
};
const char* computeShaderName( eComputeShader cs );
diff --git a/Whisper/ML/MlContext.cpp b/Whisper/ML/MlContext.cpp
index 5a29b85..6eeae09 100644
--- a/Whisper/ML/MlContext.cpp
+++ b/Whisper/ML/MlContext.cpp
@@ -734,6 +734,21 @@ Tensor MlContext::mulMatByRowTiledEx( const Tensor& a, const Tensor& b )
return res;
}
+void MlContext::addRepeatEx( Tensor& dest, const Tensor& pattern, const Tensor& finalAdd )
+{
+ if( !isSameShape( dest, finalAdd ) )
+ throw E_INVALIDARG;
+ assert( dest.getType() == eDataType::FP32 );
+
+ check( cb.update( dest, pattern, finalAdd ) );
+ bindShader( eComputeShader::addRepeatEx );
+ cb.bind();
+
+ Binder bind;
+ bind.bind( pattern, finalAdd, dest );
+ context()->Dispatch( dest.ne[ 1 ], dest.ne[ 2 ], dest.ne[ 3 ] );
+}
+
__m128i MlContext::getMemoryUse() const
{
__m128i v = cb.getMemoryUse();
diff --git a/Whisper/ML/MlContext.h b/Whisper/ML/MlContext.h
index d0c2d9e..fe0d48d 100644
--- a/Whisper/ML/MlContext.h
+++ b/Whisper/ML/MlContext.h
@@ -106,6 +106,9 @@ namespace DirectCompute
Tensor mulMatTiledEx( const Tensor& a, const Tensor& b );
Tensor mulMatByRowTiledEx( const Tensor& a, const Tensor& b );
+ // An equivalent of addRepeat( dest, pattern ) followed by addInPlace( dest, finalAdd )
+ void addRepeatEx( Tensor& dest, const Tensor& pattern, const Tensor& finalAdd );
+
__m128i getMemoryUse() const;
};
} \ No newline at end of file
diff --git a/Whisper/Whisper/WhisperContext.cpp b/Whisper/Whisper/WhisperContext.cpp
index c983039..d558aa6 100644
--- a/Whisper/Whisper/WhisperContext.cpp
+++ b/Whisper/Whisper/WhisperContext.cpp
@@ -244,10 +244,9 @@ Tensor WhisperContext::encodeLayer( const Tensor& source, size_t index, uint32_t
profiler.setNextTag( "enc.layer.4" );
cur = mulMat( layer.attnLn1.w, cur );
}
- addRepeat( cur, layer.attnLn1.b );
// add the input
- addInPlace( cur, source );
+ addRepeatEx( cur, layer.attnLn1.b, source );
// feed-forward network
Tensor inpFF = cur;
@@ -284,10 +283,8 @@ Tensor WhisperContext::encodeLayer( const Tensor& source, size_t index, uint32_t
cur = mulMat( layer.mlp1.w, cur );
}
- addRepeat( cur, layer.mlp1.b );
-
// output from this layer
- addInPlace( cur, inpFF );
+ addRepeatEx( cur, layer.mlp1.b, inpFF );
return cur;
}
@@ -523,10 +520,9 @@ Tensor WhisperContext::decodeLayer( const Tensor& inpL, size_t il, const sLayerD
{
profiler.setNextTag( "dec.layer.10" );
cur = mulMat( layer.crossAttnLn1.w, cur );
- addRepeat( cur, layer.crossAttnLn1.b );
}
// add the input
- addInPlace( cur, inpCA );
+ addRepeatEx( cur, layer.crossAttnLn1.b, inpCA );
Tensor inpFF = cur;
// feed-forward network
@@ -573,11 +569,9 @@ Tensor WhisperContext::decodeLayer( const Tensor& inpL, size_t il, const sLayerD
profiler.setNextTag( "dec.layer.12" );
cur = mulMat( layer.mlp1.w, cur );
}
- addRepeat( cur, layer.mlp1.b );
}
-
// output from this layer
- addInPlace( cur, inpFF );
+ addRepeatEx( cur, layer.mlp1.b, inpFF );
return cur;
}