mirror of
https://gitlab.com/niansa/discord_llama.git
synced 2025-03-06 20:48:25 +01:00
Converted most commands to slash commands and added tasklist
This commit is contained in:
parent
13f1fde45d
commit
52288eb0c7
2 changed files with 84 additions and 25 deletions
2
cosched
2
cosched
|
@ -1 +1 @@
|
|||
Subproject commit e38e792adc2b06ff38d1e06f8609c1087fd27b21
|
||||
Subproject commit 05eb25ce59de5fa9fdb80c0124cfdc09f920861c
|
107
main.cpp
107
main.cpp
|
@ -4,17 +4,15 @@
|
|||
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <fstream>
|
||||
#include <thread>
|
||||
#include <chrono>
|
||||
#include <functional>
|
||||
#include <array>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
#include <filesystem>
|
||||
#include <sstream>
|
||||
#include <optional>
|
||||
#include <mutex>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
|
@ -185,8 +183,8 @@ private:
|
|||
|
||||
// Must run in llama thread
|
||||
CoSched::AwaitableTask<void> llm_init() {
|
||||
// Run at realtime priority
|
||||
CoSched::Task::get_current().set_priority(CoSched::PRIO_REALTIME);
|
||||
// Run at high priority
|
||||
CoSched::Task::get_current().set_priority(CoSched::PRIO_HIGHER);
|
||||
// Set LLM thread
|
||||
llm_tid = std::this_thread::get_id();
|
||||
// Translate texts
|
||||
|
@ -546,6 +544,12 @@ public:
|
|||
std::cout << "Connected to Discord." << std::endl;
|
||||
// Register chat command once
|
||||
if (dpp::run_once<struct register_bot_commands>()) {
|
||||
auto register_command = [this] (dpp::slashcommand&& c) {
|
||||
bot.global_command_edit(c, [this, c] (const dpp::confirmation_callback_t& ccb) {
|
||||
if (ccb.is_error()) bot.global_command_create(c);
|
||||
});
|
||||
};
|
||||
// Register model commands
|
||||
for (const auto& [name, model] : config.models) {
|
||||
// Create command
|
||||
dpp::slashcommand command(name, "Start a chat with me", bot.me.id);
|
||||
|
@ -554,10 +558,13 @@ public:
|
|||
command.add_option(dpp::command_option(dpp::co_boolean, "instruct_mode", "Defaults to \"True\" for best output quality. Weather to enable instruct mode", false));
|
||||
}
|
||||
// Register command
|
||||
bot.global_command_edit(command, [this, command] (const dpp::confirmation_callback_t& ccb) {
|
||||
if (ccb.is_error()) bot.global_command_create(command);
|
||||
});
|
||||
register_command(std::move(command));
|
||||
}
|
||||
// Register other commands
|
||||
register_command(dpp::slashcommand("ping", "Check my status", bot.me.id));
|
||||
register_command(dpp::slashcommand("reset", "Reset this conversation", bot.me.id));
|
||||
register_command(dpp::slashcommand("tasklist", "Get list of tasks", bot.me.id));
|
||||
//register_command(dpp::slashcommand("taskkill", "Kill a task", bot.me.id)); TODO
|
||||
}
|
||||
if (dpp::run_once<class LM::Inference>()) {
|
||||
// Prepare llm
|
||||
|
@ -567,6 +574,73 @@ public:
|
|||
}
|
||||
});
|
||||
bot.on_slashcommand([=, this](dpp::slashcommand_t event) {
|
||||
const auto invalidate_event = [this] (const dpp::slashcommand_t& event) {
|
||||
if (is_on_own_shard(event.command.channel_id)) {
|
||||
event.thinking(true, [this, event] (const dpp::confirmation_callback_t& ccb) {
|
||||
event.delete_original_response();
|
||||
});
|
||||
}
|
||||
};
|
||||
// Process basic commands
|
||||
const auto& command_name = event.command.get_command_name();
|
||||
if (command_name == "ping") {
|
||||
// Sender message
|
||||
if (is_on_own_shard(event.command.channel_id)) {
|
||||
bot.message_create(dpp::message(event.command.channel_id, "Ping from user "+event.command.usr.format_username()+'!'));
|
||||
}
|
||||
// Recipient message
|
||||
bot.message_create(dpp::message(event.command.channel_id, "Pong from shard "+std::to_string(config.shard_id+1)+'/'+std::to_string(config.shard_count)+'!'));
|
||||
// Finalize
|
||||
invalidate_event(event);
|
||||
return;
|
||||
} else if (command_name == "reset") {
|
||||
// Delete inference from pool
|
||||
sched_thread.create_task("Language Model Inference Pool", [this, id = event.command.channel_id, user = event.command.usr] () -> CoSched::AwaitableTask<void> {
|
||||
CoSched::Task::get_current().properties.emplace("user", std::move(user));
|
||||
co_await llm_pool.delete_inference(id);
|
||||
});
|
||||
// Sender message
|
||||
bot.message_create(dpp::message(event.command.channel_id, "Conversation was reset by "+event.command.usr.format_username()+'!'));
|
||||
// Finalize
|
||||
invalidate_event(event);
|
||||
return;
|
||||
} else if (command_name == "tasklist") {
|
||||
// Build task list
|
||||
sched_thread.create_task("tasklist", [this, event, id = event.command.channel_id, user = event.command.usr] () -> CoSched::AwaitableTask<void> {
|
||||
auto& task = CoSched::Task::get_current();
|
||||
task.properties.emplace("user", std::move(user));
|
||||
// Set priority to max
|
||||
task.set_priority(CoSched::PRIO_REALTIME);
|
||||
// Header
|
||||
std::string str = "**__Task List on Shard "+std::to_string(config.shard_id)+"__**\n";
|
||||
// Produce list
|
||||
for (const auto& task : task.get_scheduler().get_tasks()) {
|
||||
// Get user
|
||||
const dpp::user *user = nullptr;
|
||||
{
|
||||
auto res = task->properties.find("user");
|
||||
if (res != task->properties.end()) {
|
||||
user = &std::any_cast<const dpp::user&>(res->second);
|
||||
}
|
||||
}
|
||||
// Append line
|
||||
str += fmt::format(" - `{}` (State: **{}**, Priority: **{}**, User: **{}**)\n", task->get_name(), task->is_suspended()?"suspended":task->get_state_string(), task->get_priority(), user?user->format_username():bot.me.format_username());
|
||||
}
|
||||
// Delete original thinking response
|
||||
if (is_on_own_shard(event.command.channel_id)) {
|
||||
event.delete_original_response();
|
||||
}
|
||||
// Send list
|
||||
bot.message_create(dpp::message(id, str));
|
||||
co_return;
|
||||
});
|
||||
// Finalize
|
||||
if (is_on_own_shard(event.command.channel_id)) {
|
||||
event.thinking(false);
|
||||
}
|
||||
return;
|
||||
}
|
||||
// Run command completion handler
|
||||
command_completion_handler(std::move(event));
|
||||
});
|
||||
bot.on_message_create([=, this](...) {
|
||||
|
@ -608,11 +682,6 @@ public:
|
|||
users[event.msg.author.id] = event.msg.author;
|
||||
// Make sure message has content
|
||||
if (event.msg.content.empty()) return;
|
||||
// Handle basic commands
|
||||
if (event.msg.content == "!ping") {
|
||||
bot.message_create(dpp::message(event.msg.channel_id, "Pong from shard "+std::to_string(config.shard_id+1)+'/'+std::to_string(config.shard_count)+'!'));
|
||||
return;
|
||||
}
|
||||
// Ignore messges from channel on another shard
|
||||
bool this_shard = false;
|
||||
db << "SELECT this_shard FROM threads "
|
||||
|
@ -630,17 +699,6 @@ public:
|
|||
}
|
||||
// Process message
|
||||
try {
|
||||
// Check for prompt commands
|
||||
if (event.msg.content == "!reset") {
|
||||
// Delete inference from pool
|
||||
sched_thread.create_task("Language Model Inference Pool", [msg = event.msg, this] () -> CoSched::AwaitableTask<void> {
|
||||
co_await llm_pool.delete_inference(msg.channel_id);
|
||||
});
|
||||
// Delete message
|
||||
bot.message_delete(event.msg.id, event.msg.channel_id);
|
||||
bot.message_create(dpp::message(event.msg.channel_id, "Conversation was reset by "+event.msg.author.format_username()+'!'));
|
||||
return;
|
||||
}
|
||||
// Copy message
|
||||
dpp::message msg = event.msg;
|
||||
// Replace bot mentions with bot username
|
||||
|
@ -679,6 +737,7 @@ public:
|
|||
}
|
||||
// Append message
|
||||
sched_thread.create_task("Language Model Inference ("+*channel_cfg.model_name+" at "+std::to_string(msg.channel_id)+")", [=, this] () -> CoSched::AwaitableTask<void> {
|
||||
CoSched::Task::get_current().properties.emplace("user", msg.author);
|
||||
// Create initial message
|
||||
auto placeholder_msg = bot.message_create_sync(dpp::message(msg.channel_id, config.texts.please_wait+" :thinking:"));
|
||||
// Get task
|
||||
|
|
Loading…
Add table
Reference in a new issue