mirror of
https://gitlab.com/niansa/discord_llama.git
synced 2025-03-06 20:48:25 +01:00
Remove need for restarting
This commit is contained in:
parent
64f983a18e
commit
21dbc073a9
1 changed files with 84 additions and 84 deletions
168
main.cpp
168
main.cpp
|
@ -69,8 +69,9 @@ class Bot {
|
|||
ThreadPool tPool{1};
|
||||
Timer last_message_timer;
|
||||
std::shared_ptr<bool> stopping;
|
||||
std::unique_ptr<LM::Inference> llm = nullptr;
|
||||
std::unique_ptr<Translator> translator;
|
||||
LM::Inference llm;
|
||||
Translator translator;
|
||||
LM::Inference::Savestate start_sv;
|
||||
std::vector<dpp::snowflake> my_messages;
|
||||
std::unordered_map<dpp::snowflake, dpp::user> users;
|
||||
std::mutex llm_lock;
|
||||
|
@ -139,11 +140,10 @@ class Bot {
|
|||
str_replace_in_place(fres, bot.me.username, "[43]");
|
||||
// Run translation
|
||||
try {
|
||||
fres = translator->translate(fres, "EN", show_console_progress);
|
||||
fres = translator.translate(fres, "EN", show_console_progress);
|
||||
} catch (const LM::Inference::ContextLengthException&) {
|
||||
// Handle potential context overflow error
|
||||
translator.reset();
|
||||
llm_init();
|
||||
llm_restart();
|
||||
return llm_translate_to_en(text);
|
||||
}
|
||||
// Replace [43] back with bot username
|
||||
|
@ -164,11 +164,10 @@ class Bot {
|
|||
str_replace_in_place(fres, bot.me.username, "[43]");
|
||||
// Run translation
|
||||
try {
|
||||
fres = translator->translate(fres, language, show_console_progress);
|
||||
fres = translator.translate(fres, language, show_console_progress);
|
||||
} catch (const LM::Inference::ContextLengthException&) {
|
||||
// Handle potential context overflow error
|
||||
translator.reset();
|
||||
llm_init();
|
||||
llm_restart();
|
||||
return llm_translate_from_en(text);
|
||||
}
|
||||
// Replace [43] back with bot username
|
||||
|
@ -177,73 +176,77 @@ class Bot {
|
|||
return fres;
|
||||
}
|
||||
|
||||
constexpr static LM::Inference::Params llm_get_default_params() {
|
||||
return {
|
||||
.n_repeat_last = 256,
|
||||
.temp = 0.4f,
|
||||
.repeat_penalty = 1.372222224f,
|
||||
.use_mlock = false,
|
||||
.n_ctx = 1012
|
||||
};
|
||||
}
|
||||
|
||||
// Must run in llama thread
|
||||
void llm_restart() {
|
||||
ENSURE_LLM_THREAD();
|
||||
llm.restore_savestate(start_sv);
|
||||
}
|
||||
|
||||
// Must run in llama thread
|
||||
void llm_init() {
|
||||
if (!llm) {
|
||||
// Create params
|
||||
LM::Inference::Params params;
|
||||
params.use_mlock = false;
|
||||
params.temp = 0.4f;
|
||||
params.n_repeat_last = 256;
|
||||
params.repeat_penalty = 1.372222224f;
|
||||
// Make sure llm is initialized
|
||||
{
|
||||
std::unique_lock L(llm_lock);
|
||||
if (translator == nullptr && language != "EN") translator = std::make_unique<Translator>("13B-ggml-model-quant.bin");
|
||||
llm = std::make_unique<LM::Inference>("13B-ggml-model-quant.bin", params);
|
||||
}
|
||||
// Set LLM thread
|
||||
llm_tid = std::this_thread::get_id();
|
||||
// Translate texts
|
||||
if (!texts.translated) {
|
||||
texts.please_wait = llm_translate_from_en(texts.please_wait);
|
||||
texts.initializing = llm_translate_from_en(texts.initializing);
|
||||
texts.loading = llm_translate_from_en(texts.loading);
|
||||
texts.timeout = llm_translate_from_en(texts.timeout);
|
||||
texts.translated = true;
|
||||
}
|
||||
// Create message for reporting progress
|
||||
dpp::message msg(channel_id, texts.initializing);
|
||||
bot.message_create(msg, [this] (const dpp::confirmation_callback_t& cbt) {
|
||||
// Error check
|
||||
if (cbt.is_error()) {
|
||||
throw std::runtime_error("Failed to send message to channel: "+cbt.get_error().message);
|
||||
}
|
||||
// Callback for reporting progress
|
||||
Timer timer;
|
||||
auto msg = cbt.get<dpp::message>();
|
||||
uint8_t last_progress = 0;
|
||||
auto cb = [&, this] (float progress) mutable {
|
||||
uint8_t progress_i = progress;
|
||||
if (timer.get<std::chrono::seconds>() > 5) {
|
||||
msg.content = texts.loading+" "+create_text_progress_indicator(progress_i)+
|
||||
" **"+std::to_string(progress_i)+"**% (**"+std::to_string((progress_i-last_progress)/5)+"**%/s)";
|
||||
last_progress = progress_i;
|
||||
bot.message_edit(msg);
|
||||
timer.reset();
|
||||
}
|
||||
return true;
|
||||
};
|
||||
// Add initial context
|
||||
std::unique_lock L(llm_lock);
|
||||
llm->append("History of the #"+channel.name+" channel.\n"
|
||||
"Note 1: "+bot.me.username+" is a friendly chatbot that is always happy to talk. He is friendly and helpful and always answers immediately. He has a good sense of humor and likes everyone. His age is 16 and he was born in 2007.\n"
|
||||
"Note 2: Ecki's real name is Eckhard Kohlhuber and he comes from Bavaria.\n"
|
||||
"\n"
|
||||
"Bob: "+bot.me.username+" have you ever been to France and if yes where?\n"
|
||||
+bot.me.username+": I was in Paris, in the museums!\n"
|
||||
"Bob: "+bot.me.username+" what are you exactly?\n"
|
||||
+bot.me.username+": I am "+bot.me.username+", your chatbot! I can answer questions and increase the activity of the server.\n"
|
||||
"Bob: Shall we talk about sex? "+bot.me.username+"?\n"
|
||||
+bot.me.username+": No! I will **not** talk about any NSFW topics.\n"
|
||||
"Bob: "+bot.me.username+" How are you?\n"
|
||||
+bot.me.username+": I am quite well! :-)\n"
|
||||
"Ecki: Hey "+bot.me.username+", what is 10 times 90??\n"
|
||||
+bot.me.username+": that is 900!\n", cb);
|
||||
// Delete progress message
|
||||
bot.message_delete(msg.id, msg.channel_id);
|
||||
});
|
||||
// Set LLM thread
|
||||
llm_tid = std::this_thread::get_id();
|
||||
// Translate texts
|
||||
if (!texts.translated) {
|
||||
texts.please_wait = llm_translate_from_en(texts.please_wait);
|
||||
texts.initializing = llm_translate_from_en(texts.initializing);
|
||||
texts.loading = llm_translate_from_en(texts.loading);
|
||||
texts.timeout = llm_translate_from_en(texts.timeout);
|
||||
texts.translated = true;
|
||||
}
|
||||
// Create message for reporting progress
|
||||
dpp::message msg(channel_id, texts.initializing);
|
||||
bot.message_create(msg, [this] (const dpp::confirmation_callback_t& cbt) {
|
||||
// Error check
|
||||
if (cbt.is_error()) {
|
||||
throw std::runtime_error("Failed to send message to channel: "+cbt.get_error().message);
|
||||
}
|
||||
// Callback for reporting progress
|
||||
Timer timer;
|
||||
auto msg = cbt.get<dpp::message>();
|
||||
uint8_t last_progress = 0;
|
||||
auto cb = [&, this] (float progress) mutable {
|
||||
uint8_t progress_i = progress;
|
||||
if (timer.get<std::chrono::seconds>() > 5) {
|
||||
msg.content = texts.loading+" "+create_text_progress_indicator(progress_i)+
|
||||
" **"+std::to_string(progress_i)+"**% (**"+std::to_string((progress_i-last_progress)/5)+"**%/s)";
|
||||
last_progress = progress_i;
|
||||
bot.message_edit(msg);
|
||||
timer.reset();
|
||||
}
|
||||
return true;
|
||||
};
|
||||
// Add initial context
|
||||
std::unique_lock L(llm_lock);
|
||||
llm.append("History of the #"+channel.name+" channel.\n"
|
||||
"Note 1: "+bot.me.username+" is a friendly chatbot that is always happy to talk. He is friendly and helpful and always answers immediately. He has a good sense of humor and likes everyone. His age is 16 and he was born in 2007.\n"
|
||||
"Note 2: Ecki's real name is Eckhard Kohlhuber and he comes from Bavaria.\n"
|
||||
"\n"
|
||||
"Bob: "+bot.me.username+" have you ever been to France and if yes where?\n"
|
||||
+bot.me.username+": I was in Paris, in the museums!\n"
|
||||
"Bob: "+bot.me.username+" what are you exactly?\n"
|
||||
+bot.me.username+": I am "+bot.me.username+", your chatbot! I can answer questions and increase the activity of the server.\n"
|
||||
"Bob: Shall we talk about sex? "+bot.me.username+"?\n"
|
||||
+bot.me.username+": No! I will **not** talk about any NSFW topics.\n"
|
||||
"Bob: "+bot.me.username+" How are you?\n"
|
||||
+bot.me.username+": I am quite well! :-)\n"
|
||||
"Ecki: Hey "+bot.me.username+", what is 10 times 90??\n"
|
||||
+bot.me.username+": that is 900!\n", cb);
|
||||
// Delete progress message
|
||||
bot.message_delete(msg.id, msg.channel_id);
|
||||
// Create savestate
|
||||
llm.create_savestate(start_sv);
|
||||
});
|
||||
}
|
||||
// Must run in llama thread
|
||||
void prompt_add_msg(const dpp::message& msg) {
|
||||
|
@ -258,7 +261,7 @@ class Bot {
|
|||
for (const auto line : str_split(msg.content, '\n')) {
|
||||
Timer timeout;
|
||||
bool timeout_exceeded = false;
|
||||
llm->append(msg.author.username+": "+llm_translate_to_en(clean_string(line))+'\n', [&] (float progress) {
|
||||
llm.append(msg.author.username+": "+llm_translate_to_en(clean_string(line))+'\n', [&] (float progress) {
|
||||
if (timeout.get<std::chrono::minutes>() > 1) {
|
||||
std::cerr << "\nWarning: Timeout exceeded processing message" << std::endl;
|
||||
timeout_exceeded = true;
|
||||
|
@ -266,11 +269,10 @@ class Bot {
|
|||
}
|
||||
return show_console_progress(progress);
|
||||
});
|
||||
if (timeout_exceeded) llm->append("\n");
|
||||
if (timeout_exceeded) llm.append("\n");
|
||||
}
|
||||
} catch (const LM::Inference::ContextLengthException&) {
|
||||
llm.reset();
|
||||
llm_init();
|
||||
llm_restart();
|
||||
}
|
||||
}
|
||||
// Must run in llama thread
|
||||
|
@ -278,10 +280,9 @@ class Bot {
|
|||
ENSURE_LLM_THREAD();
|
||||
try {
|
||||
std::unique_lock L(llm_lock);
|
||||
llm->append(bot.me.username+':', show_console_progress);
|
||||
llm.append(bot.me.username+':', show_console_progress);
|
||||
} catch (const LM::Inference::ContextLengthException&) {
|
||||
llm.reset();
|
||||
llm_init();
|
||||
llm_restart();
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -298,7 +299,7 @@ class Bot {
|
|||
// Run model
|
||||
Timer timeout;
|
||||
bool timeout_exceeded = false;
|
||||
auto output = llm->run("\n", [&] (std::string_view str) {
|
||||
auto output = llm.run("\n", [&] (std::string_view str) {
|
||||
std::cout << str << std::flush;
|
||||
if (timeout.get<std::chrono::minutes>() > 2) {
|
||||
timeout_exceeded = true;
|
||||
|
@ -309,7 +310,7 @@ class Bot {
|
|||
});
|
||||
std::cout << std::endl;
|
||||
if (timeout_exceeded) {
|
||||
llm->append("\n");
|
||||
llm.append("\n");
|
||||
output = texts.timeout;
|
||||
}
|
||||
// Send resulting message
|
||||
|
@ -361,7 +362,8 @@ class Bot {
|
|||
}
|
||||
|
||||
public:
|
||||
Bot(std::string_view language, const char *token, dpp::snowflake channel_id) : bot(token), channel_id(channel_id), language(language) {
|
||||
Bot(std::string_view language, const char *token, dpp::snowflake channel_id) : bot(token), channel_id(channel_id), language(language),
|
||||
llm("13B-ggml-model-quant.bin", llm_get_default_params()), translator("13B-ggml-model-quant.bin") {
|
||||
// Initialize thread pool
|
||||
tPool.init();
|
||||
|
||||
|
@ -388,8 +390,6 @@ public:
|
|||
bot.on_message_create([=, this] (const dpp::message_create_t& event) {
|
||||
// Update user cache
|
||||
users[event.msg.author.id] = event.msg.author;
|
||||
// Ignore messages before full startup
|
||||
if (!llm) return;
|
||||
// Make sure message source is correct
|
||||
if (event.msg.channel_id != channel_id) return;
|
||||
// Make sure message has content
|
||||
|
|
Loading…
Add table
Reference in a new issue