1
0
Fork 0
mirror of https://gitlab.com/niansa/discord_llama.git synced 2025-03-06 20:48:25 +01:00

Converted most commands to slash commands and added tasklist

This commit is contained in:
niansa 2023-05-10 15:24:55 +02:00
parent 13f1fde45d
commit 52288eb0c7
2 changed files with 84 additions and 25 deletions

@ -1 +1 @@
Subproject commit e38e792adc2b06ff38d1e06f8609c1087fd27b21
Subproject commit 05eb25ce59de5fa9fdb80c0124cfdc09f920861c

107
main.cpp
View file

@ -4,17 +4,15 @@
#include <string>
#include <string_view>
#include <sstream>
#include <stdexcept>
#include <fstream>
#include <thread>
#include <chrono>
#include <functional>
#include <array>
#include <vector>
#include <unordered_map>
#include <filesystem>
#include <sstream>
#include <optional>
#include <mutex>
#include <memory>
#include <utility>
@ -185,8 +183,8 @@ private:
// Must run in llama thread
CoSched::AwaitableTask<void> llm_init() {
// Run at realtime priority
CoSched::Task::get_current().set_priority(CoSched::PRIO_REALTIME);
// Run at high priority
CoSched::Task::get_current().set_priority(CoSched::PRIO_HIGHER);
// Set LLM thread
llm_tid = std::this_thread::get_id();
// Translate texts
@ -546,6 +544,12 @@ public:
std::cout << "Connected to Discord." << std::endl;
// Register chat command once
if (dpp::run_once<struct register_bot_commands>()) {
auto register_command = [this] (dpp::slashcommand&& c) {
bot.global_command_edit(c, [this, c] (const dpp::confirmation_callback_t& ccb) {
if (ccb.is_error()) bot.global_command_create(c);
});
};
// Register model commands
for (const auto& [name, model] : config.models) {
// Create command
dpp::slashcommand command(name, "Start a chat with me", bot.me.id);
@ -554,10 +558,13 @@ public:
command.add_option(dpp::command_option(dpp::co_boolean, "instruct_mode", "Defaults to \"True\" for best output quality. Weather to enable instruct mode", false));
}
// Register command
bot.global_command_edit(command, [this, command] (const dpp::confirmation_callback_t& ccb) {
if (ccb.is_error()) bot.global_command_create(command);
});
register_command(std::move(command));
}
// Register other commands
register_command(dpp::slashcommand("ping", "Check my status", bot.me.id));
register_command(dpp::slashcommand("reset", "Reset this conversation", bot.me.id));
register_command(dpp::slashcommand("tasklist", "Get list of tasks", bot.me.id));
//register_command(dpp::slashcommand("taskkill", "Kill a task", bot.me.id)); TODO
}
if (dpp::run_once<class LM::Inference>()) {
// Prepare llm
@ -567,6 +574,73 @@ public:
}
});
bot.on_slashcommand([=, this](dpp::slashcommand_t event) {
const auto invalidate_event = [this] (const dpp::slashcommand_t& event) {
if (is_on_own_shard(event.command.channel_id)) {
event.thinking(true, [this, event] (const dpp::confirmation_callback_t& ccb) {
event.delete_original_response();
});
}
};
// Process basic commands
const auto& command_name = event.command.get_command_name();
if (command_name == "ping") {
// Sender message
if (is_on_own_shard(event.command.channel_id)) {
bot.message_create(dpp::message(event.command.channel_id, "Ping from user "+event.command.usr.format_username()+'!'));
}
// Recipient message
bot.message_create(dpp::message(event.command.channel_id, "Pong from shard "+std::to_string(config.shard_id+1)+'/'+std::to_string(config.shard_count)+'!'));
// Finalize
invalidate_event(event);
return;
} else if (command_name == "reset") {
// Delete inference from pool
sched_thread.create_task("Language Model Inference Pool", [this, id = event.command.channel_id, user = event.command.usr] () -> CoSched::AwaitableTask<void> {
CoSched::Task::get_current().properties.emplace("user", std::move(user));
co_await llm_pool.delete_inference(id);
});
// Sender message
bot.message_create(dpp::message(event.command.channel_id, "Conversation was reset by "+event.command.usr.format_username()+'!'));
// Finalize
invalidate_event(event);
return;
} else if (command_name == "tasklist") {
// Build task list
sched_thread.create_task("tasklist", [this, event, id = event.command.channel_id, user = event.command.usr] () -> CoSched::AwaitableTask<void> {
auto& task = CoSched::Task::get_current();
task.properties.emplace("user", std::move(user));
// Set priority to max
task.set_priority(CoSched::PRIO_REALTIME);
// Header
std::string str = "**__Task List on Shard "+std::to_string(config.shard_id)+"__**\n";
// Produce list
for (const auto& task : task.get_scheduler().get_tasks()) {
// Get user
const dpp::user *user = nullptr;
{
auto res = task->properties.find("user");
if (res != task->properties.end()) {
user = &std::any_cast<const dpp::user&>(res->second);
}
}
// Append line
str += fmt::format(" - `{}` (State: **{}**, Priority: **{}**, User: **{}**)\n", task->get_name(), task->is_suspended()?"suspended":task->get_state_string(), task->get_priority(), user?user->format_username():bot.me.format_username());
}
// Delete original thinking response
if (is_on_own_shard(event.command.channel_id)) {
event.delete_original_response();
}
// Send list
bot.message_create(dpp::message(id, str));
co_return;
});
// Finalize
if (is_on_own_shard(event.command.channel_id)) {
event.thinking(false);
}
return;
}
// Run command completion handler
command_completion_handler(std::move(event));
});
bot.on_message_create([=, this](...) {
@ -608,11 +682,6 @@ public:
users[event.msg.author.id] = event.msg.author;
// Make sure message has content
if (event.msg.content.empty()) return;
// Handle basic commands
if (event.msg.content == "!ping") {
bot.message_create(dpp::message(event.msg.channel_id, "Pong from shard "+std::to_string(config.shard_id+1)+'/'+std::to_string(config.shard_count)+'!'));
return;
}
// Ignore messges from channel on another shard
bool this_shard = false;
db << "SELECT this_shard FROM threads "
@ -630,17 +699,6 @@ public:
}
// Process message
try {
// Check for prompt commands
if (event.msg.content == "!reset") {
// Delete inference from pool
sched_thread.create_task("Language Model Inference Pool", [msg = event.msg, this] () -> CoSched::AwaitableTask<void> {
co_await llm_pool.delete_inference(msg.channel_id);
});
// Delete message
bot.message_delete(event.msg.id, event.msg.channel_id);
bot.message_create(dpp::message(event.msg.channel_id, "Conversation was reset by "+event.msg.author.format_username()+'!'));
return;
}
// Copy message
dpp::message msg = event.msg;
// Replace bot mentions with bot username
@ -679,6 +737,7 @@ public:
}
// Append message
sched_thread.create_task("Language Model Inference ("+*channel_cfg.model_name+" at "+std::to_string(msg.channel_id)+")", [=, this] () -> CoSched::AwaitableTask<void> {
CoSched::Task::get_current().properties.emplace("user", msg.author);
// Create initial message
auto placeholder_msg = bot.message_create_sync(dpp::message(msg.channel_id, config.texts.please_wait+" :thinking:"));
// Get task