#include "utils.hpp" #include "config.hpp" #include "sqlite_modern_cpp/sqlite_modern_cpp.h" #include <string> #include <string_view> #include <stdexcept> #include <fstream> #include <thread> #include <chrono> #include <functional> #include <vector> #include <unordered_map> #include <filesystem> #include <optional> #include <mutex> #include <memory> #include <utility> #include <dpp/dpp.h> #include <fmt/format.h> #include <justlm.hpp> #include <justlm_pool.hpp> #include <cosched2/scheduled_thread.hpp> #include <cosched2/scheduler_mutex.hpp> class Bot { CoSched::ScheduledThread sched_thread; LM::InferencePool llm_pool; std::vector<dpp::snowflake> my_messages; std::unordered_map<dpp::snowflake, dpp::user> users; std::thread::id llm_tid; utils::Timer cleanup_timer; sqlite::database db; std::mutex command_completion_buffer_mutex; std::unordered_map<dpp::snowflake, dpp::slashcommand_t> command_completion_buffer; std::mutex thread_embeds_mutex; std::unordered_map<dpp::snowflake, dpp::message> thread_embeds; dpp::cluster bot; public: struct BotChannelConfig { const std::string *model_name; const Configuration::Model *model; bool instruct_mode = false; }; private: Configuration& config; inline static bool show_console_progress(float progress) { std::cout << ' ' << unsigned(progress) << "% \r" << std::flush; return true; } // Must run in llama thread # define ENSURE_LLM_THREAD() if (std::this_thread::get_id() != llm_tid) {throw std::runtime_error("LLM execution of '"+std::string(__PRETTY_FUNCTION__)+"' on wrong thread detected");} 0 LM::Inference::Params llm_get_params(bool instruct_mode = false) const { return { .n_threads = config.threads, .n_ctx = config.ctx_size, .n_repeat_last = unsigned(instruct_mode?0:256), .temp = 0.3f, .repeat_penalty = instruct_mode?1.0f:1.372222224f, .use_mlock = config.mlock }; } bool check_timeout(utils::Timer& timer, const dpp::message& msg, uint8_t& slow) { auto passed = timer.get<std::chrono::seconds>(); if (passed > config.timeout) { auto& task = CoSched::Task::get_current(); // Calculate new priority const CoSched::Priority prio = task.get_priority()-5; // Make sure it's above minimum if (prio < CoSched::PRIO_LOWEST) { // Stop slow = 2; return false; } // Decrease priority task.set_priority(prio); // Add snail reaction if (!slow) { slow = 1; bot.message_add_reaction(msg, "🐌"); } // Reset timeout timer timer.reset(); } // No need to stop return true; } // Must run in llama thread bool llm_restart(const std::shared_ptr<LM::Inference>& inference, const BotChannelConfig& channel_cfg) { ENSURE_LLM_THREAD(); // Deserialize init cache if not instruct mode without prompt file if (channel_cfg.instruct_mode && config.instruct_prompt_file == "none") return true; const auto path = (*channel_cfg.model_name)+(channel_cfg.instruct_mode?"_instruct_init_cache":"_init_cache"); std::ifstream f(path, std::ios::binary); if (!f) { std::cerr << "Warning: Failed to init cache open file, consider regeneration: " << path << std::endl; return false; } if (!inference->deserialize(f)) { return false; } // Set params inference->params.n_ctx_window_top_bar = inference->get_context_size(); inference->params.scroll_keep = float(config.scroll_keep) * 0.01f; return true; } // Must run in llama thread std::shared_ptr<LM::Inference> llm_start(dpp::snowflake id, const BotChannelConfig& channel_cfg) { ENSURE_LLM_THREAD(); // Get or create inference auto inference = llm_pool.create_inference(id, channel_cfg.model->weights_path, llm_get_params(channel_cfg.instruct_mode)); if (!llm_restart(inference, 
    // Must run in llama thread
    std::shared_ptr<LM::Inference> llm_get_inference(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
        ENSURE_LLM_THREAD();
        // Get inference
        auto fres = llm_pool.get_inference(id);
        if (!fres) {
            // Start new inference
            fres = llm_start(id, channel_cfg);
            // Check for error
            if (!fres) {
                return nullptr;
            }
        }
        // Set scroll callback
        fres->set_scroll_callback([msg = dpp::message(), channel_id = id] (float progress) {
            std::cout << "WARNING: " << channel_id << " is scrolling! " << progress << "% \r" << std::flush;
            return true;
        });
        // Return inference
        return fres;
    }

    // Must run in llama thread
    void llm_init() {
        // Run at high priority
        CoSched::Task::get_current().set_priority(CoSched::PRIO_HIGHER);
        // Set LLM thread
        llm_tid = std::this_thread::get_id();
        // Set scroll callback
        auto scroll_cb = [] (float) {
            std::cerr << "Error: Prompt doesn't fit into max. context size!" << std::endl;
            abort();
            return false;
        };
        // Build init caches
        std::string filename;
        for (const auto& [model_name, model_config] : config.models) {
            //TODO: Add hashes to regenerate these as needed
            // Standard prompt
            filename = model_name+"_init_cache";
            if (model_config.is_non_instruct_mode_allowed() && !std::filesystem::exists(filename) && config.prompt_file != "none") {
                std::cout << "Building init_cache for "+model_name+"..." << std::endl;
                auto llm = LM::Inference::construct(model_config.weights_path, llm_get_params());
                // Add initial context
                std::string prompt;
                {
                    // Read whole file
                    std::ifstream f(config.prompt_file);
                    if (!f) {
                        // Clean up and abort on error
                        std::cerr << "Error: Failed to open prompt file." << std::endl;
                        abort();
                    }
                    std::ostringstream sstr;
                    sstr << f.rdbuf();
                    prompt = sstr.str();
                }
                // Append
                using namespace fmt::literals;
                if (prompt.back() != '\n') prompt.push_back('\n');
                llm->set_scroll_callback(scroll_cb);
                llm->append(fmt::format(fmt::runtime(prompt), "bot_name"_a=bot.me.username), show_console_progress);
                // Serialize end result
                std::ofstream f(filename, std::ios::binary);
                llm->serialize(f);
            }
            // Instruct prompt
            filename = model_name+"_instruct_init_cache";
            if (model_config.is_instruct_mode_allowed() && !std::filesystem::exists(filename)) {
                std::cout << "Building instruct_init_cache for "+model_name+"..." << std::endl;
                auto llm = LM::Inference::construct(model_config.weights_path, llm_get_params());
                // Add initial context
                std::string prompt;
                if (config.instruct_prompt_file != "none" && !model_config.no_instruct_prompt) {
                    // Read whole file
                    std::ifstream f(config.instruct_prompt_file);
                    if (!f) {
                        // Clean up and abort on error
                        std::cerr << "Error: Failed to open instruct prompt file." << std::endl;
                        abort();
                    }
                    std::ostringstream sstr;
                    sstr << f.rdbuf();
                    prompt = sstr.str();
                    // Append instruct prompt
                    using namespace fmt::literals;
                    if (prompt.back() != '\n' && !model_config.no_extra_linebreaks) prompt.push_back('\n');
                    llm->set_scroll_callback(scroll_cb);
                    llm->append(fmt::format(fmt::runtime(prompt),
                                            "bot_name"_a=bot.me.username,
                                            "bot_prompt"_a=model_config.bot_prompt,
                                            "user_prompt"_a=model_config.user_prompt)
                                +(model_config.no_extra_linebreaks?"":"\n\n")
                                +model_config.user_prompt, show_console_progress);
                }
                // Append user prompt
                llm->append(model_config.user_prompt);
                // Serialize end result
                std::ofstream f(filename, std::ios::binary);
                llm->serialize(f);
            }
        }
        // Report complete init
        std::cout << "Init done!" << std::endl;
    }
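    // The prompt files read in llm_init() are passed through fmt with named arguments, so they
    // may reference {bot_name} (and, for instruct prompts, {bot_prompt} and {user_prompt}).
    // A hypothetical instruct prompt file could therefore look like this (the contents are an
    // illustrative assumption, not something shipped with the project):
    //
    //   Text transcript of a conversation between a user and {bot_name}, an AI assistant.
    //   {user_prompt} Hello!
    //   {bot_prompt} Hi, how can I help you?
    //
    // Any other placeholder would make fmt::format() throw at runtime, since the format string
    // is parsed via fmt::runtime().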
    // Must run in llama thread
    bool prompt_add_msg(const dpp::message& msg, const BotChannelConfig& channel_cfg) {
        ENSURE_LLM_THREAD();
        // Get inference
        auto inference = llm_get_inference(msg.channel_id, channel_cfg);
        if (!inference) {
            std::cerr << "Warning: Failed to get inference" << std::endl;
            return false;
        }
        std::string prefix;
        // Define callback for console progress and timeout
        utils::Timer timeout;
        bool timeout_exceeded = false;
        uint8_t slow = 0;
        const auto cb = [&] (float progress) {
            // Check for timeout
            if (!check_timeout(timeout, msg, slow)) return false;
            // Show progress in console
            return show_console_progress(progress);
        };
        // Instruct mode user prompt
        if (channel_cfg.instruct_mode) {
            // Append line as-is
            if (!inference->append((channel_cfg.model->no_extra_linebreaks?"\n":"\n\n")
                                   +msg.content
                                   +(channel_cfg.model->no_extra_linebreaks?"":"\n"), cb)) {
                std::cerr << "Warning: Failed to append user prompt: " << inference->get_last_error() << std::endl;
                return false;
            }
        } else {
            // Format and append lines
            for (const auto line : utils::str_split(msg.content, '\n')) {
                if (!inference->append(msg.author.username+": "+std::string(line)+'\n', cb)) {
                    std::cerr << "Warning: Failed to append user prompt (single line): " << inference->get_last_error() << std::endl;
                    return false;
                }
            }
        }
        // Append line break on timeout
        if (timeout_exceeded) return inference->append("\n");
        return true;
    }

    // Must run in llama thread
    bool prompt_add_trigger(const std::shared_ptr<LM::Inference>& inference, const BotChannelConfig& channel_cfg) {
        ENSURE_LLM_THREAD();
        if (channel_cfg.instruct_mode) {
            return inference->append((channel_cfg.model->no_extra_linebreaks?"":"\n")
                                     +channel_cfg.model->bot_prompt
                                     +(channel_cfg.model->no_extra_linebreaks?"\n":"\n\n"));
        } else {
            return inference->append(bot.me.username+':', show_console_progress);
        }
    }
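    // To illustrate what prompt_add_msg() and prompt_add_trigger() produce: in non-instruct mode
    // the context ends up as a plain chat transcript, e.g. (usernames made up for illustration):
    //
    //   Alice: How are you?
    //   MyBot:
    //
    // after which reply() below runs the model with "\n" as the reverse prompt, so generation
    // stops at the end of the bot's line. In instruct mode the model's configured bot/user
    // prompt strings are appended instead, and the user prompt acts as the reverse prompt.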
    // Must run in llama thread
    void reply(dpp::snowflake id, dpp::message& new_msg, const BotChannelConfig& channel_cfg) {
        ENSURE_LLM_THREAD();
        // Get inference
        auto inference = llm_get_inference(id, channel_cfg);
        if (!inference) {
            std::cerr << "Warning: Failed to get inference" << std::endl;
            return;
        }
        // Trigger LLM correctly
        if (!prompt_add_trigger(inference, channel_cfg)) {
            std::cerr << "Warning: Failed to add trigger to prompt: " << inference->get_last_error() << std::endl;
            return;
        }
        if (CoSched::Task::get_current().is_dead()) {
            return;
        }
        // Run model
        utils::Timer timeout;
        utils::Timer edit_timer;
        new_msg.content.clear();
        const std::string reverse_prompt = channel_cfg.instruct_mode?channel_cfg.model->user_prompt:"\n";
        uint8_t slow = 0;
        bool response_too_long = false;
        auto output = inference->run(reverse_prompt, [&] (std::string_view token) {
            std::cout << token << std::flush;
            // Check for timeout
            if (!check_timeout(timeout, new_msg, slow)) return false;
            // Make sure message isn't too long
            if (new_msg.content.size() > 1995-config.texts.length_error.size()) {
                response_too_long = true;
                return false;
            }
            // Edit live
            if (config.live_edit) {
                new_msg.content += token;
                if (edit_timer.get<std::chrono::seconds>() > 3) {
                    try {
                        bot.message_edit(new_msg);
                    } catch (...) {}
                    edit_timer.reset();
                }
            }
            return true;
        });
        if (output.empty()) {
            std::cerr << "Warning: Failed to generate message: " << inference->get_last_error() << std::endl;
            output = '<'+config.texts.empty_response+'>';
        }
        std::cout << std::endl;
        // Handle message length error
        if (response_too_long) {
            output += "...\n"+config.texts.length_error;
        }
        // Handle timeout
        else if (slow == 2) {
            output += "...\n"+config.texts.timeout;
        }
        // Handle termination
        else if (CoSched::Task::get_current().is_dead()) {
            output += "...\n"+config.texts.terminated;
        }
        // Send resulting message
        new_msg.content = std::move(output);
        try {
            bot.message_edit(new_msg);
        } catch (...) {}
        // Tell model about length error
        if (response_too_long) {
            inference->append("... Response interrupted due to length error");
        }
        // Prepare for next message
        if (!channel_cfg.instruct_mode || !channel_cfg.model->no_extra_linebreaks) {
            inference->append("\n");
        }
        if (channel_cfg.instruct_mode && channel_cfg.model->emits_eos) {
            inference->append("\n"+channel_cfg.model->user_prompt);
        }
    }

    bool check_should_reply(const dpp::message& msg) {
        // Reply if message contains username, mention or ID
        if (msg.content.find(bot.me.username) != std::string::npos) {
            return true;
        }
        // Reply if message references one of our own messages
        for (const auto msg_id : my_messages) {
            if (msg.message_reference.message_id == msg_id) {
                return true;
            }
        }
        // Reply at random
        if (config.random_response_chance) {
            if (!(unsigned(msg.id.get_creation_time()) % config.random_response_chance)) {
                return true;
            }
        }
        // Don't reply otherwise
        return false;
    }

    bool is_on_own_shard(dpp::snowflake id) const {
        return (unsigned(id.get_creation_time()) % config.shard_count) == config.shard_id;
    }

    void cleanup() {
        // Clean up InferencePool
        if (config.max_context_age) llm_pool.cleanup(config.max_context_age);
        // Reset timer
        cleanup_timer.reset();
    }
    void attempt_cleanup() {
        // Run cleanup if enough time has passed
        if (cleanup_timer.get<std::chrono::seconds>() > config.max_context_age / 4) {
            cleanup();
        }
    }

    std::string create_thread_name(const std::string& model_name, bool instruct_mode) const {
        return "Chat with "+model_name+" "                                          // Model name
               +(instruct_mode?"":"(Non Instruct mode)")                            // Instruct mode
               +(config.shard_count!=1?(" #"+std::to_string(config.shard_id)):""); // Shard ID
    }

    dpp::embed create_chat_embed(dpp::snowflake guild_id, dpp::snowflake thread_id, const std::string& model_name, bool instruct_mode, const dpp::user& author, std::string_view first_message = "") const {
        dpp::embed embed;
        // Create embed
        embed.set_title(create_thread_name(model_name, instruct_mode))
             .set_description("[Open the chat](https://discord.com/channels/"+std::to_string(guild_id)+'/'+std::to_string(thread_id)+')')
             .set_footer(dpp::embed_footer().set_text("Started by "+author.format_username()))
             .set_color(utils::get_unique_color(model_name));
        // Add first message if any
        if (!first_message.empty()) {
            // Make sure it's not too long
            std::string shorted(utils::max_words(first_message, 12));
            if (shorted.size() != first_message.size()) {
                shorted += "...";
            }
            embed.description += "\n\n> "+shorted;
        }
        // Warn about non-instruct mode
        if (instruct_mode == false) {
            embed.description += "\n\n**In the selected mode, the quality is highly degraded**, but the conversation is more humorous. "
                                 "Please avoid this if you want helpful responses or want to evaluate the model's quality.";
        }
        // Return final result
        return embed;
    }
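    // Sharding here is purely arithmetic: is_on_own_shard() takes a snowflake's creation time
    // modulo config.shard_count. A quick worked example, assuming shard_count == 2 (the
    // timestamps are made up for illustration):
    //
    //   unsigned(id.get_creation_time()) == 1700000000  ->  1700000000 % 2 == 0  ->  shard 0
    //   unsigned(id.get_creation_time()) == 1700000001  ->  1700000001 % 2 == 1  ->  shard 1
    //
    // so every channel/thread is consistently owned by exactly one shard without any
    // cross-process coordination.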
    // This function is responsible for sharding thread creation
    // A bit ugly but a nice way to avoid having to communicate over any other means than just the Discord API
    bool command_completion_handler(dpp::slashcommand_t&& event, dpp::channel *thread = nullptr) {
        // Stop if this is not the correct shard for thread creation
        if (thread == nullptr) {
            // But register this command first (store a copy; `event` is still used below)
            std::scoped_lock L(command_completion_buffer_mutex);
            command_completion_buffer.emplace(event.command.id, event);
            // And then actually stop
            if (!is_on_own_shard(event.command.channel_id)) return false;
        }
        // Get model by name
        auto res = config.models.find(event.command.get_command_name());
        if (res == config.models.end()) {
            // Model does not exist, delete corresponding command
            bot.global_command_delete(event.command.get_command_interaction().id);
            return false;
        }
        const auto& [model_name, model_config] = *res;
        // Get whether to enable instruct mode
        bool instruct_mode;
        {
            const auto& instruct_mode_param = event.get_parameter("instruct_mode");
            if (instruct_mode_param.index()) {
                instruct_mode = std::get<bool>(instruct_mode_param);
            } else {
                instruct_mode = model_config.instruct_mode_policy != Configuration::Model::InstructModePolicy::Forbid;
            }
        }
        // Create thread if it doesn't exist or update it if it does
        if (thread == nullptr) {
            bot.thread_create(std::to_string(event.command.id), event.command.channel_id, 1440, dpp::CHANNEL_PUBLIC_THREAD, true, 15,
                              [this, event, model_name = res->first] (const dpp::confirmation_callback_t& ccb) {
                // Check for error
                if (ccb.is_error()) {
                    std::cout << "Thread creation failed: " << ccb.get_error().message << std::endl;
                    event.reply(dpp::message(config.texts.thread_create_fail).set_flags(dpp::message_flags::m_ephemeral));
                    return;
                }
                std::cout << "Responsible for creating thread: " << ccb.get<dpp::thread>().id << std::endl;
                // Report success
                event.reply(dpp::message("Okay!").set_flags(dpp::message_flags::m_ephemeral));
            });
        } else {
            bool this_shard = is_on_own_shard(thread->id);
            // Add thread to database
            db << "INSERT INTO threads (id, model, instruct_mode, this_shard) VALUES (?, ?, ?, ?);"
               << std::to_string(thread->id) << model_name << instruct_mode << this_shard;
            // Stop if this is not the correct shard for thread finalization
            if (!this_shard) return false;
            // Set name
            std::cout << "Responsible for finalizing thread: " << thread->id << std::endl;
            thread->name = create_thread_name(model_name, instruct_mode);
            bot.channel_edit(*thread);
            // Send embed
            const auto embed = create_chat_embed(event.command.guild_id, thread->id, model_name, instruct_mode, event.command.usr);
            bot.message_create(dpp::message(event.command.channel_id, embed),
                               [this, thread_id = thread->id] (const dpp::confirmation_callback_t& ccb) {
                // Check for error
                if (ccb.is_error()) {
                    std::cerr << "Warning: Failed to create embed: " << ccb.get_error().message << std::endl;
                    return;
                }
                // Get message
                const auto& msg = ccb.get<dpp::message>();
                // Add to embed list
                std::scoped_lock L(thread_embeds_mutex);
                thread_embeds[thread_id] = msg;
            });
        }
        return true;
    }
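    // How the handshake above works end to end: the shard that owns the slash command buffers the
    // event and creates a thread whose name is the command ID. Every shard then receives the
    // resulting dpp::mt_thread_created system message (see the on_message_create handler further
    // down), parses the command ID back out of the thread name, looks the buffered event up in
    // command_completion_buffer and calls command_completion_handler() again with the thread
    // pointer set, at which point the owning shard renames the thread and posts the chat embed.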
public:
    Bot(decltype(config) cfg)
            : llm_pool(cfg.pool_size, "discord_llama", !cfg.persistance),
              db("database.sqlite3"), bot(cfg.token), config(cfg) {
        // Initialize database
        db << "CREATE TABLE IF NOT EXISTS threads ("
              "    id TEXT PRIMARY KEY NOT NULL,"
              "    model TEXT,"
              "    instruct_mode INTEGER,"
              "    this_shard INTEGER,"
              "    UNIQUE(id)"
              ");";

        // Start scheduled thread
        sched_thread.start();

        // Configure bot
        bot.on_log(dpp::utility::cout_logger());
        bot.intents = dpp::i_guild_messages | dpp::i_message_content;

        // Set callbacks
        bot.on_ready([=, this] (const dpp::ready_t&) { //TODO: Consider removal
            std::cout << "Connected to Discord." << std::endl;
            // Register chat command once
            if (dpp::run_once<struct register_bot_commands>()) {
                auto register_command = [this] (dpp::slashcommand&& c) {
                    bot.global_command_edit(c, [this, c] (const dpp::confirmation_callback_t& ccb) {
                        if (ccb.is_error()) bot.global_command_create(c);
                    });
                };
                // Register model commands
                for (const auto& [name, model] : config.models) {
                    // Create command
                    dpp::slashcommand command(name, "Start a chat with me", bot.me.id);
                    // Add instruct mode option
                    if (model.instruct_mode_policy == Configuration::Model::InstructModePolicy::Allow) {
                        command.add_option(dpp::command_option(dpp::co_boolean, "instruct_mode", "Whether to enable instruct mode. Defaults to \"True\" for best output quality", false));
                    }
                    // Register command
                    register_command(std::move(command));
                }
                // Register other commands
                register_command(dpp::slashcommand("ping", "Check my status", bot.me.id));
                register_command(dpp::slashcommand("reset", "Reset this conversation", bot.me.id));
                register_command(dpp::slashcommand("tasklist", "Get list of tasks", bot.me.id));
                //register_command(dpp::slashcommand("taskkill", "Kill a task", bot.me.id)); TODO
            }
            if (dpp::run_once<class LM::Inference>()) {
                // Prepare llm
                sched_thread.create_task("Language Model Initialization", [this] () -> void {
                    llm_init();
                });
            }
        });
        bot.on_slashcommand([=, this](dpp::slashcommand_t event) {
            const auto invalidate_event = [this] (const dpp::slashcommand_t& event) {
                if (is_on_own_shard(event.command.channel_id)) {
                    event.thinking(true, [this, event] (const dpp::confirmation_callback_t& ccb) {
                        event.delete_original_response();
                    });
                }
            };
            // Process basic commands
            const auto& command_name = event.command.get_command_name();
            if (command_name == "ping") {
                // Sender message
                if (is_on_own_shard(event.command.channel_id)) {
                    bot.message_create(dpp::message(event.command.channel_id, "Ping from user "+event.command.usr.format_username()+'!'));
                }
                // Recipient message
                bot.message_create(dpp::message(event.command.channel_id, "Pong from shard "+std::to_string(config.shard_id+1)+'/'+std::to_string(config.shard_count)+'!'));
                // Finalize
                invalidate_event(event);
                return;
            } else if (command_name == "reset") {
                // Delete inference from pool
                sched_thread.create_task("Language Model Inference Pool", [this, id = event.command.channel_id, user = event.command.usr] () -> void {
                    CoSched::Task::get_current().user_data = std::move(user);
                    llm_pool.delete_inference(id);
                });
                // Sender message
                if (is_on_own_shard(event.command.channel_id)) {
                    bot.message_create(dpp::message(event.command.channel_id, "Conversation was reset by "+event.command.usr.format_username()+'!'));
                }
                // Finalize
                invalidate_event(event);
                return;
            } else if (command_name == "tasklist") {
                // Build task list
                sched_thread.create_task("tasklist", [this, event, id = event.command.channel_id, user = event.command.usr] () -> void {
                    auto& task = CoSched::Task::get_current();
                    task.user_data = std::move(user);
                    // Set priority to max
                    task.set_priority(CoSched::PRIO_REALTIME);
                    // Header
                    std::string str = "**__Task List on Shard "+std::to_string(config.shard_id)+"__**\n";
                    // Produce list
                    for (const auto& task : task.get_scheduler().get_tasks()) {
                        // Get user
                        const dpp::user *user = nullptr;
                        {
                            if (task->user_data.has_value()) {
                                user = &std::any_cast<const dpp::user&>(task->user_data);
                            }
                        }
                        // Append line
                        str += fmt::format("- `{}` (State: **{}**, Priority: **{}**, User: **{}**)\n",
                                           task->get_name(),
                                           task->is_suspended()?"suspended":task->get_state_string(),
                                           task->get_priority(),
                                           user?user->format_username():bot.me.format_username());
                    }
                    // Delete original thinking response
                    if (is_on_own_shard(event.command.channel_id)) {
                        event.delete_original_response();
                    }
                    // Send list
                    bot.message_create(dpp::message(id, str));
                    return;
                });
                // Finalize
                if (is_on_own_shard(event.command.channel_id)) {
                    event.thinking(false);
                }
                return;
            }
            // Run command completion handler
            command_completion_handler(std::move(event));
        });
        bot.on_message_create([=, this](const dpp::message_create_t&) {
            // Attempt cleanup
            attempt_cleanup();
        });
        bot.on_message_create([=, this](const dpp::message_create_t& event) {
            // Check that this is for thread creation
            if (event.msg.type != dpp::mt_thread_created) return;
            // Get thread that was created
            bot.channel_get(event.msg.id, [this, msg_id = event.msg.id, channel_id = event.msg.channel_id] (const dpp::confirmation_callback_t& ccb) {
                // Stop on error
                if (ccb.is_error()) return;
                // Get thread
                auto thread = ccb.get<dpp::channel>();
                // Attempt to get command ID
                dpp::snowflake command_id;
                try {
                    command_id = thread.name;
                } catch (...) {
                    return;
                }
                // Find command
                std::scoped_lock L(command_completion_buffer_mutex);
                auto res = command_completion_buffer.find(command_id);
                if (res == command_completion_buffer.end()) {
                    return;
                }
                // Complete command
                auto handled = command_completion_handler(std::move(res->second), &thread);
                // Remove command from buffer
                command_completion_buffer.erase(res);
                // Delete this message if we handled it
                if (handled) bot.message_delete(msg_id, channel_id);
            });
        });
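        // The handler below serializes inference per conversation: each incoming message becomes a
        // CoSched task named after the model and channel, and a task suspends itself while another
        // task with the same name is still runnable, to be resumed by that task once it finishes.
        // Rough shape of the wait loop used inside it:
        //
        //   while (another task with my name exists && something else is still runnable) {
        //       task.set_suspended(true);
        //       if (!task.yield()) return;   // scheduler asked us to terminate
        //   }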
        bot.on_message_create([=, this] (const dpp::message_create_t& event) {
            // Update user cache
            users[event.msg.author.id] = event.msg.author;
            // Make sure message has content
            if (event.msg.content.empty()) return;
            // Ignore messages from channels on another shard
            bool this_shard = is_on_own_shard(event.msg.channel_id);
            db << "SELECT this_shard FROM threads "
                  "WHERE id = ?;" << std::to_string(event.msg.channel_id)
               >> [&](int _this_shard) {
                this_shard = _this_shard;
            };
            if (!this_shard) return;
            // Ignore own messages
            if (event.msg.author.id == bot.me.id) {
                // Add message to list of own messages
                my_messages.push_back(event.msg.id);
                return;
            }
            // Process message
            try {
                // Copy message
                dpp::message msg = event.msg;
                // Replace bot mentions with bot username
                utils::str_replace_in_place(msg.content, "<@"+std::to_string(bot.me.id)+'>', bot.me.username);
                // Replace all other known users
                for (const auto& [user_id, user] : users) {
                    utils::str_replace_in_place(msg.content, "<@"+std::to_string(user_id)+'>', user.username);
                }
                // Get channel config
                BotChannelConfig channel_cfg;
                // Attempt to find thread first... (shard ownership was already checked above)
                bool in_bot_thread = false, model_missing = false;
                db << "SELECT model, instruct_mode FROM threads "
                      "WHERE id = ?;" << std::to_string(msg.channel_id)
                   >> [&](const std::string& model_name, int instruct_mode) {
                    in_bot_thread = true;
                    channel_cfg.instruct_mode = instruct_mode;
                    // Find model
                    auto res = config.models.find(model_name);
                    if (res == config.models.end()) {
                        bot.message_create(dpp::message(msg.channel_id, config.texts.model_missing));
                        model_missing = true;
                        return;
                    }
                    channel_cfg.model_name = &res->first;
                    channel_cfg.model = &res->second;
                };
                if (model_missing) return;
                // Otherwise just fall back to the default model config if allowed
                if (!in_bot_thread) {
                    if (config.threads_only) return;
                    channel_cfg.model_name = &config.default_inference_model;
                    channel_cfg.model = config.default_inference_model_cfg;
                }
                // Append message
                sched_thread.create_task("Language Model Inference ("+*channel_cfg.model_name+" at "+std::to_string(msg.channel_id)+")", [=, this] () -> void {
                    CoSched::Task::get_current().user_data = msg.author;
                    // Create initial message
                    dpp::message placeholder_msg(msg.channel_id, config.texts.please_wait+" :thinking:");
                    // Get task
                    auto &task = CoSched::Task::get_current();
                    // Await previous completion
                    while (true) {
                        // Check that there are no other tasks with the same name
                        bool is_unique = true;
                        bool any_non_suspended = false;
                        for (const auto& other_task : task.get_scheduler().get_tasks()) {
                            if (&task == other_task.get()) continue;
                            if (task.get_name() == other_task->get_name()) {
                                is_unique = false;
                            } else if (!other_task->is_suspended()) {
                                any_non_suspended = true;
                            }
                        }
                        // Stop looking if there is no task that isn't suspended
                        if (!any_non_suspended) break;
                        // Stop looking if task is unique
                        if (is_unique) break;
                        // Suspend, we'll be woken up by that other task
                        task.set_suspended(true);
                        if (!task.yield()) return;
                    }
                    // Check if message should reply
                    bool should_reply = false;
                    if (in_bot_thread) {
                        should_reply = true;
                    } else if (msg.content == "!trigger") {
                        bot.message_delete(msg.id, msg.channel_id);
                        should_reply = true;
                    } else {
                        should_reply = check_should_reply(msg);
                    }
                    if (should_reply) {
                        // Send placeholder
                        placeholder_msg = bot.message_create_sync(placeholder_msg);
                        // Add user message
                        if (!prompt_add_msg(msg, channel_cfg)) {
                            std::cerr << "Warning: Failed to add user message, not going to reply" << std::endl;
                            return;
                        }
                        // Send a reply
                        reply(msg.channel_id, placeholder_msg, channel_cfg);
                    } else {
                        // Add user message
                        if (!prompt_add_msg(msg, channel_cfg)) {
                            std::cerr << "Warning: Failed to add user message" << std::endl;
                            return;
                        }
                    }
                    // Unsuspend other tasks with same name
                    for (const auto& other_task : task.get_scheduler().get_tasks()) {
                        if (task.get_name() == other_task->get_name()) {
                            other_task->set_suspended(false);
                        }
                    }
                });
                // Find thread embed
                std::scoped_lock L(thread_embeds_mutex);
                auto res = thread_embeds.find(msg.channel_id);
                if (res == thread_embeds.end()) {
                    return;
                }
                // Update that embed
                auto embed_msg = res->second;
                embed_msg.embeds[0] = create_chat_embed(msg.guild_id, msg.channel_id, *channel_cfg.model_name, channel_cfg.instruct_mode, msg.author, msg.content);
                bot.message_edit(embed_msg);
                // Remove thread embed linkage from the map
                thread_embeds.erase(res);
            } catch (const std::exception& e) {
                std::cerr << "Warning: " << e.what() << std::endl;
            }
        });
    }

    void start() {
        cleanup();
        bot.start(dpp::st_wait);
    }
    void stop_prepare() {
        if (config.persistance) {
            sched_thread.create_task("Language Model Shutdown", [=, this] () -> void {
                llm_pool.store_all();
            });
        }
        sched_thread.wait();
    }
};


int main(int argc, char **argv) {
    // Parse configuration
    Configuration cfg;
    cfg.parse_configs(argc<2?"":argv[1]);

    // Construct and configure bot
    Bot bot(cfg);

    // Set signal handlers if available
#   ifdef sa_sigaction
    struct sigaction sigact;
    static Bot& bot_st = bot;
    static const auto main_thread = std::this_thread::get_id();
    sigact.sa_handler = [] (int) {
        if (std::this_thread::get_id() == main_thread) {
            bot_st.stop_prepare();
            exit(0);
        }
    };
    sigemptyset(&sigact.sa_mask);
    sigact.sa_flags = 0;
    sigaction(SIGTERM, &sigact, nullptr);
    sigaction(SIGINT, &sigact, nullptr);
    sigaction(SIGHUP, &sigact, nullptr);
#   endif

    // Start bot
    bot.start();
}