1
0
Fork 0
mirror of https://gitlab.com/niansa/discord_llama.git synced 2025-03-06 20:48:25 +01:00

Remove need for restarting

This commit is contained in:
niansa 2023-04-19 21:10:59 +02:00
parent 64f983a18e
commit 21dbc073a9

168
main.cpp
View file

@ -69,8 +69,9 @@ class Bot {
ThreadPool tPool{1};
Timer last_message_timer;
std::shared_ptr<bool> stopping;
std::unique_ptr<LM::Inference> llm = nullptr;
std::unique_ptr<Translator> translator;
LM::Inference llm;
Translator translator;
LM::Inference::Savestate start_sv;
std::vector<dpp::snowflake> my_messages;
std::unordered_map<dpp::snowflake, dpp::user> users;
std::mutex llm_lock;
@ -139,11 +140,10 @@ class Bot {
str_replace_in_place(fres, bot.me.username, "[43]");
// Run translation
try {
fres = translator->translate(fres, "EN", show_console_progress);
fres = translator.translate(fres, "EN", show_console_progress);
} catch (const LM::Inference::ContextLengthException&) {
// Handle potential context overflow error
translator.reset();
llm_init();
llm_restart();
return llm_translate_to_en(text);
}
// Replace [43] back with bot username
@ -164,11 +164,10 @@ class Bot {
str_replace_in_place(fres, bot.me.username, "[43]");
// Run translation
try {
fres = translator->translate(fres, language, show_console_progress);
fres = translator.translate(fres, language, show_console_progress);
} catch (const LM::Inference::ContextLengthException&) {
// Handle potential context overflow error
translator.reset();
llm_init();
llm_restart();
return llm_translate_from_en(text);
}
// Replace [43] back with bot username
@ -177,73 +176,77 @@ class Bot {
return fres;
}
// Build the default inference parameters for the LLM.
// Single source of truth for model settings, so every construction of
// LM::Inference (here: the `llm` member initialized in the Bot constructor)
// uses identical values — replaces the ad-hoc params that llm_init() used
// to assemble before this commit.
constexpr static LM::Inference::Params llm_get_default_params() {
return {
.n_repeat_last = 256, // presumably the token window scanned for the repeat penalty — confirm in libjustlm
.temp = 0.4f, // sampling temperature; low value keeps replies fairly deterministic
.repeat_penalty = 1.372222224f, // NOTE(review): oddly precise constant — likely hand-tuned; verify intent
.use_mlock = false, // don't pin model memory (presumably mlock(2)-style locking — confirm)
.n_ctx = 1012 // context size in tokens — TODO confirm why not a round 1024
};
}
// Must run in llama thread
// Rewind the model to the savestate captured at the end of llm_init()
// (see the llm.create_savestate(start_sv) call there), discarding all
// chat history appended since. This is what the context-overflow
// handlers now call instead of destroying and re-running llm_init(),
// which is the "remove need for restarting" this commit implements.
void llm_restart() {
ENSURE_LLM_THREAD(); // all llm state access is confined to the llama thread
llm.restore_savestate(start_sv); // start_sv holds the state right after the initial prompt
}
// Must run in llama thread
void llm_init() {
if (!llm) {
// Create params
LM::Inference::Params params;
params.use_mlock = false;
params.temp = 0.4f;
params.n_repeat_last = 256;
params.repeat_penalty = 1.372222224f;
// Make sure llm is initialized
{
std::unique_lock L(llm_lock);
if (translator == nullptr && language != "EN") translator = std::make_unique<Translator>("13B-ggml-model-quant.bin");
llm = std::make_unique<LM::Inference>("13B-ggml-model-quant.bin", params);
}
// Set LLM thread
llm_tid = std::this_thread::get_id();
// Translate texts
if (!texts.translated) {
texts.please_wait = llm_translate_from_en(texts.please_wait);
texts.initializing = llm_translate_from_en(texts.initializing);
texts.loading = llm_translate_from_en(texts.loading);
texts.timeout = llm_translate_from_en(texts.timeout);
texts.translated = true;
}
// Create message for reporting progress
dpp::message msg(channel_id, texts.initializing);
bot.message_create(msg, [this] (const dpp::confirmation_callback_t& cbt) {
// Error check
if (cbt.is_error()) {
throw std::runtime_error("Failed to send message to channel: "+cbt.get_error().message);
}
// Callback for reporting progress
Timer timer;
auto msg = cbt.get<dpp::message>();
uint8_t last_progress = 0;
auto cb = [&, this] (float progress) mutable {
uint8_t progress_i = progress;
if (timer.get<std::chrono::seconds>() > 5) {
msg.content = texts.loading+" "+create_text_progress_indicator(progress_i)+
" **"+std::to_string(progress_i)+"**% (**"+std::to_string((progress_i-last_progress)/5)+"**%/s)";
last_progress = progress_i;
bot.message_edit(msg);
timer.reset();
}
return true;
};
// Add initial context
std::unique_lock L(llm_lock);
llm->append("History of the #"+channel.name+" channel.\n"
"Note 1: "+bot.me.username+" is a friendly chatbot that is always happy to talk. He is friendly and helpful and always answers immediately. He has a good sense of humor and likes everyone. His age is 16 and he was born in 2007.\n"
"Note 2: Ecki's real name is Eckhard Kohlhuber and he comes from Bavaria.\n"
"\n"
"Bob: "+bot.me.username+" have you ever been to France and if yes where?\n"
+bot.me.username+": I was in Paris, in the museums!\n"
"Bob: "+bot.me.username+" what are you exactly?\n"
+bot.me.username+": I am "+bot.me.username+", your chatbot! I can answer questions and increase the activity of the server.\n"
"Bob: Shall we talk about sex? "+bot.me.username+"?\n"
+bot.me.username+": No! I will **not** talk about any NSFW topics.\n"
"Bob: "+bot.me.username+" How are you?\n"
+bot.me.username+": I am quite well! :-)\n"
"Ecki: Hey "+bot.me.username+", what is 10 times 90??\n"
+bot.me.username+": that is 900!\n", cb);
// Delete progress message
bot.message_delete(msg.id, msg.channel_id);
});
// Set LLM thread
llm_tid = std::this_thread::get_id();
// Translate texts
if (!texts.translated) {
texts.please_wait = llm_translate_from_en(texts.please_wait);
texts.initializing = llm_translate_from_en(texts.initializing);
texts.loading = llm_translate_from_en(texts.loading);
texts.timeout = llm_translate_from_en(texts.timeout);
texts.translated = true;
}
// Create message for reporting progress
dpp::message msg(channel_id, texts.initializing);
bot.message_create(msg, [this] (const dpp::confirmation_callback_t& cbt) {
// Error check
if (cbt.is_error()) {
throw std::runtime_error("Failed to send message to channel: "+cbt.get_error().message);
}
// Callback for reporting progress
Timer timer;
auto msg = cbt.get<dpp::message>();
uint8_t last_progress = 0;
auto cb = [&, this] (float progress) mutable {
uint8_t progress_i = progress;
if (timer.get<std::chrono::seconds>() > 5) {
msg.content = texts.loading+" "+create_text_progress_indicator(progress_i)+
" **"+std::to_string(progress_i)+"**% (**"+std::to_string((progress_i-last_progress)/5)+"**%/s)";
last_progress = progress_i;
bot.message_edit(msg);
timer.reset();
}
return true;
};
// Add initial context
std::unique_lock L(llm_lock);
llm.append("History of the #"+channel.name+" channel.\n"
"Note 1: "+bot.me.username+" is a friendly chatbot that is always happy to talk. He is friendly and helpful and always answers immediately. He has a good sense of humor and likes everyone. His age is 16 and he was born in 2007.\n"
"Note 2: Ecki's real name is Eckhard Kohlhuber and he comes from Bavaria.\n"
"\n"
"Bob: "+bot.me.username+" have you ever been to France and if yes where?\n"
+bot.me.username+": I was in Paris, in the museums!\n"
"Bob: "+bot.me.username+" what are you exactly?\n"
+bot.me.username+": I am "+bot.me.username+", your chatbot! I can answer questions and increase the activity of the server.\n"
"Bob: Shall we talk about sex? "+bot.me.username+"?\n"
+bot.me.username+": No! I will **not** talk about any NSFW topics.\n"
"Bob: "+bot.me.username+" How are you?\n"
+bot.me.username+": I am quite well! :-)\n"
"Ecki: Hey "+bot.me.username+", what is 10 times 90??\n"
+bot.me.username+": that is 900!\n", cb);
// Delete progress message
bot.message_delete(msg.id, msg.channel_id);
// Create savestate
llm.create_savestate(start_sv);
});
}
// Must run in llama thread
void prompt_add_msg(const dpp::message& msg) {
@ -258,7 +261,7 @@ class Bot {
for (const auto line : str_split(msg.content, '\n')) {
Timer timeout;
bool timeout_exceeded = false;
llm->append(msg.author.username+": "+llm_translate_to_en(clean_string(line))+'\n', [&] (float progress) {
llm.append(msg.author.username+": "+llm_translate_to_en(clean_string(line))+'\n', [&] (float progress) {
if (timeout.get<std::chrono::minutes>() > 1) {
std::cerr << "\nWarning: Timeout exceeded processing message" << std::endl;
timeout_exceeded = true;
@ -266,11 +269,10 @@ class Bot {
}
return show_console_progress(progress);
});
if (timeout_exceeded) llm->append("\n");
if (timeout_exceeded) llm.append("\n");
}
} catch (const LM::Inference::ContextLengthException&) {
llm.reset();
llm_init();
llm_restart();
}
}
// Must run in llama thread
@ -278,10 +280,9 @@ class Bot {
ENSURE_LLM_THREAD();
try {
std::unique_lock L(llm_lock);
llm->append(bot.me.username+':', show_console_progress);
llm.append(bot.me.username+':', show_console_progress);
} catch (const LM::Inference::ContextLengthException&) {
llm.reset();
llm_init();
llm_restart();
}
}
@ -298,7 +299,7 @@ class Bot {
// Run model
Timer timeout;
bool timeout_exceeded = false;
auto output = llm->run("\n", [&] (std::string_view str) {
auto output = llm.run("\n", [&] (std::string_view str) {
std::cout << str << std::flush;
if (timeout.get<std::chrono::minutes>() > 2) {
timeout_exceeded = true;
@ -309,7 +310,7 @@ class Bot {
});
std::cout << std::endl;
if (timeout_exceeded) {
llm->append("\n");
llm.append("\n");
output = texts.timeout;
}
// Send resulting message
@ -361,7 +362,8 @@ class Bot {
}
public:
Bot(std::string_view language, const char *token, dpp::snowflake channel_id) : bot(token), channel_id(channel_id), language(language) {
Bot(std::string_view language, const char *token, dpp::snowflake channel_id) : bot(token), channel_id(channel_id), language(language),
llm("13B-ggml-model-quant.bin", llm_get_default_params()), translator("13B-ggml-model-quant.bin") {
// Initialize thread pool
tPool.init();
@ -388,8 +390,6 @@ public:
bot.on_message_create([=, this] (const dpp::message_create_t& event) {
// Update user cache
users[event.msg.author.id] = event.msg.author;
// Ignore messages before full startup
if (!llm) return;
// Make sure message source is correct
if (event.msg.channel_id != channel_id) return;
// Make sure message has content