From af96377a6ca3b7164675c75a41c99f981b4cac64 Mon Sep 17 00:00:00 2001 From: niansa Date: Sun, 23 Apr 2023 16:11:57 +0200 Subject: [PATCH] Only init translator if really needed --- main.cpp | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/main.cpp b/main.cpp index 611fa29..56bc6db 100644 --- a/main.cpp +++ b/main.cpp @@ -52,7 +52,7 @@ class Bot { Timer last_message_timer; std::shared_ptr stopping; LM::InferencePool llm_pool; - Translator translator; + std::unique_ptr translator; std::vector my_messages; std::unordered_map users; std::thread::id llm_tid; @@ -109,8 +109,8 @@ class Bot { // Must run in llama thread std::string_view llm_translate_to_en(std::string_view text) { ENSURE_LLM_THREAD(); - // No need for translation if language is english already - if (language == "EN") { + // Skip if there is no translator + if (translator == nullptr) { std::cout << "(" << language << ") " << text << std::endl; return text; } @@ -121,7 +121,7 @@ class Bot { str_replace_in_place(fres, bot.me.username, "[43]"); // Run translation try { - fres = translator.translate(fres, "EN", show_console_progress); + fres = translator->translate(fres, "EN", show_console_progress); } catch (const LM::Inference::ContextLengthException&) { // Handle potential context overflow error return "(Translation impossible)"; @@ -135,8 +135,8 @@ class Bot { // Must run in llama thread std::string_view llm_translate_from_en(std::string_view text) { ENSURE_LLM_THREAD(); - // No need for translation if language is english already - if (language == "EN") { + // Skip if there is no translator + if (translator == nullptr) { std::cout << "(" << language << ") " << text << std::endl; return text; } @@ -147,7 +147,7 @@ class Bot { str_replace_in_place(fres, bot.me.username, "[43]"); // Run translation try { - fres = translator.translate(fres, language, show_console_progress); + fres = translator->translate(fres, language, show_console_progress); } catch (const LM::Inference::ContextLengthException&) { // Handle potential context overflow error return "(Translation impossible)"; @@ -159,7 +159,7 @@ class Bot { } LM::Inference::Params llm_get_translation_params() const { - auto fres = translator.get_params(); + auto fres = translator->get_params(); fres.n_threads = config.threads; fres.use_mlock = config.mlock; return fres; @@ -348,13 +348,20 @@ public: } config; Bot(const Configuration& cfg) : config(cfg), bot(cfg.token), language(cfg.language), - llm_pool(cfg.pool_size, "discord_llama", !cfg.persistance), translator(cfg.translation_model, llm_get_translation_params()) { + llm_pool(cfg.pool_size, "discord_llama", !cfg.persistance) { // Configure llm_pool llm_pool.set_store_on_destruct(cfg.persistance); // Initialize thread pool tPool.init(); + // Prepare translator + if (language != "EN") { + tPool.submit([this] () { + translator = std::make_unique(config.translation_model, llm_get_translation_params()); + }); + } + // Prepare llm tPool.submit(std::bind(&Bot::llm_init, this));