summaryrefslogtreecommitdiffstats
path: root/GUI
diff options
context:
space:
mode:
authoryum <yum.food.vr@gmail.com>2024-06-09 15:54:30 -0700
committeryum <yum.food.vr@gmail.com>2024-06-09 15:54:30 -0700
commitf2b21dd5afebd6b76b5835168f7d1bd3bec21f5d (patch)
tree9a72c443088efdd7fdac31ab3ea69b756e75f7df /GUI
parent72b9fb8337cfb7bddc58f74b8977e4a2283e6728 (diff)
Add checkbox for flash-attention
Pre-3000 series GPUs don't support it. Oops!
Diffstat (limited to 'GUI')
-rw-r--r--GUI/GUI/GUI/Config.cpp2
-rw-r--r--GUI/GUI/GUI/Config.h1
-rw-r--r--GUI/GUI/GUI/Frame.cpp18
-rw-r--r--GUI/GUI/GUI/Frame.h1
4 files changed, 22 insertions, 0 deletions
diff --git a/GUI/GUI/GUI/Config.cpp b/GUI/GUI/GUI/Config.cpp
index 1314c4d..0a71ad1 100644
--- a/GUI/GUI/GUI/Config.cpp
+++ b/GUI/GUI/GUI/Config.cpp
@@ -76,6 +76,7 @@ AppConfig::AppConfig(wxTextCtrl* out)
browser_src_port(8097),
commit_fuzz_threshold(4),
use_cpu(false),
+ use_flash_attention(false),
use_builtin(false),
enable_uwu_filter(false),
remove_trailing_period(false),
@@ -125,6 +126,7 @@ bool AppConfig::Serialize(const std::filesystem::path& path) {
cm.Set("browser_src_port", browser_src_port);
cm.Set("commit_fuzz_threshold", commit_fuzz_threshold);
cm.Set("use_cpu", use_cpu);
+ cm.Set("use_flash_attention", use_flash_attention);
cm.Set("use_builtin", use_builtin);
cm.Set("enable_uwu_filter", enable_uwu_filter);
cm.Set("remove_trailing_period", remove_trailing_period);
diff --git a/GUI/GUI/GUI/Config.h b/GUI/GUI/GUI/Config.h
index 906982e..e75e4d5 100644
--- a/GUI/GUI/GUI/Config.h
+++ b/GUI/GUI/GUI/Config.h
@@ -62,6 +62,7 @@ public:
int browser_src_port;
int commit_fuzz_threshold;
bool use_cpu;
+ bool use_flash_attention;
bool use_builtin;
bool enable_uwu_filter;
bool remove_trailing_period;
diff --git a/GUI/GUI/GUI/Frame.cpp b/GUI/GUI/GUI/Frame.cpp
index 2ee6f1b..9a69651 100644
--- a/GUI/GUI/GUI/Frame.cpp
+++ b/GUI/GUI/GUI/Frame.cpp
@@ -64,6 +64,7 @@ namespace {
ID_PY_APP_ENABLE_LOCAL_BEEP,
ID_PY_APP_ENABLE_BROWSER_SRC,
ID_PY_APP_USE_CPU,
+ ID_PY_APP_USE_FLASH_ATTENTION,
ID_PY_APP_USE_BUILTIN,
ID_PY_APP_ENABLE_UWU_FILTER,
ID_PY_APP_REMOVE_TRAILING_PERIOD,
@@ -995,6 +996,16 @@ Frame::Frame()
);
py_app_use_cpu_ = py_app_use_cpu;
+ auto* py_app_use_flash_attention = new wxCheckBox(py_config_panel,
+ ID_PY_APP_USE_FLASH_ATTENTION, "Use flash attention");
+ py_app_use_flash_attention->SetValue(app_c_->use_flash_attention);
+ py_app_use_flash_attention->SetToolTip(
+ "If checked, the transcription engine will use flash "
+ "attention for inference. This is much faster and more "
+ "efficient, but requires a 3000 series GPU or higher."
+ );
+ py_app_use_flash_attention_ = py_app_use_flash_attention;
+
auto* py_app_use_builtin = new wxCheckBox(py_config_panel,
ID_PY_APP_USE_BUILTIN, "Use built-in chatbox");
py_app_use_builtin->SetValue(app_c_->use_builtin);
@@ -1112,6 +1123,8 @@ Frame::Frame()
/*flags=*/wxEXPAND);
sizer->Add(py_app_use_cpu, /*proportion=*/0,
/*flags=*/wxEXPAND);
+ sizer->Add(py_app_use_flash_attention, /*proportion=*/0,
+ /*flags=*/wxEXPAND);
sizer->Add(py_app_use_builtin, /*proportion=*/0,
/*flags=*/wxEXPAND);
sizer->Add(py_app_enable_uwu_filter, /*proportion=*/0,
@@ -1701,6 +1714,9 @@ void Frame::ApplyConfigToInputFields()
auto* py_app_use_cpu = static_cast<wxCheckBox*>(FindWindowById(ID_PY_APP_USE_CPU));
py_app_use_cpu->SetValue(app_c_->use_cpu);
+ auto* py_app_use_flash_attention = static_cast<wxCheckBox*>(FindWindowById(ID_PY_APP_USE_FLASH_ATTENTION));
+ py_app_use_flash_attention->SetValue(app_c_->use_flash_attention);
+
auto* py_app_use_builtin = static_cast<wxCheckBox*>(FindWindowById(ID_PY_APP_USE_BUILTIN));
py_app_use_builtin->SetValue(app_c_->use_builtin);
@@ -2450,6 +2466,7 @@ void Frame::OnAppStart(wxCommandEvent& event) {
const bool enable_local_beep = py_app_enable_local_beep_->GetValue();
const bool enable_browser_src = py_app_enable_browser_src_->GetValue();
const bool use_cpu = py_app_use_cpu_->GetValue();
+ const bool use_flash_attention = py_app_use_flash_attention_->GetValue();
const bool use_builtin = py_app_use_builtin_->GetValue();
const bool enable_uwu_filter = py_app_enable_uwu_filter_->GetValue();
const bool remove_trailing_period = py_app_remove_trailing_period_->GetValue();
@@ -2490,6 +2507,7 @@ void Frame::OnAppStart(wxCommandEvent& event) {
app_c_->enable_browser_src = enable_browser_src;
app_c_->browser_src_port = browser_src_port;
app_c_->use_cpu = use_cpu;
+ app_c_->use_flash_attention = use_flash_attention;
app_c_->use_builtin = use_builtin;
app_c_->enable_uwu_filter = enable_uwu_filter;
app_c_->remove_trailing_period = remove_trailing_period;
diff --git a/GUI/GUI/GUI/Frame.h b/GUI/GUI/GUI/Frame.h
index f509dae..fc8bac8 100644
--- a/GUI/GUI/GUI/Frame.h
+++ b/GUI/GUI/GUI/Frame.h
@@ -71,6 +71,7 @@ private:
wxCheckBox* py_app_enable_local_beep_;
wxCheckBox* py_app_enable_browser_src_;
wxCheckBox* py_app_use_cpu_;
+ wxCheckBox* py_app_use_flash_attention_;
wxCheckBox* py_app_use_builtin_;
wxCheckBox* py_app_enable_uwu_filter_;
wxCheckBox* py_app_remove_trailing_period_;