Mirror of https://gitlab.com/niansa/discord_llama.git

Commit dfcdcc7c1b (parent f113ae5a81)
Reorganization and added experimental thread embeds

5 changed files with 200 additions and 89 deletions

CMakeLists.txt (2 changed lines)

@@ -11,7 +11,7 @@ add_subdirectory(DPP)
 add_subdirectory(thread-pool)
 add_subdirectory(fmt)
 
-add_executable(discord_llama main.cpp)
+add_executable(discord_llama main.cpp utils.cpp utils.hpp)
 target_link_libraries(discord_llama PUBLIC dpp fmt pthread libjustlm anyproc ggml threadpool sqlite3)
 
 install(TARGETS discord_llama

Timer.hpp (deleted, 25 lines)

@@ -1,25 +0,0 @@
-#ifndef _PHASMOENGINE_TIMER_HPP
-#define _PHASMOENGINE_TIMER_HPP
-#include <chrono>
-
-
-
-class Timer {
-    std::chrono::time_point<std::chrono::high_resolution_clock> value;
-
-public:
-    Timer() {
-        reset();
-    }
-
-    void reset() {
-        value = std::chrono::high_resolution_clock::now();
-    }
-
-    template<typename Unit = std::chrono::milliseconds>
-    auto get() {
-        auto duration = std::chrono::duration_cast<Unit>(std::chrono::high_resolution_clock::now() - value);
-        return duration.count();
-    }
-};
-#endif

main.cpp (143 changed lines)

@@ -1,4 +1,4 @@
-#include "Timer.hpp"
+#include "utils.hpp"
 #include "sqlite_modern_cpp/sqlite_modern_cpp.h"
 
 #include <string>
@@ -25,39 +25,6 @@
 
 
 
-static
-std::vector<std::string_view> str_split(std::string_view s, char delimiter, size_t times = -1) {
-std::vector<std::string_view> to_return;
-decltype(s.size()) start = 0, finish = 0;
-while ((finish = s.find_first_of(delimiter, start)) != std::string_view::npos) {
-to_return.emplace_back(s.substr(start, finish - start));
-start = finish + 1;
-if (to_return.size() == times) { break; }
-}
-to_return.emplace_back(s.substr(start));
-return to_return;
-}
-
-static
-void str_replace_in_place(std::string& subject, std::string_view search,
-const std::string& replace) {
-if (search.empty()) return;
-size_t pos = 0;
-while ((pos = subject.find(search, pos)) != std::string::npos) {
-subject.replace(pos, search.length(), replace);
-pos += replace.length();
-}
-}
-
-static
-void clean_for_command_name(std::string& value) {
-for (auto& c : value) {
-if (c == '.') c = '_';
-if (isalpha(c)) c = tolower(c);
-}
-}
-
-
 class Bot {
 ThreadPool thread_pool{1};
 std::shared_ptr<bool> stopping;
@@ -71,7 +38,9 @@ class Bot {
 std::mutex command_completion_buffer_mutex;
 std::unordered_map<dpp::snowflake, dpp::slashcommand_t> command_completion_buffer;
 
-std::string_view language;
+std::mutex thread_embeds_mutex;
+std::unordered_map<dpp::snowflake, dpp::snowflake> thread_embeds;
+
 dpp::cluster bot;
 
 public:
@@ -148,18 +117,18 @@ private:
 ENSURE_LLM_THREAD();
 // Skip if there is no translator
 if (translator == nullptr || skip) {
-std::cout << "(" << language << ") " << text << std::endl;
+std::cout << "(" << config.language << ") " << text << std::endl;
 return text;
 }
 // I am optimizing heavily for the above case. This function always returns a reference so a trick is needed here
 static std::string fres;
 fres = text;
 // Replace bot username with [43]
-str_replace_in_place(fres, bot.me.username, "[43]");
+utils::str_replace_in_place(fres, bot.me.username, "[43]");
 // Run translation
 fres = translator->translate(fres, "EN", show_console_progress);
 // Replace [43] back with bot username
-str_replace_in_place(fres, "[43]", bot.me.username);
+utils::str_replace_in_place(fres, "[43]", bot.me.username);
 std::cout << text << " --> (EN) " << fres << std::endl;
 return fres;
 }
@@ -169,19 +138,19 @@ private:
 ENSURE_LLM_THREAD();
 // Skip if there is no translator
 if (translator == nullptr || skip) {
-std::cout << "(" << language << ") " << text << std::endl;
+std::cout << "(" << config.language << ") " << text << std::endl;
 return text;
 }
 // I am optimizing heavily for the above case. This function always returns a reference so a trick is needed here
 static std::string fres;
 fres = text;
 // Replace bot username with [43]
-str_replace_in_place(fres, bot.me.username, "[43]");
+utils::str_replace_in_place(fres, bot.me.username, "[43]");
 // Run translation
-fres = translator->translate(fres, language, show_console_progress);
+fres = translator->translate(fres, config.language, show_console_progress);
 // Replace [43] back with bot username
-str_replace_in_place(fres, "[43]", bot.me.username);
-std::cout << text << " --> (" << language << ") " << fres << std::endl;
+utils::str_replace_in_place(fres, "[43]", bot.me.username);
+std::cout << text << " --> (" << config.language << ") " << fres << std::endl;
 return fres;
 }
 
@@ -332,7 +301,7 @@ private:
 auto& inference = llm_get_inference(msg.channel_id, channel_cfg);
 std::string prefix;
 // Define callback for console progress and timeout
-Timer timeout;
+utils::Timer timeout;
 bool timeout_exceeded = false;
 const auto cb = [&] (float progress) {
 if (timeout.get<std::chrono::minutes>() > 1) {
@@ -348,7 +317,7 @@ private:
 inference.append("\n\n"+std::string(llm_translate_to_en(msg.content, channel_cfg.model_config->no_translate))+'\n', cb);
 } else {
 // Format and append lines
-for (const auto line : str_split(msg.content, '\n')) {
+for (const auto line : utils::str_split(msg.content, '\n')) {
 inference.append(msg.author.username+": "+std::string(llm_translate_to_en(line, channel_cfg.model_config->no_translate))+'\n', cb);
 }
 }
@@ -374,8 +343,8 @@ private:
 // Trigger LLM correctly
 prompt_add_trigger(inference, channel_cfg);
 // Run model
-Timer timeout;
-Timer edit_timer;
+utils::Timer timeout;
+utils::Timer edit_timer;
 bool timeout_exceeded = false;
 msg.content.clear();
 auto output = inference.run(channel_cfg.instruct_mode?channel_cfg.model_config->user_prompt:"\n", [&] (std::string_view token) {
@@ -444,6 +413,32 @@ private:
 return (unsigned(id.get_creation_time()) % config.shard_count) == config.shard_id;
 }
 
+std::string create_thread_name(const std::string& model_name, bool instruct_mode) const {
+return "Chat with "+model_name+" " // Model name
+       +(instruct_mode?"":"(Non Instruct mode) ") // Instruct mode
+       +'#'+(config.shard_count!=1?std::to_string(config.shard_id):""); // Shard ID
+}
+
+dpp::embed create_chat_embed(dpp::snowflake guild_id, dpp::snowflake thread_id, const std::string& model_name, bool instruct_mode, const dpp::user& author, const std::string& first_message = "") const {
+dpp::embed embed;
+// Create embed
+embed.set_title(create_thread_name(model_name, instruct_mode))
+     .set_description("[Open the chat](https://discord.com/channels/"+std::to_string(guild_id)+'/'+std::to_string(thread_id+')'))
+     .set_footer(dpp::embed_footer().set_text("Started by "+author.format_username()))
+     .set_color(utils::get_unique_color(model_name));
+// Add first message if any
+if (!first_message.empty()) {
+// Make sure it's not too long
+std::string shorted(utils::max_words(first_message, 12));
+if (shorted.size() != first_message.size()) {
+shorted += "...";
+}
+embed.add_field("⠀", shorted);
+}
+// Return final result
+return embed;
+}
+
 // This function is responsible for sharding thread creation
 // A bit ugly but a nice way to avoid having to communicate over any other means than just the Discord API
 void command_completion_handler(dpp::slashcommand_t&& event, dpp::channel *thread = nullptr) {
@@ -465,15 +460,13 @@ private:
 const auto& [model_name, model_config] = *res;
 // Get weather to enable instruct mode
 bool instruct_mode;
-const auto& instruct_mode_param = event.get_parameter("instruct_mode");
-if (model_config.instruct_mode_policy == ModelConfig::InstructModePolicy::Allow) {
+{
+const auto& instruct_mode_param = event.get_parameter("instruct_mode");
 if (instruct_mode_param.index()) {
 instruct_mode = std::get<bool>(instruct_mode_param);
 } else {
 instruct_mode = true;
 }
-} else {
-instruct_mode = model_config.instruct_mode_policy == ModelConfig::InstructModePolicy::Force;
 }
 // Create thread if it doesn't exist or update it if it does
 if (thread == nullptr) {
@@ -497,17 +490,27 @@
 if (!on_own_shard(thread->id)) return;
 // Set name
 std::cout << "Responsible for finalizing thread: " << thread->id << std::endl;
-thread->name = "Chat with "+model_name+" " // Model name
-               +(instruct_mode?"":"(Non Instruct mode) ") // Instruct mode
-               +'#'+(config.shard_count!=1?std::to_string(config.shard_id):""); // Shard ID
+thread->name = create_thread_name(model_name, instruct_mode);
 bot.channel_edit(*thread);
+// Send embed
+const auto embed = create_chat_embed(event.command.guild_id, thread->id, model_name, instruct_mode, event.command.usr);
+bot.message_create(dpp::message(event.command.channel_id, embed),
+                   [this, thread_id = thread->id] (const dpp::confirmation_callback_t& ccb) {
+// Check for error
+if (ccb.is_error()) {
+std::cerr << "Warning: Failed to create embed: " << ccb.get_error().message << std::endl;
+return;
+}
+// Add to embed list
+thread_embeds[thread_id] = ccb.get<dpp::message>().id;
+});
 }
 }
 
 public:
 Bot(decltype(config) cfg, decltype(model_configs) model_configs)
-      : config(cfg), model_configs(model_configs), bot(cfg.token),
-        language(cfg.language), db("database.sqlite3"),
+      : config(cfg), model_configs(model_configs),
+        bot(cfg.token), db("database.sqlite3"),
         llm_pool(cfg.pool_size, "discord_llama", !cfg.persistance) {
 // Initialize database
 db << "CREATE TABLE IF NOT EXISTS threads ("
@@ -524,7 +527,7 @@ public:
 thread_pool.init();
 
 // Prepare translator
-if (language != "EN") {
+if (cfg.language != "EN") {
 thread_pool.submit([this] () {
 std::cout << "Preparing translator..." << std::endl;
 translator = std::make_unique<Translator>(config.translation_model_cfg->weight_path, llm_get_translation_params());
@@ -565,7 +568,7 @@ public:
 // Check that this is for thread creation
 if (event.msg.type != dpp::mt_thread_created) return;
 // Get thread that was created
-bot.channel_get(event.msg.id, [this] (const dpp::confirmation_callback_t& ccb) {
+bot.channel_get(event.msg.id, [this, msg_id = event.msg.id, channel_id = event.msg.channel_id] (const dpp::confirmation_callback_t& ccb) {
 // Stop on error
 if (ccb.is_error()) return;
 // Get thread
@@ -587,6 +590,8 @@ public:
 command_completion_handler(std::move(res->second), &thread);
 // Remove command from buffer
 command_completion_buffer.erase(res);
+// Delete this message
+bot.message_delete(msg_id, channel_id);
 });
 });
 bot.on_message_create([=, this] (const dpp::message_create_t& event) {
@@ -616,10 +621,10 @@ public:
 return;
 }
 // Replace bot mentions with bot username
-str_replace_in_place(msg.content, "<@"+std::to_string(bot.me.id)+'>', bot.me.username);
+utils::str_replace_in_place(msg.content, "<@"+std::to_string(bot.me.id)+'>', bot.me.username);
 // Replace all other known users
 for (const auto& [user_id, user] : users) {
-str_replace_in_place(msg.content, "<@"+std::to_string(user_id)+'>', user.username);
+utils::str_replace_in_place(msg.content, "<@"+std::to_string(user_id)+'>', user.username);
 }
 // Get channel config
 BotChannelConfig channel_cfg;
@@ -665,6 +670,18 @@ public:
 } else {
 attempt_reply(msg, channel_cfg);
 }
+// Find thread embed
+std::scoped_lock L(thread_embeds_mutex);
+auto res = thread_embeds.find(msg.channel_id);
+if (res == thread_embeds.end()) {
+return;
+}
+// Update that embed
+const auto embed = create_chat_embed(msg.guild_id, msg.channel_id, *channel_cfg.model_name, channel_cfg.instruct_mode, msg.author, msg.content);
+dpp::message embed_msg;
+embed_msg.id = res->second;
+embed_msg.add_embed(embed);
+bot.message_edit(embed_msg);
 } catch (const std::exception& e) {
 std::cerr << "Warning: " << e.what() << std::endl;
 }
@@ -731,10 +748,10 @@ int main(int argc, char **argv) {
 cfg.language = std::move(value);
 } else if (key == "default_inference_model") {
 cfg.default_inference_model = std::move(value);
-clean_for_command_name(cfg.default_inference_model);
+utils::clean_for_command_name(cfg.default_inference_model);
 } else if (key == "translation_model") {
 cfg.translation_model = std::move(value);
-clean_for_command_name(cfg.translation_model);
+utils::clean_for_command_name(cfg.translation_model);
 } else if (key == "prompt_file") {
 cfg.prompt_file = std::move(value);
 } else if (key == "instruct_prompt_file") {
@@ -780,7 +797,7 @@ int main(int argc, char **argv) {
 // Get model name
 auto model_name = file.path().filename().string();
 model_name.erase(model_name.size()-4, 4);
-clean_for_command_name(model_name);
+utils::clean_for_command_name(model_name);
 // Parse model config
 Bot::ModelConfig model_cfg;
 std::ifstream cfgf(file.path());

utils.cpp (new file, 54 lines)

@@ -0,0 +1,54 @@
+#include "utils.hpp"
+
+
+
+namespace utils {
+std::vector<std::string_view> str_split(std::string_view s, char delimiter, size_t times) {
+    std::vector<std::string_view> to_return;
+    decltype(s.size()) start = 0, finish = 0;
+    while ((finish = s.find_first_of(delimiter, start)) != std::string_view::npos) {
+        to_return.emplace_back(s.substr(start, finish - start));
+        start = finish + 1;
+        if (to_return.size() == times) { break; }
+    }
+    to_return.emplace_back(s.substr(start));
+    return to_return;
+}
+
+void str_replace_in_place(std::string& subject, std::string_view search,
+                          const std::string& replace) {
+    if (search.empty()) return;
+    size_t pos = 0;
+    while ((pos = subject.find(search, pos)) != std::string::npos) {
+        subject.replace(pos, search.length(), replace);
+        pos += replace.length();
+    }
+}
+
+void clean_for_command_name(std::string& value) {
+    for (auto& c : value) {
+        if (c == '.') c = '_';
+        if (isalpha(c)) c = tolower(c);
+    }
+}
+
+std::string_view max_words(std::string_view text, unsigned count) {
+    unsigned word_len = 0,
+             word_count = 0,
+             idx;
+    // Get idx after last word
+    for (idx = 0; idx != text.size() && word_count != count; idx++) {
+        char c = text[idx];
+        if (c == ' ' || word_len == 7) {
+            if (word_len != 0) {
+                word_count++;
+                word_len = 0;
+            }
+        } else {
+            word_len++;
+        }
+    }
+    // Return resulting string
+    return {text.data()+idx, text.size()-idx};
+}
+}
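
Note (not part of the commit): a minimal standalone sketch of how the string helpers added in utils.cpp behave. The main() wrapper and the sample strings are illustrative assumptions; only the utils:: calls come from the diff above.

    #include "utils.hpp"

    #include <iostream>
    #include <string>

    int main() {
        // Split a message into lines, as the bot does before appending them to the inference
        for (const auto line : utils::str_split("first line\nsecond line", '\n'))
            std::cout << line << '\n';                    // prints "first line", then "second line"

        // Replace a raw mention with a readable username, in place
        std::string msg = "hi <@123456789>!";
        utils::str_replace_in_place(msg, "<@123456789>", "discord_llama");
        std::cout << msg << '\n';                         // prints "hi discord_llama!"

        // Normalize a model file name into a slash-command-friendly name
        std::string name = "Open.Assistant";
        utils::clean_for_command_name(name);              // '.' -> '_', letters lowercased
        std::cout << name << '\n';                        // prints "open_assistant"
    }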

utils.hpp (new file, 65 lines)

@@ -0,0 +1,65 @@
+#ifndef UTILS_HPP
+#define UTILS_HPP
+#include <string>
+#include <string_view>
+#include <initializer_list>
+#include <vector>
+#include <chrono>
+
+
+namespace utils {
+class Timer {
+    std::chrono::time_point<std::chrono::high_resolution_clock> value;
+
+public:
+    Timer() {
+        reset();
+    }
+
+    void reset() {
+        value = std::chrono::high_resolution_clock::now();
+    }
+
+    template<typename Unit = std::chrono::milliseconds>
+    auto get() {
+        auto duration = std::chrono::duration_cast<Unit>(std::chrono::high_resolution_clock::now() - value);
+        return duration.count();
+    }
+};
+
+
+std::vector<std::string_view> str_split(std::string_view s, char delimiter, size_t times = -1);
+
+void str_replace_in_place(std::string& subject, std::string_view search, const std::string& replace);
+
+void clean_for_command_name(std::string& value);
+
+std::string_view max_words(std::string_view text, unsigned count);
+
+inline
+uint32_t get_unique_color(const auto& input) {
+    auto i = std::hash<typename std::remove_const<typename std::remove_reference<decltype(input)>::type>::type>{}(input);
+    const std::initializer_list<uint32_t> colors = {
+        0xf44336,
+        0xe91e63,
+        0x9c27b0,
+        0x673ab7,
+        0x3f51b5,
+        0x2196f3,
+        0x03a9f4,
+        0x00bcd4,
+        0x009688,
+        0x4caf50,
+        0x8bc34a,
+        0xcddc39,
+        0xffeb3b,
+        0xffc107,
+        0xff9800,
+        0xff5722,
+        0x795548,
+        0xcfd8dc
+    };
+    return *(colors.begin()+(i%colors.size()));
+}
+}
+#endif // UTILS_HPP
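
Note (not part of the commit): a similarly hedged sketch of the utils::Timer class and get_unique_color() declared in utils.hpp. The sleep duration, the one-second threshold, and the sample model name are made up for illustration.

    #include "utils.hpp"

    #include <cstdio>
    #include <string>
    #include <thread>

    int main() {
        // A Timer starts at construction; get<>() reports the time elapsed since the last reset()
        utils::Timer timeout;
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
        if (timeout.get<std::chrono::seconds>() > 1) {
            std::puts("timeout exceeded");                // not reached after only ~50 ms
        }
        std::printf("elapsed: %lld ms\n", static_cast<long long>(timeout.get()));

        // get_unique_color() hashes its argument onto one of 18 fixed colors,
        // so the same model name always yields the same embed color
        const std::string model_name = "example-model";
        std::printf("color: #%06x\n", static_cast<unsigned>(utils::get_unique_color(model_name)));
    }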