summaryrefslogtreecommitdiffstats
path: root/Whisper/ML/Context.ops.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'Whisper/ML/Context.ops.cpp')
-rw-r--r--Whisper/ML/Context.ops.cpp280
1 files changed, 280 insertions, 0 deletions
diff --git a/Whisper/ML/Context.ops.cpp b/Whisper/ML/Context.ops.cpp
new file mode 100644
index 0000000..7dfca9f
--- /dev/null
+++ b/Whisper/ML/Context.ops.cpp
@@ -0,0 +1,280 @@
+#include "stdafx.h"
+#include "MlContext.h"
+#include "testUtils.h"
+using namespace DirectCompute;
+
+Tensor MlContext::createTensor( eDataType type, const std::array<uint32_t, 4>& ne )
+{
+ Tensor res;
+ check( res.create( type, ne ) );
+ return res;
+}
+
+Tensor MlContext::createTensor( eDataType type, std::initializer_list<uint32_t> ne )
+{
+ size_t nDims = ne.size();
+ if( 0 == nDims || nDims > 4 )
+ throw E_INVALIDARG;
+ std::array<uint32_t, 4> arr;
+ for( size_t i = 0; i < nDims; i++ )
+ arr[ i ] = ne.begin()[ i ];
+ for( size_t i = nDims; i < 4; i++ )
+ arr[ i ] = 1;
+ return createTensor( type, arr );
+}
+
+Tensor MlContext::conv_1d_1s( const Tensor& a, const Tensor& b )
+{
+ assert( b.isMatrix() );
+ assert( a.ne[ 1 ] == b.ne[ 1 ] );
+ assert( a.ne[ 3 ] == 1 );
+
+ Tensor res = createTensor( eDataType::FP32, { b.ne[ 0 ], a.ne[ 2 ] } );
+
+ convolution( a, b, res );
+ return res;
+}
+
+Tensor MlContext::conv_1d_2s( const Tensor& a, const Tensor& b )
+{
+ assert( b.isMatrix() );
+ assert( a.ne[ 1 ] == b.ne[ 1 ] );
+ assert( a.ne[ 3 ] == 1 );
+
+ Tensor res = createTensor( eDataType::FP32, { b.ne[ 0 ] / 2, a.ne[ 2 ] } );
+#if 0
+ static PrintUniqueTensorSizes printSize( "conv_1d_2s" );
+ printSize.print( a, b );
+#endif
+ convolution2( a, b, res );
+ return res;
+}
+
+namespace
+{
+ inline bool canRepeat( const TensorShape& t0, const TensorShape& t1 )
+ {
+ return ( t1.ne[ 0 ] % t0.ne[ 0 ] == 0 ) &&
+ ( t1.ne[ 1 ] % t0.ne[ 1 ] == 0 ) &&
+ ( t1.ne[ 2 ] % t0.ne[ 2 ] == 0 ) &&
+ ( t1.ne[ 3 ] % t0.ne[ 3 ] == 0 );
+ }
+}
+
+Tensor MlContext::cwiseBinary( const Tensor& a, const Tensor& b, eComputeShader cs )
+{
+ assert( isSameShape( a, b ) );
+ Tensor res = createTensor( a.getType(), a.ne );
+ cwiseBinary( a, b, res, cs );
+ return res;
+}
+
+Tensor __declspec( noinline ) MlContext::view2d( const Tensor& a, uint32_t ne0, uint32_t ne1, uint32_t nb1, uint32_t offset )
+{
+ if( 0 != offset )
+ throw E_NOTIMPL;
+
+ Tensor res = a;
+ res.ne = { ne0, ne1, 1, 1 };
+
+ res.nb[ 1 ] = nb1;
+ res.nb[ 2 ] = res.nb[ 3 ] = nb1 * ne1;
+ return res;
+}
+
+Tensor MlContext::transpose( const Tensor& a )
+{
+ Tensor result = a;
+ std::swap( result.ne[ 0 ], result.ne[ 1 ] );
+ std::swap( result.nb[ 0 ], result.nb[ 1 ] );
+ return result;
+}
+
+Tensor MlContext::norm( const Tensor& a )
+{
+ Tensor res = createTensor( a.getType(), a.ne );
+ norm( a, res );
+ return res;
+}
+
+Tensor MlContext::mulMat( const Tensor& a, const Tensor& b )
+{
+ if( !canMulMat( a, b ) )
+ throw E_INVALIDARG;
+ Tensor res = createTensor( eDataType::FP32, { a.ne[ 1 ], b.ne[ 1 ], a.ne[ 2 ], b.ne[ 3 ] } );
+ if constexpr( enableInexactOptimizations )
+ mulMatTiled( a, b, res );
+ else
+ mulMat( a, b, res );
+#if 0
+ Tensor testTiled;
+ check( testTiled.create( eDataType::FP32, res.ne ) );
+ mulMatTiled( a, b, testTiled );
+
+ std::vector<float> current, tiled;
+ res.download( current );
+ testTiled.download( tiled );
+ sTensorDiff diff = computeDiff( current.data(), tiled.data(), current.size() );
+ diff.print( "mulMatTiled" );
+#endif
+ return res;
+}
+
+Tensor MlContext::mulMatEx( const Tensor& a, const Tensor& b, const char* tagName )
+{
+ if( !canMulMat( a, b ) )
+ throw E_INVALIDARG;
+ if( 0 != a.nb[ 0 ] )
+ throw E_INVALIDARG; // The first argument is expected to be pre-transposed
+
+ const uint16_t tag = profiler.setNextTag( tagName );
+
+ if( b.ne[ 1 ] != 1 )
+ {
+ if( b.nb[ 0 ] != 0 )
+ {
+ Tensor rhs = reshapePanels( b );
+ profiler.setNextTag( tag );
+ return mulMatTiledEx( a, rhs );
+ }
+ else
+ {
+ // Second argument already reshaped into these panels
+ return mulMatTiledEx( a, b );
+ }
+ }
+ else
+ {
+ if( 0 != b.nb[ 0 ] )
+ return mulMatByRowTiledEx( a, b );
+
+ // That shader requires classic VRAM layout of the second argument, gonna fail with pre-transposed one
+ throw E_INVALIDARG;
+ }
+}
+
+Tensor MlContext::permute( const Tensor& a, uint8_t axis0, uint8_t axis1, uint8_t axis2, uint8_t axis3 )
+{
+ assert( axis0 < 4 );
+ assert( axis1 < 4 );
+ assert( axis2 < 4 );
+ assert( axis3 < 4 );
+
+ assert( axis0 != axis1 );
+ assert( axis0 != axis2 );
+ assert( axis0 != axis3 );
+ assert( axis1 != axis2 );
+ assert( axis1 != axis3 );
+ assert( axis2 != axis3 );
+
+ Tensor res = a;
+ res.ne[ axis0 ] = a.ne[ 0 ];
+ res.ne[ axis1 ] = a.ne[ 1 ];
+ res.ne[ axis2 ] = a.ne[ 2 ];
+ res.ne[ axis3 ] = a.ne[ 3 ];
+
+ res.nb[ axis0 ] = a.nb[ 0 ];
+ res.nb[ axis1 ] = a.nb[ 1 ];
+ res.nb[ axis2 ] = a.nb[ 2 ];
+ res.nb[ axis3 ] = a.nb[ 3 ];
+ return res;
+}
+
+Tensor MlContext::flashAttention( const Tensor& q, const Tensor& k, const Tensor& v, bool masked )
+{
+ if( !canMulMat( k, q ) )
+ throw E_INVALIDARG;
+
+ if constexpr( enableInexactOptimizations )
+ {
+ if( !masked )
+ {
+ profiler.setNextTag( "flashAttn.1" );
+ Tensor tmp = mulMat( k, q );
+
+ const float tempScale = (float)( 1.0 / sqrt( (double)(int)q.ne[ 0 ] ) );
+ softMax( tmp, tempScale );
+
+ profiler.setNextTag( "flashAttn.2" );
+ return mulMat( v, tmp );
+ }
+ }
+
+ Tensor res = createTensor( eDataType::FP32, q.ne );
+ flashAttention( q, k, v, res, masked );
+
+#if 0
+ Tensor tmpMat = mulMat( k, q );
+ float scale = (float)( 1.0 / sqrt( (double)(int)q.ne[ 0 ] ) );
+ softMax( tmpMat, scale );
+ Tensor testRes = mulMat( v, tmpMat );
+ computeDiff( res, testRes ).print( "flashAttention mulmat" );
+#endif
+
+ return res;
+}
+
+Tensor MlContext::copy( const Tensor& a, eDataType type, std::initializer_list<uint32_t> size )
+{
+ const size_t dims = size.size();
+ if( 0 == dims || dims > 4 )
+ throw E_BOUNDS;
+
+ size_t nRequested = 1;
+ for( size_t i = 0; i < dims; i++ )
+ {
+ uint32_t n = size.begin()[ i ];
+ nRequested *= n;
+ }
+ if( nRequested != a.countElements() )
+ throw E_INVALIDARG;
+
+ const eDataType st = a.getType();
+ Tensor res;
+ if( a.isContinuous() && st == type )
+ {
+ // Same type, and it's dense - no need to call any compute shaders, equal to reshape
+ res = a;
+ for( size_t i = 0; i < dims; i++ )
+ res.ne[ i ] = size.begin()[ i ];;
+ for( size_t i = dims; i < 4; i++ )
+ res.ne[ i ] = 1;
+ res.setDenseStrides();
+ }
+ else
+ {
+ // Either converting non-continuous to continuous, or converting types
+ res = createTensor( type, size );
+ copyImpl( a, res, st == eDataType::FP32 && type == eDataType::FP16 );
+ }
+ return res;
+}
+
+void MlContext::copyInPlace( Tensor& dest, const Tensor& a, eDataType type, std::initializer_list<uint32_t> size )
+{
+ assert( type == dest.getType() );
+
+ const size_t dims = size.size();
+ if( 0 == dims || dims > 4 )
+ throw E_BOUNDS;
+
+ size_t nRequested = 1;
+ for( size_t i = 0; i < dims; i++ )
+ {
+ uint32_t n = size.begin()[ i ];
+ nRequested *= n;
+ }
+ if( nRequested != a.countElements() || nRequested != dest.countElements() )
+ throw E_INVALIDARG;
+
+ // Reshape the destination
+ for( size_t i = 0; i < dims; i++ )
+ dest.ne[ i ] = size.begin()[ i ];
+ for( size_t i = dims; i < 4; i++ )
+ dest.ne[ i ] = 1;
+ dest.setDenseStrides();
+
+ // Call the shader
+ const eDataType st = a.getType();
+ copyImpl( a, dest, st == eDataType::FP32 && type == eDataType::FP16 );
+} \ No newline at end of file