summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author: yum <yum.food.vr@gmail.com> 2024-11-16 15:27:25 -0800
committer: yum <yum.food.vr@gmail.com> 2024-11-16 15:52:24 -0800
commit: 953c03d21d6ef75d115a25ff83e2e3a706b685a6 (patch)
tree: 0fc15d8be147a021b6e9aaaf50749a7aa9f191df
parent: 673d701ea471daebecb1fb0c1edf79a2017a78ac (diff)
Remove flash_attention toggle
Deprecated in the Python release of CTranslate2 as of 4.4.0: https://github.com/OpenNMT/CTranslate2/blob/master/CHANGELOG.md#v440-2024-09-09
-rw-r--r-- GUI/GUI/GUI/Frame.cpp 18
-rw-r--r-- GUI/GUI/GUI/Frame.h 1
-rw-r--r-- Scripts/transcribe_v2.py 12
3 files changed, 2 insertions, 29 deletions
diff --git a/GUI/GUI/GUI/Frame.cpp b/GUI/GUI/GUI/Frame.cpp
index 0782376..697e18a 100644
--- a/GUI/GUI/GUI/Frame.cpp
+++ b/GUI/GUI/GUI/Frame.cpp
@@ -65,7 +65,6 @@ namespace {
ID_PY_APP_ENABLE_ORIG_LANG,
ID_PY_APP_ENABLE_BROWSER_SRC,
ID_PY_APP_USE_CPU,
- ID_PY_APP_USE_FLASH_ATTENTION,
ID_PY_APP_USE_BUILTIN,
ID_PY_APP_ENABLE_UWU_FILTER,
ID_PY_APP_REMOVE_TRAILING_PERIOD,
@@ -1022,16 +1021,6 @@ Frame::Frame()
);
py_app_use_cpu_ = py_app_use_cpu;
- auto* py_app_use_flash_attention = new wxCheckBox(py_config_panel,
- ID_PY_APP_USE_FLASH_ATTENTION, "Use flash attention");
- py_app_use_flash_attention->SetValue(app_c_->use_flash_attention);
- py_app_use_flash_attention->SetToolTip(
- "If checked, the transcription engine will use flash "
- "attention for inference. This is much faster and more "
- "efficient, but requires a 3000 series GPU or higher."
- );
- py_app_use_flash_attention_ = py_app_use_flash_attention;
-
auto* py_app_use_builtin = new wxCheckBox(py_config_panel,
ID_PY_APP_USE_BUILTIN, "Use built-in chatbox");
py_app_use_builtin->SetValue(app_c_->use_builtin);
@@ -1151,8 +1140,6 @@ Frame::Frame()
/*flags=*/wxEXPAND);
sizer->Add(py_app_use_cpu, /*proportion=*/0,
/*flags=*/wxEXPAND);
- sizer->Add(py_app_use_flash_attention, /*proportion=*/0,
- /*flags=*/wxEXPAND);
sizer->Add(py_app_use_builtin, /*proportion=*/0,
/*flags=*/wxEXPAND);
sizer->Add(py_app_enable_uwu_filter, /*proportion=*/0,
@@ -1745,9 +1732,6 @@ void Frame::ApplyConfigToInputFields()
auto* py_app_use_cpu = static_cast<wxCheckBox*>(FindWindowById(ID_PY_APP_USE_CPU));
py_app_use_cpu->SetValue(app_c_->use_cpu);
- auto* py_app_use_flash_attention = static_cast<wxCheckBox*>(FindWindowById(ID_PY_APP_USE_FLASH_ATTENTION));
- py_app_use_flash_attention->SetValue(app_c_->use_flash_attention);
-
auto* py_app_use_builtin = static_cast<wxCheckBox*>(FindWindowById(ID_PY_APP_USE_BUILTIN));
py_app_use_builtin->SetValue(app_c_->use_builtin);
@@ -2587,7 +2571,6 @@ void Frame::OnAppStart(wxCommandEvent& event) {
const bool enable_orig_lang = py_app_enable_orig_lang_->GetValue();
const bool enable_browser_src = py_app_enable_browser_src_->GetValue();
const bool use_cpu = py_app_use_cpu_->GetValue();
- const bool use_flash_attention = py_app_use_flash_attention_->GetValue();
const bool use_builtin = py_app_use_builtin_->GetValue();
const bool enable_uwu_filter = py_app_enable_uwu_filter_->GetValue();
const bool remove_trailing_period = py_app_remove_trailing_period_->GetValue();
@@ -2629,7 +2612,6 @@ void Frame::OnAppStart(wxCommandEvent& event) {
app_c_->enable_browser_src = enable_browser_src;
app_c_->browser_src_port = browser_src_port;
app_c_->use_cpu = use_cpu;
- app_c_->use_flash_attention = use_flash_attention;
app_c_->use_builtin = use_builtin;
app_c_->enable_uwu_filter = enable_uwu_filter;
app_c_->remove_trailing_period = remove_trailing_period;
diff --git a/GUI/GUI/GUI/Frame.h b/GUI/GUI/GUI/Frame.h
index 2d682a7..7e5c7c7 100644
--- a/GUI/GUI/GUI/Frame.h
+++ b/GUI/GUI/GUI/Frame.h
@@ -72,7 +72,6 @@ private:
wxCheckBox* py_app_enable_orig_lang_;
wxCheckBox* py_app_enable_browser_src_;
wxCheckBox* py_app_use_cpu_;
- wxCheckBox* py_app_use_flash_attention_;
wxCheckBox* py_app_use_builtin_;
wxCheckBox* py_app_enable_uwu_filter_;
wxCheckBox* py_app_remove_trailing_period_;
diff --git a/Scripts/transcribe_v2.py b/Scripts/transcribe_v2.py
index 63d7f52..e024bae 100644
--- a/Scripts/transcribe_v2.py
+++ b/Scripts/transcribe_v2.py
@@ -423,9 +423,6 @@ class Whisper:
model_device = "cuda"
if cfg["use_cpu"]:
model_device = "cpu"
- if cfg["use_flash_attention"]:
- print(f"Flash attention disabled on CPU", file=sys.stderr)
- cfg["use_flash_attention"] = False
already_downloaded = os.path.exists(model_root)
@@ -434,8 +431,7 @@ class Whisper:
device_index = cfg["gpu_idx"],
compute_type = cfg["compute_type"],
download_root = model_root,
- local_files_only = already_downloaded,
- flash_attention = cfg["use_flash_attention"])
+ local_files_only = already_downloaded)
def transcribe(self, frames: bytes = None) -> typing.List[Segment]:
if frames is None:
@@ -612,15 +608,11 @@ class TranslationPlugin(StreamingPlugin):
model_device = "cuda"
if cfg["use_cpu"]:
model_device = "cpu"
- if cfg["use_flash_attention"]:
- print(f"Flash attention disabled on CPU", file=sys.stderr)
- cfg["use_flash_attention"] = False
self.translator = ctranslate2.Translator(output_dir,
device = model_device,
device_index = cfg["gpu_idx"],
- compute_type = cfg["compute_type"],
- flash_attention = cfg["use_flash_attention"])
+ compute_type = cfg["compute_type"])
whisper_lang = cfg["language"]
nllb_lang = lang_compat.whisper_to_nllb[whisper_lang]