diff options
Diffstat (limited to 'Whisper')
| -rw-r--r-- | Whisper/MF/AudioBuffer.h | 1 | ||||
| -rw-r--r-- | Whisper/Whisper/ContextImpl.cpp | 2 | ||||
| -rw-r--r-- | Whisper/WhisperCLI/WhisperCLI.cpp | 196 | ||||
| -rw-r--r-- | Whisper/WhisperCLI/WhisperCLI.vcxproj | 143 | ||||
| -rw-r--r-- | Whisper/WhisperCLI/WhisperCLI.vcxproj.filters | 22 |
5 files changed, 361 insertions, 3 deletions
diff --git a/Whisper/MF/AudioBuffer.h b/Whisper/MF/AudioBuffer.h index 63c4a8c..11b5ead 100644 --- a/Whisper/MF/AudioBuffer.h +++ b/Whisper/MF/AudioBuffer.h @@ -47,7 +47,6 @@ namespace Whisper void dropFirst(size_t len) { - assert(len <= mono.size()); if (len >= mono.size()) { mono.clear(); return; diff --git a/Whisper/Whisper/ContextImpl.cpp b/Whisper/Whisper/ContextImpl.cpp index 4e384ca..22347cf 100644 --- a/Whisper/Whisper/ContextImpl.cpp +++ b/Whisper/Whisper/ContextImpl.cpp @@ -837,12 +837,10 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const // in ascending order, we're always copying from old or // identical data. for (int nth_beam = 0; nth_beam < ctx_.size(); nth_beam++) { -#if 0 // Trivial optimization: only copy if beams differ. if (nth_beam == best_beams_and_tokens[nth_beam].first) { continue; } -#endif ctx_[nth_beam] = ctx_[best_beams_and_tokens[nth_beam].first]; } diff --git a/Whisper/WhisperCLI/WhisperCLI.cpp b/Whisper/WhisperCLI/WhisperCLI.cpp new file mode 100644 index 0000000..af0a59f --- /dev/null +++ b/Whisper/WhisperCLI/WhisperCLI.cpp @@ -0,0 +1,196 @@ +#define WIN32_LEAN_AND_MEAN + +#include <Unknwn.h> +#include <windows.h> + +#include "Whisper/API/whisperWindows.h" + +#include <iostream> +#include <locale> +#include <set> +#include <string> +#include <string_view> + +using std::cout; +using std::cerr; +using std::endl; +using namespace Whisper; + +struct Config { + std::wstring audio_path = L"input.wav"; + std::wstring model_path = L"model.bin"; + eSamplingStrategy decode_method = eSamplingStrategy::BeamSearch; +}; + +bool hasArg(int argc, int shift, char* arg) { + if (shift + 1 >= argc) { + cerr << "Error: " << arg << " is missing argument" << endl; + return false; + } + return true; +} + +std::wstring cstrToWstr(char* c_str) { + int length = MultiByteToWideChar(CP_UTF8, 0, c_str, -1, NULL, 0); + std::wstring result(length, 0); + MultiByteToWideChar(CP_UTF8, 0, c_str, -1, result.data(), result.size()); + return result; +} + + +bool parseArgs(int argc, char* argv[], Config& c) { + int shift = 1; + while (shift < argc) { + if (std::string_view(argv[shift]) == "--audio_path") { + if (!hasArg(argc, shift, argv[shift])) { + return false; + } + c.audio_path = cstrToWstr(argv[shift + 1]); + shift += 2; + continue; + } + if (std::string_view(argv[shift]) == "--model_path") { + if (!hasArg(argc, shift, argv[shift])) { + return false; + } + c.model_path = cstrToWstr(argv[shift + 1]); + shift += 2; + continue; + } + if (std::string_view(argv[shift]) == "--decode_method") { + if (!hasArg(argc, shift, argv[shift])) { + return false; + } + std::string_view decode_method(argv[shift + 1]); + if (decode_method == "greedy") { + cerr << "Using greedy decode " << endl; + c.decode_method = eSamplingStrategy::Greedy; + } + else if (decode_method == "beam") { + cerr << "Using beam decode " << endl; + c.decode_method = eSamplingStrategy::BeamSearch; + } + else { + cerr << "Unsupported decode method: " << decode_method << endl; + return false; + } + shift += 2; + continue; + } + cerr << "Unrecognized argument: \"" << argv[shift] << '"' << endl; + return false; + } + return true; +} + +int main(int argc, char* argv[]) +{ + Config c; + if (!parseArgs(argc, argv, c)) { + cerr << "Failed to parse args"; + return 1; + } + + iMediaFoundation* f = nullptr; + HRESULT err = initMediaFoundation(&f); + if (FAILED(err)) { + cerr << "Failed to init media foundation: " << err << endl; + return 1; + } + + Whisper::iAudioBuffer* buffer = nullptr; + err = f->loadAudioFile(c.audio_path.c_str(), /*stereo=*/false, &buffer); + if (FAILED(err)) { + cerr << "Failed to load audio file 'input.wav': " << err << endl; + return 1; + } + + Whisper::iModel* model = nullptr; + err = Whisper::loadModel(c.model_path.c_str(), eModelImplementation::GPU, /*flags=*/0, /*callbacks=*/nullptr, &model); + if (FAILED(err)) { + cerr << "Failed to open model 'model.bin': " << err << endl; + return 1; + } + + Whisper::iContext* context = nullptr; + err = model->createContext(&context); + if (FAILED(err)) { + cerr << "Failed to create context: " << err << endl; + return 1; + } + + Whisper::sFullParams wparams{}; + context->fullDefaultParams(c.decode_method, &wparams); + if (c.decode_method == eSamplingStrategy::BeamSearch) { + wparams.beam_search.beam_width = 5; + wparams.beam_search.n_best = 5; + } + wparams.language = Whisper::makeLanguageKey("en"); + wparams.n_max_text_ctx = 100; + + err = context->runFull(wparams, buffer); + if (FAILED(err)) { + cerr << "Failed to transcribe: " << err << endl; + return 1; + } + + Whisper::iTranscribeResult* result = nullptr; + err = context->getResults(eResultFlags::Tokens, &result); + if (FAILED(err)) { + cerr << "Failed to get transcription results: " << err << endl; + return 1; + } + + std::set<int> special_tokens; + { + Whisper::SpecialTokens st; + err = model->getSpecialTokens(st); + if (FAILED(err)) { + cerr << "Failed to get special tokens: " << err << endl; + } + special_tokens.insert(st.Not); + special_tokens.insert(st.PreviousWord); + special_tokens.insert(st.SentenceStart); + special_tokens.insert(st.TaskTranscribe); + special_tokens.insert(st.TaskTranslate); + special_tokens.insert(st.TranscriptionBegin); + special_tokens.insert(st.TranscriptionEnd); + special_tokens.insert(st.TranscriptionStart); + } + + sTranscribeLength length; + err = result->getSize(length); + if (FAILED(err)) { + cerr << "Failed to get transcription length: " << err << endl; + } + auto* segments = result->getSegments(); + auto* tokens = result->getTokens(); + bool is_metadata = false; + for (int i = 0; i < length.countSegments; i++) { + auto& segment = segments[i]; + for (int j = 0; j < segment.countTokens; j++) { + const sToken& tok = tokens[segment.firstToken + j]; + if (special_tokens.contains(tok.id)) { + continue; + } + std::string_view tok_str(tok.text); + if (tok_str.starts_with("[") || + tok_str.starts_with(" [")) { + if (tok_str.ends_with("]")) { + continue; + } + is_metadata = true; + continue; + } + if (is_metadata && + tok_str.ends_with("]")) { + is_metadata = false; + continue; + } + cout << tok.text; + } + } + cout << endl; + + return 0; +}
\ No newline at end of file diff --git a/Whisper/WhisperCLI/WhisperCLI.vcxproj b/Whisper/WhisperCLI/WhisperCLI.vcxproj new file mode 100644 index 0000000..8295eee --- /dev/null +++ b/Whisper/WhisperCLI/WhisperCLI.vcxproj @@ -0,0 +1,143 @@ +<?xml version="1.0" encoding="utf-8"?> +<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <ItemGroup Label="ProjectConfigurations"> + <ProjectConfiguration Include="Debug|Win32"> + <Configuration>Debug</Configuration> + <Platform>Win32</Platform> + </ProjectConfiguration> + <ProjectConfiguration Include="Release|Win32"> + <Configuration>Release</Configuration> + <Platform>Win32</Platform> + </ProjectConfiguration> + <ProjectConfiguration Include="Debug|x64"> + <Configuration>Debug</Configuration> + <Platform>x64</Platform> + </ProjectConfiguration> + <ProjectConfiguration Include="Release|x64"> + <Configuration>Release</Configuration> + <Platform>x64</Platform> + </ProjectConfiguration> + </ItemGroup> + <PropertyGroup Label="Globals"> + <VCProjectVersion>16.0</VCProjectVersion> + <Keyword>Win32Proj</Keyword> + <ProjectGuid>{b561d29f-be1d-4a4f-acfc-4d075cbc9108}</ProjectGuid> + <RootNamespace>WhisperCLI</RootNamespace> + <WindowsTargetPlatformVersion>10.0</WindowsTargetPlatformVersion> + </PropertyGroup> + <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" /> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'" Label="Configuration"> + <ConfigurationType>Application</ConfigurationType> + <UseDebugLibraries>true</UseDebugLibraries> + <PlatformToolset>v143</PlatformToolset> + <CharacterSet>Unicode</CharacterSet> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'" Label="Configuration"> + <ConfigurationType>Application</ConfigurationType> + <UseDebugLibraries>false</UseDebugLibraries> + <PlatformToolset>v143</PlatformToolset> + <WholeProgramOptimization>true</WholeProgramOptimization> + <CharacterSet>Unicode</CharacterSet> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration"> + <ConfigurationType>Application</ConfigurationType> + <UseDebugLibraries>true</UseDebugLibraries> + <PlatformToolset>v143</PlatformToolset> + <CharacterSet>Unicode</CharacterSet> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration"> + <ConfigurationType>Application</ConfigurationType> + <UseDebugLibraries>false</UseDebugLibraries> + <PlatformToolset>v143</PlatformToolset> + <WholeProgramOptimization>true</WholeProgramOptimization> + <CharacterSet>Unicode</CharacterSet> + </PropertyGroup> + <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" /> + <ImportGroup Label="ExtensionSettings"> + </ImportGroup> + <ImportGroup Label="Shared"> + </ImportGroup> + <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'"> + <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + </ImportGroup> + <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|Win32'"> + <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + </ImportGroup> + <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> + <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + </ImportGroup> + <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> + <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" /> + </ImportGroup> + <PropertyGroup Label="UserMacros" /> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> + <IncludePath>$(VC_IncludePath);$(WindowsSDK_IncludePath);$(SolutionDir);</IncludePath> + </PropertyGroup> + <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> + <IncludePath>$(VC_IncludePath);$(WindowsSDK_IncludePath);$(SolutionDir);</IncludePath> + </PropertyGroup> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'"> + <ClCompile> + <WarningLevel>Level3</WarningLevel> + <SDLCheck>true</SDLCheck> + <PreprocessorDefinitions>WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <ConformanceMode>true</ConformanceMode> + </ClCompile> + <Link> + <SubSystem>Console</SubSystem> + <GenerateDebugInformation>true</GenerateDebugInformation> + </Link> + </ItemDefinitionGroup> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'"> + <ClCompile> + <WarningLevel>Level3</WarningLevel> + <FunctionLevelLinking>true</FunctionLevelLinking> + <IntrinsicFunctions>true</IntrinsicFunctions> + <SDLCheck>true</SDLCheck> + <PreprocessorDefinitions>WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <ConformanceMode>true</ConformanceMode> + </ClCompile> + <Link> + <SubSystem>Console</SubSystem> + <EnableCOMDATFolding>true</EnableCOMDATFolding> + <OptimizeReferences>true</OptimizeReferences> + <GenerateDebugInformation>true</GenerateDebugInformation> + </Link> + </ItemDefinitionGroup> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'"> + <ClCompile> + <WarningLevel>Level3</WarningLevel> + <SDLCheck>true</SDLCheck> + <PreprocessorDefinitions>_DEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <ConformanceMode>true</ConformanceMode> + <LanguageStandard>stdcpp20</LanguageStandard> + </ClCompile> + <Link> + <SubSystem>Console</SubSystem> + <GenerateDebugInformation>true</GenerateDebugInformation> + <AdditionalDependencies>$(CoreLibraryDependencies);%(AdditionalDependencies);$(OutDir)Whisper.lib</AdditionalDependencies> + </Link> + </ItemDefinitionGroup> + <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'"> + <ClCompile> + <WarningLevel>Level3</WarningLevel> + <FunctionLevelLinking>true</FunctionLevelLinking> + <IntrinsicFunctions>true</IntrinsicFunctions> + <SDLCheck>true</SDLCheck> + <PreprocessorDefinitions>NDEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions> + <ConformanceMode>true</ConformanceMode> + </ClCompile> + <Link> + <SubSystem>Console</SubSystem> + <EnableCOMDATFolding>true</EnableCOMDATFolding> + <OptimizeReferences>true</OptimizeReferences> + <GenerateDebugInformation>true</GenerateDebugInformation> + </Link> + </ItemDefinitionGroup> + <ItemGroup> + <ClCompile Include="WhisperCLI.cpp" /> + </ItemGroup> + <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" /> + <ImportGroup Label="ExtensionTargets"> + </ImportGroup> +</Project>
\ No newline at end of file diff --git a/Whisper/WhisperCLI/WhisperCLI.vcxproj.filters b/Whisper/WhisperCLI/WhisperCLI.vcxproj.filters new file mode 100644 index 0000000..34bd24e --- /dev/null +++ b/Whisper/WhisperCLI/WhisperCLI.vcxproj.filters @@ -0,0 +1,22 @@ +<?xml version="1.0" encoding="utf-8"?> +<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003"> + <ItemGroup> + <Filter Include="Source Files"> + <UniqueIdentifier>{4FC737F1-C7A5-4376-A066-2A32D752A2FF}</UniqueIdentifier> + <Extensions>cpp;c;cc;cxx;c++;cppm;ixx;def;odl;idl;hpj;bat;asm;asmx</Extensions> + </Filter> + <Filter Include="Header Files"> + <UniqueIdentifier>{93995380-89BD-4b04-88EB-625FBE52EBFB}</UniqueIdentifier> + <Extensions>h;hh;hpp;hxx;h++;hm;inl;inc;ipp;xsd</Extensions> + </Filter> + <Filter Include="Resource Files"> + <UniqueIdentifier>{67DA6AB6-F800-4c08-8B7A-83BB121AAD01}</UniqueIdentifier> + <Extensions>rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms</Extensions> + </Filter> + </ItemGroup> + <ItemGroup> + <ClCompile Include="WhisperCLI.cpp"> + <Filter>Source Files</Filter> + </ClCompile> + </ItemGroup> +</Project>
\ No newline at end of file |
