Mirror of https://gitlab.com/niansa/discord_llama.git (synced 2025-03-06 20:48:25 +01:00)
Implemented inference pooling

parent 18de66f8ab
commit 9b626a049a

3 changed files with 104 additions and 113 deletions
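In short: the single LM::Inference member that served every channel is replaced by an LM::InferencePool, and each Discord channel id gets its own lazily created inference that is primed from a serialized "init_cache" file. The sketch below is assembled from the calls visible in this diff (it is not the project's literal code) to show the per-channel lookup pattern:

    // Sketch only: per-channel lookup as wired up in this commit.
    #include <fstream>
    #include <string>
    #include <dpp/dpp.h>
    #include <justlm_pool.hpp>

    LM::Inference& inference_for(LM::InferencePool& pool, dpp::snowflake channel_id,
                                 const std::string& model, const LM::Inference::Params& params) {
        auto existing = pool.get_inference(channel_id);   // empty if this channel is new
        if (existing.has_value()) return existing.value();
        auto& fresh = pool.get_or_create_inference(channel_id, model, params);
        std::ifstream cache("init_cache", std::ios::binary);
        fresh.deserialize(cache);                         // restore the pre-evaluated initial prompt
        return fresh;
    }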
anyproc (2 changed lines)
@@ -1 +1 @@
Subproject commit b35cf5b1139b097e27995bef1050625521c73a95
Subproject commit bee9170bcafc28adf8183333b44c93bbd8a64242
@@ -1,7 +1,9 @@
token MTA0MDYxMTQzNjUwNzk1OTMyNw.Gl_iMU.jVVM3bRqBJVi8ORVpWHquOivlASGJpRySt8qFg
channel 1099461766888574989

# The following parameters are set to their defaults here and can be omitted
language EN
inference_model 13B-ggml-model-quant.bin
translation_model 13B-ggml-model-quant.bin
mlock false
pool_size 2
threads 4
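pool_size and threads are new keys; they map onto the Configuration struct added in main.cpp below. As a self-contained sketch of how such a key/value file can be loaded (the project's actual loop lives in main() further down and differs in detail):

    #include <fstream>
    #include <sstream>
    #include <string>

    struct Configuration {
        std::string token, channel,
                    language = "EN",
                    inference_model = "13B-ggml-model-quant.bin",
                    translation_model = "13B-ggml-model-quant.bin";
        unsigned pool_size = 2, threads = 4;
        bool mlock = false;
    };

    Configuration parse_config(const std::string& path) {
        Configuration cfg;
        std::ifstream file(path);
        std::string line;
        while (std::getline(file, line)) {
            std::istringstream iss(line);
            std::string key, value;
            iss >> key;
            std::getline(iss >> std::ws, value);
            if (key.empty() || key[0] == '#') continue;    // skip blank lines and comments
            if      (key == "token")             cfg.token = value;
            else if (key == "channel")           cfg.channel = value;
            else if (key == "language")          cfg.language = value;
            else if (key == "inference_model")   cfg.inference_model = value;
            else if (key == "translation_model") cfg.translation_model = value;
            else if (key == "pool_size")         cfg.pool_size = std::stoul(value);
            else if (key == "threads")           cfg.threads = std::stoul(value);
            else if (key == "mlock")             cfg.mlock = (value == "true");
        }
        return cfg;
    }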
main.cpp (211 changed lines)
@@ -11,10 +11,12 @@
#include <array>
#include <vector>
#include <unordered_map>
#include <sstream>
#include <mutex>
#include <memory>
#include <dpp/dpp.h>
#include <justlm.hpp>
#include <justlm_pool.hpp>
#include <anyproc.hpp>
#include <ThreadPool.h>
@@ -48,18 +50,14 @@ class Bot {
    ThreadPool tPool{1};
    Timer last_message_timer;
    std::shared_ptr<bool> stopping;
    LM::Inference llm;
    LM::InferencePool llm_pool;
    Translator translator;
    LM::Inference::Savestate start_sv;
    std::vector<dpp::snowflake> my_messages;
    std::unordered_map<dpp::snowflake, dpp::user> users;
    std::mutex llm_lock;
    std::thread::id llm_tid;

    std::string_view language;
    dpp::cluster bot;
    dpp::channel channel;
    dpp::snowflake channel_id;

    struct Texts {
        std::string please_wait = "Please wait...",
@@ -125,8 +123,7 @@ class Bot {
            fres = translator.translate(fres, "EN", show_console_progress);
        } catch (const LM::Inference::ContextLengthException&) {
            // Handle potential context overflow error
            llm_restart();
            return llm_translate_to_en(text);
            return "(Translation impossible)";
        }
        // Replace [43] back with bot username
        str_replace_in_place(fres, "[43]", bot.me.username);
@@ -152,8 +149,7 @@ class Bot {
            fres = translator.translate(fres, language, show_console_progress);
        } catch (const LM::Inference::ContextLengthException&) {
            // Handle potential context overflow error
            llm_restart();
            return llm_translate_from_en(text);
            return "(Translation impossible)";
        }
        // Replace [43] back with bot username
        str_replace_in_place(fres, "[43]", bot.me.username);
@@ -161,20 +157,47 @@ class Bot {
        return fres;
    }

    constexpr static LM::Inference::Params llm_get_params(bool mlock) {
    LM::Inference::Params llm_get_translation_params() const {
        auto fres = translator.get_params();
        fres.n_threads = config.threads;
        fres.use_mlock = config.mlock;
        return fres;
    }
    LM::Inference::Params llm_get_params() const {
        return {
            .n_threads = int(config.threads),
            .n_ctx = 1012,
            .n_repeat_last = 256,
            .temp = 0.3f,
            .repeat_penalty = 1.372222224f,
            .use_mlock = mlock
            .use_mlock = config.mlock
        };
    }

    // Must run in llama thread
    void llm_restart() {
    void llm_restart(LM::Inference& inference) {
        // Deserialize init cache
        std::ifstream f("init_cache", std::ios::binary);
        inference.deserialize(f);
    }
    // Must run in llama thread
    LM::Inference &llm_restart(dpp::snowflake id) {
        ENSURE_LLM_THREAD();
        llm.restore_savestate(start_sv);
        // Get or create inference
        auto& inference = llm_pool.get_or_create_inference(id, config.inference_model, llm_get_params());
        llm_restart(inference);
        return inference;
    }

    // Must run in llama thread
    LM::Inference &llm_get_inference(dpp::snowflake id) {
        ENSURE_LLM_THREAD();
        auto inference_opt = llm_pool.get_inference(id);
        if (!inference_opt.has_value()) {
            // Start new inference
            inference_opt = llm_restart(id);
        }
        return inference_opt.value();
    }

    // Must run in llama thread
@@ -189,64 +212,43 @@ class Bot {
            texts.timeout = llm_translate_from_en(texts.timeout);
            texts.translated = true;
        }
        // Create message for reporting progress
        dpp::message msg(channel_id, texts.initializing);
        bot.message_create(msg, [this] (const dpp::confirmation_callback_t& cbt) {
            // Error check
            if (cbt.is_error()) {
                throw std::runtime_error("Failed to send message to channel: "+cbt.get_error().message);
            }
            // Callback for reporting progress
            Timer timer;
            auto msg = cbt.get<dpp::message>();
            uint8_t last_progress = 0;
            auto cb = [&, this] (float progress) mutable {
                uint8_t progress_i = progress;
                if (timer.get<std::chrono::seconds>() > 5) {
                    msg.content = texts.loading+" "+create_text_progress_indicator(progress_i)+
                                  " **"+std::to_string(progress_i)+"**% (**"+std::to_string((progress_i-last_progress)/5)+"**%/s)";
                    last_progress = progress_i;
                    bot.message_edit(msg);
                    timer.reset();
                }
                return true;
            };
            // Add initial context
            std::unique_lock L(llm_lock);
            llm.append("History of the #"+channel.name+" channel.\n"
                       "Note 1: "+bot.me.username+" is a friendly chatbot that is always happy to talk. He is friendly and helpful and always answers immediately. He has a good sense of humor and likes everyone. His age is unknown.\n"
                       "Note 2: Ecki's real name is Eckhard Kohlhuber and he comes from Bavaria.\n"
                       "\n"
                       "Bob: "+bot.me.username+" have you ever been to France and if yes where?\n"
                       +bot.me.username+": I was in Paris, in the museums!\n"
                       "Bob: "+bot.me.username+" what are you exactly?\n"
                       +bot.me.username+": I am "+bot.me.username+", your chatbot! I can answer questions and increase the activity of the server.\n"
                       "Bob: Shall we talk about sex? "+bot.me.username+"?\n"
                       +bot.me.username+": No! I will **not** talk about any NSFW topics.\n"
                       "Bob: "+bot.me.username+" How are you?\n"
                       +bot.me.username+": I am quite well! :-)\n"
                       "Ecki: Hey "+bot.me.username+", what is 10 times 90??\n"
                       +bot.me.username+": that is 900!\n", cb);
            // Delete progress message
            bot.message_delete(msg.id, msg.channel_id);
            // Create savestate
            llm.create_savestate(start_sv);
        });
        // Inference for init cache TODO: Don't recreate on each startup
        LM::Inference llm(config.inference_model, llm_get_params());
        std::ofstream f("init_cache", std::ios::binary);
        // Add initial context
        llm.append("History of the discord server.\n"
                   "Note 1: "+bot.me.username+" is a friendly chatbot that is always happy to talk. He is friendly and helpful and always answers immediately. He has a good sense of humor and likes everyone. His age is unknown.\n"
                   "Note 2: Ecki's real name is Eckhard Kohlhuber and he comes from Bavaria.\n" // Little easter egg
                   "\n"
                   "This is the #meta channel.\n"
                   "Bob: "+bot.me.username+" have you ever been to France and if yes where?\n"
                   +bot.me.username+": I was in Paris, in the museums!\n"
                   "Bob: "+bot.me.username+" what are you exactly?\n"
                   +bot.me.username+": I am "+bot.me.username+", your chatbot! I can answer questions and increase the activity of the server.\n"
                   "Bob: Shall we talk about sex? "+bot.me.username+"?\n"
                   +bot.me.username+": No! I will **not** talk about any NSFW topics.\n"
                   "Bob: "+bot.me.username+" How are you?\n"
                   +bot.me.username+": I am quite well! :-)\n"
                   "Ecki: Hey "+bot.me.username+", what is 10 times 90??\n"
                   +bot.me.username+": that is 900!\n", show_console_progress);
        // Serialize end result
        llm.serialize(f);
    }
    // Must run in llama thread
    void prompt_add_msg(const dpp::message& msg) {
        ENSURE_LLM_THREAD();
        // Make sure message isn't too long
        if (msg.content.size() > 512) {
            return;
        }
        // Get inference
        auto& inference = llm_get_inference(msg.channel_id);
        try {
            // Make sure message isn't too long
            if (msg.content.size() > 512) {
                return;
            }
            // Format and append line
            std::unique_lock L(llm_lock);
            // Format and append lines
            for (const auto line : str_split(msg.content, '\n')) {
                Timer timeout;
                bool timeout_exceeded = false;
                llm.append(msg.author.username+": "+std::string(llm_translate_to_en(line))+'\n', [&] (float progress) {
                inference.append(msg.author.username+": "+std::string(llm_translate_to_en(line))+'\n', [&] (float progress) {
                    if (timeout.get<std::chrono::minutes>() > 1) {
                        std::cerr << "\nWarning: Timeout exceeded processing message" << std::endl;
                        timeout_exceeded = true;
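The llm_init() portion of this hunk no longer keeps a savestate of a shared model; instead it evaluates the initial prompt once on a throwaway LM::Inference and serializes the result to an "init_cache" file, which llm_restart() later streams into every newly created per-channel inference. A simplified sketch of that build-once/restore-many idea (error handling and the Discord-specific prompt omitted; this is not the file's exact code):

    #include <fstream>
    #include <string>
    #include <justlm.hpp>

    // Evaluate the initial prompt once and persist the resulting model state.
    void build_init_cache(const std::string& model_path, const LM::Inference::Params& params,
                          const std::string& initial_prompt) {
        LM::Inference scratch(model_path, params);   // throwaway instance
        scratch.append(initial_prompt);              // evaluate the prompt one time
        std::ofstream out("init_cache", std::ios::binary);
        scratch.serialize(out);
    }

    // Restore that state into a fresh per-channel inference instead of re-evaluating.
    void restore_init_cache(LM::Inference& inference) {
        std::ifstream in("init_cache", std::ios::binary);
        inference.deserialize(in);
    }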
@@ -254,37 +256,40 @@ class Bot {
                    }
                    return show_console_progress(progress);
                });
                if (timeout_exceeded) llm.append("\n");
                if (timeout_exceeded) inference.append("\n");
            }
        } catch (const LM::Inference::ContextLengthException&) {
            llm_restart();
            llm_restart(inference);
            prompt_add_msg(msg);
        }
    }
    // Must run in llama thread
    void prompt_add_trigger() {
    void prompt_add_trigger(dpp::snowflake id) {
        ENSURE_LLM_THREAD();
        auto& inference = llm_get_inference(id);
        try {
            std::unique_lock L(llm_lock);
            llm.append(bot.me.username+':', show_console_progress);
            inference.append(bot.me.username+':', show_console_progress);
        } catch (const LM::Inference::ContextLengthException&) {
            llm_restart();
            llm_restart(inference);
        }
    }

    // Must run in llama thread
    void reply(const std::function<void ()>& after_placeholder_creation = nullptr) {
    void reply(dpp::snowflake id, const std::function<void ()>& after_placeholder_creation = nullptr) {
        ENSURE_LLM_THREAD();
        try {
            // Create placeholder message
            auto msg = bot.message_create_sync(dpp::message(channel_id, texts.please_wait+" :thinking:"));
            auto msg = bot.message_create_sync(dpp::message(id, texts.please_wait+" :thinking:"));
            // Call after_placeholder_creation callback
            if (after_placeholder_creation) after_placeholder_creation();
            // Trigger LLM correctly
            prompt_add_trigger();
            prompt_add_trigger(id);
            // Get inference
            auto& inference = llm_get_inference(id);
            // Run model
            Timer timeout;
            bool timeout_exceeded = false;
            auto output = llm.run("\n", [&] (std::string_view str) {
            auto output = inference.run("\n", [&] (std::string_view str) {
                std::cout << str << std::flush;
                if (timeout.get<std::chrono::minutes>() > 2) {
                    timeout_exceeded = true;
@@ -295,7 +300,7 @@ class Bot {
            });
            std::cout << std::endl;
            if (timeout_exceeded) {
                llm.append("\n");
                inference.append("\n");
                output = texts.timeout;
            }
            // Send resulting message
@@ -311,13 +316,13 @@ class Bot {
        ENSURE_LLM_THREAD();
        // Reply if message contains username, mention or ID
        if (msg.content.find(bot.me.username) != std::string::npos) {
            reply(after_placeholder_creation);
            reply(msg.channel_id, after_placeholder_creation);
            return true;
        }
        // Reply if message references user
        for (const auto msg_id : my_messages) {
            if (msg.message_reference.message_id == msg_id) {
                reply(after_placeholder_creation);
                reply(msg.channel_id, after_placeholder_creation);
                return true;
            }
        }
@@ -325,61 +330,43 @@ class Bot {
        return false;
    }

    void enqueue_reply() {
        tPool.submit(std::bind(&Bot::reply, this, nullptr));
    }

    void idle_auto_reply() {
        auto s = stopping;
        do {
            // Wait for a bit
            std::this_thread::sleep_for(std::chrono::minutes(5));
            // Check if last message was more than 20 minutes ago
            if (last_message_timer.get<std::chrono::hours>() > 3) {
                // Force reply
                enqueue_reply();
            }
        } while (!*s);
    void enqueue_reply(dpp::snowflake id) {
        tPool.submit(std::bind(&Bot::reply, this, id, nullptr));
    }

public:
    struct Configuration {
        std::string token,
                    channel,
                    language = "EN",
                    inference_model = "13B-ggml-model-quant.bin",
                    translation_model = "13B-ggml-model-quant.bin";
        unsigned pool_size = 2,
                 threads = 4;
        bool mlock = false;
    } config;

    Bot(const Configuration& cfg) : config(cfg), bot(cfg.token), channel_id(cfg.channel), language(cfg.language),
                                    llm(cfg.inference_model, llm_get_params(cfg.mlock)), translator(cfg.translation_model) {
    Bot(const Configuration& cfg) : config(cfg), bot(cfg.token), language(cfg.language),
                                    llm_pool(cfg.pool_size, "discord_llama", false), translator(cfg.translation_model, llm_get_translation_params()) {
        // Configure llm_pool
        llm_pool.set_store_on_destruct(true);

        // Initialize thread pool
        tPool.init();

        // Prepare llm
        tPool.submit(std::bind(&Bot::llm_init, this));

        // Configure bot
        bot.on_log(dpp::utility::cout_logger());
        bot.intents = dpp::i_guild_messages | dpp::i_message_content;

        // Set callbacks
        bot.on_ready([=, this] (const dpp::ready_t&) {
            // Get channel
            bot.channel_get(channel_id, [=, this] (const dpp::confirmation_callback_t& cbt) {
                if (cbt.is_error()) {
                    throw std::runtime_error("Failed to get channel: "+cbt.get_error().message);
                }
                channel = cbt.get<dpp::channel>();
                // Append initial prompt
                tPool.submit(std::bind(&Bot::llm_init, this));
                // Start idle auto reply thread
                std::thread(std::bind(&Bot::idle_auto_reply, this)).detach();
            });
        bot.on_ready([=, this] (const dpp::ready_t&) { //TODO: Consider removal
            std::cout << "Connected to Discord." << std::endl;
        });
        bot.on_message_create([=, this] (const dpp::message_create_t& event) {
            // Update user cache
            users[event.msg.author.id] = event.msg.author;
            // Make sure message source is correct
            if (event.msg.channel_id != channel_id) return;
            // Make sure message has content
            if (event.msg.content.empty()) return;
            // Reset last message timer
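All model work still goes through the single-worker ThreadPool (tPool{1}): Discord event handlers only enqueue jobs, and ENSURE_LLM_THREAD() guards the methods that must run on that worker. A minimal sketch of that arrangement, assuming only the ThreadPool calls already used in this diff (brace construction, init(), submit()):

    #include <ThreadPool.h>
    #include <functional>

    // Sketch only: funnel every llama.cpp call through one worker thread so
    // callers from Discord callbacks never touch the models directly.
    struct LlamaQueue {
        ThreadPool pool{1};                  // single worker = single llama thread

        void start() { pool.init(); }

        void enqueue(std::function<void()> job) {
            pool.submit(std::move(job));     // runs later, on the llama thread
        }
    };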
@@ -404,7 +391,7 @@ public:
                // Delete message
                bot.message_delete(msg.id, msg.channel_id);
                // Send a reply
                reply();
                reply(msg.channel_id);
            } else {
                tPool.submit([=, this] () {
                    // Attempt to send a reply
@@ -457,14 +444,16 @@ int main(int argc, char **argv) {
        // Check key and ignore comment lines
        if (key == "token") {
            cfg.token = std::move(value);
        } else if (key == "channel") {
            cfg.channel = std::move(value);
        } else if (key == "language") {
            cfg.language = std::move(value);
        } else if (key == "inference_model") {
            cfg.inference_model = std::move(value);
        } else if (key == "translation_model") {
            cfg.translation_model = std::move(value);
        } else if (key == "pool_size") {
            cfg.pool_size = std::stoi(value);
        } else if (key == "threads") {
            cfg.threads = std::stoi(value);
        } else if (key == "mlock") {
            cfg.mlock = (value=="true")?true:false;
        } else if (!key.empty() && key[0] != '#') {