summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorKonstantin <const@const.me>2023-01-18 20:35:30 +0100
committerKonstantin <const@const.me>2023-01-18 20:35:30 +0100
commit11c399b70c7ad5664b6060b39632e6b9fa815350 (patch)
tree763afed51699017749d3f0398f16928aad7544a4
parentad097a744759c6a78e1b33ea9d2b4b2af01c529d (diff)
Optional startup flags to override performance-related defaults for the compute shaders
-rw-r--r--Examples/WhisperDesktop/LoadModelDlg.cpp2
-rw-r--r--Whisper/API/eGpuModelFlags.h13
-rw-r--r--Whisper/API/iContext.cl.h3
-rw-r--r--Whisper/API/iContext.h3
-rw-r--r--Whisper/D3D/device.cpp54
-rw-r--r--Whisper/D3D/device.h20
-rw-r--r--Whisper/D3D/startup.cpp4
-rw-r--r--Whisper/D3D/startup.h2
-rw-r--r--Whisper/ML/mlStartup.cpp4
-rw-r--r--Whisper/ML/mlStartup.h2
-rw-r--r--Whisper/Whisper.vcxproj1
-rw-r--r--Whisper/Whisper.vcxproj.filters1
-rw-r--r--Whisper/Whisper/ModelImpl.cpp6
-rw-r--r--Whisper/Whisper/ModelImpl.h3
-rw-r--r--Whisper/Whisper/WhisperModel.cpp12
-rw-r--r--Whisper/modelFactory.cpp10
-rw-r--r--Whisper/modelFactory.h2
-rw-r--r--Whisper/stdafx.h7
-rw-r--r--WhisperNet/API/eGpuModelFlags.cs28
-rw-r--r--WhisperNet/API/eModelImplementation.cs2
-rw-r--r--WhisperNet/Library.cs2
21 files changed, 131 insertions, 50 deletions
diff --git a/Examples/WhisperDesktop/LoadModelDlg.cpp b/Examples/WhisperDesktop/LoadModelDlg.cpp
index 1b2bf03..3aa84e7 100644
--- a/Examples/WhisperDesktop/LoadModelDlg.cpp
+++ b/Examples/WhisperDesktop/LoadModelDlg.cpp
@@ -140,7 +140,7 @@ void __stdcall LoadModelDlg::poolCallback() noexcept
lmcb.cancel = nullptr;
lmcb.progress = &LoadModelDlg::progressCallback;
lmcb.pv = this;
- HRESULT hr = Whisper::loadModel( path, impl, &lmcb, &model );
+ HRESULT hr = Whisper::loadModel( path, impl, 0, &lmcb, &model );
if( SUCCEEDED( hr ) )
appState.model = model;
else
diff --git a/Whisper/API/eGpuModelFlags.h b/Whisper/API/eGpuModelFlags.h
new file mode 100644
index 0000000..96f5a76
--- /dev/null
+++ b/Whisper/API/eGpuModelFlags.h
@@ -0,0 +1,13 @@
+#pragma once
+#include <stdint.h>
+
+namespace Whisper
+{
+ enum struct eGpuModelFlags : uint32_t
+ {
+ Wave32 = 1,
+ Wave64 = 2,
+ NoReshapedMatMul = 4,
+ UseReshapedMatMul = 8,
+ };
+} \ No newline at end of file
diff --git a/Whisper/API/iContext.cl.h b/Whisper/API/iContext.cl.h
index 97d34c7..fdb15ce 100644
--- a/Whisper/API/iContext.cl.h
+++ b/Whisper/API/iContext.cl.h
@@ -5,6 +5,7 @@
#include "loggerApi.h"
#include "sLanguageList.h"
#include "sLoadModelCallbacks.h"
+#include "eGpuModelFlags.h"
namespace Whisper
{
@@ -55,7 +56,7 @@ namespace Whisper
};
HRESULT COMLIGHTCALL setupLogger( const sLoggerSetup& setup );
- HRESULT COMLIGHTCALL loadModel( const wchar_t* path, eModelImplementation impl, const sLoadModelCallbacks* callbacks, iModel** pp );
+ HRESULT COMLIGHTCALL loadModel( const wchar_t* path, eModelImplementation impl, uint32_t flags, const sLoadModelCallbacks* callbacks, iModel** pp );
uint32_t COMLIGHTCALL findLanguageKeyW( const wchar_t* lang );
uint32_t COMLIGHTCALL findLanguageKeyA( const char* lang );
diff --git a/Whisper/API/iContext.h b/Whisper/API/iContext.h
index 9661093..d6ca29c 100644
--- a/Whisper/API/iContext.h
+++ b/Whisper/API/iContext.h
@@ -4,6 +4,7 @@
#include "loggerApi.h"
#include "sLanguageList.h"
#include "sLoadModelCallbacks.h"
+#include "eGpuModelFlags.h"
namespace Whisper
{
@@ -50,7 +51,7 @@ namespace Whisper
};
HRESULT __stdcall setupLogger( const sLoggerSetup& setup );
- HRESULT __stdcall loadModel( const wchar_t* path, eModelImplementation impl, const sLoadModelCallbacks* callbacks, iModel** pp );
+ HRESULT __stdcall loadModel( const wchar_t* path, eModelImplementation impl, uint32_t flags, const sLoadModelCallbacks* callbacks, iModel** pp );
uint32_t __stdcall findLanguageKeyW( const wchar_t* lang );
uint32_t __stdcall findLanguageKeyA( const char* lang );
diff --git a/Whisper/D3D/device.cpp b/Whisper/D3D/device.cpp
index 4eb5a60..5b0a6e8 100644
--- a/Whisper/D3D/device.cpp
+++ b/Whisper/D3D/device.cpp
@@ -4,6 +4,7 @@
#include <ammintrin.h>
#pragma comment(lib, "D3D11.lib")
#include "RenderDoc/renderDoc.h"
+#include "../API/eGpuModelFlags.h"
namespace DirectCompute
{
@@ -54,7 +55,25 @@ namespace DirectCompute
sGpuInfo s_gpuInfo = {};
const sGpuInfo& gpuInfo = s_gpuInfo;
- static HRESULT queryDeviceInfo()
+ using Whisper::eGpuModelFlags;
+ inline constexpr uint32_t operator|( eGpuModelFlags a, eGpuModelFlags b )
+ {
+ return (uint32_t)a | (uint32_t)b;
+ }
+ inline bool operator&( uint32_t flags, eGpuModelFlags bit )
+ {
+ return 0 != ( flags & (uint32_t)bit );
+ }
+ inline bool merge3( uint32_t flags, eGpuModelFlags enabled, eGpuModelFlags disabled, bool def )
+ {
+ if( flags & enabled )
+ return true;
+ if( flags & disabled )
+ return false;
+ return def;
+ }
+
+ static HRESULT queryDeviceInfo( uint32_t flags )
{
if( nullptr == g_device )
return OLE_E_BLANK;
@@ -77,15 +96,44 @@ namespace DirectCompute
s_gpuInfo.vramDedicated = desc.DedicatedVideoMemory;
s_gpuInfo.ramDedicated = desc.DedicatedSystemMemory;
s_gpuInfo.ramShared = desc.SharedSystemMemory;
+
+ // Set up these flags
+ uint8_t ef = 0;
+ const bool amd = ( s_gpuInfo.vendor == eGpuVendor::AMD );
+ if( merge3( flags, eGpuModelFlags::Wave64, eGpuModelFlags::Wave32, amd ) )
+ ef |= (uint8_t)eGpuEffectiveFlags::Wave64;
+ if( merge3( flags, eGpuModelFlags::UseReshapedMatMul, eGpuModelFlags::NoReshapedMatMul, amd ) )
+ ef |= (uint8_t)eGpuEffectiveFlags::ReshapedMatMul;
+ s_gpuInfo.flags = (eGpuEffectiveFlags)ef;
+
+ return S_OK;
+ }
+
+ static HRESULT validateFlags( uint32_t flags )
+ {
+ constexpr uint32_t waveBoth = eGpuModelFlags::Wave32 | eGpuModelFlags::Wave64;
+ if( ( flags & waveBoth ) == waveBoth )
+ {
+ logError( u8"eGpuModelFlags.%s and eGpuModelFlags.%s are mutually exclusive", "Wave32", "Wave64" );
+ return E_INVALIDARG;
+ }
+
+ constexpr uint32_t reshapedBoth = eGpuModelFlags::NoReshapedMatMul | eGpuModelFlags::UseReshapedMatMul;
+ if( ( flags & reshapedBoth ) == reshapedBoth )
+ {
+ logError( u8"eGpuModelFlags.%s and eGpuModelFlags.%s are mutually exclusive", "NoReshapedMatMul", "UseReshapedMatMul" );
+ return E_INVALIDARG;
+ }
return S_OK;
}
- HRESULT initialize()
+ HRESULT initialize( uint32_t flags )
{
+ CHECK( validateFlags( flags ) );
HRESULT hr = createDevice();
if( hr != S_OK )
return hr;
- queryDeviceInfo();
+ queryDeviceInfo( flags );
return S_OK;
}
diff --git a/Whisper/D3D/device.h b/Whisper/D3D/device.h
index dfcb766..474013e 100644
--- a/Whisper/D3D/device.h
+++ b/Whisper/D3D/device.h
@@ -8,7 +8,7 @@ namespace DirectCompute
ID3D11DeviceContext* context();
D3D_FEATURE_LEVEL featureLevel();
- HRESULT initialize();
+ HRESULT initialize( uint32_t flags );
void terminate();
// DXGI_ADAPTER_DESC.VendorId magic numbers; they come from that database: https://pcisig.com/membership/member-companies
@@ -20,17 +20,24 @@ namespace DirectCompute
VMWare = 0x15ad,
};
+ enum struct eGpuEffectiveFlags : uint8_t
+ {
+ Wave64 = 1,
+ ReshapedMatMul = 2,
+ };
+
struct sGpuInfo
{
- std::wstring description;
+ eGpuEffectiveFlags flags;
eGpuVendor vendor;
uint16_t device, revision;
uint32_t subsystem;
size_t vramDedicated, ramDedicated, ramShared;
+ std::wstring description;
inline bool wave64() const
{
- return vendor == eGpuVendor::AMD;
+ return 0 != ( (uint8_t)flags & (uint8_t)eGpuEffectiveFlags::Wave64 );
}
// On nVidia 1080Ti that approach is much slower, by a factor of 2.4
@@ -38,15 +45,10 @@ namespace DirectCompute
// Dunno why that is, maybe 'coz on that AMD complete panels fit in L3 cache.
// Anyway, we do want extra 30% perf on AMD Cezanne, so only using that code on AMD GPUs.
// Dunno how it gonna behave on other GPUs, need to test.
-#if RESHAPED_MATRIX_MULTIPLY
inline bool useReshapedMatMul() const
{
- // return true;
- return vendor == eGpuVendor::AMD;
+ return 0 != ( (uint8_t)flags & (uint8_t)eGpuEffectiveFlags::ReshapedMatMul );
}
-#else
- constexpr bool useReshapedMatMul() const { return false; }
-#endif
};
extern const sGpuInfo& gpuInfo;
diff --git a/Whisper/D3D/startup.cpp b/Whisper/D3D/startup.cpp
index 2ff0b0c..540e7d8 100644
--- a/Whisper/D3D/startup.cpp
+++ b/Whisper/D3D/startup.cpp
@@ -2,9 +2,9 @@
#include "startup.h"
#include "device.h"
-HRESULT DirectCompute::d3dStartup()
+HRESULT DirectCompute::d3dStartup( uint32_t flags )
{
- HRESULT hr = DirectCompute::initialize();
+ HRESULT hr = DirectCompute::initialize( flags );
if( SUCCEEDED( hr ) )
hr = createComputeShaders();
return hr;
diff --git a/Whisper/D3D/startup.h b/Whisper/D3D/startup.h
index de42fae..e798142 100644
--- a/Whisper/D3D/startup.h
+++ b/Whisper/D3D/startup.h
@@ -3,7 +3,7 @@ using HRESULT = long;
namespace DirectCompute
{
- HRESULT d3dStartup();
+ HRESULT d3dStartup( uint32_t flags );
void d3dShutdown();
HRESULT createComputeShaders();
diff --git a/Whisper/ML/mlStartup.cpp b/Whisper/ML/mlStartup.cpp
index 815ce52..5a0c88a 100644
--- a/Whisper/ML/mlStartup.cpp
+++ b/Whisper/ML/mlStartup.cpp
@@ -12,9 +12,9 @@ namespace DirectCompute
{
const LookupTables& lookupTables = s_tables;
- HRESULT mlStartup()
+ HRESULT mlStartup( uint32_t flags )
{
- CHECK( d3dStartup() );
+ CHECK( d3dStartup( flags ) );
CHECK( s_tables.create() );
return S_OK;
}
diff --git a/Whisper/ML/mlStartup.h b/Whisper/ML/mlStartup.h
index ef5020c..a18a5cf 100644
--- a/Whisper/ML/mlStartup.h
+++ b/Whisper/ML/mlStartup.h
@@ -3,6 +3,6 @@ using HRESULT = long;
namespace DirectCompute
{
- HRESULT mlStartup();
+ HRESULT mlStartup( uint32_t flags );
void mlShutdown();
} \ No newline at end of file
diff --git a/Whisper/Whisper.vcxproj b/Whisper/Whisper.vcxproj
index f270440..5cf4f08 100644
--- a/Whisper/Whisper.vcxproj
+++ b/Whisper/Whisper.vcxproj
@@ -215,6 +215,7 @@
<ClCompile Include="whisperCom.cpp" />
</ItemGroup>
<ItemGroup>
+ <ClInclude Include="API\eGpuModelFlags.h" />
<ClInclude Include="API\iContext.h" />
<ClInclude Include="API\iMediaFoundation.h" />
<ClInclude Include="API\iTranscribeResult.h" />
diff --git a/Whisper/Whisper.vcxproj.filters b/Whisper/Whisper.vcxproj.filters
index 193fbe7..8a7c371 100644
--- a/Whisper/Whisper.vcxproj.filters
+++ b/Whisper/Whisper.vcxproj.filters
@@ -187,6 +187,7 @@
<ClInclude Include="API\sLoadModelCallbacks.h" />
<ClInclude Include="ML\Reshaper.h" />
<ClInclude Include="ML\reshapedMultiply.h" />
+ <ClInclude Include="API\eGpuModelFlags.h" />
</ItemGroup>
<ItemGroup>
<None Include="whisper.def" />
diff --git a/Whisper/Whisper/ModelImpl.cpp b/Whisper/Whisper/ModelImpl.cpp
index 968c3ce..3e46076 100644
--- a/Whisper/Whisper/ModelImpl.cpp
+++ b/Whisper/Whisper/ModelImpl.cpp
@@ -16,7 +16,7 @@ HRESULT ModelImpl::FinalConstruct()
{
if( 1 != InterlockedIncrement( &s_refCounter ) )
return S_FALSE;
- return DirectCompute::mlStartup();
+ return DirectCompute::mlStartup( gpuFlags );
}
void ModelImpl::FinalRelease()
@@ -75,7 +75,7 @@ inline bool hasAvxAndFma()
return true;
}
-HRESULT __stdcall Whisper::loadGpuModel( const wchar_t* path, bool hybrid, const sLoadModelCallbacks* callbacks, iModel** pp )
+HRESULT __stdcall Whisper::loadGpuModel( const wchar_t* path, bool hybrid, uint32_t flags, const sLoadModelCallbacks* callbacks, iModel** pp )
{
if( nullptr == path || nullptr == pp )
return E_POINTER;
@@ -108,7 +108,7 @@ HRESULT __stdcall Whisper::loadGpuModel( const wchar_t* path, bool hybrid, const
}
ComLight::CComPtr<ComLight::Object<ModelImpl>> obj;
- CHECK( ComLight::Object<ModelImpl>::create( obj ) );
+ CHECK( ComLight::Object<ModelImpl>::create( obj, flags ) );
hr = obj->load( &stream, hybrid, callbacks );
if( FAILED( hr ) )
{
diff --git a/Whisper/Whisper/ModelImpl.h b/Whisper/Whisper/ModelImpl.h
index 4bcea12..0571cfb 100644
--- a/Whisper/Whisper/ModelImpl.h
+++ b/Whisper/Whisper/ModelImpl.h
@@ -11,6 +11,7 @@ namespace Whisper
class ModelImpl : public ComLight::ObjectRoot<iModel>
{
WhisperModel model;
+ const uint32_t gpuFlags;
HRESULT COMLIGHTCALL createContext( iContext** pp ) override final;
@@ -31,7 +32,7 @@ namespace Whisper
}
public:
-
+ ModelImpl( uint32_t flags ) : gpuFlags( flags ) { }
HRESULT FinalConstruct();
void FinalRelease();
diff --git a/Whisper/Whisper/WhisperModel.cpp b/Whisper/Whisper/WhisperModel.cpp
index 28e4540..a99529c 100644
--- a/Whisper/Whisper/WhisperModel.cpp
+++ b/Whisper/Whisper/WhisperModel.cpp
@@ -35,7 +35,6 @@ namespace
PendingTensor( DirectCompute::Tensor& tensor, ePostProcessing pp = ePostProcessing::None ) :
dest( &tensor ), postProcessing( pp ) { }
-#if RESHAPED_MATRIX_MULTIPLY
// If you wonder why not reshape them after all tensors are loaded, doing that on the fly is faster because CPU and GPU work in parallel
// In the current version, CPU reads data for a next tensor, while in the meantime GPU reshapes a previously loaded tensor.
HRESULT postProcess( Reshaper& rs, eDataType dt )
@@ -59,7 +58,6 @@ namespace
return E_UNEXPECTED;
}
}
-#endif
};
void populateEncodeTensorsMap( CAtlMap<CStringA, PendingTensor>& map, int layersEnc, DirectCompute::ModelBuffers& tensors )
@@ -261,9 +259,7 @@ HRESULT WhisperModel::loadGpu( ComLight::iReadStream* stm, CallbacksImpl& callba
CAtlMap<CStringA, PendingTensor> map;
populateTensorsMap( map, parameters.n_audio_layer, parameters.n_text_layer, tensors, false );
-#if RESHAPED_MATRIX_MULTIPLY
DirectCompute::Reshaper reshape;
-#endif
std::vector<uint8_t> bytesVector;
size_t countLoaded = 0;
@@ -328,9 +324,7 @@ HRESULT WhisperModel::loadGpu( ComLight::iReadStream* stm, CallbacksImpl& callba
CHECK( readBytes( stm, bytesVector.data(), bytesVector.size() ) );
cb += bytesVector.size();
CHECK( p->m_value.dest->createImmutable( dt, ne, bytesVector.data() ) );
-#if RESHAPED_MATRIX_MULTIPLY
CHECK( p->m_value.postProcess( reshape, dt ) );
-#endif
countLoaded++;
}
@@ -350,11 +344,7 @@ HRESULT WhisperModel::loadHybrid( ComLight::iReadStream* stm, CallbacksImpl& cal
{
CAtlMap<CStringA, PendingTensor> map;
populateTensorsMap( map, parameters.n_audio_layer, parameters.n_text_layer, tensors, true );
-
-#if RESHAPED_MATRIX_MULTIPLY
DirectCompute::Reshaper reshape;
-#endif
-
CpuCompute::HybridLoader loader( hybridTensors, parameters.n_text_layer );
std::vector<uint8_t> bytesVector;
@@ -422,9 +412,7 @@ HRESULT WhisperModel::loadHybrid( ComLight::iReadStream* stm, CallbacksImpl& cal
}
CHECK( readBytes( stm, bytesVector.data(), bytesVector.size() ) );
CHECK( p->m_value.dest->createImmutable( dt, ne, bytesVector.data() ) );
-#if RESHAPED_MATRIX_MULTIPLY
CHECK( p->m_value.postProcess( reshape, dt ) );
-#endif
countLoaded++;
cb += bytesVector.size();
}
diff --git a/Whisper/modelFactory.cpp b/Whisper/modelFactory.cpp
index a708551..d50b854 100644
--- a/Whisper/modelFactory.cpp
+++ b/Whisper/modelFactory.cpp
@@ -1,16 +1,18 @@
-#include "stdafx.h"
+#include "stdafx.h"
#include "modelFactory.h"
#include "API/iContext.cl.h"
-HRESULT COMLIGHTCALL Whisper::loadModel( const wchar_t* path, eModelImplementation impl, const sLoadModelCallbacks* callbacks, iModel** pp )
+HRESULT COMLIGHTCALL Whisper::loadModel( const wchar_t* path, eModelImplementation impl, uint32_t flags, const sLoadModelCallbacks* callbacks, iModel** pp )
{
switch( impl )
{
case eModelImplementation::GPU:
- return loadGpuModel( path, false, callbacks, pp );
+ return loadGpuModel( path, false, flags, callbacks, pp );
case eModelImplementation::Hybrid:
- return loadGpuModel( path, true, callbacks, pp );
+ return loadGpuModel( path, true, flags, callbacks, pp );
case eModelImplementation::Reference:
+ if( 0 != flags )
+ logWarning( u8"The reference model doesn’t currently use any flags, argument ignored" );
return loadReferenceCpuModel( path, pp );
}
diff --git a/Whisper/modelFactory.h b/Whisper/modelFactory.h
index ebe77b1..78c04d4 100644
--- a/Whisper/modelFactory.h
+++ b/Whisper/modelFactory.h
@@ -5,7 +5,7 @@ namespace Whisper
{
struct iModel;
- HRESULT __stdcall loadGpuModel( const wchar_t* path, bool hybrid, const sLoadModelCallbacks* callbacks, iModel** pp );
+ HRESULT __stdcall loadGpuModel( const wchar_t* path, bool hybrid, uint32_t flags, const sLoadModelCallbacks* callbacks, iModel** pp );
HRESULT __stdcall loadReferenceCpuModel( const wchar_t* path, iModel** pp );
} \ No newline at end of file
diff --git a/Whisper/stdafx.h b/Whisper/stdafx.h
index e08f3ab..5df6e1b 100644
--- a/Whisper/stdafx.h
+++ b/Whisper/stdafx.h
@@ -36,9 +36,4 @@
// In addition to collecting total GPU times per compute shader, also collect and print performance data about individual invocations of some of the most expensive shaders
// The feature is relatively cheap in terms of performance overhead, but pretty much useless in production, and clutters debug console with all these numbers
-#define PROFILER_COLLECT_TAGS 0
-
-// Reshape some of the tensors to a better VRAM layout while loading a model
-// So far, the feature is only used on AMD GPUs. On AMD Vega integrated GPUs it helps by up to 30%.
-// Should be enabled in production build
-#define RESHAPED_MATRIX_MULTIPLY 1 \ No newline at end of file
+#define PROFILER_COLLECT_TAGS 0 \ No newline at end of file
diff --git a/WhisperNet/API/eGpuModelFlags.cs b/WhisperNet/API/eGpuModelFlags.cs
new file mode 100644
index 0000000..106235f
--- /dev/null
+++ b/WhisperNet/API/eGpuModelFlags.cs
@@ -0,0 +1,28 @@
+namespace Whisper
+{
+ /// <summary>These flags affect compute shaders performance (which ones are faster depends on GPU model),<br/>
+ /// and VRAM memory usage (UseReshapedMatMul needs slightly more VRAM).</summary>
+ [Flags]
+ public enum eGpuModelFlags: uint
+ {
+ /// <summary>Equivalent to <c>Wave32 | NoReshapedMatMul</c> on Intel and nVidia GPUs,<br/>
+ /// and <c>Wave64 | UseReshapedMatMul</c> on AMD GPUs</summary>
+ None = 0,
+
+ /// <summary>Use Wave32 version of compute shaders even on AMD GPUs</summary>
+ /// <remarks>Incompatible with <see cref="Wave64" /></remarks>
+ Wave32 = 1,
+
+ /// <summary>Use Wave64 version of compute shaders even on nVidia and Intel GPUs</summary>
+ /// <remarks>Incompatible with <see cref="Wave32" /></remarks>
+ Wave64 = 2,
+
+ /// <summary>Do not use reshaped matrix multiplication shaders on AMD GPUs</summary>
+ /// <remarks>Incompatible with <see cref="UseReshapedMatMul" /></remarks>
+ NoReshapedMatMul = 4,
+
+ /// <summary>Use reshaped matrix multiplication shaders even on nVidia and Intel GPUs</summary>
+ /// <remarks>Incompatible with <see cref="NoReshapedMatMul" /></remarks>
+ UseReshapedMatMul = 8,
+ }
+} \ No newline at end of file
diff --git a/WhisperNet/API/eModelImplementation.cs b/WhisperNet/API/eModelImplementation.cs
index 1b0a079..a0e61fb 100644
--- a/WhisperNet/API/eModelImplementation.cs
+++ b/WhisperNet/API/eModelImplementation.cs
@@ -1,6 +1,6 @@
namespace Whisper
{
- /// <summary>Implementation value for the <see cref="Library.loadModel(string, eModelImplementation)" /> factory function</summary>
+ /// <summary>Implementation value for the <see cref="Library.loadModel(string, eGpuModelFlags, eModelImplementation)" /> factory function</summary>
public enum eModelImplementation: uint
{
/// <summary>GPGPU implementation based on Direct3D 11.0 compute shaders</summary>
diff --git a/WhisperNet/Library.cs b/WhisperNet/Library.cs
index 72ecb6e..5bdb0a3 100644
--- a/WhisperNet/Library.cs
+++ b/WhisperNet/Library.cs
@@ -35,7 +35,7 @@ namespace Whisper
/// <remarks>Models are large, depending on user’s disk speed this might take a while, and this function blocks the calling thread.<br/>
/// Consider <see cref="loadModelAsync" /> instead.</remarks>
/// <seealso href="https://huggingface.co/datasets/ggerganov/whisper.cpp" />
- public static iModel loadModel( string path, eModelImplementation impl = eModelImplementation.GPU )
+ public static iModel loadModel( string path, eGpuModelFlags flags = eGpuModelFlags.None, eModelImplementation impl = eModelImplementation.GPU )
{
iModel model;
sLoadModelCallbacks callbacks = default;