1
0
Fork 0
mirror of https://gitlab.com/niansa/discord_llama.git synced 2025-03-06 20:48:25 +01:00

Only init translator if really needed

This commit is contained in:
niansa 2023-04-23 16:11:57 +02:00
parent 024d7a06ce
commit af96377a6c

View file

@ -52,7 +52,7 @@ class Bot {
Timer last_message_timer;
std::shared_ptr<bool> stopping;
LM::InferencePool llm_pool;
Translator translator;
std::unique_ptr<Translator> translator;
std::vector<dpp::snowflake> my_messages;
std::unordered_map<dpp::snowflake, dpp::user> users;
std::thread::id llm_tid;
@ -109,8 +109,8 @@ class Bot {
// Must run in llama thread
std::string_view llm_translate_to_en(std::string_view text) {
ENSURE_LLM_THREAD();
// No need for translation if language is english already
if (language == "EN") {
// Skip if there is no translator
if (translator == nullptr) {
std::cout << "(" << language << ") " << text << std::endl;
return text;
}
@ -121,7 +121,7 @@ class Bot {
str_replace_in_place(fres, bot.me.username, "[43]");
// Run translation
try {
fres = translator.translate(fres, "EN", show_console_progress);
fres = translator->translate(fres, "EN", show_console_progress);
} catch (const LM::Inference::ContextLengthException&) {
// Handle potential context overflow error
return "(Translation impossible)";
@ -135,8 +135,8 @@ class Bot {
// Must run in llama thread
std::string_view llm_translate_from_en(std::string_view text) {
ENSURE_LLM_THREAD();
// No need for translation if language is english already
if (language == "EN") {
// Skip if there is no translator
if (translator == nullptr) {
std::cout << "(" << language << ") " << text << std::endl;
return text;
}
@ -147,7 +147,7 @@ class Bot {
str_replace_in_place(fres, bot.me.username, "[43]");
// Run translation
try {
fres = translator.translate(fres, language, show_console_progress);
fres = translator->translate(fres, language, show_console_progress);
} catch (const LM::Inference::ContextLengthException&) {
// Handle potential context overflow error
return "(Translation impossible)";
@ -159,7 +159,7 @@ class Bot {
}
LM::Inference::Params llm_get_translation_params() const {
auto fres = translator.get_params();
auto fres = translator->get_params();
fres.n_threads = config.threads;
fres.use_mlock = config.mlock;
return fres;
@ -348,13 +348,20 @@ public:
} config;
Bot(const Configuration& cfg) : config(cfg), bot(cfg.token), language(cfg.language),
llm_pool(cfg.pool_size, "discord_llama", !cfg.persistance), translator(cfg.translation_model, llm_get_translation_params()) {
llm_pool(cfg.pool_size, "discord_llama", !cfg.persistance) {
// Configure llm_pool
llm_pool.set_store_on_destruct(cfg.persistance);
// Initialize thread pool
tPool.init();
// Prepare translator
if (language != "EN") {
tPool.submit([this] () {
translator = std::make_unique<Translator>(config.translation_model, llm_get_translation_params());
});
}
// Prepare llm
tPool.submit(std::bind(&Bot::llm_init, this));