mirror of
https://gitlab.com/niansa/discord_llama.git
synced 2025-03-06 20:48:25 +01:00
Only init translator if really needed
This commit is contained in:
parent
024d7a06ce
commit
af96377a6c
1 changed files with 16 additions and 9 deletions
25
main.cpp
25
main.cpp
|
@ -52,7 +52,7 @@ class Bot {
|
|||
Timer last_message_timer;
|
||||
std::shared_ptr<bool> stopping;
|
||||
LM::InferencePool llm_pool;
|
||||
Translator translator;
|
||||
std::unique_ptr<Translator> translator;
|
||||
std::vector<dpp::snowflake> my_messages;
|
||||
std::unordered_map<dpp::snowflake, dpp::user> users;
|
||||
std::thread::id llm_tid;
|
||||
|
@ -109,8 +109,8 @@ class Bot {
|
|||
// Must run in llama thread
|
||||
std::string_view llm_translate_to_en(std::string_view text) {
|
||||
ENSURE_LLM_THREAD();
|
||||
// No need for translation if language is english already
|
||||
if (language == "EN") {
|
||||
// Skip if there is no translator
|
||||
if (translator == nullptr) {
|
||||
std::cout << "(" << language << ") " << text << std::endl;
|
||||
return text;
|
||||
}
|
||||
|
@ -121,7 +121,7 @@ class Bot {
|
|||
str_replace_in_place(fres, bot.me.username, "[43]");
|
||||
// Run translation
|
||||
try {
|
||||
fres = translator.translate(fres, "EN", show_console_progress);
|
||||
fres = translator->translate(fres, "EN", show_console_progress);
|
||||
} catch (const LM::Inference::ContextLengthException&) {
|
||||
// Handle potential context overflow error
|
||||
return "(Translation impossible)";
|
||||
|
@ -135,8 +135,8 @@ class Bot {
|
|||
// Must run in llama thread
|
||||
std::string_view llm_translate_from_en(std::string_view text) {
|
||||
ENSURE_LLM_THREAD();
|
||||
// No need for translation if language is english already
|
||||
if (language == "EN") {
|
||||
// Skip if there is no translator
|
||||
if (translator == nullptr) {
|
||||
std::cout << "(" << language << ") " << text << std::endl;
|
||||
return text;
|
||||
}
|
||||
|
@ -147,7 +147,7 @@ class Bot {
|
|||
str_replace_in_place(fres, bot.me.username, "[43]");
|
||||
// Run translation
|
||||
try {
|
||||
fres = translator.translate(fres, language, show_console_progress);
|
||||
fres = translator->translate(fres, language, show_console_progress);
|
||||
} catch (const LM::Inference::ContextLengthException&) {
|
||||
// Handle potential context overflow error
|
||||
return "(Translation impossible)";
|
||||
|
@ -159,7 +159,7 @@ class Bot {
|
|||
}
|
||||
|
||||
LM::Inference::Params llm_get_translation_params() const {
|
||||
auto fres = translator.get_params();
|
||||
auto fres = translator->get_params();
|
||||
fres.n_threads = config.threads;
|
||||
fres.use_mlock = config.mlock;
|
||||
return fres;
|
||||
|
@ -348,13 +348,20 @@ public:
|
|||
} config;
|
||||
|
||||
Bot(const Configuration& cfg) : config(cfg), bot(cfg.token), language(cfg.language),
|
||||
llm_pool(cfg.pool_size, "discord_llama", !cfg.persistance), translator(cfg.translation_model, llm_get_translation_params()) {
|
||||
llm_pool(cfg.pool_size, "discord_llama", !cfg.persistance) {
|
||||
// Configure llm_pool
|
||||
llm_pool.set_store_on_destruct(cfg.persistance);
|
||||
|
||||
// Initialize thread pool
|
||||
tPool.init();
|
||||
|
||||
// Prepare translator
|
||||
if (language != "EN") {
|
||||
tPool.submit([this] () {
|
||||
translator = std::make_unique<Translator>(config.translation_model, llm_get_translation_params());
|
||||
});
|
||||
}
|
||||
|
||||
// Prepare llm
|
||||
tPool.submit(std::bind(&Bot::llm_init, this));
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue