diff options
| -rw-r--r-- | GUI/GUI/GUI/Config.cpp | 2 | ||||
| -rw-r--r-- | GUI/GUI/GUI/Config.h | 1 | ||||
| -rw-r--r-- | GUI/GUI/GUI/Frame.cpp | 18 | ||||
| -rw-r--r-- | GUI/GUI/GUI/Frame.h | 1 | ||||
| -rw-r--r-- | Scripts/transcribe_v2.py | 4 |
5 files changed, 25 insertions, 1 deletions
diff --git a/GUI/GUI/GUI/Config.cpp b/GUI/GUI/GUI/Config.cpp index 1314c4d..0a71ad1 100644 --- a/GUI/GUI/GUI/Config.cpp +++ b/GUI/GUI/GUI/Config.cpp @@ -76,6 +76,7 @@ AppConfig::AppConfig(wxTextCtrl* out) browser_src_port(8097),
commit_fuzz_threshold(4),
use_cpu(false),
+ use_flash_attention(false),
use_builtin(false),
enable_uwu_filter(false),
remove_trailing_period(false),
@@ -125,6 +126,7 @@ bool AppConfig::Serialize(const std::filesystem::path& path) { cm.Set("browser_src_port", browser_src_port);
cm.Set("commit_fuzz_threshold", commit_fuzz_threshold);
cm.Set("use_cpu", use_cpu);
+ cm.Set("use_flash_attention", use_flash_attention);
cm.Set("use_builtin", use_builtin);
cm.Set("enable_uwu_filter", enable_uwu_filter);
cm.Set("remove_trailing_period", remove_trailing_period);
diff --git a/GUI/GUI/GUI/Config.h b/GUI/GUI/GUI/Config.h index 906982e..e75e4d5 100644 --- a/GUI/GUI/GUI/Config.h +++ b/GUI/GUI/GUI/Config.h @@ -62,6 +62,7 @@ public: int browser_src_port;
int commit_fuzz_threshold;
bool use_cpu;
+ bool use_flash_attention;
bool use_builtin;
bool enable_uwu_filter;
bool remove_trailing_period;
diff --git a/GUI/GUI/GUI/Frame.cpp b/GUI/GUI/GUI/Frame.cpp index 2ee6f1b..9a69651 100644 --- a/GUI/GUI/GUI/Frame.cpp +++ b/GUI/GUI/GUI/Frame.cpp @@ -64,6 +64,7 @@ namespace { ID_PY_APP_ENABLE_LOCAL_BEEP,
ID_PY_APP_ENABLE_BROWSER_SRC,
ID_PY_APP_USE_CPU,
+ ID_PY_APP_USE_FLASH_ATTENTION,
ID_PY_APP_USE_BUILTIN,
ID_PY_APP_ENABLE_UWU_FILTER,
ID_PY_APP_REMOVE_TRAILING_PERIOD,
@@ -995,6 +996,16 @@ Frame::Frame() );
py_app_use_cpu_ = py_app_use_cpu;
+ auto* py_app_use_flash_attention = new wxCheckBox(py_config_panel,
+ ID_PY_APP_USE_FLASH_ATTENTION, "Use flash attention");
+ py_app_use_flash_attention->SetValue(app_c_->use_flash_attention);
+ py_app_use_flash_attention->SetToolTip(
+ "If checked, the transcription engine will use flash "
+ "attention for inference. This is much faster and more "
+ "efficient, but requires a 3000 series GPU or higher."
+ );
+ py_app_use_flash_attention_ = py_app_use_flash_attention;
+
auto* py_app_use_builtin = new wxCheckBox(py_config_panel,
ID_PY_APP_USE_BUILTIN, "Use built-in chatbox");
py_app_use_builtin->SetValue(app_c_->use_builtin);
@@ -1112,6 +1123,8 @@ Frame::Frame() /*flags=*/wxEXPAND);
sizer->Add(py_app_use_cpu, /*proportion=*/0,
/*flags=*/wxEXPAND);
+ sizer->Add(py_app_use_flash_attention, /*proportion=*/0,
+ /*flags=*/wxEXPAND);
sizer->Add(py_app_use_builtin, /*proportion=*/0,
/*flags=*/wxEXPAND);
sizer->Add(py_app_enable_uwu_filter, /*proportion=*/0,
@@ -1701,6 +1714,9 @@ void Frame::ApplyConfigToInputFields() auto* py_app_use_cpu = static_cast<wxCheckBox*>(FindWindowById(ID_PY_APP_USE_CPU));
py_app_use_cpu->SetValue(app_c_->use_cpu);
+ auto* py_app_use_flash_attention = static_cast<wxCheckBox*>(FindWindowById(ID_PY_APP_USE_FLASH_ATTENTION));
+ py_app_use_flash_attention->SetValue(app_c_->use_flash_attention);
+
auto* py_app_use_builtin = static_cast<wxCheckBox*>(FindWindowById(ID_PY_APP_USE_BUILTIN));
py_app_use_builtin->SetValue(app_c_->use_builtin);
@@ -2450,6 +2466,7 @@ void Frame::OnAppStart(wxCommandEvent& event) { const bool enable_local_beep = py_app_enable_local_beep_->GetValue();
const bool enable_browser_src = py_app_enable_browser_src_->GetValue();
const bool use_cpu = py_app_use_cpu_->GetValue();
+ const bool use_flash_attention = py_app_use_flash_attention_->GetValue();
const bool use_builtin = py_app_use_builtin_->GetValue();
const bool enable_uwu_filter = py_app_enable_uwu_filter_->GetValue();
const bool remove_trailing_period = py_app_remove_trailing_period_->GetValue();
@@ -2490,6 +2507,7 @@ void Frame::OnAppStart(wxCommandEvent& event) { app_c_->enable_browser_src = enable_browser_src;
app_c_->browser_src_port = browser_src_port;
app_c_->use_cpu = use_cpu;
+ app_c_->use_flash_attention = use_flash_attention;
app_c_->use_builtin = use_builtin;
app_c_->enable_uwu_filter = enable_uwu_filter;
app_c_->remove_trailing_period = remove_trailing_period;
diff --git a/GUI/GUI/GUI/Frame.h b/GUI/GUI/GUI/Frame.h index f509dae..fc8bac8 100644 --- a/GUI/GUI/GUI/Frame.h +++ b/GUI/GUI/GUI/Frame.h @@ -71,6 +71,7 @@ private: wxCheckBox* py_app_enable_local_beep_;
wxCheckBox* py_app_enable_browser_src_;
wxCheckBox* py_app_use_cpu_;
+ wxCheckBox* py_app_use_flash_attention_;
wxCheckBox* py_app_use_builtin_;
wxCheckBox* py_app_enable_uwu_filter_;
wxCheckBox* py_app_remove_trailing_period_;
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py index 32deb42..2f37945 100644 --- a/Scripts/transcribe_v2.py +++ b/Scripts/transcribe_v2.py @@ -426,13 +426,15 @@ class Whisper: already_downloaded = os.path.exists(model_root) + print(f"Use flash attention {cfg['use_flash_attention']}") + self.model = WhisperModel(model_str, device = model_device, device_index = cfg["gpu_idx"], compute_type = cfg["compute_type"], download_root = model_root, local_files_only = already_downloaded, - flash_attention = True) + flash_attention = cfg["use_flash_attention"]) def transcribe(self, frames: bytes = None) -> typing.List[Segment]: if frames is None: |
