summaryrefslogtreecommitdiffstats
path: root/Whisper/ML/tensorOpsTests.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'Whisper/ML/tensorOpsTests.cpp')
-rw-r--r--Whisper/ML/tensorOpsTests.cpp183
1 files changed, 183 insertions, 0 deletions
diff --git a/Whisper/ML/tensorOpsTests.cpp b/Whisper/ML/tensorOpsTests.cpp
new file mode 100644
index 0000000..adc020e
--- /dev/null
+++ b/Whisper/ML/tensorOpsTests.cpp
@@ -0,0 +1,183 @@
+#include "stdafx.h"
+#include "tensorOpsTests.h"
+#include "MlContext.h"
+#include "TensorEx.h"
+#include "../D3D/shaders.h"
+#include "../D3D/Binder.h"
+#include "testUtils.h"
+#include "../Whisper/WhisperContext.h"
+
+void DirectCompute::testMulMat( const ggml_tensor* src0, const ggml_tensor* src1, const ggml_tensor* dst, const void* tempBuffer )
+{
+ return;
+ CaptureRaii capture;
+ const size_t nb00 = src0->nb[ 0 ];
+ const size_t nb01 = src0->nb[ 1 ];
+
+ if( src0->type != GGML_TYPE_F16 )
+ return; // TODO
+
+ if( nb01 < nb00 )
+ return; // TODO
+
+ WhisperContext& ctx = WhisperContext::current();
+
+ Tensor arg0, arg1;
+ check( arg0.create( *src0, eBufferUse::Immutable, true ) );
+ check( arg1.create( *src1, eBufferUse::Immutable, true ) );
+ TensorEx res;
+ check( res.create( *dst, eBufferUse::ReadWriteDownload, false ) );
+
+ ctx.mulMat( arg0, arg1, res );
+
+ std::vector<float> tv;
+ check( res.download( tv ) );
+
+ const size_t len = tv.size();
+ computeDiff( tv.data(), (const float*)dst->data, len ).print( "testMulMat-product" );
+
+#if 0
+ dbgWriteBinaryFile( L"product-orig.bin", dst->data, len * 4 );
+ dbgWriteBinaryFile( L"product-gpu.bin", tv.data(), len * 4 );
+ __debugbreak();
+#endif
+}
+
+#if 0
+void DirectCompute::testMulMatReshape( const ggml_tensor* src1, const void* tempBuffer )
+{
+ Tensor src;
+ check( src.create( *src1, eBufferUse::Immutable, true ) );
+
+ const size_t ne10 = (uint32_t)src1->ne[ 0 ];
+ const size_t ne11 = (uint32_t)src1->ne[ 1 ];
+ const size_t ne12 = (uint32_t)src1->ne[ 2 ];
+ const size_t ne13 = (uint32_t)src1->ne[ 3 ];
+ if( 1 != ne13 )
+ throw E_UNEXPECTED;
+ const size_t tempLength = ne10 * ne11 * ne12 * ne13;
+
+ Context& ctx = Context::current();
+ const ReadWriteViews& temp = ctx.temp.fp16( tempLength );
+
+ {
+ Binder bind;
+ ctx.cb.bind();
+
+ bindShader( eComputeShader::mulMatDotReshape );
+
+ ctx.cb.update( src );
+ bind.bind( src, temp );
+ context()->Dispatch( (UINT)ne11, (UINT)ne12, 1 );
+ }
+
+ std::vector<uint16_t> reshaped;
+ check( downloadBuffer( temp, reshaped ) );
+ computeDiff( reshaped.data(), (const uint16_t*)tempBuffer, reshaped.size() ).print( "testMulMatReshape" );
+
+#if 0
+ dbgWriteBinaryFile( L"fp32.bin", src1->data, ggml_nbytes( src1 ) );
+ dbgWriteBinaryFile( L"fp16-cpu.bin", tempBuffer, reshaped.size() * 2 );
+ dbgWriteBinaryFile( L"fp16-gpu.bin", reshaped.data(), reshaped.size() * 2 );
+ __debugbreak();
+#endif
+}
+#endif
+
+void DirectCompute::computeMulMat( const ggml_tensor* src0, const ggml_tensor* src1, ggml_tensor* dst )
+{
+ CaptureRaii capture;
+ const size_t nb00 = src0->nb[ 0 ];
+ const size_t nb01 = src0->nb[ 1 ];
+
+ if( src0->type != GGML_TYPE_F16 )
+ throw E_INVALIDARG;
+ if( nb01 < nb00 )
+ throw E_INVALIDARG;
+
+ WhisperContext& ctx = WhisperContext::current();
+
+ Tensor arg0, arg1;
+ check( arg0.create( *src0, eBufferUse::Immutable, true ) );
+ check( arg1.create( *src1, eBufferUse::Immutable, true ) );
+ TensorEx res;
+ check( res.create( *dst, eBufferUse::ReadWriteDownload, false ) );
+
+ ctx.mulMat( arg0, arg1, res );
+
+ check( res.download( dst->data ) );
+}
+
+void DirectCompute::testFlashAttention( const ggml_tensor* q, const ggml_tensor* k, const ggml_tensor* v, bool masked, const ggml_tensor* dst )
+{
+ CaptureRaii capture;
+
+ Tensor Q, K, V;
+ TensorEx res;
+ check( Q.create( *q, eBufferUse::Immutable, true ) );
+ check( K.create( *k, eBufferUse::Immutable, true ) );
+ check( V.create( *v, eBufferUse::Immutable, true ) );
+ check( res.create( *dst, eBufferUse::ReadWriteDownload, false ) );
+
+ WhisperContext& ctx = WhisperContext::current();
+ ctx.flashAttention( Q, K, V, res, masked );
+
+ std::vector<float> tv;
+ check( res.download( tv ) );
+
+ const size_t len = tv.size();
+ computeDiff( tv.data(), (const float*)dst->data, len ).print( "testFlashAttention" );
+}
+
+void DirectCompute::computeFlashAttention( const ggml_tensor* q, const ggml_tensor* k, const ggml_tensor* v, bool masked, ggml_tensor* dst )
+{
+ CaptureRaii capture;
+
+ Tensor Q, K, V;
+ TensorEx res;
+ check( Q.create( *q, eBufferUse::Immutable, true ) );
+ check( K.create( *k, eBufferUse::Immutable, true ) );
+ check( V.create( *v, eBufferUse::Immutable, true ) );
+ check( res.create( *dst, eBufferUse::ReadWriteDownload, false ) );
+
+ WhisperContext& ctx = WhisperContext::current();
+ ctx.flashAttention( Q, K, V, res, masked );
+
+ check( res.download( dst->data ) );
+}
+
+void DirectCompute::testConvolution( const ggml_tensor* src0, const ggml_tensor* src1, const ggml_tensor* dst )
+{
+ CaptureRaii capture;
+
+ Tensor arg0, arg1;
+ check( arg0.create( *src0, eBufferUse::Immutable, true ) );
+ check( arg1.create( *src1, eBufferUse::Immutable, true ) );
+ TensorEx res;
+ check( res.create( *dst, eBufferUse::ReadWriteDownload, false ) );
+
+ WhisperContext& ctx = WhisperContext::current();
+ ctx.convolution( arg0, arg1, res );
+
+ std::vector<float> tv;
+ check( res.download( tv ) );
+
+ const size_t len = tv.size();
+ computeDiff( tv.data(), (const float*)dst->data, len ).print( "testConvolution" );
+}
+
+void DirectCompute::computeConvolution( const ggml_tensor* src0, const ggml_tensor* src1, ggml_tensor* dst )
+{
+ CaptureRaii capture;
+
+ Tensor arg0, arg1;
+ check( arg0.create( *src0, eBufferUse::Immutable, true ) );
+ check( arg1.create( *src1, eBufferUse::Immutable, true ) );
+ TensorEx res;
+ check( res.create( *dst, eBufferUse::ReadWriteDownload, false ) );
+
+ WhisperContext& ctx = WhisperContext::current();
+ ctx.convolution( arg0, arg1, res );
+
+ res.download( dst->data );
+} \ No newline at end of file