
Reorganization and added experimental thread embeds

This commit is contained in:
niansa 2023-04-29 21:26:13 +02:00
parent f113ae5a81
commit dfcdcc7c1b
5 changed files with 200 additions and 89 deletions

CMakeLists.txt

@@ -11,7 +11,7 @@ add_subdirectory(DPP)
add_subdirectory(thread-pool)
add_subdirectory(fmt)
add_executable(discord_llama main.cpp)
add_executable(discord_llama main.cpp utils.cpp utils.hpp)
target_link_libraries(discord_llama PUBLIC dpp fmt pthread libjustlm anyproc ggml threadpool sqlite3)
install(TARGETS discord_llama

Timer.hpp (deleted)

@@ -1,25 +0,0 @@
#ifndef _PHASMOENGINE_TIMER_HPP
#define _PHASMOENGINE_TIMER_HPP
#include <chrono>
class Timer {
std::chrono::time_point<std::chrono::high_resolution_clock> value;
public:
Timer() {
reset();
}
void reset() {
value = std::chrono::high_resolution_clock::now();
}
template<typename Unit = std::chrono::milliseconds>
auto get() {
auto duration = std::chrono::duration_cast<Unit>(std::chrono::high_resolution_clock::now() - value);
return duration.count();
}
};
#endif
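
The Timer class removed here is not gone: it reappears unchanged as utils::Timer in the new utils.hpp at the bottom of this commit. Below is a minimal sketch of the timeout pattern main.cpp uses it for, assuming utils.hpp is on the include path (the sleep merely stands in for model work).

// Sketch only: the same Timer, now reached via utils.hpp.
#include "utils.hpp"
#include <chrono>
#include <iostream>
#include <thread>

int main() {
    utils::Timer timeout;                        // starts counting on construction
    bool timeout_exceeded = false;
    std::this_thread::sleep_for(std::chrono::milliseconds(50)); // stand-in for model work
    if (timeout.get<std::chrono::minutes>() > 1) {              // same check as in main.cpp
        timeout_exceeded = true;
        std::cout << "\nWarning: Timeout exceeded generating response" << std::endl;
    }
    std::cout << "Elapsed: " << timeout.get() << " ms, exceeded: "
              << std::boolalpha << timeout_exceeded << std::endl;
}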

main.cpp

@@ -1,4 +1,4 @@
#include "Timer.hpp"
#include "utils.hpp"
#include "sqlite_modern_cpp/sqlite_modern_cpp.h"
#include <string>
@@ -25,39 +25,6 @@
static
std::vector<std::string_view> str_split(std::string_view s, char delimiter, size_t times = -1) {
std::vector<std::string_view> to_return;
decltype(s.size()) start = 0, finish = 0;
while ((finish = s.find_first_of(delimiter, start)) != std::string_view::npos) {
to_return.emplace_back(s.substr(start, finish - start));
start = finish + 1;
if (to_return.size() == times) { break; }
}
to_return.emplace_back(s.substr(start));
return to_return;
}
static
void str_replace_in_place(std::string& subject, std::string_view search,
const std::string& replace) {
if (search.empty()) return;
size_t pos = 0;
while ((pos = subject.find(search, pos)) != std::string::npos) {
subject.replace(pos, search.length(), replace);
pos += replace.length();
}
}
static
void clean_for_command_name(std::string& value) {
for (auto& c : value) {
if (c == '.') c = '_';
if (isalpha(c)) c = tolower(c);
}
}
class Bot {
ThreadPool thread_pool{1};
std::shared_ptr<bool> stopping;
@@ -71,7 +38,9 @@ class Bot {
std::mutex command_completion_buffer_mutex;
std::unordered_map<dpp::snowflake, dpp::slashcommand_t> command_completion_buffer;
std::string_view language;
std::mutex thread_embeds_mutex;
std::unordered_map<dpp::snowflake, dpp::snowflake> thread_embeds;
dpp::cluster bot;
public:
@@ -148,18 +117,18 @@ private:
ENSURE_LLM_THREAD();
// Skip if there is no translator
if (translator == nullptr || skip) {
std::cout << "(" << language << ") " << text << std::endl;
std::cout << "(" << config.language << ") " << text << std::endl;
return text;
}
// I am optimizing heavily for the above case. This function always returns a reference so a trick is needed here
static std::string fres;
fres = text;
// Replace bot username with [43]
str_replace_in_place(fres, bot.me.username, "[43]");
utils::str_replace_in_place(fres, bot.me.username, "[43]");
// Run translation
fres = translator->translate(fres, "EN", show_console_progress);
// Replace [43] back with bot username
str_replace_in_place(fres, "[43]", bot.me.username);
utils::str_replace_in_place(fres, "[43]", bot.me.username);
std::cout << text << " --> (EN) " << fres << std::endl;
return fres;
}
@@ -169,19 +138,19 @@ private:
ENSURE_LLM_THREAD();
// Skip if there is no translator
if (translator == nullptr || skip) {
std::cout << "(" << language << ") " << text << std::endl;
std::cout << "(" << config.language << ") " << text << std::endl;
return text;
}
// I am optimizing heavily for the above case. This function always returns a reference so a trick is needed here
static std::string fres;
fres = text;
// Replace bot username with [43]
str_replace_in_place(fres, bot.me.username, "[43]");
utils::str_replace_in_place(fres, bot.me.username, "[43]");
// Run translation
fres = translator->translate(fres, language, show_console_progress);
fres = translator->translate(fres, config.language, show_console_progress);
// Replace [43] back with bot username
str_replace_in_place(fres, "[43]", bot.me.username);
std::cout << text << " --> (" << language << ") " << fres << std::endl;
utils::str_replace_in_place(fres, "[43]", bot.me.username);
std::cout << text << " --> (" << config.language << ") " << fres << std::endl;
return fres;
}
@@ -332,7 +301,7 @@ private:
auto& inference = llm_get_inference(msg.channel_id, channel_cfg);
std::string prefix;
// Define callback for console progress and timeout
Timer timeout;
utils::Timer timeout;
bool timeout_exceeded = false;
const auto cb = [&] (float progress) {
if (timeout.get<std::chrono::minutes>() > 1) {
@@ -348,7 +317,7 @@ private:
inference.append("\n\n"+std::string(llm_translate_to_en(msg.content, channel_cfg.model_config->no_translate))+'\n', cb);
} else {
// Format and append lines
for (const auto line : str_split(msg.content, '\n')) {
for (const auto line : utils::str_split(msg.content, '\n')) {
inference.append(msg.author.username+": "+std::string(llm_translate_to_en(line, channel_cfg.model_config->no_translate))+'\n', cb);
}
}
@@ -374,8 +343,8 @@ private:
// Trigger LLM correctly
prompt_add_trigger(inference, channel_cfg);
// Run model
Timer timeout;
Timer edit_timer;
utils::Timer timeout;
utils::Timer edit_timer;
bool timeout_exceeded = false;
msg.content.clear();
auto output = inference.run(channel_cfg.instruct_mode?channel_cfg.model_config->user_prompt:"\n", [&] (std::string_view token) {
@@ -444,6 +413,32 @@ private:
return (unsigned(id.get_creation_time()) % config.shard_count) == config.shard_id;
}
std::string create_thread_name(const std::string& model_name, bool instruct_mode) const {
return "Chat with "+model_name+" " // Model name
+(instruct_mode?"":"(Non Instruct mode) ") // Instruct mode
+'#'+(config.shard_count!=1?std::to_string(config.shard_id):""); // Shard ID
}
dpp::embed create_chat_embed(dpp::snowflake guild_id, dpp::snowflake thread_id, const std::string& model_name, bool instruct_mode, const dpp::user& author, const std::string& first_message = "") const {
dpp::embed embed;
// Create embed
embed.set_title(create_thread_name(model_name, instruct_mode))
.set_description("[Open the chat](https://discord.com/channels/"+std::to_string(guild_id)+'/'+std::to_string(thread_id)+')')
.set_footer(dpp::embed_footer().set_text("Started by "+author.format_username()))
.set_color(utils::get_unique_color(model_name));
// Add first message if any
if (!first_message.empty()) {
// Make sure it's not too long
std::string shorted(utils::max_words(first_message, 12));
if (shorted.size() != first_message.size()) {
shorted += "...";
}
embed.add_field("", shorted);
}
// Return final result
return embed;
}
// This function is responsible for sharding thread creation
// A bit ugly but a nice way to avoid having to communicate over any other means than just the Discord API
void command_completion_handler(dpp::slashcommand_t&& event, dpp::channel *thread = nullptr) {
@@ -465,15 +460,13 @@ private:
const auto& [model_name, model_config] = *res;
// Get whether to enable instruct mode
bool instruct_mode;
const auto& instruct_mode_param = event.get_parameter("instruct_mode");
if (model_config.instruct_mode_policy == ModelConfig::InstructModePolicy::Allow) {
{
const auto& instruct_mode_param = event.get_parameter("instruct_mode");
if (instruct_mode_param.index()) {
instruct_mode = std::get<bool>(instruct_mode_param);
} else {
instruct_mode = true;
}
} else {
instruct_mode = model_config.instruct_mode_policy == ModelConfig::InstructModePolicy::Force;
}
// Create thread if it doesn't exist or update it if it does
if (thread == nullptr) {
@@ -497,17 +490,27 @@ private:
if (!on_own_shard(thread->id)) return;
// Set name
std::cout << "Responsible for finalizing thread: " << thread->id << std::endl;
thread->name = "Chat with "+model_name+" " // Model name
+(instruct_mode?"":"(Non Instruct mode) ") // Instruct mode
+'#'+(config.shard_count!=1?std::to_string(config.shard_id):""); // Shard ID
thread->name = create_thread_name(model_name, instruct_mode);
bot.channel_edit(*thread);
// Send embed
const auto embed = create_chat_embed(event.command.guild_id, thread->id, model_name, instruct_mode, event.command.usr);
bot.message_create(dpp::message(event.command.channel_id, embed),
[this, thread_id = thread->id] (const dpp::confirmation_callback_t& ccb) {
// Check for error
if (ccb.is_error()) {
std::cerr << "Warning: Failed to create embed: " << ccb.get_error().message << std::endl;
return;
}
// Add to embed list
thread_embeds[thread_id] = ccb.get<dpp::message>().id;
});
}
}
public:
Bot(decltype(config) cfg, decltype(model_configs) model_configs)
: config(cfg), model_configs(model_configs), bot(cfg.token),
language(cfg.language), db("database.sqlite3"),
: config(cfg), model_configs(model_configs),
bot(cfg.token), db("database.sqlite3"),
llm_pool(cfg.pool_size, "discord_llama", !cfg.persistance) {
// Initialize database
db << "CREATE TABLE IF NOT EXISTS threads ("
@@ -524,7 +527,7 @@ public:
thread_pool.init();
// Prepare translator
if (language != "EN") {
if (cfg.language != "EN") {
thread_pool.submit([this] () {
std::cout << "Preparing translator..." << std::endl;
translator = std::make_unique<Translator>(config.translation_model_cfg->weight_path, llm_get_translation_params());
@@ -565,7 +568,7 @@ public:
// Check that this is for thread creation
if (event.msg.type != dpp::mt_thread_created) return;
// Get thread that was created
bot.channel_get(event.msg.id, [this] (const dpp::confirmation_callback_t& ccb) {
bot.channel_get(event.msg.id, [this, msg_id = event.msg.id, channel_id = event.msg.channel_id] (const dpp::confirmation_callback_t& ccb) {
// Stop on error
if (ccb.is_error()) return;
// Get thread
@@ -587,6 +590,8 @@ public:
command_completion_handler(std::move(res->second), &thread);
// Remove command from buffer
command_completion_buffer.erase(res);
// Delete this message
bot.message_delete(msg_id, channel_id);
});
});
bot.on_message_create([=, this] (const dpp::message_create_t& event) {
@@ -616,10 +621,10 @@ public:
return;
}
// Replace bot mentions with bot username
str_replace_in_place(msg.content, "<@"+std::to_string(bot.me.id)+'>', bot.me.username);
utils::str_replace_in_place(msg.content, "<@"+std::to_string(bot.me.id)+'>', bot.me.username);
// Replace all other known users
for (const auto& [user_id, user] : users) {
str_replace_in_place(msg.content, "<@"+std::to_string(user_id)+'>', user.username);
utils::str_replace_in_place(msg.content, "<@"+std::to_string(user_id)+'>', user.username);
}
// Get channel config
BotChannelConfig channel_cfg;
@@ -665,6 +670,18 @@ public:
} else {
attempt_reply(msg, channel_cfg);
}
// Find thread embed
std::scoped_lock L(thread_embeds_mutex);
auto res = thread_embeds.find(msg.channel_id);
if (res == thread_embeds.end()) {
return;
}
// Update that embed
const auto embed = create_chat_embed(msg.guild_id, msg.channel_id, *channel_cfg.model_name, channel_cfg.instruct_mode, msg.author, msg.content);
dpp::message embed_msg;
embed_msg.id = res->second;
embed_msg.add_embed(embed);
bot.message_edit(embed_msg);
} catch (const std::exception& e) {
std::cerr << "Warning: " << e.what() << std::endl;
}
@@ -731,10 +748,10 @@ int main(int argc, char **argv) {
cfg.language = std::move(value);
} else if (key == "default_inference_model") {
cfg.default_inference_model = std::move(value);
clean_for_command_name(cfg.default_inference_model);
utils::clean_for_command_name(cfg.default_inference_model);
} else if (key == "translation_model") {
cfg.translation_model = std::move(value);
clean_for_command_name(cfg.translation_model);
utils::clean_for_command_name(cfg.translation_model);
} else if (key == "prompt_file") {
cfg.prompt_file = std::move(value);
} else if (key == "instruct_prompt_file") {
@@ -780,7 +797,7 @@ int main(int argc, char **argv) {
// Get model name
auto model_name = file.path().filename().string();
model_name.erase(model_name.size()-4, 4);
clean_for_command_name(model_name);
utils::clean_for_command_name(model_name);
// Parse model config
Bot::ModelConfig model_cfg;
std::ifstream cfgf(file.path());
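
As an aside, the on_own_shard() test used throughout main.cpp reduces to a modulo over the snowflake's creation time, which is what lets each shard claim its own threads without any communication channel beyond the Discord API itself. A standalone sketch of that idea follows (creation_time_s is a stand-in for dpp::snowflake::get_creation_time(); the timestamps are made up):

#include <iostream>

// Same test as Bot::on_own_shard above, with the snowflake reduced to its
// creation timestamp in seconds so the sketch needs no DPP types.
bool on_own_shard(double creation_time_s, unsigned shard_count, unsigned shard_id) {
    return (static_cast<unsigned>(creation_time_s) % shard_count) == shard_id;
}

int main() {
    // With two shards, a thread created at an odd second belongs to shard 1.
    std::cout << std::boolalpha
              << on_own_shard(1682796373.0, 2, 1) << '\n'   // true  (odd second)
              << on_own_shard(1682796374.0, 2, 1) << '\n';  // false (even second)
}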

utils.cpp (new file)

@@ -0,0 +1,54 @@
#include "utils.hpp"
namespace utils {
std::vector<std::string_view> str_split(std::string_view s, char delimiter, size_t times) {
std::vector<std::string_view> to_return;
decltype(s.size()) start = 0, finish = 0;
while ((finish = s.find_first_of(delimiter, start)) != std::string_view::npos) {
to_return.emplace_back(s.substr(start, finish - start));
start = finish + 1;
if (to_return.size() == times) { break; }
}
to_return.emplace_back(s.substr(start));
return to_return;
}
void str_replace_in_place(std::string& subject, std::string_view search,
const std::string& replace) {
if (search.empty()) return;
size_t pos = 0;
while ((pos = subject.find(search, pos)) != std::string::npos) {
subject.replace(pos, search.length(), replace);
pos += replace.length();
}
}
void clean_for_command_name(std::string& value) {
for (auto& c : value) {
if (c == '.') c = '_';
if (isalpha(c)) c = tolower(c);
}
}
std::string_view max_words(std::string_view text, unsigned count) {
unsigned word_len = 0,
word_count = 0,
idx;
// Get idx after last word
for (idx = 0; idx != text.size() && word_count != count; idx++) {
char c = text[idx];
if (c == ' ' || word_len == 7) {
if (word_len != 0) {
word_count++;
word_len = 0;
}
} else {
word_len++;
}
}
// Return everything up to (and including) the last counted word
return {text.data(), idx};
}
}

utils.hpp (new file)

@@ -0,0 +1,65 @@
#ifndef UTILS_HPP
#define UTILS_HPP
#include <string>
#include <string_view>
#include <initializer_list>
#include <vector>
#include <chrono>
namespace utils {
class Timer {
std::chrono::time_point<std::chrono::high_resolution_clock> value;
public:
Timer() {
reset();
}
void reset() {
value = std::chrono::high_resolution_clock::now();
}
template<typename Unit = std::chrono::milliseconds>
auto get() {
auto duration = std::chrono::duration_cast<Unit>(std::chrono::high_resolution_clock::now() - value);
return duration.count();
}
};
std::vector<std::string_view> str_split(std::string_view s, char delimiter, size_t times = -1);
void str_replace_in_place(std::string& subject, std::string_view search, const std::string& replace);
void clean_for_command_name(std::string& value);
std::string_view max_words(std::string_view text, unsigned count);
inline
uint32_t get_unique_color(const auto& input) {
auto i = std::hash<typename std::remove_const<typename std::remove_reference<decltype(input)>::type>::type>{}(input);
const std::initializer_list<uint32_t> colors = {
0xf44336,
0xe91e63,
0x9c27b0,
0x673ab7,
0x3f51b5,
0x2196f3,
0x03a9f4,
0x00bcd4,
0x009688,
0x4caf50,
0x8bc34a,
0xcddc39,
0xffeb3b,
0xffc107,
0xff9800,
0xff5722,
0x795548,
0xcfd8dc
};
return *(colors.begin()+(i%colors.size()));
}
}
#endif // UTILS_HPP
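
For reference, a minimal usage sketch of the helpers collected in this new utils unit, assuming utils.cpp/utils.hpp are built into the target as in the CMakeLists.txt change above (the sample strings are made up; the expected outputs follow from the definitions):

#include "utils.hpp"
#include <iostream>
#include <string>

int main() {
    // str_split: one string_view per line, as used when appending chat history.
    for (const auto line : utils::str_split("first line\nsecond line", '\n'))
        std::cout << line << '\n';                    // "first line", "second line"

    // str_replace_in_place: the same call main.cpp uses to swap the bot's
    // username for the "[43]" placeholder around translation.
    std::string text = "Hello ExampleBot!";
    utils::str_replace_in_place(text, "ExampleBot", "[43]");
    std::cout << text << '\n';                        // "Hello [43]!"

    // clean_for_command_name: letters lowercased, dots turned into underscores.
    std::string model = "Example.Model";
    utils::clean_for_command_name(model);
    std::cout << model << '\n';                       // "example_model"

    // get_unique_color: a stable palette entry derived from the name's hash.
    std::cout << std::hex << utils::get_unique_color(model) << '\n';
}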