summary | refs | log | tree | commit | diff | stats
path: root/Whisper
diff options
context:
space:
mode:
Diffstat (limited to 'Whisper')
-rw-r--r-- Whisper/MF/AudioBuffer.h | 1
-rw-r--r-- Whisper/Whisper/ContextImpl.cpp | 2
-rw-r--r-- Whisper/WhisperCLI/WhisperCLI.cpp | 196
-rw-r--r-- Whisper/WhisperCLI/WhisperCLI.vcxproj | 143
-rw-r--r-- Whisper/WhisperCLI/WhisperCLI.vcxproj.filters | 22
5 files changed, 361 insertions, 3 deletions
diff --git a/Whisper/MF/AudioBuffer.h b/Whisper/MF/AudioBuffer.h
index 63c4a8c..11b5ead 100644
--- a/Whisper/MF/AudioBuffer.h
+++ b/Whisper/MF/AudioBuffer.h
@@ -47,7 +47,6 @@ namespace Whisper
void dropFirst(size_t len)
{
- assert(len <= mono.size());
if (len >= mono.size()) {
mono.clear();
return;
diff --git a/Whisper/Whisper/ContextImpl.cpp b/Whisper/Whisper/ContextImpl.cpp
index 4e384ca..22347cf 100644
--- a/Whisper/Whisper/ContextImpl.cpp
+++ b/Whisper/Whisper/ContextImpl.cpp
@@ -837,12 +837,10 @@ HRESULT COMLIGHTCALL ContextImpl::runFullImpl( const sFullParams& params, const
// in ascending order, we're always copying from old or
// identical data.
for (int nth_beam = 0; nth_beam < ctx_.size(); nth_beam++) {
-#if 0
// Trivial optimization: only copy if beams differ.
if (nth_beam == best_beams_and_tokens[nth_beam].first) {
continue;
}
-#endif
ctx_[nth_beam] = ctx_[best_beams_and_tokens[nth_beam].first];
}
diff --git a/Whisper/WhisperCLI/WhisperCLI.cpp b/Whisper/WhisperCLI/WhisperCLI.cpp
new file mode 100644
index 0000000..af0a59f
--- /dev/null
+++ b/Whisper/WhisperCLI/WhisperCLI.cpp
@@ -0,0 +1,196 @@
+#define WIN32_LEAN_AND_MEAN
+
+#include <Unknwn.h>
+#include <windows.h>
+
+#include "Whisper/API/whisperWindows.h"
+
+#include <iostream>
+#include <locale>
+#include <set>
+#include <string>
+#include <string_view>
+
+using std::cout;
+using std::cerr;
+using std::endl;
+using namespace Whisper;
+
+/// Settings for one run of the command-line tool. Each member carries the
+/// value that is used when the matching option is not given.
+struct Config {
+    /// Audio file to transcribe; overridden by --audio_path.
+    std::wstring audio_path{ L"input.wav" };
+    /// Model file to load; overridden by --model_path.
+    std::wstring model_path{ L"model.bin" };
+    /// Sampling strategy; overridden by --decode_method (greedy | beam).
+    eSamplingStrategy decode_method{ eSamplingStrategy::BeamSearch };
+};
+
+/// Tells whether the option at index `shift` is followed by a value.
+///
+/// @param argc  Total number of command-line entries.
+/// @param shift Index of the option being examined.
+/// @param arg   Option text, used only in the error message.
+/// @return true when index `shift + 1` is still inside the argument list;
+///         otherwise an error is written to stderr and false is returned.
+bool hasArg(int argc, int shift, char* arg) {
+    const bool value_present = (shift + 1 < argc);
+    if (value_present) {
+        return true;
+    }
+    cerr << "Error: " << arg << " is missing argument" << endl;
+    return false;
+}
+
+/// Converts a null-terminated UTF-8 string to a UTF-16 std::wstring.
+///
+/// @param c_str Null-terminated UTF-8 input.
+/// @return The converted text, or an empty string when the conversion fails.
+std::wstring cstrToWstr(const char* c_str) {
+    // Passing -1 as the input length makes the API process the terminating
+    // null as well, so the reported count includes one extra character.
+    const int required = MultiByteToWideChar(CP_UTF8, 0, c_str, -1, nullptr, 0);
+    if (required <= 0) {
+        return std::wstring();
+    }
+    std::wstring result(static_cast<size_t>(required), L'\0');
+    const int written = MultiByteToWideChar(CP_UTF8, 0, c_str, -1, result.data(), required);
+    if (written <= 0) {
+        return std::wstring();
+    }
+    // Drop the terminator that was converted along with the text; leaving it
+    // in would make size() one too large and embed a L'\0' in the string.
+    result.resize(static_cast<size_t>(written) - 1);
+    return result;
+}
+
+
+/// Reads the command line into `c`.
+///
+/// Recognized options, each followed by exactly one value:
+///   --audio_path <file>, --model_path <file>, --decode_method <greedy|beam>.
+///
+/// @param argc Number of entries in `argv`.
+/// @param argv Command-line entries; index 0 is skipped.
+/// @param c    Receives the parsed values; untouched fields keep their defaults.
+/// @return true when every entry was consumed; false (after writing a message
+///         to stderr) on an unknown option, a missing value, or an unsupported
+///         decode method.
+bool parseArgs(int argc, char* argv[], Config& c) {
+    // Every accepted option occupies two slots: the flag and its value.
+    for (int pos = 1; pos < argc; pos += 2) {
+        const std::string_view flag(argv[pos]);
+        const bool known = flag == "--audio_path"
+            || flag == "--model_path"
+            || flag == "--decode_method";
+        if (!known) {
+            cerr << "Unrecognized argument: \"" << argv[pos] << '"' << endl;
+            return false;
+        }
+        if (!hasArg(argc, pos, argv[pos])) {
+            return false;
+        }
+
+        char* value = argv[pos + 1];
+        if (flag == "--audio_path") {
+            c.audio_path = cstrToWstr(value);
+        }
+        else if (flag == "--model_path") {
+            c.model_path = cstrToWstr(value);
+        }
+        else {
+            const std::string_view method(value);
+            if (method == "greedy") {
+                cerr << "Using greedy decode " << endl;
+                c.decode_method = eSamplingStrategy::Greedy;
+            }
+            else if (method == "beam") {
+                cerr << "Using beam decode " << endl;
+                c.decode_method = eSamplingStrategy::BeamSearch;
+            }
+            else {
+                cerr << "Unsupported decode method: " << method << endl;
+                return false;
+            }
+        }
+    }
+    return true;
+}
+
+/// Converts a UTF-16 string to UTF-8 so it can be written to the narrow
+/// diagnostic stream. Returns an empty string when the conversion fails.
+static std::string wstrToUtf8(const std::wstring& wide) {
+    // With an input length of -1 the source is treated as null-terminated and
+    // the reported byte count includes the terminator.
+    const int required = WideCharToMultiByte(CP_UTF8, 0, wide.c_str(), -1, nullptr, 0, nullptr, nullptr);
+    if (required <= 0) {
+        return std::string();
+    }
+    std::string utf8(static_cast<size_t>(required), '\0');
+    const int written = WideCharToMultiByte(CP_UTF8, 0, wide.c_str(), -1, utf8.data(), required, nullptr, nullptr);
+    if (written <= 0) {
+        return std::string();
+    }
+    utf8.resize(static_cast<size_t>(written) - 1);
+    return utf8;
+}
+
+/// Entry point: parses the options, loads the audio file and the model,
+/// runs a full transcription and writes the recognized text to stdout.
+///
+/// @return 0 on success, 1 after reporting the first failure on stderr.
+int main(int argc, char* argv[])
+{
+    Config c;
+    if (!parseArgs(argc, argv, c)) {
+        cerr << "Failed to parse args" << endl;
+        return 1;
+    }
+
+    // NOTE(review): none of the interface pointers obtained below is released
+    // before returning; confirm whether the library expects explicit cleanup.
+    iMediaFoundation* f = nullptr;
+    HRESULT err = initMediaFoundation(&f);
+    if (FAILED(err)) {
+        cerr << "Failed to init media foundation: " << err << endl;
+        return 1;
+    }
+
+    Whisper::iAudioBuffer* buffer = nullptr;
+    err = f->loadAudioFile(c.audio_path.c_str(), /*stereo=*/false, &buffer);
+    if (FAILED(err)) {
+        // Report the path that was actually requested, not the default name.
+        cerr << "Failed to load audio file '" << wstrToUtf8(c.audio_path) << "': " << err << endl;
+        return 1;
+    }
+
+    Whisper::iModel* model = nullptr;
+    err = Whisper::loadModel(c.model_path.c_str(), eModelImplementation::GPU, /*flags=*/0, /*callbacks=*/nullptr, &model);
+    if (FAILED(err)) {
+        cerr << "Failed to open model '" << wstrToUtf8(c.model_path) << "': " << err << endl;
+        return 1;
+    }
+
+    Whisper::iContext* context = nullptr;
+    err = model->createContext(&context);
+    if (FAILED(err)) {
+        cerr << "Failed to create context: " << err << endl;
+        return 1;
+    }
+
+    // Start from the library defaults for the chosen strategy, then override.
+    Whisper::sFullParams wparams{};
+    context->fullDefaultParams(c.decode_method, &wparams);
+    if (c.decode_method == eSamplingStrategy::BeamSearch) {
+        wparams.beam_search.beam_width = 5;
+        wparams.beam_search.n_best = 5;
+    }
+    wparams.language = Whisper::makeLanguageKey("en");
+    wparams.n_max_text_ctx = 100;
+
+    err = context->runFull(wparams, buffer);
+    if (FAILED(err)) {
+        cerr << "Failed to transcribe: " << err << endl;
+        return 1;
+    }
+
+    Whisper::iTranscribeResult* result = nullptr;
+    err = context->getResults(eResultFlags::Tokens, &result);
+    if (FAILED(err)) {
+        cerr << "Failed to get transcription results: " << err << endl;
+        return 1;
+    }
+
+    // Token ids that are never printed.
+    std::set<int> special_tokens;
+    {
+        Whisper::SpecialTokens st;
+        err = model->getSpecialTokens(st);
+        if (FAILED(err)) {
+            cerr << "Failed to get special tokens: " << err << endl;
+            // `st` was not filled in; using it would insert indeterminate ids.
+            return 1;
+        }
+        special_tokens.insert(st.Not);
+        special_tokens.insert(st.PreviousWord);
+        special_tokens.insert(st.SentenceStart);
+        special_tokens.insert(st.TaskTranscribe);
+        special_tokens.insert(st.TaskTranslate);
+        special_tokens.insert(st.TranscriptionBegin);
+        special_tokens.insert(st.TranscriptionEnd);
+        special_tokens.insert(st.TranscriptionStart);
+    }
+
+    sTranscribeLength length;
+    err = result->getSize(length);
+    if (FAILED(err)) {
+        cerr << "Failed to get transcription length: " << err << endl;
+        // `length` was not filled in; the loop bounds below would be garbage.
+        return 1;
+    }
+    auto* segments = result->getSegments();
+    auto* tokens = result->getTokens();
+    for (int i = 0; i < length.countSegments; i++) {
+        const auto& segment = segments[i];
+        // True while inside a bracketed annotation that spans several tokens.
+        // Reset per segment so an unclosed '[' cannot hide later segments.
+        bool in_annotation = false;
+        for (int j = 0; j < segment.countTokens; j++) {
+            const sToken& tok = tokens[segment.firstToken + j];
+            if (special_tokens.contains(tok.id)) {
+                continue;
+            }
+            const std::string_view tok_str(tok.text);
+            if (in_annotation) {
+                // Suppress everything up to and including the closing token;
+                // previously the tokens between '[' and ']' were still printed.
+                if (tok_str.ends_with("]")) {
+                    in_annotation = false;
+                }
+                continue;
+            }
+            if (tok_str.starts_with("[") ||
+                tok_str.starts_with(" [")) {
+                // A token such as "[xyz]" is complete on its own; otherwise
+                // the annotation continues in the following tokens.
+                in_annotation = !tok_str.ends_with("]");
+                continue;
+            }
+            cout << tok.text;
+        }
+    }
+    cout << endl;
+
+    return 0;
+} \ No newline at end of file
diff --git a/Whisper/WhisperCLI/WhisperCLI.vcxproj b/Whisper/WhisperCLI/WhisperCLI.vcxproj
new file mode 100644
index 0000000..8295eee
--- /dev/null
+++ b/Whisper/WhisperCLI/WhisperCLI.vcxproj
@@ -0,0 +1,143 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project DefaultTargets="Build" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+ <ItemGroup Label="ProjectConfigurations">
+ <ProjectConfiguration Include="Debug|Win32">
+ <Configuration>Debug</Configuration>
+ <Platform>Win32</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="Release|Win32">
+ <Configuration>Release</Configuration>
+ <Platform>Win32</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="Debug|x64">
+ <Configuration>Debug</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ <ProjectConfiguration Include="Release|x64">
+ <Configuration>Release</Configuration>
+ <Platform>x64</Platform>
+ </ProjectConfiguration>
+ </ItemGroup>
+ <PropertyGroup Label="Globals">
+ <VCProjectVersion>16.0</VCProjectVersion>
+ <Keyword>Win32Proj</Keyword>
+ <ProjectGuid>{b561d29f-be1d-4a4f-acfc-4d075cbc9108}</ProjectGuid>
+ <RootNamespace>WhisperCLI</RootNamespace>
+ <WindowsTargetPlatformVersion>10.0</WindowsTargetPlatformVersion>
+ </PropertyGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.Default.props" />
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'" Label="Configuration">
+ <ConfigurationType>Application</ConfigurationType>
+ <UseDebugLibraries>true</UseDebugLibraries>
+ <PlatformToolset>v143</PlatformToolset>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'" Label="Configuration">
+ <ConfigurationType>Application</ConfigurationType>
+ <UseDebugLibraries>false</UseDebugLibraries>
+ <PlatformToolset>v143</PlatformToolset>
+ <WholeProgramOptimization>true</WholeProgramOptimization>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'" Label="Configuration">
+ <ConfigurationType>Application</ConfigurationType>
+ <UseDebugLibraries>true</UseDebugLibraries>
+ <PlatformToolset>v143</PlatformToolset>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'" Label="Configuration">
+ <ConfigurationType>Application</ConfigurationType>
+ <UseDebugLibraries>false</UseDebugLibraries>
+ <PlatformToolset>v143</PlatformToolset>
+ <WholeProgramOptimization>true</WholeProgramOptimization>
+ <CharacterSet>Unicode</CharacterSet>
+ </PropertyGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
+ <ImportGroup Label="ExtensionSettings">
+ </ImportGroup>
+ <ImportGroup Label="Shared">
+ </ImportGroup>
+ <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ </ImportGroup>
+ <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ </ImportGroup>
+ <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ </ImportGroup>
+ <ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
+ </ImportGroup>
+ <PropertyGroup Label="UserMacros" />
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <IncludePath>$(VC_IncludePath);$(WindowsSDK_IncludePath);$(SolutionDir);</IncludePath>
+ </PropertyGroup>
+ <PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+ <IncludePath>$(VC_IncludePath);$(WindowsSDK_IncludePath);$(SolutionDir);</IncludePath>
+ </PropertyGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
+ <ClCompile>
+ <WarningLevel>Level3</WarningLevel>
+ <SDLCheck>true</SDLCheck>
+ <PreprocessorDefinitions>WIN32;_DEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <ConformanceMode>true</ConformanceMode>
+ </ClCompile>
+ <Link>
+ <SubSystem>Console</SubSystem>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">
+ <ClCompile>
+ <WarningLevel>Level3</WarningLevel>
+ <FunctionLevelLinking>true</FunctionLevelLinking>
+ <IntrinsicFunctions>true</IntrinsicFunctions>
+ <SDLCheck>true</SDLCheck>
+ <PreprocessorDefinitions>WIN32;NDEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <ConformanceMode>true</ConformanceMode>
+ </ClCompile>
+ <Link>
+ <SubSystem>Console</SubSystem>
+ <EnableCOMDATFolding>true</EnableCOMDATFolding>
+ <OptimizeReferences>true</OptimizeReferences>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">
+ <ClCompile>
+ <WarningLevel>Level3</WarningLevel>
+ <SDLCheck>true</SDLCheck>
+ <PreprocessorDefinitions>_DEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+ <ConformanceMode>true</ConformanceMode>
+ <LanguageStandard>stdcpp20</LanguageStandard>
+ </ClCompile>
+ <Link>
+ <SubSystem>Console</SubSystem>
+ <GenerateDebugInformation>true</GenerateDebugInformation>
+ <AdditionalDependencies>$(CoreLibraryDependencies);%(AdditionalDependencies);$(OutDir)Whisper.lib</AdditionalDependencies>
+ </Link>
+ </ItemDefinitionGroup>
+ <ItemDefinitionGroup Condition="'$(Configuration)|$(Platform)'=='Release|x64'">
+   <!-- Kept in step with Debug|x64: WhisperCLI.cpp uses C++20 library calls
+        (std::set::contains, string_view::starts_with/ends_with), and the
+        Debug|x64 configuration links against Whisper.lib. -->
+   <ClCompile>
+     <WarningLevel>Level3</WarningLevel>
+     <FunctionLevelLinking>true</FunctionLevelLinking>
+     <IntrinsicFunctions>true</IntrinsicFunctions>
+     <SDLCheck>true</SDLCheck>
+     <PreprocessorDefinitions>NDEBUG;_CONSOLE;%(PreprocessorDefinitions)</PreprocessorDefinitions>
+     <ConformanceMode>true</ConformanceMode>
+     <LanguageStandard>stdcpp20</LanguageStandard>
+   </ClCompile>
+   <Link>
+     <SubSystem>Console</SubSystem>
+     <EnableCOMDATFolding>true</EnableCOMDATFolding>
+     <OptimizeReferences>true</OptimizeReferences>
+     <GenerateDebugInformation>true</GenerateDebugInformation>
+     <AdditionalDependencies>$(CoreLibraryDependencies);%(AdditionalDependencies);$(OutDir)Whisper.lib</AdditionalDependencies>
+   </Link>
+ </ItemDefinitionGroup>
+ <ItemGroup>
+ <ClCompile Include="WhisperCLI.cpp" />
+ </ItemGroup>
+ <Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
+ <ImportGroup Label="ExtensionTargets">
+ </ImportGroup>
+</Project> \ No newline at end of file
diff --git a/Whisper/WhisperCLI/WhisperCLI.vcxproj.filters b/Whisper/WhisperCLI/WhisperCLI.vcxproj.filters
new file mode 100644
index 0000000..34bd24e
--- /dev/null
+++ b/Whisper/WhisperCLI/WhisperCLI.vcxproj.filters
@@ -0,0 +1,22 @@
+<?xml version="1.0" encoding="utf-8"?>
+<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
+ <ItemGroup>
+ <Filter Include="Source Files">
+ <UniqueIdentifier>{4FC737F1-C7A5-4376-A066-2A32D752A2FF}</UniqueIdentifier>
+ <Extensions>cpp;c;cc;cxx;c++;cppm;ixx;def;odl;idl;hpj;bat;asm;asmx</Extensions>
+ </Filter>
+ <Filter Include="Header Files">
+ <UniqueIdentifier>{93995380-89BD-4b04-88EB-625FBE52EBFB}</UniqueIdentifier>
+ <Extensions>h;hh;hpp;hxx;h++;hm;inl;inc;ipp;xsd</Extensions>
+ </Filter>
+ <Filter Include="Resource Files">
+ <UniqueIdentifier>{67DA6AB6-F800-4c08-8B7A-83BB121AAD01}</UniqueIdentifier>
+ <Extensions>rc;ico;cur;bmp;dlg;rc2;rct;bin;rgs;gif;jpg;jpeg;jpe;resx;tiff;tif;png;wav;mfcribbon-ms</Extensions>
+ </Filter>
+ </ItemGroup>
+ <ItemGroup>
+ <ClCompile Include="WhisperCLI.cpp">
+ <Filter>Source Files</Filter>
+ </ClCompile>
+ </ItemGroup>
+</Project> \ No newline at end of file