#include "Random.hpp"
#include "Timer.hpp"
// NOTE(review): the original #include targets and ALL template arguments were
// stripped from this file (everything between angle brackets was lost during
// extraction). The headers and template parameters below are reconstructed
// from usage and MUST be confirmed against the project's build files.
#include <array>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <string_view>
#include <thread>
#include <vector>
#include <dpp/dpp.h>       // TODO confirm: dpp::cluster, dpp::message, ...
#include <justlm.hpp>      // TODO confirm: LM::Inference and its exceptions
#include <thread_pool.hpp> // TODO confirm: ThreadPool
#include <translator.hpp>  // TODO confirm: Translator


// Split `s` at each `delimiter`, performing at most `times` splits
// (default: unlimited). The remainder after the last split is appended
// as the final element, so the result is never empty.
static std::vector<std::string> str_split(std::string_view s, char delimiter, size_t times = -1) {
    std::vector<std::string> to_return;
    decltype(s.size()) start = 0, finish = 0;
    while ((finish = s.find_first_of(delimiter, start)) != std::string_view::npos) {
        to_return.emplace_back(s.substr(start, finish - start));
        start = finish + 1;
        if (to_return.size() == times) {
            break;
        }
    }
    to_return.emplace_back(s.substr(start));
    return to_return;
}

// Replace every occurrence of `search` in `subject` with `replace`, in place.
// Advancing `pos` past the replacement avoids infinite loops when `replace`
// contains `search`.
static void str_replace_in_place(std::string& subject, std::string_view search, const std::string& replace) {
    size_t pos = 0;
    while ((pos = subject.find(search, pos)) != std::string::npos) {
        subject.replace(pos, search.length(), replace);
        pos += replace.length();
    }
}

// Strip all bytes except printable ASCII, newlines, and the UTF-8 bytes of
// the German umlauts/eszett. Operates on raw bytes, not code points.
static inline std::string clean_string(std::string_view str) {
    std::string fres;
    for (const auto c : str) {
        // Bugfix: the original also compared against index [2] of each
        // two-byte UTF-8 literal ("ä"[2] etc.), which is the NUL terminator —
        // so '\0' bytes silently passed the filter. Only bytes [0] and [1]
        // are meaningful.
        if ((c >= 0x20 && c <= 0x7E) || c == '\n'
                || c == "ä"[0] || c == "ä"[1]
                || c == "ö"[0] || c == "ö"[1]
                || c == "ü"[0] || c == "ü"[1]
                || c == "Ä"[0] || c == "Ä"[1]
                || c == "Ö"[0] || c == "Ö"[1]
                || c == "Ü"[0] || c == "Ü"[1]
                || c == "ß"[0] || c == "ß"[1]) {
            fres.push_back(c);
        }
    }
    return fres;
}


// A Discord chatbot backed by a local LLM. All LLM work is funneled through a
// single-threaded ThreadPool so that inference never runs concurrently.
class Bot {
    RandomGenerator rng;
    ThreadPool tPool{1};            // exactly one worker — this is "the llama thread"
    Timer last_message_timer;
    std::shared_ptr<bool> stopping; // shared with the detached idle thread
    std::unique_ptr<LM::Inference> llm = nullptr;
    std::unique_ptr<Translator> translator; // TODO confirm type — only created when language != "EN"
    std::vector<dpp::snowflake> my_messages; // IDs of messages this bot sent
    std::mutex llm_lock;
    std::thread::id llm_tid;
    std::string_view language;      // NOTE: view into argv — valid for program lifetime
    dpp::cluster bot;
    dpp::channel channel;
    dpp::snowflake channel_id;

    // User-visible strings; translated once at init if language != "EN".
    struct Texts {
        std::string please_wait = "Please wait...",
                    loading = "Loading...",
                    initializing = "Initializing...",
                    timeout = "Error: Timeout";
        bool translated = false;
    } texts;

    // Render a `[###   ]`-style progress bar (backtick-quoted for Discord
    // monospace) for a percentage in [0, 100].
    inline static std::string create_text_progress_indicator(uint8_t percentage) {
        static constexpr uint8_t divisor = 3,
                                 width = 100 / divisor;
        // Progress bar percentage lookup, built at compile time
        const static auto indicator_lookup = [] () consteval {
            std::array<uint8_t, 101> fres;
            for (uint8_t it = 0; it != 101; it++) {
                fres[it] = it / divisor;
            }
            return fres;
        }();
        // Initialize string
        std::string fres;
        fres.resize(width+4);
        fres[0] = '`';
        fres[1] = '[';
        // Append progress
        const uint8_t bars = indicator_lookup[percentage];
        for (uint8_t it = 0; it != width; it++) {
            if (it < bars) fres[it+2] = '#';
            else fres[it+2] = ' ';
        }
        // Finalize and return string
        fres[width+2] = ']';
        fres[width+3] = '`';
        return fres;
    }

