Remove need for restarting

niansa 2023-04-19 21:10:59 +02:00
parent 64f983a18e
commit 21dbc073a9

@@ -69,8 +69,9 @@ class Bot {
     ThreadPool tPool{1};
     Timer last_message_timer;
     std::shared_ptr<bool> stopping;
-    std::unique_ptr<LM::Inference> llm = nullptr;
-    std::unique_ptr<Translator> translator;
+    LM::Inference llm;
+    Translator translator;
+    LM::Inference::Savestate start_sv;
     std::vector<dpp::snowflake> my_messages;
     std::unordered_map<dpp::snowflake, dpp::user> users;
     std::mutex llm_lock;
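
The added start_sv member is the heart of this commit: the bot now snapshots the model state once, right after the initial prompt has been evaluated, and rewinds to that snapshot on context overflow instead of tearing the model down and reloading it. A minimal self-contained sketch of the pattern, using hypothetical stand-in types (the real LM::Inference snapshots the evaluated model state, not a plain string):

    #include <iostream>
    #include <string>

    class Inference {
        std::string ctx; // stands in for the model's evaluated context
    public:
        struct Savestate { std::string ctx; };
        void append(const std::string& s) { ctx += s; }
        void create_savestate(Savestate& sv) const { sv.ctx = ctx; }  // snapshot once
        void restore_savestate(const Savestate& sv) { ctx = sv.ctx; } // cheap rewind
        const std::string& context() const { return ctx; }
    };

    int main() {
        Inference llm;
        Inference::Savestate start_sv;
        llm.append("initial prompt\n");
        llm.create_savestate(start_sv);   // taken right after priming
        llm.append("lots of chat history that eventually overflows\n");
        llm.restore_savestate(start_sv);  // instead of reconstructing the model
        std::cout << llm.context();       // back to just the initial prompt
    }
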
@@ -139,11 +140,10 @@ class Bot {
         str_replace_in_place(fres, bot.me.username, "[43]");
         // Run translation
         try {
-            fres = translator->translate(fres, "EN", show_console_progress);
+            fres = translator.translate(fres, "EN", show_console_progress);
         } catch (const LM::Inference::ContextLengthException&) {
             // Handle potential context overflow error
-            translator.reset();
-            llm_init();
+            llm_restart();
             return llm_translate_to_en(text);
         }
         // Replace [43] back with bot username
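
Both translation helpers now recover from a context overflow the same way: restore the savestate, then retry once via a recursive call. A self-contained sketch of that shape, with stand-in types and stub functions (the real code catches LM::Inference::ContextLengthException and calls the translator member shown above):

    #include <iostream>
    #include <stdexcept>
    #include <string>

    struct ContextLengthException : std::runtime_error {
        ContextLengthException() : std::runtime_error("context overflow") {}
    };

    static bool context_full = true; // pretend the first attempt overflows

    void llm_restart() { context_full = false; } // stands in for restore_savestate(start_sv)

    std::string translate(const std::string& text) {
        if (context_full) throw ContextLengthException{};
        return text; // stub translation
    }

    std::string llm_translate_to_en(const std::string& text) {
        try {
            return translate(text);
        } catch (const ContextLengthException&) {
            llm_restart();                    // cheap rewind instead of a model reload
            return llm_translate_to_en(text); // retry with the freed-up context
        }
    }

    int main() { std::cout << llm_translate_to_en("hallo welt") << '\n'; }
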
@@ -164,11 +164,10 @@ class Bot {
         str_replace_in_place(fres, bot.me.username, "[43]");
         // Run translation
         try {
-            fres = translator->translate(fres, language, show_console_progress);
+            fres = translator.translate(fres, language, show_console_progress);
         } catch (const LM::Inference::ContextLengthException&) {
             // Handle potential context overflow error
-            translator.reset();
-            llm_init();
+            llm_restart();
             return llm_translate_from_en(text);
         }
         // Replace [43] back with bot username
@@ -177,21 +176,24 @@ class Bot {
         return fres;
     }
+    constexpr static LM::Inference::Params llm_get_default_params() {
+        return {
+            .n_repeat_last = 256,
+            .temp = 0.4f,
+            .repeat_penalty = 1.372222224f,
+            .use_mlock = false,
+            .n_ctx = 1012
+        };
+    }
+    // Must run in llama thread
+    void llm_restart() {
+        ENSURE_LLM_THREAD();
+        llm.restore_savestate(start_sv);
+    }
     // Must run in llama thread
     void llm_init() {
-        if (!llm) {
-            // Create params
-            LM::Inference::Params params;
-            params.use_mlock = false;
-            params.temp = 0.4f;
-            params.n_repeat_last = 256;
-            params.repeat_penalty = 1.372222224f;
-            // Make sure llm is initialized
-            {
-                std::unique_lock L(llm_lock);
-                if (translator == nullptr && language != "EN") translator = std::make_unique<Translator>("13B-ggml-model-quant.bin");
-                llm = std::make_unique<LM::Inference>("13B-ggml-model-quant.bin", params);
-            }
         // Set LLM thread
         llm_tid = std::this_thread::get_id();
         // Translate texts
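
Making the defaults a constexpr static function is what lets the model be constructed eagerly in the Bot constructor's member-initializer list (final hunks below) instead of lazily behind a unique_ptr. Note that C++20 designated initializers such as .temp = 0.4f must follow the fields' declaration order, so the order above presumably mirrors LM::Inference::Params. A condensed, self-contained sketch of the shape (Params and Inference are stand-ins for the real LM:: types; values from this diff):

    struct Params { int n_repeat_last = 0; float temp = 0.0f; };
    struct Inference {
        Inference(const char* /*model_path*/, Params /*params*/) {}
    };

    class Bot {
        Inference llm; // plain member: constructed before the ctor body runs
        constexpr static Params llm_get_default_params() {
            return { .n_repeat_last = 256, .temp = 0.4f };
        }
    public:
        Bot() : llm("13B-ggml-model-quant.bin", llm_get_default_params()) {}
    };

    int main() { Bot bot; }
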
@@ -226,7 +228,7 @@ class Bot {
             };
             // Add initial context
             std::unique_lock L(llm_lock);
-            llm->append("History of the #"+channel.name+" channel.\n"
+            llm.append("History of the #"+channel.name+" channel.\n"
                 "Note 1: "+bot.me.username+" is a friendly chatbot that is always happy to talk. He is friendly and helpful and always answers immediately. He has a good sense of humor and likes everyone. His age is 16 and he was born in 2007.\n"
                 "Note 2: Ecki's real name is Eckhard Kohlhuber and he comes from Bavaria.\n"
                 "\n"
@@ -242,9 +244,10 @@ class Bot {
                 +bot.me.username+": that is 900!\n", cb);
             // Delete progress message
             bot.message_delete(msg.id, msg.channel_id);
+            // Create savestate
+            llm.create_savestate(start_sv);
         });
-        }
     }
     // Must run in llama thread
     void prompt_add_msg(const dpp::message& msg) {
         ENSURE_LLM_THREAD();
@@ -258,7 +261,7 @@ class Bot {
             for (const auto line : str_split(msg.content, '\n')) {
                 Timer timeout;
                 bool timeout_exceeded = false;
-                llm->append(msg.author.username+": "+llm_translate_to_en(clean_string(line))+'\n', [&] (float progress) {
+                llm.append(msg.author.username+": "+llm_translate_to_en(clean_string(line))+'\n', [&] (float progress) {
                     if (timeout.get<std::chrono::minutes>() > 1) {
                         std::cerr << "\nWarning: Timeout exceeded processing message" << std::endl;
                         timeout_exceeded = true;
@@ -266,11 +269,10 @@ class Bot {
                     }
                     return show_console_progress(progress);
                 });
-                if (timeout_exceeded) llm->append("\n");
+                if (timeout_exceeded) llm.append("\n");
             }
         } catch (const LM::Inference::ContextLengthException&) {
-            llm.reset();
-            llm_init();
+            llm_restart();
         }
     }
     // Must run in llama thread
@@ -278,10 +280,9 @@ class Bot {
         ENSURE_LLM_THREAD();
         try {
             std::unique_lock L(llm_lock);
-            llm->append(bot.me.username+':', show_console_progress);
+            llm.append(bot.me.username+':', show_console_progress);
         } catch (const LM::Inference::ContextLengthException&) {
-            llm.reset();
-            llm_init();
+            llm_restart();
         }
     }
@@ -298,7 +299,7 @@ class Bot {
         // Run model
         Timer timeout;
         bool timeout_exceeded = false;
-        auto output = llm->run("\n", [&] (std::string_view str) {
+        auto output = llm.run("\n", [&] (std::string_view str) {
             std::cout << str << std::flush;
             if (timeout.get<std::chrono::minutes>() > 2) {
                 timeout_exceeded = true;
@@ -309,7 +310,7 @@ class Bot {
         });
         std::cout << std::endl;
         if (timeout_exceeded) {
-            llm->append("\n");
+            llm.append("\n");
             output = texts.timeout;
         }
         // Send resulting message
@@ -361,7 +362,8 @@ class Bot {
     }
 public:
-    Bot(std::string_view language, const char *token, dpp::snowflake channel_id) : bot(token), channel_id(channel_id), language(language) {
+    Bot(std::string_view language, const char *token, dpp::snowflake channel_id) : bot(token), channel_id(channel_id), language(language),
+          llm("13B-ggml-model-quant.bin", llm_get_default_params()), translator("13B-ggml-model-quant.bin") {
         // Initialize thread pool
         tPool.init();
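
With llm and translator now in the member-initializer list, both model files are loaded during construction, before the constructor body runs. C++ initializes members in declaration order (here llm, then translator, per the first hunk), not in the order they appear in the initializer list. A tiny hypothetical illustration of that rule:

    #include <iostream>

    struct Demo {
        int first;
        int second;
        Demo() : second(2), first(1) {} // first is still initialized before second
    };

    int main() {
        Demo d;
        std::cout << d.first << ' ' << d.second << '\n'; // prints: 1 2
    }
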
@@ -388,8 +390,6 @@ public:
         bot.on_message_create([=, this] (const dpp::message_create_t& event) {
             // Update user cache
             users[event.msg.author.id] = event.msg.author;
-            // Ignore messages before full startup
-            if (!llm) return;
             // Make sure message source is correct
             if (event.msg.channel_id != channel_id) return;
             // Make sure message has content