// discord_llama/main.cpp (mirror of https://gitlab.com/niansa/discord_llama.git)
#include "utils.hpp"
#include "config.hpp"
#include "sqlite_modern_cpp/sqlite_modern_cpp.h"
#include <string>
#include <string_view>
#include <stdexcept>
#include <fstream>
#include <thread>
#include <chrono>
#include <functional>
#include <vector>
#include <unordered_map>
#include <filesystem>
#include <optional>
#include <mutex>
#include <memory>
#include <utility>
#include <dpp/dpp.h>
#include <fmt/format.h>
#include <justlm.hpp>
#include <justlm_pool.hpp>
#include <cosched2/scheduled_thread.hpp>
#include <cosched2/scheduler_mutex.hpp>
class Bot {
CoSched::ScheduledThread sched_thread;
LM::InferencePool llm_pool;
std::vector<dpp::snowflake> my_messages;
std::unordered_map<dpp::snowflake, dpp::user> users;
std::thread::id llm_tid;
utils::Timer cleanup_timer;
sqlite::database db;
std::mutex command_completion_buffer_mutex;
std::unordered_map<dpp::snowflake, dpp::slashcommand_t> command_completion_buffer;
std::mutex thread_embeds_mutex;
std::unordered_map<dpp::snowflake, dpp::message> thread_embeds;
dpp::cluster bot;
public:
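// Channel-level settings resolved per message: the model that serves the
// channel and whether it runs in instruct mode.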
struct BotChannelConfig {
const std::string *model_name;
const Configuration::Model *model;
bool instruct_mode = false;
};
private:
Configuration& config;
inline static
bool show_console_progress(float progress) {
std::cout << ' ' << unsigned(progress) << "% \r" << std::flush;
return true;
}
// Must run in llama thread
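// ENSURE_LLM_THREAD() throws if the enclosing function is executed on any
// thread other than the one recorded in llm_tid by llm_init().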
# define ENSURE_LLM_THREAD() if (std::this_thread::get_id() != llm_tid) {throw std::runtime_error("LLM execution of '"+std::string(__PRETTY_FUNCTION__)+"' on wrong thread detected");} 0
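// Inference parameters: instruct mode disables the repetition penalty entirely,
// while chat mode penalizes repeats over the last 256 tokens.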
LM::Inference::Params llm_get_params(bool instruct_mode = false) const {
return {
.n_threads = config.threads,
.n_ctx = config.ctx_size,
.n_repeat_last = unsigned(instruct_mode?0:256),
.temp = 0.3f,
.repeat_penalty = instruct_mode?1.0f:1.372222224f,
.use_mlock = config.mlock
};
}
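// Timeout handling for a single response. "slow" acts as a small state machine:
// 0 = on time, 1 = timed out once (snail reaction added, priority lowered),
// 2 = priority fell below the minimum and generation is aborted.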
bool check_timeout(utils::Timer& timer, const dpp::message& msg, uint8_t& slow) {
auto passed = timer.get<std::chrono::seconds>();
if (passed > config.timeout) {
auto& task = CoSched::Task::get_current();
// Calculate new priority
const CoSched::Priority prio = task.get_priority()-5;
// Make sure it's above minimum
if (prio < CoSched::PRIO_LOWEST) {
// Stop
slow = 2;
return false;
}
// Decrease priority
task.set_priority(prio);
// Add snail reaction
if (!slow) {
slow = 1;
bot.message_add_reaction(msg, "🐌");
}
// Reset timeout timer
timer.reset();
}
// No need to stop
return true;
}
// Must run in llama thread
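// Primes a fresh inference by deserializing the pre-built init cache
// ("<model>_init_cache" or "<model>_instruct_init_cache") and configuring
// context-window scrolling.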
bool llm_restart(const std::shared_ptr<LM::Inference>& inference, const BotChannelConfig& channel_cfg) {
ENSURE_LLM_THREAD();
// Nothing to deserialize in instruct mode without an instruct prompt file
if (channel_cfg.instruct_mode && config.instruct_prompt_file == "none") return true;
const auto path = (*channel_cfg.model_name)+(channel_cfg.instruct_mode?"_instruct_init_cache":"_init_cache");
std::ifstream f(path, std::ios::binary);
if (!f) {
std::cerr << "Warning: Failed to init cache open file, consider regeneration: " << path << std::endl;
return false;
}
if (!inference->deserialize(f)) {
return false;
}
// Set params
inference->params.n_ctx_window_top_bar = inference->get_context_size();
inference->params.scroll_keep = float(config.scroll_keep) * 0.01f;
return true;
}
// Must run in llama thread
std::shared_ptr<LM::Inference> llm_start(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
ENSURE_LLM_THREAD();
// Get or create inference
auto inference = llm_pool.create_inference(id, channel_cfg.model->weights_path, llm_get_params(channel_cfg.instruct_mode));
if (!llm_restart(inference, channel_cfg)) {
std::cerr << "Warning: Failed to deserialize cache: " << inference->get_last_error() << std::endl;
return nullptr;
}
return inference;
}
// Must run in llama thread
std::shared_ptr<LM::Inference> llm_get_inference(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
ENSURE_LLM_THREAD();
// Get inference
auto fres = llm_pool.get_inference(id);
if (!fres) {
// Start new inference
fres = llm_start(id, channel_cfg);
// Check for error
if (!fres) {
return nullptr;
}
}
// Set scroll callback
fres->set_scroll_callback([msg = dpp::message(), channel_id = id] (float progress) {
std::cout << "WARNING: " << channel_id << " is scrolling! " << progress << "% \r" << std::flush;
return true;
});
// Return inference
return fres;
}
// Must run in llama thread
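// Builds the per-model init caches (standard and instruct prompts) on startup
// so that new conversations can deserialize them instead of re-evaluating the
// full prompt each time. Also records the LLM thread id used by ENSURE_LLM_THREAD().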
void llm_init() {
// Run at high priority
CoSched::Task::get_current().set_priority(CoSched::PRIO_HIGHER);
// Set LLM thread
llm_tid = std::this_thread::get_id();
// Set scroll callback
auto scroll_cb = [] (float) {
std::cerr << "Error: Prompt doesn't fit into max. context size!" << std::endl;
abort();
return false;
};
// Build init caches
std::string filename;
for (const auto& [model_name, model_config] : config.models) {
//TODO: Add hashes to regenerate these as needed
// Standard prompt
filename = model_name+"_init_cache";
if (model_config.is_non_instruct_mode_allowed() &&
!std::filesystem::exists(filename) && config.prompt_file != "none") {
std::cout << "Building init_cache for "+model_name+"..." << std::endl;
auto llm = LM::Inference::construct(model_config.weights_path, llm_get_params());
// Add initial context
std::string prompt;
{
// Read whole file
std::ifstream f(config.prompt_file);
if (!f) {
// Clean up and abort on error
std::cerr << "Error: Failed to open prompt file." << std::endl;
abort();
}
std::ostringstream sstr;
sstr << f.rdbuf();
prompt = sstr.str();
}
// Append
using namespace fmt::literals;
if (prompt.back() != '\n') prompt.push_back('\n');
llm->set_scroll_callback(scroll_cb);
llm->append(fmt::format(fmt::runtime(prompt), "bot_name"_a=bot.me.username), show_console_progress);
// Serialize end result
std::ofstream f(filename, std::ios::binary);
llm->serialize(f);
}
// Instruct prompt
filename = model_name+"_instruct_init_cache";
if (model_config.is_instruct_mode_allowed() &&
!std::filesystem::exists(filename)) {
std::cout << "Building instruct_init_cache for "+model_name+"..." << std::endl;
auto llm = LM::Inference::construct(model_config.weights_path, llm_get_params());
// Add initial context
std::string prompt;
if (config.instruct_prompt_file != "none" && !model_config.no_instruct_prompt) {
// Read whole file
std::ifstream f(config.instruct_prompt_file);
if (!f) {
// Clean up and abort on error
std::cerr << "Error: Failed to open instruct prompt file." << std::endl;
abort();
}
std::ostringstream sstr;
sstr << f.rdbuf();
prompt = sstr.str();
// Append instruct prompt
using namespace fmt::literals;
if (prompt.back() != '\n' && !model_config.no_extra_linebreaks) prompt.push_back('\n');
llm->set_scroll_callback(scroll_cb);
llm->append(fmt::format(fmt::runtime(prompt), "bot_name"_a=bot.me.username, "bot_prompt"_a=model_config.bot_prompt, "user_prompt"_a=model_config.user_prompt)+(model_config.no_extra_linebreaks?"":"\n\n")+model_config.user_prompt, show_console_progress);
}
// Append user prompt
llm->append(model_config.user_prompt);
// Serialize end result
std::ofstream f(filename, std::ios::binary);
llm->serialize(f);
}
}
// Report complete init
std::cout << "Init done!" << std::endl;
}
// Must run in llama thread
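// Appends a user message to the channel's context: verbatim in instruct mode,
// otherwise line by line in "username: text" chat format.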
bool prompt_add_msg(const dpp::message& msg, const BotChannelConfig& channel_cfg) {
ENSURE_LLM_THREAD();
// Get inference
auto inference = llm_get_inference(msg.channel_id, channel_cfg);
if (!inference) {
std::cerr << "Warning: Failed to get inference" << std::endl;
return false;
}
std::string prefix;
// Define callback for console progress and timeout
utils::Timer timeout;
bool timeout_exceeded = false;
uint8_t slow = 0;
const auto cb = [&] (float progress) {
// Check for timeout; remember it so the interrupted line gets terminated below
if (!check_timeout(timeout, msg, slow)) {
timeout_exceeded = true;
return false;
}
// Show progress in console
return show_console_progress(progress);
};
// Instruct mode user prompt
if (channel_cfg.instruct_mode) {
// Append line as-is
if (!inference->append((channel_cfg.model->no_extra_linebreaks?"\n":"\n\n")
+msg.content
+(channel_cfg.model->no_extra_linebreaks?"":"\n"), cb)) {
std::cerr << "Warning: Failed to append user prompt: " << inference->get_last_error() << std::endl;
return false;
}
} else {
// Format and append lines
for (const auto line : utils::str_split(msg.content, '\n')) {
if (!inference->append(msg.author.username+": "+std::string(line)+'\n', cb)) {
std::cerr << "Warning: Failed to append user prompt (single line): " << inference->get_last_error() << std::endl;
return false;
}
}
}
// Append line break on timeout
if (timeout_exceeded) return inference->append("\n");
return true;
}
// Must run in llama thread
bool prompt_add_trigger(const std::shared_ptr<LM::Inference>& inference, const BotChannelConfig& channel_cfg) {
ENSURE_LLM_THREAD();
if (channel_cfg.instruct_mode) {
return inference->append((channel_cfg.model->no_extra_linebreaks?"":"\n")
+channel_cfg.model->bot_prompt
+(channel_cfg.model->no_extra_linebreaks?"\n":"\n\n"));
} else {
return inference->append(bot.me.username+':', show_console_progress);
}
}
// Must run in llama thread
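// Generates the bot's reply: appends the bot prompt as a trigger, runs the model
// until the reverse prompt is hit, live-edits the placeholder message (if live_edit
// is enabled) and appends a note on length errors, timeouts or termination.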
void reply(dpp::snowflake id, dpp::message& new_msg, const BotChannelConfig& channel_cfg) {
ENSURE_LLM_THREAD();
// Get inference
auto inference = llm_get_inference(id, channel_cfg);
if (!inference) {
std::cerr << "Warning: Failed to get inference" << std::endl;
return;
}
// Trigger LLM correctly
if (!prompt_add_trigger(inference, channel_cfg)) {
std::cerr << "Warning: Failed to add trigger to prompt: " << inference->get_last_error() << std::endl;
return;
}
if (CoSched::Task::get_current().is_dead()) {
return;
}
// Run model
utils::Timer timeout;
utils::Timer edit_timer;
new_msg.content.clear();
const std::string reverse_prompt = channel_cfg.instruct_mode?channel_cfg.model->user_prompt:"\n";
uint8_t slow = 0;
bool response_too_long = false;
auto output = inference->run(reverse_prompt, [&] (std::string_view token) {
std::cout << token << std::flush;
// Check for timeout
if (!check_timeout(timeout, new_msg, slow)) return false;
// Make sure message isn't too long
if (new_msg.content.size() > 1995-config.texts.length_error.size()) {
response_too_long = true;
return false;
}
// Edit live
if (config.live_edit) {
new_msg.content += token;
if (edit_timer.get<std::chrono::seconds>() > 3) {
try {
bot.message_edit(new_msg);
} catch (...) {}
edit_timer.reset();
}
}
return true;
});
if (output.empty()) {
std::cerr << "Warning: Failed to generate message: " << inference->get_last_error() << std::endl;
output = '<'+config.texts.empty_response+'>';
}
std::cout << std::endl;
// Handle message length error
if (response_too_long) {
output += "...\n"+config.texts.length_error;
}
// Handle timeout
else if (slow == 2) {
output += "...\n"+config.texts.timeout;
}
// Handle termination
else if (CoSched::Task::get_current().is_dead()) {
output += "...\n"+config.texts.terminated;
}
// Send resulting message
new_msg.content = std::move(output);
try {
bot.message_edit(new_msg);
} catch (...) {}
// Tell model about length error
if (response_too_long) {
inference->append("... Response interrupted due to length error");
}
// Prepare for next message
if (!channel_cfg.instruct_mode || !channel_cfg.model->no_extra_linebreaks) {
inference->append("\n");
}
if (channel_cfg.instruct_mode && channel_cfg.model->emits_eos) {
inference->append("\n"+channel_cfg.model->user_prompt);
}
}
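// Heuristic for unsolicited replies: mention of the bot's username, a reply to
// one of its own messages, or a pseudo-random chance derived from the message
// snowflake's creation time.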
bool check_should_reply(const dpp::message& msg) {
// Reply if message contains username, mention or ID
if (msg.content.find(bot.me.username) != std::string::npos) {
return true;
}
// Reply if message references user
for (const auto msg_id : my_messages) {
if (msg.message_reference.message_id == msg_id) {
return true;
}
}
// Reply at random
if (config.random_response_chance) {
if (!(unsigned(msg.id.get_creation_time()) % config.random_response_chance)) {
return true;
}
}
// Don't reply otherwise
return false;
}
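// Deterministic shard assignment: the snowflake's creation timestamp modulo the
// shard count decides which shard owns a channel or thread.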
bool is_on_own_shard(dpp::snowflake id) const {
return (unsigned(id.get_creation_time()) % config.shard_count) == config.shard_id;
}
void cleanup() {
// Clean up InferencePool
if (config.max_context_age) llm_pool.cleanup(config.max_context_age);
// Reset timer
cleanup_timer.reset();
}
void attempt_cleanup() {
// Run cleanup if enough time has passed
if (cleanup_timer.get<std::chrono::seconds>() > config.max_context_age / 4) {
cleanup();
}
}
std::string create_thread_name(const std::string& model_name, bool instruct_mode) const {
return "Chat with "+model_name+" " // Model name
+(instruct_mode?"":"(Non Instruct mode)") // Instruct mode
+(config.shard_count!=1?(" #"+std::to_string(config.shard_id)):""); // Shard ID
}
dpp::embed create_chat_embed(dpp::snowflake guild_id, dpp::snowflake thread_id, const std::string& model_name, bool instruct_mode, const dpp::user& author, std::string_view first_message = "") const {
dpp::embed embed;
// Create embed
embed.set_title(create_thread_name(model_name, instruct_mode))
.set_description("[Open the chat](https://discord.com/channels/"+std::to_string(guild_id)+'/'+std::to_string(thread_id)+')')
.set_footer(dpp::embed_footer().set_text("Started by "+author.format_username()))
.set_color(utils::get_unique_color(model_name));
// Add first message if any
if (!first_message.empty()) {
// Make sure it's not too long
std::string shorted(utils::max_words(first_message, 12));
if (shorted.size() != first_message.size()) {
shorted += "...";
}
embed.description += "\n\n> "+shorted;
}
// Warn about non-instruct mode
if (instruct_mode == false) {
embed.description += "\n\n**In the selected mode, the quality is highly degraded**, but the conversation is more humorous. Please avoid this mode if you want helpful responses or want to evaluate the model's quality.";
}
// Return final result
return embed;
}
// This function is responsible for sharded thread creation.
// A bit ugly, but a nice way to avoid communicating over anything other than the Discord API itself.
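// First pass (thread == nullptr): the command is buffered and, on the shard that
// owns the channel, a thread named after the command ID is created. Second pass
// (triggered by the thread-created system message): the buffered command is looked
// up again, every shard records the thread in its database, and the owning shard
// renames it and posts the chat embed.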
bool command_completion_handler(dpp::slashcommand_t&& event, dpp::channel *thread = nullptr) {
// Stop if this is not the correct shard for thread creation
if (thread == nullptr) {
// But register this command first
std::scoped_lock L(command_completion_buffer_mutex);
command_completion_buffer.emplace(event.command.id, std::move(event));
// And then actually stop
if (!is_on_own_shard(event.command.channel_id)) return false;
}
// Get model by name
auto res = config.models.find(event.command.get_command_name());
if (res == config.models.end()) {
// Model does not exist, delete corresponding command
bot.global_command_delete(event.command.get_command_interaction().id);
return false;
}
const auto& [model_name, model_config] = *res;
// Get whether to enable instruct mode
bool instruct_mode;
{
const auto& instruct_mode_param = event.get_parameter("instruct_mode");
if (instruct_mode_param.index()) {
instruct_mode = std::get<bool>(instruct_mode_param);
} else {
instruct_mode = model_config.instruct_mode_policy != Configuration::Model::InstructModePolicy::Forbid;
}
}
// Create thread if it doesn't exist or update it if it does
if (thread == nullptr) {
bot.thread_create(std::to_string(event.command.id), event.command.channel_id, 1440, dpp::CHANNEL_PUBLIC_THREAD, true, 15,
[this, event, model_name = res->first] (const dpp::confirmation_callback_t& ccb) {
// Check for error
if (ccb.is_error()) {
std::cout << "Thread creation failed: " << ccb.get_error().message << std::endl;
event.reply(dpp::message(config.texts.thread_create_fail).set_flags(dpp::message_flags::m_ephemeral));
return;
}
std::cout << "Responsible for creating thread: " << ccb.get<dpp::thread>().id << std::endl;
// Report success
event.reply(dpp::message("Okay!").set_flags(dpp::message_flags::m_ephemeral));
});
} else {
bool this_shard = is_on_own_shard(thread->id);
// Add thread to database
db << "INSERT INTO threads (id, model, instruct_mode, this_shard) VALUES (?, ?, ?, ?);"
<< std::to_string(thread->id) << model_name << instruct_mode << this_shard;
// Stop if this is not the correct shard for thread finalization
if (!this_shard) return false;
// Set name
std::cout << "Responsible for finalizing thread: " << thread->id << std::endl;
thread->name = create_thread_name(model_name, instruct_mode);
bot.channel_edit(*thread);
// Send embed
const auto embed = create_chat_embed(event.command.guild_id, thread->id, model_name, instruct_mode, event.command.usr);
bot.message_create(dpp::message(event.command.channel_id, embed),
[this, thread_id = thread->id] (const dpp::confirmation_callback_t& ccb) {
// Check for error
if (ccb.is_error()) {
std::cerr << "Warning: Failed to create embed: " << ccb.get_error().message << std::endl;
return;
}
// Get message
const auto& msg = ccb.get<dpp::message>();
// Add to embed list
std::scoped_lock L(thread_embeds_mutex);
thread_embeds[thread_id] = msg;
});
}
return true;
}
public:
Bot(decltype(config) cfg)
: llm_pool(cfg.pool_size, "discord_llama", !cfg.persistance),
db("database.sqlite3"), bot(cfg.token), config(cfg) {
// Initialize database
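// The threads table records, for every bot-managed thread, which model serves it,
// whether instruct mode is enabled and whether this shard owns it.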
db << "CREATE TABLE IF NOT EXISTS threads ("
" id TEXT PRIMARY KEY NOT NULL,"
" model TEXT,"
" instruct_mode INTEGER,"
" this_shard INTEGER,"
" UNIQUE(id)"
");";
// Start Scheduled Thread
sched_thread.start();
// Configure bot
bot.on_log(dpp::utility::cout_logger());
bot.intents = dpp::i_guild_messages | dpp::i_message_content;
// Set callbacks
bot.on_ready([=, this] (const dpp::ready_t&) { //TODO: Consider removal
std::cout << "Connected to Discord." << std::endl;
// Register chat command once
if (dpp::run_once<struct register_bot_commands>()) {
auto register_command = [this] (dpp::slashcommand&& c) {
bot.global_command_edit(c, [this, c] (const dpp::confirmation_callback_t& ccb) {
if (ccb.is_error()) bot.global_command_create(c);
});
};
// Register model commands
for (const auto& [name, model] : config.models) {
// Create command
dpp::slashcommand command(name, "Start a chat with me", bot.me.id);
// Add instruct mode option
if (model.instruct_mode_policy == Configuration::Model::InstructModePolicy::Allow) {
command.add_option(dpp::command_option(dpp::co_boolean, "instruct_mode", "Whether to enable instruct mode. Defaults to \"True\" for best output quality", false));
}
// Register command
register_command(std::move(command));
}
// Register other commands
register_command(dpp::slashcommand("ping", "Check my status", bot.me.id));
register_command(dpp::slashcommand("reset", "Reset this conversation", bot.me.id));
register_command(dpp::slashcommand("tasklist", "Get list of tasks", bot.me.id));
//register_command(dpp::slashcommand("taskkill", "Kill a task", bot.me.id)); TODO
}
if (dpp::run_once<class LM::Inference>()) {
// Prepare llm
sched_thread.create_task("Language Model Initialization", [this] () -> void {
llm_init();
});
}
});
bot.on_slashcommand([=, this](dpp::slashcommand_t event) {
const auto invalidate_event = [this] (const dpp::slashcommand_t& event) {
if (is_on_own_shard(event.command.channel_id)) {
event.thinking(true, [this, event] (const dpp::confirmation_callback_t& ccb) {
event.delete_original_response();
});
}
};
// Process basic commands
const auto& command_name = event.command.get_command_name();
if (command_name == "ping") {
// Sender message
if (is_on_own_shard(event.command.channel_id)) {
bot.message_create(dpp::message(event.command.channel_id, "Ping from user "+event.command.usr.format_username()+'!'));
}
// Recipient message
bot.message_create(dpp::message(event.command.channel_id, "Pong from shard "+std::to_string(config.shard_id+1)+'/'+std::to_string(config.shard_count)+'!'));
// Finalize
invalidate_event(event);
return;
} else if (command_name == "reset") {
// Delete inference from pool
sched_thread.create_task("Language Model Inference Pool", [this, id = event.command.channel_id, user = event.command.usr] () -> void {
CoSched::Task::get_current().user_data = std::move(user);
llm_pool.delete_inference(id);
});
// Sender message
if (is_on_own_shard(event.command.channel_id)) {
bot.message_create(dpp::message(event.command.channel_id, "Conversation was reset by "+event.command.usr.format_username()+'!'));
}
// Finalize
invalidate_event(event);
return;
} else if (command_name == "tasklist") {
// Build task list
sched_thread.create_task("tasklist", [this, event, id = event.command.channel_id, user = event.command.usr] () -> void {
auto& task = CoSched::Task::get_current();
task.user_data = std::move(user);
// Set priority to max
task.set_priority(CoSched::PRIO_REALTIME);
// Header
std::string str = "**__Task List on Shard "+std::to_string(config.shard_id)+"__**\n";
// Produce list
for (const auto& task : task.get_scheduler().get_tasks()) {
// Get user
const dpp::user *user = nullptr;
{
if (task->user_data.has_value()) {
user = &std::any_cast<const dpp::user&>(task->user_data);
}
}
// Append line
str += fmt::format("- `{}` (State: **{}**, Priority: **{}**, User: **{}**)\n", task->get_name(), task->is_suspended()?"suspended":task->get_state_string(), task->get_priority(), user?user->format_username():bot.me.format_username());
}
// Delete original thinking response
if (is_on_own_shard(event.command.channel_id)) {
event.delete_original_response();
}
// Send list
bot.message_create(dpp::message(id, str));
return;
});
// Finalize
if (is_on_own_shard(event.command.channel_id)) {
event.thinking(false);
}
return;
}
// Run command completion handler
command_completion_handler(std::move(event));
});
bot.on_message_create([=, this](const dpp::message_create_t&) {
// Attempt cleanup
attempt_cleanup();
});
bot.on_message_create([=, this](const dpp::message_create_t& event) {
// Check that this is for thread creation
if (event.msg.type != dpp::mt_thread_created) return;
// Get thread that was created
bot.channel_get(event.msg.id, [this, msg_id = event.msg.id, channel_id = event.msg.channel_id] (const dpp::confirmation_callback_t& ccb) {
// Stop on error
if (ccb.is_error()) return;
// Get thread
auto thread = ccb.get<dpp::channel>();
// Attempt to get command ID
dpp::snowflake command_id;
try {
command_id = thread.name;
} catch (...) {
return;
}
// Find command
std::scoped_lock L(command_completion_buffer_mutex);
auto res = command_completion_buffer.find(command_id);
if (res == command_completion_buffer.end()) {
return;
}
// Complete command
auto handled = command_completion_handler(std::move(res->second), &thread);
// Remove command from buffer
command_completion_buffer.erase(res);
// Delete this message if we handled it
if (handled) bot.message_delete(msg_id, channel_id);
});
});
bot.on_message_create([=, this] (const dpp::message_create_t& event) {
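// Main message pipeline: cache the author, filter out channels owned by other
// shards, resolve the channel configuration (thread entry or default model),
// then queue an inference task on the scheduled thread.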
// Update user cache
users[event.msg.author.id] = event.msg.author;
// Make sure message has content
if (event.msg.content.empty()) return;
// Ignore messages from channels owned by another shard
bool this_shard = is_on_own_shard(event.msg.channel_id);
db << "SELECT this_shard FROM threads "
"WHERE id = ?;"
<< std::to_string(event.msg.channel_id)
>> [&](int _this_shard) {
this_shard = _this_shard;
};
if (!this_shard) return;
// Ignore own messages
if (event.msg.author.id == bot.me.id) {
// Add message to list of own messages
my_messages.push_back(event.msg.id);
return;
}
// Process message
try {
// Copy message
dpp::message msg = event.msg;
// Replace bot mentions with bot username
utils::str_replace_in_place(msg.content, "<@"+std::to_string(bot.me.id)+'>', bot.me.username);
// Replace all other known users
for (const auto& [user_id, user] : users) {
utils::str_replace_in_place(msg.content, "<@"+std::to_string(user_id)+'>', user.username);
}
// Get channel config
BotChannelConfig channel_cfg;
// Attempt to find thread first...
bool in_bot_thread = false,
model_missing = false;
db << "SELECT model, instruct_mode, this_shard FROM threads "
"WHERE id = ?;"
<< std::to_string(msg.channel_id)
>> [&](const std::string& model_name, int instruct_mode) {
in_bot_thread = true;
channel_cfg.instruct_mode = instruct_mode;
// Find model
auto res = config.models.find(model_name);
if (res == config.models.end()) {
bot.message_create(dpp::message(msg.channel_id, config.texts.model_missing));
model_missing = true;
return;
}
channel_cfg.model_name = &res->first;
channel_cfg.model = &res->second;
};
if (model_missing) return;
// Otherwise just fall back to the default model config if allowed
if (!in_bot_thread) {
if (config.threads_only) return;
channel_cfg.model_name = &config.default_inference_model;
channel_cfg.model = config.default_inference_model_cfg;
}
// Append message
sched_thread.create_task("Language Model Inference ("+*channel_cfg.model_name+" at "+std::to_string(msg.channel_id)+")", [=, this] () -> void {
CoSched::Task::get_current().user_data = msg.author;
// Create initial message
dpp::message placeholder_msg(msg.channel_id, config.texts.please_wait+" :thinking:");
// Get task
auto &task = CoSched::Task::get_current();
// Await previous completion
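// Tasks for the same channel and model share a name; a newer task suspends itself
// until the running one finishes and unsuspends it, so messages are processed
// strictly in order per channel. The any_non_suspended check avoids suspending
// when every other task is already suspended.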
while (true) {
// Check that there are no other tasks with the same name
bool is_unique = true;
bool any_non_suspended = false;
for (const auto& other_task : task.get_scheduler().get_tasks()) {
if (&task == other_task.get()) continue;
if (task.get_name() == other_task->get_name()) {
is_unique = false;
} else if (!other_task->is_suspended()) {
any_non_suspended = true;
}
}
// Stop looking if there is no task that isn't suspended
if (!any_non_suspended) break;
// Stop looking if task is unique
if (is_unique) break;
// Suspend, we'll be woken up by that other task
task.set_suspended(true);
if (!task.yield()) return;
}
// Check if message should reply
bool should_reply = false;
if (in_bot_thread) {
should_reply = true;
} else if (msg.content == "!trigger") {
bot.message_delete(msg.id, msg.channel_id);
should_reply = true;
} else {
should_reply = check_should_reply(msg);
}
if (should_reply) {
// Send placeholder
placeholder_msg = bot.message_create_sync(placeholder_msg);
// Add user message
if (!prompt_add_msg(msg, channel_cfg)) {
std::cerr << "Warning: Failed to add user message, not going to reply" << std::endl;
return;
}
// Send a reply
reply(msg.channel_id, placeholder_msg, channel_cfg);
} else {
// Add user message
if (!prompt_add_msg(msg, channel_cfg)) {
std::cerr << "Warning: Failed to add user message" << std::endl;
return;
}
}
// Unsuspend other tasks with same name
for (const auto& other_task : task.get_scheduler().get_tasks()) {
if (task.get_name() == other_task->get_name()) {
other_task->set_suspended(false);
}
}
});
// Find thread embed
std::scoped_lock L(thread_embeds_mutex);
auto res = thread_embeds.find(msg.channel_id);
if (res == thread_embeds.end()) {
return;
}
// Update that embed
auto embed_msg = res->second;
embed_msg.embeds[0] = create_chat_embed(msg.guild_id, msg.channel_id, *channel_cfg.model_name, channel_cfg.instruct_mode, msg.author, msg.content);
bot.message_edit(embed_msg);
// Remove thread embed linkage from vector
thread_embeds.erase(res);
} catch (const std::exception& e) {
std::cerr << "Warning: " << e.what() << std::endl;
}
});
}
void start() {
cleanup();
bot.start(dpp::st_wait);
}
void stop_prepare() {
if (config.persistance) {
sched_thread.create_task("Language Model Shutdown", [=, this] () -> void {
llm_pool.store_all();
});
}
sched_thread.wait();
}
};
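// Entry point: loads the configuration from argv[1] (if given), constructs the
// bot, installs SIGTERM/SIGINT/SIGHUP handlers that run stop_prepare() before
// exiting, then blocks in the Discord event loop.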
int main(int argc, char **argv) {
// Parse configuration
Configuration cfg;
cfg.parse_configs(argc<2?"":argv[1]);
// Construct and configure bot
Bot bot(cfg);
// Set signal handlers if available
# ifdef sa_sigaction
struct sigaction sigact;
static Bot& bot_st = bot;
static const auto main_thread = std::this_thread::get_id();
sigact.sa_handler = [] (int) {
if (std::this_thread::get_id() == main_thread) {
bot_st.stop_prepare();
exit(0);
}
};
sigemptyset(&sigact.sa_mask);
sigact.sa_flags = 0;
sigaction(SIGTERM, &sigact, nullptr);
sigaction(SIGINT, &sigact, nullptr);
sigaction(SIGHUP, &sigact, nullptr);
# endif
// Start bot
bot.start();
}