    // Progress callback for console output; always returns true (never cancels).
    inline static bool show_console_progress(float progress) {
        std::cout << ' ' << unsigned(progress) << "% \r" << std::flush;
        return true;
    }

    // Guard: throws if the caller is not on the llama thread. The trailing `0`
    // swallows the caller's semicolon so the macro behaves like a statement.
#   define ENSURE_LLM_THREAD() if (std::this_thread::get_id() != llm_tid) {throw std::runtime_error("LLM execution of '"+std::string(__PRETTY_FUNCTION__)+"' on wrong thread detected");} 0

    // Translate `text` to English. Must run in llama thread.
    // Returns a reference into a function-local static buffer (safe only
    // because all calls are serialized on the llama thread).
    const std::string& llm_translate_to_en(const std::string& text) {
        ENSURE_LLM_THREAD();
        // No need for translation if language is english already
        if (language == "EN") return text; // I am optimizing heavily for the above case. This function always returns a reference so a trick is needed here
        static std::string fres;
        fres = text;
        // Replace bot username with [43] so the model never sees/garbles it
        str_replace_in_place(fres, bot.me.username, "[43]");
        // Run translation
        try {
            fres = translator->translate(fres, "EN", show_console_progress);
        } catch (const LM::Inference::ContextLengthException&) {
            // Handle potential context overflow error: rebuild and retry
            translator.reset();
            llm_init();
            return llm_translate_to_en(text);
        }
        // Replace [43] back with bot username
        str_replace_in_place(fres, "[43]", bot.me.username);
        std::cout << text << " --> (EN) " << fres << std::endl;
        return fres;
    }

    // Translate `text` from English into the configured language.
    // Must run in llama thread. Same static-buffer trick as above.
    const std::string& llm_translate_from_en(const std::string& text) {
        ENSURE_LLM_THREAD();
        // No need for translation if language is english already
        if (language == "EN") return text; // I am optimizing heavily for the above case. This function always returns a reference so a trick is needed here
        static std::string fres;
        fres = text;
        // Replace bot username with [43]
        str_replace_in_place(fres, bot.me.username, "[43]");
        // Run translation
        try {
            fres = translator->translate(fres, language, show_console_progress);
        } catch (const LM::Inference::ContextLengthException&) {
            // Handle potential context overflow error: rebuild and retry
            translator.reset();
            llm_init();
            return llm_translate_from_en(text);
        }
        // Replace [43] back with bot username
        str_replace_in_place(fres, "[43]", bot.me.username);
        std::cout << text << " --> (" << language << ") " << fres << std::endl;
        return fres;
    }

