Mirror of https://gitlab.com/niansa/discord_llama.git

Implemented inference pooling

niansa 2023-04-23 15:32:20 +02:00
parent 18de66f8ab
commit 9b626a049a
3 changed files with 104 additions and 113 deletions
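This commit replaces the bot's single, mutex-guarded LM::Inference instance with an LM::InferencePool keyed by Discord channel ID, so each channel keeps its own conversation context and only a bounded number of contexts stay loaded at once. As a rough sketch of the pattern (the Inference stand-in, the LRU policy, and all method names below are assumptions for illustration, not the justlm API):

```cpp
#include <cstddef>
#include <cstdint>
#include <list>
#include <string>
#include <unordered_map>

// Stand-in for a loaded model context; not the justlm type.
struct Inference {
    std::string model;
    explicit Inference(std::string m) : model(std::move(m)) {}
};

// Minimal pool keyed by channel ID: at most `capacity` contexts stay
// resident, and the least recently used one is evicted to make room
// (the real pool presumably stores evicted contexts to disk instead).
class InferencePool {
    std::size_t capacity;
    std::list<std::uint64_t> lru; // front = most recently used
    std::unordered_map<std::uint64_t, Inference> instances;

    void touch(std::uint64_t id) {
        lru.remove(id);
        lru.push_front(id);
    }

public:
    explicit InferencePool(std::size_t capacity) : capacity(capacity) {}

    // Get the context for `id`, creating it (and evicting if full) as needed.
    Inference& get_or_create(std::uint64_t id, const std::string& model) {
        auto it = instances.find(id);
        if (it == instances.end()) {
            if (instances.size() >= capacity && !lru.empty()) {
                instances.erase(lru.back());
                lru.pop_back();
            }
            it = instances.emplace(id, Inference(model)).first;
        }
        touch(id);
        return it->second;
    }
};
```

With the new pool_size default of 2, two channels can chat without evictions; a third causes the least recently used context to be dropped (or, in the real pool, presumably stored to disk and restored later).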

@@ -1 +1 @@
-Subproject commit b35cf5b1139b097e27995bef1050625521c73a95
+Subproject commit bee9170bcafc28adf8183333b44c93bbd8a64242


@@ -1,7 +1,9 @@
token MTA0MDYxMTQzNjUwNzk1OTMyNw.Gl_iMU.jVVM3bRqBJVi8ORVpWHquOivlASGJpRySt8qFg
channel 1099461766888574989
# The following parameters are set to their defaults here and can be omitted
language EN
inference_model 13B-ggml-model-quant.bin
translation_model 13B-ggml-model-quant.bin
mlock false
+pool_size 2
+threads 4
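pool_size and threads are the new keys here; like the others they use a single-line "key value" format, and main.cpp parses them with std::stoi (last hunk below). A minimal sketch of a reader for this format (read_config is a hypothetical helper, not code from the bot):

```cpp
#include <fstream>
#include <sstream>
#include <string>
#include <unordered_map>

// Reads "key value" lines, skipping blanks and '#' comments.
std::unordered_map<std::string, std::string> read_config(const std::string& path) {
    std::unordered_map<std::string, std::string> res;
    std::ifstream file(path);
    for (std::string line; std::getline(file, line); ) {
        if (line.empty() || line[0] == '#') continue;
        std::istringstream ls(line);
        std::string key, value;
        ls >> key;                       // first token is the key
        std::getline(ls >> std::ws, value); // rest of the line is the value
        res[key] = value;
    }
    return res;
}
```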

main.cpp

@@ -11,10 +11,12 @@
#include <array>
#include <vector>
#include <unordered_map>
#include <sstream>
#include <mutex>
#include <memory>
#include <dpp/dpp.h>
#include <justlm.hpp>
+#include <justlm_pool.hpp>
#include <anyproc.hpp>
#include <ThreadPool.h>
@@ -48,18 +50,14 @@ class Bot {
ThreadPool tPool{1};
Timer last_message_timer;
std::shared_ptr<bool> stopping;
-LM::Inference llm;
+LM::InferencePool llm_pool;
Translator translator;
-LM::Inference::Savestate start_sv;
std::vector<dpp::snowflake> my_messages;
std::unordered_map<dpp::snowflake, dpp::user> users;
-std::mutex llm_lock;
std::thread::id llm_tid;
std::string_view language;
dpp::cluster bot;
-dpp::channel channel;
dpp::snowflake channel_id;
struct Texts {
std::string please_wait = "Please wait...",
@@ -125,8 +123,7 @@ class Bot {
fres = translator.translate(fres, "EN", show_console_progress);
} catch (const LM::Inference::ContextLengthException&) {
// Handle potential context overflow error
-llm_restart();
-return llm_translate_to_en(text);
+return "(Translation impossible)";
}
// Replace [43] back with bot username
str_replace_in_place(fres, "[43]", bot.me.username);
@@ -152,8 +149,7 @@ class Bot {
fres = translator.translate(fres, language, show_console_progress);
} catch (const LM::Inference::ContextLengthException&) {
// Handle potential context overflow error
-llm_restart();
-return llm_translate_from_en(text);
+return "(Translation impossible)";
}
// Replace [43] back with bot username
str_replace_in_place(fres, "[43]", bot.me.username);
@@ -161,20 +157,47 @@ class Bot {
return fres;
}
-constexpr static LM::Inference::Params llm_get_params(bool mlock) {
+LM::Inference::Params llm_get_translation_params() const {
+auto fres = translator.get_params();
+fres.n_threads = config.threads;
+fres.use_mlock = config.mlock;
+return fres;
+}
+LM::Inference::Params llm_get_params() const {
return {
.n_threads = int(config.threads),
.n_ctx = 1012,
.n_repeat_last = 256,
.temp = 0.3f,
.repeat_penalty = 1.372222224f,
-.use_mlock = mlock
+.use_mlock = config.mlock
};
}
// Must run in llama thread
-void llm_restart() {
+void llm_restart(LM::Inference& inference) {
+// Deserialize init cache
+std::ifstream f("init_cache", std::ios::binary);
+inference.deserialize(f);
}
+// Must run in llama thread
+LM::Inference &llm_restart(dpp::snowflake id) {
ENSURE_LLM_THREAD();
-llm.restore_savestate(start_sv);
+// Get or create inference
+auto& inference = llm_pool.get_or_create_inference(id, config.inference_model, llm_get_params());
+llm_restart(inference);
+return inference;
+}
+// Must run in llama thread
+LM::Inference &llm_get_inference(dpp::snowflake id) {
+ENSURE_LLM_THREAD();
+auto inference_opt = llm_pool.get_inference(id);
+if (!inference_opt.has_value()) {
+// Start new inference
+inference_opt = llm_restart(id);
+}
+return inference_opt.value();
+}
// Must run in llama thread
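llm_get_inference above first asks the pool for an existing per-channel instance and only primes a fresh one (via llm_restart) on a miss. get_inference evidently returns an optional handle that is empty until the instance exists; the same shape can be written with std::optional<std::reference_wrapper<T>> (an assumption for illustration; justlm may use its own optional type):

```cpp
#include <functional>
#include <optional>
#include <unordered_map>

// Hypothetical lookup: an empty result means the caller must create the entry.
template <typename K, typename V>
std::optional<std::reference_wrapper<V>> find_ref(std::unordered_map<K, V>& map, const K& key) {
    auto it = map.find(key);
    if (it == map.end()) return std::nullopt;
    return std::ref(it->second);
}

// Usage mirroring llm_get_inference:
//   auto opt = find_ref(instances, id);
//   if (!opt.has_value()) opt = create(id); // create(id) returns V&
//   V& v = opt.value();
```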
@@ -189,64 +212,43 @@ class Bot {
texts.timeout = llm_translate_from_en(texts.timeout);
texts.translated = true;
}
-// Create message for reporting progress
-dpp::message msg(channel_id, texts.initializing);
-bot.message_create(msg, [this] (const dpp::confirmation_callback_t& cbt) {
-// Error check
-if (cbt.is_error()) {
-throw std::runtime_error("Failed to send message to channel: "+cbt.get_error().message);
-}
-// Callback for reporting progress
-Timer timer;
-auto msg = cbt.get<dpp::message>();
-uint8_t last_progress = 0;
-auto cb = [&, this] (float progress) mutable {
-uint8_t progress_i = progress;
-if (timer.get<std::chrono::seconds>() > 5) {
-msg.content = texts.loading+" "+create_text_progress_indicator(progress_i)+
-" **"+std::to_string(progress_i)+"**% (**"+std::to_string((progress_i-last_progress)/5)+"**%/s)";
-last_progress = progress_i;
-bot.message_edit(msg);
-timer.reset();
-}
-return true;
-};
-// Add initial context
-std::unique_lock L(llm_lock);
-llm.append("History of the #"+channel.name+" channel.\n"
-"Note 1: "+bot.me.username+" is a friendly chatbot that is always happy to talk. He is friendly and helpful and always answers immediately. He has a good sense of humor and likes everyone. His age is unknown.\n"
-"Note 2: Ecki's real name is Eckhard Kohlhuber and he comes from Bavaria.\n"
-"\n"
-"Bob: "+bot.me.username+" have you ever been to France and if yes where?\n"
-+bot.me.username+": I was in Paris, in the museums!\n"
-"Bob: "+bot.me.username+" what are you exactly?\n"
-+bot.me.username+": I am "+bot.me.username+", your chatbot! I can answer questions and increase the activity of the server.\n"
-"Bob: Shall we talk about sex? "+bot.me.username+"?\n"
-+bot.me.username+": No! I will **not** talk about any NSFW topics.\n"
-"Bob: "+bot.me.username+" How are you?\n"
-+bot.me.username+": I am quite well! :-)\n"
-"Ecki: Hey "+bot.me.username+", what is 10 times 90??\n"
-+bot.me.username+": that is 900!\n", cb);
-// Delete progress message
-bot.message_delete(msg.id, msg.channel_id);
-// Create savestate
-llm.create_savestate(start_sv);
-});
+// Inference for init cache TODO: Don't recreate on each startup
+LM::Inference llm(config.inference_model, llm_get_params());
+std::ofstream f("init_cache", std::ios::binary);
+// Add initial context
+llm.append("History of the discord server.\n"
+"Note 1: "+bot.me.username+" is a friendly chatbot that is always happy to talk. He is friendly and helpful and always answers immediately. He has a good sense of humor and likes everyone. His age is unknown.\n"
+"Note 2: Ecki's real name is Eckhard Kohlhuber and he comes from Bavaria.\n" // Little easter egg
+"\n"
+"This is the #meta channel.\n"
+"Bob: "+bot.me.username+" have you ever been to France and if yes where?\n"
++bot.me.username+": I was in Paris, in the museums!\n"
+"Bob: "+bot.me.username+" what are you exactly?\n"
++bot.me.username+": I am "+bot.me.username+", your chatbot! I can answer questions and increase the activity of the server.\n"
+"Bob: Shall we talk about sex? "+bot.me.username+"?\n"
++bot.me.username+": No! I will **not** talk about any NSFW topics.\n"
+"Bob: "+bot.me.username+" How are you?\n"
++bot.me.username+": I am quite well! :-)\n"
+"Ecki: Hey "+bot.me.username+", what is 10 times 90??\n"
++bot.me.username+": that is 900!\n", show_console_progress);
+// Serialize end result
+llm.serialize(f);
}
// Must run in llama thread
void prompt_add_msg(const dpp::message& msg) {
ENSURE_LLM_THREAD();
-// Make sure message isn't too long
-if (msg.content.size() > 512) {
-return;
-}
+// Get inference
+auto& inference = llm_get_inference(msg.channel_id);
try {
+// Make sure message isn't too long
+if (msg.content.size() > 512) {
+return;
+}
-// Format and append line
-std::unique_lock L(llm_lock);
+// Format and append lines
for (const auto line : str_split(msg.content, '\n')) {
Timer timeout;
bool timeout_exceeded = false;
-llm.append(msg.author.username+": "+std::string(llm_translate_to_en(line))+'\n', [&] (float progress) {
+inference.append(msg.author.username+": "+std::string(llm_translate_to_en(line))+'\n', [&] (float progress) {
if (timeout.get<std::chrono::minutes>() > 1) {
std::cerr << "\nWarning: Timeout exceeded processing message" << std::endl;
timeout_exceeded = true;
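llm_init now evaluates the initial prompt once per startup and serializes the finished context to an init_cache file, and llm_restart rehydrates any pooled inference from that file, so a context overflow no longer costs a full re-evaluation of the prompt. A toy version of that write-once/restore-many cycle (State and its methods are stand-ins for the serialized model context, not the justlm interface):

```cpp
#include <fstream>
#include <iterator>
#include <string>

// Stand-in for a model context that can be snapshotted to disk.
struct State {
    std::string prompt;
    void serialize(std::ostream& out) const { out << prompt; }
    void deserialize(std::istream& in) {
        prompt.assign(std::istreambuf_iterator<char>(in),
                      std::istreambuf_iterator<char>());
    }
};

int main() {
    // Startup: evaluate the initial prompt once, then cache the result.
    State fresh{"History of the discord server.\n"};
    {
        std::ofstream f("init_cache", std::ios::binary);
        fresh.serialize(f);
    }
    // Context overflow: restore from the cache instead of recomputing.
    State restored;
    std::ifstream f("init_cache", std::ios::binary);
    restored.deserialize(f);
}
```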
@@ -254,37 +256,40 @@ class Bot {
}
return show_console_progress(progress);
});
-if (timeout_exceeded) llm.append("\n");
+if (timeout_exceeded) inference.append("\n");
}
} catch (const LM::Inference::ContextLengthException&) {
-llm_restart();
+llm_restart(inference);
prompt_add_msg(msg);
}
}
// Must run in llama thread
-void prompt_add_trigger() {
+void prompt_add_trigger(dpp::snowflake id) {
ENSURE_LLM_THREAD();
+auto& inference = llm_get_inference(id);
try {
-std::unique_lock L(llm_lock);
-llm.append(bot.me.username+':', show_console_progress);
+inference.append(bot.me.username+':', show_console_progress);
} catch (const LM::Inference::ContextLengthException&) {
-llm_restart();
+llm_restart(inference);
}
}
// Must run in llama thread
-void reply(const std::function<void ()>& after_placeholder_creation = nullptr) {
+void reply(dpp::snowflake id, const std::function<void ()>& after_placeholder_creation = nullptr) {
ENSURE_LLM_THREAD();
try {
// Create placeholder message
-auto msg = bot.message_create_sync(dpp::message(channel_id, texts.please_wait+" :thinking:"));
+auto msg = bot.message_create_sync(dpp::message(id, texts.please_wait+" :thinking:"));
// Call after_placeholder_creation callback
if (after_placeholder_creation) after_placeholder_creation();
// Trigger LLM correctly
-prompt_add_trigger();
+prompt_add_trigger(id);
+// Get inference
+auto& inference = llm_get_inference(id);
// Run model
Timer timeout;
bool timeout_exceeded = false;
-auto output = llm.run("\n", [&] (std::string_view str) {
+auto output = inference.run("\n", [&] (std::string_view str) {
std::cout << str << std::flush;
if (timeout.get<std::chrono::minutes>() > 2) {
timeout_exceeded = true;
@@ -295,7 +300,7 @@ class Bot {
});
std::cout << std::endl;
if (timeout_exceeded) {
-llm.append("\n");
+inference.append("\n");
output = texts.timeout;
}
// Send resulting message
@@ -311,13 +316,13 @@ class Bot {
ENSURE_LLM_THREAD();
// Reply if message contains username, mention or ID
if (msg.content.find(bot.me.username) != std::string::npos) {
-reply(after_placeholder_creation);
+reply(msg.channel_id, after_placeholder_creation);
return true;
}
// Reply if message references user
for (const auto msg_id : my_messages) {
if (msg.message_reference.message_id == msg_id) {
-reply(after_placeholder_creation);
+reply(msg.channel_id, after_placeholder_creation);
return true;
}
}
@@ -325,61 +330,43 @@ class Bot {
return false;
}
-void enqueue_reply() {
-tPool.submit(std::bind(&Bot::reply, this, nullptr));
-}
-void idle_auto_reply() {
-auto s = stopping;
-do {
-// Wait for a bit
-std::this_thread::sleep_for(std::chrono::minutes(5));
-// Check if last message was more than 20 minutes ago
-if (last_message_timer.get<std::chrono::hours>() > 3) {
-// Force reply
-enqueue_reply();
-}
-} while (!*s);
+void enqueue_reply(dpp::snowflake id) {
+tPool.submit(std::bind(&Bot::reply, this, id, nullptr));
+}
public:
struct Configuration {
std::string token,
channel,
language = "EN",
inference_model = "13B-ggml-model-quant.bin",
translation_model = "13B-ggml-model-quant.bin";
+unsigned pool_size = 2,
+threads = 4;
bool mlock = false;
} config;
-Bot(const Configuration& cfg) : config(cfg), bot(cfg.token), channel_id(cfg.channel), language(cfg.language),
-llm(cfg.inference_model, llm_get_params(cfg.mlock)), translator(cfg.translation_model) {
+Bot(const Configuration& cfg) : config(cfg), bot(cfg.token), language(cfg.language),
+llm_pool(cfg.pool_size, "discord_llama", false), translator(cfg.translation_model, llm_get_translation_params()) {
+// Configure llm_pool
+llm_pool.set_store_on_destruct(true);
// Initialize thread pool
tPool.init();
+// Prepare llm
+tPool.submit(std::bind(&Bot::llm_init, this));
// Configure bot
bot.on_log(dpp::utility::cout_logger());
bot.intents = dpp::i_guild_messages | dpp::i_message_content;
// Set callbacks
-bot.on_ready([=, this] (const dpp::ready_t&) {
-// Get channel
-bot.channel_get(channel_id, [=, this] (const dpp::confirmation_callback_t& cbt) {
-if (cbt.is_error()) {
-throw std::runtime_error("Failed to get channel: "+cbt.get_error().message);
-}
-channel = cbt.get<dpp::channel>();
-// Append initial prompt
-tPool.submit(std::bind(&Bot::llm_init, this));
-// Start idle auto reply thread
-std::thread(std::bind(&Bot::idle_auto_reply, this)).detach();
-});
+bot.on_ready([=, this] (const dpp::ready_t&) { //TODO: Consider removal
+std::cout << "Connected to Discord." << std::endl;
});
bot.on_message_create([=, this] (const dpp::message_create_t& event) {
// Update user cache
users[event.msg.author.id] = event.msg.author;
// Make sure message source is correct
if (event.msg.channel_id != channel_id) return;
// Make sure message has content
if (event.msg.content.empty()) return;
// Reset last message timer
@@ -404,7 +391,7 @@ public:
// Delete message
bot.message_delete(msg.id, msg.channel_id);
// Send a reply
-reply();
+reply(msg.channel_id);
} else {
tPool.submit([=, this] () {
// Attempt to send a reply
@@ -457,14 +444,16 @@ int main(int argc, char **argv) {
// Check key and ignore comment lines
if (key == "token") {
cfg.token = std::move(value);
} else if (key == "channel") {
cfg.channel = std::move(value);
} else if (key == "language") {
cfg.language = std::move(value);
} else if (key == "inference_model") {
cfg.inference_model = std::move(value);
} else if (key == "translation_model") {
cfg.translation_model = std::move(value);
} else if (key == "pool_size") {
cfg.pool_size = std::stoi(value);
} else if (key == "threads") {
cfg.threads = std::stoi(value);
} else if (key == "mlock") {
cfg.mlock = (value=="true")?true:false;
} else if (!key.empty() && key[0] != '#') {