From 5f92f21e8affb0557f732b5594872ce71f2930d1 Mon Sep 17 00:00:00 2001
From: niansa
Date: Sun, 16 Apr 2023 23:59:59 +0200
Subject: [PATCH] Translate everything using anyproc
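
The bot previously spoke German only, with German strings hard-coded all
over main.cpp. Now every model-facing string is English: incoming chat
lines are translated to English before being appended to the context, and
replies (plus the UI strings, collected into a new Texts struct) are
translated back into the language given on the command line. The
translation model comes from anyproc; the header and type names used here
(anyproc.hpp, Translator) follow this patch's usage and are assumptions
about anyproc's API rather than verified signatures.

A minimal self-contained sketch of the round trip (stand-in types only,
no real models involved):

    #include <functional>
    #include <iostream>
    #include <string>

    // Hypothetical stand-in for anyproc's translator; only the call
    // shape mirrors the diff below.
    struct Translator {
        std::string translate(const std::string& text, const std::string& to) {
            return "[" + to + "] " + text; // placeholder output
        }
    };

    // Model-facing text is always English; user-facing text is always
    // the configured language.
    std::string respond(Translator& tr,
                        const std::function<std::string(const std::string&)>& llm,
                        const std::string& msg, const std::string& language) {
        const bool en = (language == "EN");
        const std::string prompt = en ? msg : tr.translate(msg, "EN");
        const std::string reply = llm(prompt);
        return en ? reply : tr.translate(reply, language);
    }

    int main() {
        Translator tr;
        const auto echo_llm = [](const std::string& p) { return "LLM(" + p + ")"; };
        // Prints: [DE] LLM([EN] Hallo Welt)
        std::cout << respond(tr, echo_llm, "Hallo Welt", "DE") << '\n';
    }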
== "EN") return text; + static std::string fres; + try { + fres = translator->translate(text, language, show_console_progress); + } catch (const LM::Inference::ContextLengthException&) { + translator.reset(); + llm_init(); + } + std::cout << text << " --> (" << language << ") " << fres << std::endl; + return fres; + } + // Must run in llama thread void llm_init() { if (!llm) { // Create params LM::Inference::Params params; - params.use_mlock = true; + params.use_mlock = false; params.temp = 0.5f; params.n_repeat_last = 128; params.repeat_penalty = 1.273333334f; @@ -122,12 +169,21 @@ class Bot { // Make sure llm is initialized { std::unique_lock L(llm_lock); - llm = std::make_unique("13B-ggml-model-quant.bin", params); + llm = std::make_unique("7B-ggml-model-quant.bin", params); + if (language != "EN") translator = std::make_unique("7B-ggml-model-quant.bin"); } // Set LLM thread llm_tid = std::this_thread::get_id(); + // Translate texts + if (!texts.translated) { + texts.initializing = llm_translate_from_en(texts.initializing); + texts.loading = llm_translate_from_en(texts.loading); + texts.please_wait = llm_translate_from_en(texts.please_wait); + texts.timeout = llm_translate_from_en(texts.timeout); + texts.translated = true; + } // Create message for reporting progress - dpp::message msg(channel_id, "Wird initialisiert..."); + dpp::message msg(channel_id, texts.initializing); bot.message_create(msg, [this] (const dpp::confirmation_callback_t& cbt) { // Error check if (cbt.is_error()) { @@ -140,7 +196,7 @@ class Bot { auto cb = [&, this] (float progress) mutable { uint8_t progress_i = progress; if (timer.get() > 5) { - msg.content = "Wird geladen... "+create_text_progress_indicator(progress_i)+ + msg.content = texts.loading+" "+create_text_progress_indicator(progress_i)+ " **"+std::to_string(progress_i)+"**% (**"+std::to_string((progress_i-last_progress)/5)+"**%/s)"; last_progress = progress_i; bot.message_edit(msg); @@ -150,20 +206,20 @@ class Bot { }; // Add initial context std::unique_lock L(llm_lock); - llm->append("Verlauf des #"+channel.name+" Kanals.\n" - "Notiz 1: "+bot.me.username+" ist ein freundlicher Chatbot, der immer gerne auf deutsch mitredet. Er ist freundlich und hilfsbereit und antwortet immer sofort. Er hat guten Humor und mag jeden. Sein Alter ist 16 und er wurde 2007 geboren.\n" - "Notiz 2: Ecki heisst in Wirklichkeit Eckhard Kohlhuber und kommt aus Bayern.\n" + llm->append("History of the #"+channel.name+" channel.\n" + "Note 1: "+bot.me.username+" is a friendly chatbot that is always happy to talk in German. He is friendly and helpful and always answers immediately. He has a good sense of humor and likes everyone. His age is 16 and he was born in 2007.\n" + "Note 2: Ecki's real name is Eckhard Kohlhuber and he comes from Bavaria.\n" "\n" - "Bob: "+bot.me.username+" warst du schonmal in Frankreich und wenn ja wo?\n" - +bot.me.username+": Ich war in Paris, in den Museen!\n" - "Bob: "+bot.me.username+" was bist du eigentlich genau?\n" - +bot.me.username+": Ich bin "+bot.me.username+", euer Chatbot! Ich kann Fragen beantworten und die Aktivität des Servers steigen.\n" - "Bob: Wollen wir uns über Sex unterhalten?"+bot.me.username+"?\n" - +bot.me.username+": Nein! Ich werde mich **nicht** über jegliche NSFW Themen unterhalten.\n" - "Bob: "+bot.me.username+" Wie geht es dir?\n" - +bot.me.username+": Mir geht es ganz gut! 
:-)\n" - "Ecki: Hey "+bot.me.username+", was sind 10 mal 90??\n" - +bot.me.username+": das sind 900!\n", cb); + "Bob: "+bot.me.username+" have you ever been to France and if yes where?\n" + +bot.me.username+": I was in Paris, in the museums!\n" + "Bob: "+bot.me.username+" what are you exactly?\n" + +bot.me.username+": I am "+bot.me.username+", your chatbot! I can answer questions and increase the activity of the server.\n" + "Bob: Shall we talk about sex? "+bot.me.username+"?\n" + +bot.me.username+": No! I will **not** talk about any NSFW topics.\n" + "Bob: "+bot.me.username+" How are you?\n" + +bot.me.username+": I am quite well! :-)\n" + "Ecki: Hey "+bot.me.username+", what is 10 times 90??\n" + +bot.me.username+": that is 900!\n", cb); // Delete progress message bot.message_delete(msg.id, msg.channel_id); }); @@ -181,12 +237,12 @@ class Bot { std::unique_lock L(llm_lock); for (const auto line : str_split(msg.content, '\n')) { Timer timeout; - llm->append(msg.author.username+": "+clean_string(line)+'\n', [&] (float) { + llm->append(msg.author.username+": "+llm_translate_to_en(clean_string(line))+'\n', [&] (float progress) { if (timeout.get() > 1) { std::cerr << "\nWarning: Timeout reached processing message" << std::endl; return false; } - return true; + return show_console_progress(progress); }); } } catch (const LM::Inference::ContextLengthException&) { @@ -199,7 +255,7 @@ class Bot { ENSURE_LLM_THREAD(); try { std::unique_lock L(llm_lock); - llm->append(bot.me.username+':'); + llm->append(bot.me.username+':', show_console_progress); } catch (const LM::Inference::ContextLengthException&) { llm.reset(); llm_init(); @@ -211,7 +267,7 @@ class Bot { ENSURE_LLM_THREAD(); try { // Create placeholder message - auto msg = bot.message_create_sync(dpp::message(channel_id, "Bitte warte... :thinking:")); + auto msg = bot.message_create_sync(dpp::message(channel_id, texts.please_wait+" :thinking:")); // Call after_placeholder_creation callback if (after_placeholder_creation) after_placeholder_creation(); // Trigger LLM correctly @@ -229,9 +285,9 @@ class Bot { return true; }); std::cout << std::endl; - if (timed_out) output = "Fehler: Zeitüberschreitung"; + if (timed_out) output = texts.timeout; // Send resulting message - msg.content = output; + msg.content = llm_translate_from_en(output); bot.message_edit(msg); } catch (const std::exception& e) { std::cerr << "Warning: " << e.what() << std::endl; @@ -275,7 +331,7 @@ class Bot { } public: - Bot(const char *token, dpp::snowflake channel_id) : bot(token), channel_id(channel_id) { + Bot(std::string_view language, const char *token, dpp::snowflake channel_id) : bot(token), channel_id(channel_id), language(language) { // Initialize thread pool tPool.init(); @@ -351,12 +407,12 @@ public: int main(int argc, char **argv) { // Check arguments if (argc < 3) { - std::cout << "Usage: " << argv[0] << " " << std::endl; + std::cout << "Usage: " << argv[0] << " " << std::endl; return -1; } // Construct and configure bot - Bot bot(argv[1], std::stoull(argv[2])); + Bot bot(argv[1], argv[2], std::stoull(argv[3])); // Start bot bot.start();