    // Lazily construct translator + LLM, translate the UI texts, then feed the
    // initial example-conversation prompt while reporting progress to Discord.
    // Must run in llama thread.
    void llm_init() {
        if (!llm) {
            // Create params
            LM::Inference::Params params;
            params.use_mlock = false;
            params.temp = 0.5f;
            params.n_repeat_last = 128;
            params.repeat_penalty = 1.273333334f;
            // Make sure llm is initialized
            {
                std::unique_lock L(llm_lock);
                if (translator == nullptr && language != "EN")
                    translator = std::make_unique<Translator>("7B-ggml-model-quant.bin");
                llm = std::make_unique<LM::Inference>("13B-ggml-model-quant.bin", params);
            }
            // Set LLM thread
            llm_tid = std::this_thread::get_id();
            // Translate texts once
            if (!texts.translated) {
                texts.please_wait = llm_translate_from_en(texts.please_wait);
                texts.initializing = llm_translate_from_en(texts.initializing);
                texts.loading = llm_translate_from_en(texts.loading);
                texts.timeout = llm_translate_from_en(texts.timeout);
                texts.translated = true;
            }
            // Create message for reporting progress
            dpp::message msg(channel_id, texts.initializing);
            bot.message_create(msg, [this] (const dpp::confirmation_callback_t& cbt) {
                // Error check
                if (cbt.is_error()) {
                    throw std::runtime_error("Failed to send message to channel: "+cbt.get_error().message);
                }
                // Callback for reporting progress; edits the placeholder at
                // most every 5 (timer units — presumably seconds, TODO confirm)
                Timer timer;
                auto msg = cbt.get<dpp::message>(); // TODO confirm accessor — template arg was lost
                uint8_t last_progress = 0;
                auto cb = [&, this] (float progress) mutable {
                    uint8_t progress_i = progress;
                    if (timer.get() > 5) {
                        msg.content = texts.loading+" "+create_text_progress_indicator(progress_i)+
                                      " **"+std::to_string(progress_i)+"**% (**"+std::to_string((progress_i-last_progress)/5)+"**%/s)";
                        last_progress = progress_i;
                        bot.message_edit(msg);
                        timer.reset();
                    }
                    return true;
                };
                // Add initial context
                std::unique_lock L(llm_lock);
                llm->append("History of the #"+channel.name+" channel.\n"
                            "Note 1: "+bot.me.username+" is a friendly chatbot that is always happy to talk. He is friendly and helpful and always answers immediately. He has a good sense of humor and likes everyone. His age is 16 and he was born in 2007.\n"
                            "Note 2: Ecki's real name is Eckhard Kohlhuber and he comes from Bavaria.\n"
                            "\n"
                            "Bob: "+bot.me.username+" have you ever been to France and if yes where?\n"
                            +bot.me.username+": I was in Paris, in the museums!\n"
                            "Bob: "+bot.me.username+" what are you exactly?\n"
                            +bot.me.username+": I am "+bot.me.username+", your chatbot! I can answer questions and increase the activity of the server.\n"
                            "Bob: Shall we talk about sex? "+bot.me.username+"?\n"
                            +bot.me.username+": No! I will **not** talk about any NSFW topics.\n"
                            "Bob: "+bot.me.username+" How are you?\n"
                            +bot.me.username+": I am quite well! :-)\n"
                            "Ecki: Hey "+bot.me.username+", what is 10 times 90??\n"
                            +bot.me.username+": that is 900!\n", cb);
                // Delete progress message
                bot.message_delete(msg.id, msg.channel_id);
            });
        }
    }

    // Append a user message (translated to English, cleaned, line by line) to
    // the model context. Must run in llama thread.
    void prompt_add_msg(const dpp::message& msg) {
        ENSURE_LLM_THREAD();
        try {
            // Make sure message isn't too long
            if (msg.content.size() > 512) {
                return;
            }
            // Format and append line
            std::unique_lock L(llm_lock);
            for (const auto line : str_split(msg.content, '\n')) {
                Timer timeout;
                llm->append(msg.author.username+": "+llm_translate_to_en(clean_string(line))+'\n', [&] (float progress) {
                    if (timeout.get() > 1) {
                        std::cerr << "\nWarning: Timeout reached processing message" << std::endl;
                        return false; // cancel this append
                    }
                    return show_console_progress(progress);
                });
            }
        } catch (const LM::Inference::ContextLengthException&) {
            // Context overflow: rebuild the model from scratch
            llm.reset();
            llm_init();
        }
    }

    // Append "<botname>:" so the model starts generating a reply as itself.
    // Must run in llama thread.
    void prompt_add_trigger() {
        ENSURE_LLM_THREAD();
        try {
            std::unique_lock L(llm_lock);
            llm->append(bot.me.username+':', show_console_progress);
        } catch (const LM::Inference::ContextLengthException&) {
            llm.reset();
            llm_init();
        }
    }

    // Post a placeholder, run generation, then edit the placeholder with the
    // (back-translated) result. Must run in llama thread.
    void reply(const std::function<void ()>& after_placeholder_creation = nullptr) {
        ENSURE_LLM_THREAD();
        try {
            // Create placeholder message
            auto msg = bot.message_create_sync(dpp::message(channel_id, texts.please_wait+" :thinking:"));
            // Call after_placeholder_creation callback
            if (after_placeholder_creation) after_placeholder_creation();
            // Trigger LLM correctly
            prompt_add_trigger();
            // Run model; generation stops at newline or on timeout
            Timer timeout;
            bool timed_out = false;
            auto output = llm->run("\n", [&] (std::string_view str) {
                std::cout << str << std::flush;
                if (timeout.get() > 2) {
                    timed_out = true;
                    std::cerr << "\nWarning: Timeout reached generating message";
                    return false;
                }
                return true;
            });
            std::cout << std::endl;
            if (timed_out) output = texts.timeout;
            // Send resulting message
            msg.content = llm_translate_from_en(output);
            bot.message_edit(msg);
        } catch (const std::exception& e) {
            std::cerr << "Warning: " << e.what() << std::endl;
        }
    }

