From f2b21dd5afebd6b76b5835168f7d1bd3bec21f5d Mon Sep 17 00:00:00 2001 From: yum Date: Sun, 9 Jun 2024 15:54:30 -0700 Subject: Add checkbox for flash-attention Pre-3000 series GPUs don't support it. Oops! --- GUI/GUI/GUI/Config.cpp | 2 ++ GUI/GUI/GUI/Config.h | 1 + GUI/GUI/GUI/Frame.cpp | 18 ++++++++++++++++++ GUI/GUI/GUI/Frame.h | 1 + Scripts/transcribe_v2.py | 4 +++- 5 files changed, 25 insertions(+), 1 deletion(-) diff --git a/GUI/GUI/GUI/Config.cpp b/GUI/GUI/GUI/Config.cpp index 1314c4d..0a71ad1 100644 --- a/GUI/GUI/GUI/Config.cpp +++ b/GUI/GUI/GUI/Config.cpp @@ -76,6 +76,7 @@ AppConfig::AppConfig(wxTextCtrl* out) browser_src_port(8097), commit_fuzz_threshold(4), use_cpu(false), + use_flash_attention(false), use_builtin(false), enable_uwu_filter(false), remove_trailing_period(false), @@ -125,6 +126,7 @@ bool AppConfig::Serialize(const std::filesystem::path& path) { cm.Set("browser_src_port", browser_src_port); cm.Set("commit_fuzz_threshold", commit_fuzz_threshold); cm.Set("use_cpu", use_cpu); + cm.Set("use_flash_attention", use_flash_attention); cm.Set("use_builtin", use_builtin); cm.Set("enable_uwu_filter", enable_uwu_filter); cm.Set("remove_trailing_period", remove_trailing_period); diff --git a/GUI/GUI/GUI/Config.h b/GUI/GUI/GUI/Config.h index 906982e..e75e4d5 100644 --- a/GUI/GUI/GUI/Config.h +++ b/GUI/GUI/GUI/Config.h @@ -62,6 +62,7 @@ public: int browser_src_port; int commit_fuzz_threshold; bool use_cpu; + bool use_flash_attention; bool use_builtin; bool enable_uwu_filter; bool remove_trailing_period; diff --git a/GUI/GUI/GUI/Frame.cpp b/GUI/GUI/GUI/Frame.cpp index 2ee6f1b..9a69651 100644 --- a/GUI/GUI/GUI/Frame.cpp +++ b/GUI/GUI/GUI/Frame.cpp @@ -64,6 +64,7 @@ namespace { ID_PY_APP_ENABLE_LOCAL_BEEP, ID_PY_APP_ENABLE_BROWSER_SRC, ID_PY_APP_USE_CPU, + ID_PY_APP_USE_FLASH_ATTENTION, ID_PY_APP_USE_BUILTIN, ID_PY_APP_ENABLE_UWU_FILTER, ID_PY_APP_REMOVE_TRAILING_PERIOD, @@ -995,6 +996,16 @@ Frame::Frame() ); py_app_use_cpu_ = py_app_use_cpu; + auto* py_app_use_flash_attention = new wxCheckBox(py_config_panel, + ID_PY_APP_USE_FLASH_ATTENTION, "Use flash attention"); + py_app_use_flash_attention->SetValue(app_c_->use_flash_attention); + py_app_use_flash_attention->SetToolTip( + "If checked, the transcription engine will use flash " + "attention for inference. This is much faster and more " + "efficient, but requires a 3000 series GPU or higher." + ); + py_app_use_flash_attention_ = py_app_use_flash_attention; + auto* py_app_use_builtin = new wxCheckBox(py_config_panel, ID_PY_APP_USE_BUILTIN, "Use built-in chatbox"); py_app_use_builtin->SetValue(app_c_->use_builtin); @@ -1112,6 +1123,8 @@ Frame::Frame() /*flags=*/wxEXPAND); sizer->Add(py_app_use_cpu, /*proportion=*/0, /*flags=*/wxEXPAND); + sizer->Add(py_app_use_flash_attention, /*proportion=*/0, + /*flags=*/wxEXPAND); sizer->Add(py_app_use_builtin, /*proportion=*/0, /*flags=*/wxEXPAND); sizer->Add(py_app_enable_uwu_filter, /*proportion=*/0, @@ -1701,6 +1714,9 @@ void Frame::ApplyConfigToInputFields() auto* py_app_use_cpu = static_cast(FindWindowById(ID_PY_APP_USE_CPU)); py_app_use_cpu->SetValue(app_c_->use_cpu); + auto* py_app_use_flash_attention = static_cast(FindWindowById(ID_PY_APP_USE_FLASH_ATTENTION)); + py_app_use_flash_attention->SetValue(app_c_->use_flash_attention); + auto* py_app_use_builtin = static_cast(FindWindowById(ID_PY_APP_USE_BUILTIN)); py_app_use_builtin->SetValue(app_c_->use_builtin); @@ -2450,6 +2466,7 @@ void Frame::OnAppStart(wxCommandEvent& event) { const bool enable_local_beep = py_app_enable_local_beep_->GetValue(); const bool enable_browser_src = py_app_enable_browser_src_->GetValue(); const bool use_cpu = py_app_use_cpu_->GetValue(); + const bool use_flash_attention = py_app_use_flash_attention_->GetValue(); const bool use_builtin = py_app_use_builtin_->GetValue(); const bool enable_uwu_filter = py_app_enable_uwu_filter_->GetValue(); const bool remove_trailing_period = py_app_remove_trailing_period_->GetValue(); @@ -2490,6 +2507,7 @@ void Frame::OnAppStart(wxCommandEvent& event) { app_c_->enable_browser_src = enable_browser_src; app_c_->browser_src_port = browser_src_port; app_c_->use_cpu = use_cpu; + app_c_->use_flash_attention = use_flash_attention; app_c_->use_builtin = use_builtin; app_c_->enable_uwu_filter = enable_uwu_filter; app_c_->remove_trailing_period = remove_trailing_period; diff --git a/GUI/GUI/GUI/Frame.h b/GUI/GUI/GUI/Frame.h index f509dae..fc8bac8 100644 --- a/GUI/GUI/GUI/Frame.h +++ b/GUI/GUI/GUI/Frame.h @@ -71,6 +71,7 @@ private: wxCheckBox* py_app_enable_local_beep_; wxCheckBox* py_app_enable_browser_src_; wxCheckBox* py_app_use_cpu_; + wxCheckBox* py_app_use_flash_attention_; wxCheckBox* py_app_use_builtin_; wxCheckBox* py_app_enable_uwu_filter_; wxCheckBox* py_app_remove_trailing_period_; diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py index 32deb42..2f37945 100644 --- a/Scripts/transcribe_v2.py +++ b/Scripts/transcribe_v2.py @@ -426,13 +426,15 @@ class Whisper: already_downloaded = os.path.exists(model_root) + print(f"Use flash attention {cfg['use_flash_attention']}") + self.model = WhisperModel(model_str, device = model_device, device_index = cfg["gpu_idx"], compute_type = cfg["compute_type"], download_root = model_root, local_files_only = already_downloaded, - flash_attention = True) + flash_attention = cfg["use_flash_attention"]) def transcribe(self, frames: bytes = None) -> typing.List[Segment]: if frames is None: -- cgit v1.2.3