    // Reply if the message mentions the bot by name or references one of the
    // bot's own messages; returns whether a reply was produced.
    // Must run in llama thread.
    bool attempt_reply(const dpp::message& msg, const std::function<void ()>& after_placeholder_creation = nullptr) {
        ENSURE_LLM_THREAD();
        // Decide randomly
        /*if (rng.getBool(0.075f)) {
            return reply();
        }*/
        // Reply if message contains username, mention or ID
        if (msg.content.find(bot.me.username) != std::string::npos) {
            reply(after_placeholder_creation);
            return true;
        }
        // Reply if message references user
        for (const auto msg_id : my_messages) {
            if (msg.message_reference.message_id == msg_id) {
                reply(after_placeholder_creation);
                return true;
            }
        }
        // Don't reply otherwise
        return false;
    }

    // Queue an unprompted reply onto the llama thread.
    void enqueue_reply() {
        tPool.submit(std::bind(&Bot::reply, this, nullptr));
    }

    // Detached-thread loop: wake every 5 minutes and force a reply if the
    // channel has been idle. Holds its own copy of `stopping` so the shared
    // flag outlives the Bot if needed.
    void idle_auto_reply() {
        auto s = stopping;
        do {
            // Wait for a bit
            std::this_thread::sleep_for(std::chrono::minutes(5));
            // Check if last message was more than 20 minutes ago
            // NOTE(review): comment/code mismatch — the threshold is 3, and
            // Timer's unit isn't visible here; confirm this equals 20 minutes.
            if (last_message_timer.get() > 3) {
                // Force reply
                enqueue_reply();
            }
        } while (!*s);
    }

public:
    // `language` is an ISO-style code ("EN" skips translation); `token` is the
    // Discord bot token; `channel_id` the single channel the bot serves.
    Bot(std::string_view language, const char *token, dpp::snowflake channel_id)
            : bot(token), channel_id(channel_id), language(language) {
        // Initialize thread pool
        tPool.init();

        // Configure bot
        bot.on_log(dpp::utility::cout_logger());
        bot.intents = dpp::i_guild_messages | dpp::i_message_content;

        // Set callbacks
        bot.on_ready([=, this] (const dpp::ready_t&) {
            // Get channel
            bot.channel_get(channel_id, [=, this] (const dpp::confirmation_callback_t& cbt) {
                if (cbt.is_error()) {
                    throw std::runtime_error("Failed to get channel: "+cbt.get_error().message);
                }
                channel = cbt.get<dpp::channel>(); // TODO confirm accessor — template arg was lost
                // Initialize random generator
                rng.seed(bot.me.id);
                // Append initial prompt
                tPool.submit(std::bind(&Bot::llm_init, this));
                // Start idle auto reply thread
                std::thread(std::bind(&Bot::idle_auto_reply, this)).detach();
            });
        });
        bot.on_message_create([=, this] (const dpp::message_create_t& event) {
            // Ignore messages before full startup
            if (!llm) return;
            // Make sure message source is correct
            if (event.msg.channel_id != channel_id) return;
            // Make sure message has content
            if (event.msg.content.empty()) return;
            // Reset last message timer
            last_message_timer.reset();
            // Ignore own messages
            if (event.msg.author.id == bot.me.id) {
                // Add message to list of own messages
                my_messages.push_back(event.msg.id);
                return;
            }
            // Move on in another thread
            tPool.submit([this, msg = event.msg] () mutable {
                try {
                    // Replace bot mentions with bot username
                    str_replace_in_place(msg.content, "<@"+std::to_string(bot.me.id)+'>', bot.me.username);
                    if (msg.content == "!trigger") {
                        // Delete message
                        bot.message_delete(msg.id, msg.channel_id);
                        // Send a reply
                        reply();
                    } else {
                        tPool.submit([=, this] () {
                            // Attempt to send a reply
                            bool replied = attempt_reply(msg, [=, this] () {
                                // Append message to history
                                prompt_add_msg(msg);
                            });
                            // If none was send, still append message to history.
                            if (!replied) {
                                // Append message to history
                                prompt_add_msg(msg);
                            }
                        });
                    }
                } catch (const std::exception& e) {
                    std::cerr << "Warning: " << e.what() << std::endl;
                }
            });
        });
    }

    // Blocks until the cluster shuts down, then signals the idle thread.
    void start() {
        stopping = std::make_shared<bool>(false);
        bot.start(dpp::st_wait);
        *stopping = true;
    }
};


int main(int argc, char **argv) {
    // Check arguments
    if (argc < 4) {
        // Bugfix: the bracketed placeholders were stripped from the original
        // usage string; restored to match the argv order consumed below.
        std::cout << "Usage: " << argv[0] << " <language> <token> <channel_id>" << std::endl;
        return -1;
    }

    // Construct and configure bot
    Bot bot(argv[1], argv[2], std::stoull(argv[3]));

    // Start bot
    bot.start();
}