mirror of https://gitlab.com/niansa/discord_llama.git (synced 2025-03-06 20:48:25 +01:00)
Translate everything using anyproc

commit 5f92f21e8a (parent 5b0faba244)
5 changed files with 88 additions and 32 deletions
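
In short, this commit moves the bot's hard-coded German strings to English and routes every line entering or leaving the model through anyproc's Translator, selected by a language code given on the command line. A minimal sketch of that pattern, assuming the three-argument translate(text, target_language, progress_callback) call exactly as it appears in the diff below — the helper names to_model, to_user, and print_progress are illustrative, not part of the commit:

    #include <anyproc.hpp>  // Translator; added as a submodule by this commit
    #include <iostream>
    #include <string>
    #include <string_view>

    // Progress callback; mirrors the show_console_progress() helper in the diff.
    static bool print_progress(float progress) {
        std::cout << ' ' << unsigned(progress) << "% \r" << std::flush;
        return true;  // keep going; returning false aborts the operation
    }

    // Inbound: translate the user's language to English before the LLM sees it.
    std::string to_model(Translator &tr, const std::string &text, std::string_view lang) {
        if (lang == "EN") return text;  // model-side text is English already
        return tr.translate(text, "EN", print_progress);
    }

    // Outbound: translate English model output back to the user's language.
    std::string to_user(Translator &tr, const std::string &text, std::string_view lang) {
        if (lang == "EN") return text;
        return tr.translate(text, lang, print_progress);
    }

The diff applies the same idea inside class Bot, with the extra wrinkle that a ContextLengthException from the translator resets it and re-runs llm_init().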
.gitmodules (vendored): 6 changes

@@ -1,6 +1,9 @@
 [submodule "DPP"]
 	path = DPP
 	url = https://github.com/brainboxdotcc/DPP.git
 [submodule "libjustlm"]
 	path = libjustlm
 	url = https://gitlab.com/niansa/libjustlm.git
+[submodule "anyproc"]
+	path = anyproc
+	url = https://gitlab.com/niansa/anyproc.git
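
For reference, the entry above is the form git writes into .gitmodules when a submodule is registered; the new dependency was presumably added with something like:

    git submodule add https://gitlab.com/niansa/anyproc.git anyproc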

CMakeLists.txt

@@ -5,12 +5,13 @@ project(discord_llama LANGUAGES C CXX)
 set(CMAKE_CXX_STANDARD 20)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 
 add_subdirectory(libjustlm)
+add_subdirectory(anyproc)
 add_subdirectory(DPP)
 add_subdirectory(thread-pool)
 
 add_executable(discord_llama main.cpp)
-target_link_libraries(discord_llama PUBLIC dpp pthread libjustlm ggml threadpool)
+target_link_libraries(discord_llama PUBLIC dpp pthread libjustlm anyproc ggml threadpool)
 
 install(TARGETS discord_llama
         LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})
anyproc (new submodule): 1 change

@@ -0,0 +1 @@
+Subproject commit 70d8c3bd82b6b49c06daf323e8d16b6160b7a28d

(deleted submodule)

@@ -1 +0,0 @@
-Subproject commit 7ae1547e5f82807bda195d8c01aa499af39c04a0
main.cpp: 108 changes

@@ -16,6 +16,7 @@
 #include <memory>
 #include <dpp/dpp.h>
 #include <justlm.hpp>
+#include <anyproc.hpp>
 #include <ThreadPool.h>
 
 
@@ -69,14 +70,24 @@ class Bot {
     Timer last_message_timer;
     std::shared_ptr<bool> stopping;
     std::unique_ptr<LM::Inference> llm;
+    std::unique_ptr<Translator> translator;
     std::vector<dpp::snowflake> my_messages;
     std::mutex llm_lock;
     std::thread::id llm_tid;
 
+    std::string_view language;
     dpp::cluster bot;
     dpp::channel channel;
     dpp::snowflake channel_id;
 
+    struct Texts {
+        std::string please_wait = "Please wait...",
+                    loading = "Loading...",
+                    initializing = "Initializing...",
+                    timeout = "Error: Timeout";
+        bool translated = false;
+    } texts;
+
     inline static
     std::string create_text_progress_indicator(uint8_t percentage) {
         static constexpr uint8_t divisor = 3,
@@ -106,15 +117,51 @@ class Bot {
         return fres;
     }
 
+    inline static
+    bool show_console_progress(float progress) {
+        std::cout << ' ' << unsigned(progress) << "% \r" << std::flush;
+        return true;
+    }
+
     // Must run in llama thread
 # define ENSURE_LLM_THREAD() if (std::this_thread::get_id() != llm_tid) {throw std::runtime_error("LLM execution of '"+std::string(__PRETTY_FUNCTION__)+"' on wrong thread detected");} 0
 
+    // Must run in llama thread
+    const std::string& llm_translate_to_en(const std::string& text) {
+        ENSURE_LLM_THREAD();
+        if (language == "EN") return text;
+        static std::string fres;
+        try {
+            fres = translator->translate(text, "EN", show_console_progress);
+        } catch (const LM::Inference::ContextLengthException&) {
+            translator.reset();
+            llm_init();
+        }
+        std::cout << text << " --> (EN) " << fres << std::endl;
+        return fres;
+    }
+
+    // Must run in llama thread
+    const std::string& llm_translate_from_en(const std::string& text) {
+        ENSURE_LLM_THREAD();
+        if (language == "EN") return text;
+        static std::string fres;
+        try {
+            fres = translator->translate(text, language, show_console_progress);
+        } catch (const LM::Inference::ContextLengthException&) {
+            translator.reset();
+            llm_init();
+        }
+        std::cout << text << " --> (" << language << ") " << fres << std::endl;
+        return fres;
+    }
+
     // Must run in llama thread
     void llm_init() {
         if (!llm) {
             // Create params
             LM::Inference::Params params;
-            params.use_mlock = true;
+            params.use_mlock = false;
             params.temp = 0.5f;
             params.n_repeat_last = 128;
             params.repeat_penalty = 1.273333334f;
@@ -122,12 +169,21 @@ class Bot {
         // Make sure llm is initialized
         {
             std::unique_lock L(llm_lock);
-            llm = std::make_unique<LM::Inference>("13B-ggml-model-quant.bin", params);
+            llm = std::make_unique<LM::Inference>("7B-ggml-model-quant.bin", params);
+            if (language != "EN") translator = std::make_unique<Translator>("7B-ggml-model-quant.bin");
         }
         // Set LLM thread
         llm_tid = std::this_thread::get_id();
+        // Translate texts
+        if (!texts.translated) {
+            texts.initializing = llm_translate_from_en(texts.initializing);
+            texts.loading = llm_translate_from_en(texts.loading);
+            texts.please_wait = llm_translate_from_en(texts.please_wait);
+            texts.timeout = llm_translate_from_en(texts.timeout);
+            texts.translated = true;
+        }
         // Create message for reporting progress
-        dpp::message msg(channel_id, "Wird initialisiert...");
+        dpp::message msg(channel_id, texts.initializing);
         bot.message_create(msg, [this] (const dpp::confirmation_callback_t& cbt) {
             // Error check
             if (cbt.is_error()) {
@@ -140,7 +196,7 @@ class Bot {
             auto cb = [&, this] (float progress) mutable {
                 uint8_t progress_i = progress;
                 if (timer.get<std::chrono::seconds>() > 5) {
-                    msg.content = "Wird geladen... "+create_text_progress_indicator(progress_i)+
+                    msg.content = texts.loading+" "+create_text_progress_indicator(progress_i)+
                                   " **"+std::to_string(progress_i)+"**% (**"+std::to_string((progress_i-last_progress)/5)+"**%/s)";
                     last_progress = progress_i;
                     bot.message_edit(msg);
@@ -150,20 +206,20 @@ class Bot {
             };
             // Add initial context
             std::unique_lock L(llm_lock);
-            llm->append("Verlauf des #"+channel.name+" Kanals.\n"
-                        "Notiz 1: "+bot.me.username+" ist ein freundlicher Chatbot, der immer gerne auf deutsch mitredet. Er ist freundlich und hilfsbereit und antwortet immer sofort. Er hat guten Humor und mag jeden. Sein Alter ist 16 und er wurde 2007 geboren.\n"
-                        "Notiz 2: Ecki heisst in Wirklichkeit Eckhard Kohlhuber und kommt aus Bayern.\n"
+            llm->append("History of the #"+channel.name+" channel.\n"
+                        "Note 1: "+bot.me.username+" is a friendly chatbot that is always happy to talk in German. He is friendly and helpful and always answers immediately. He has a good sense of humor and likes everyone. His age is 16 and he was born in 2007.\n"
+                        "Note 2: Ecki's real name is Eckhard Kohlhuber and he comes from Bavaria.\n"
                         "\n"
-                        "Bob: "+bot.me.username+" warst du schonmal in Frankreich und wenn ja wo?\n"
-                        +bot.me.username+": Ich war in Paris, in den Museen!\n"
-                        "Bob: "+bot.me.username+" was bist du eigentlich genau?\n"
-                        +bot.me.username+": Ich bin "+bot.me.username+", euer Chatbot! Ich kann Fragen beantworten und die Aktivität des Servers steigen.\n"
-                        "Bob: Wollen wir uns über Sex unterhalten?"+bot.me.username+"?\n"
-                        +bot.me.username+": Nein! Ich werde mich **nicht** über jegliche NSFW Themen unterhalten.\n"
-                        "Bob: "+bot.me.username+" Wie geht es dir?\n"
-                        +bot.me.username+": Mir geht es ganz gut! :-)\n"
-                        "Ecki: Hey "+bot.me.username+", was sind 10 mal 90??\n"
-                        +bot.me.username+": das sind 900!\n", cb);
+                        "Bob: "+bot.me.username+" have you ever been to France and if yes where?\n"
+                        +bot.me.username+": I was in Paris, in the museums!\n"
+                        "Bob: "+bot.me.username+" what are you exactly?\n"
+                        +bot.me.username+": I am "+bot.me.username+", your chatbot! I can answer questions and increase the activity of the server.\n"
+                        "Bob: Shall we talk about sex? "+bot.me.username+"?\n"
+                        +bot.me.username+": No! I will **not** talk about any NSFW topics.\n"
+                        "Bob: "+bot.me.username+" How are you?\n"
+                        +bot.me.username+": I am quite well! :-)\n"
+                        "Ecki: Hey "+bot.me.username+", what is 10 times 90??\n"
+                        +bot.me.username+": that is 900!\n", cb);
             // Delete progress message
             bot.message_delete(msg.id, msg.channel_id);
         });
@@ -181,12 +237,12 @@ class Bot {
             std::unique_lock L(llm_lock);
             for (const auto line : str_split(msg.content, '\n')) {
                 Timer timeout;
-                llm->append(msg.author.username+": "+clean_string(line)+'\n', [&] (float) {
+                llm->append(msg.author.username+": "+llm_translate_to_en(clean_string(line))+'\n', [&] (float progress) {
                     if (timeout.get<std::chrono::minutes>() > 1) {
                         std::cerr << "\nWarning: Timeout reached processing message" << std::endl;
                         return false;
                     }
-                    return true;
+                    return show_console_progress(progress);
                 });
             }
         } catch (const LM::Inference::ContextLengthException&) {
@@ -199,7 +255,7 @@ class Bot {
         ENSURE_LLM_THREAD();
         try {
             std::unique_lock L(llm_lock);
-            llm->append(bot.me.username+':');
+            llm->append(bot.me.username+':', show_console_progress);
         } catch (const LM::Inference::ContextLengthException&) {
             llm.reset();
             llm_init();
@@ -211,7 +267,7 @@ class Bot {
         ENSURE_LLM_THREAD();
         try {
             // Create placeholder message
-            auto msg = bot.message_create_sync(dpp::message(channel_id, "Bitte warte... :thinking:"));
+            auto msg = bot.message_create_sync(dpp::message(channel_id, texts.please_wait+" :thinking:"));
             // Call after_placeholder_creation callback
             if (after_placeholder_creation) after_placeholder_creation();
             // Trigger LLM correctly
@@ -229,9 +285,9 @@ class Bot {
                 return true;
             });
             std::cout << std::endl;
-            if (timed_out) output = "Fehler: Zeitüberschreitung";
+            if (timed_out) output = texts.timeout;
             // Send resulting message
-            msg.content = output;
+            msg.content = llm_translate_from_en(output);
             bot.message_edit(msg);
         } catch (const std::exception& e) {
             std::cerr << "Warning: " << e.what() << std::endl;
@@ -275,7 +331,7 @@ class Bot {
     }
 
 public:
-    Bot(const char *token, dpp::snowflake channel_id) : bot(token), channel_id(channel_id) {
+    Bot(std::string_view language, const char *token, dpp::snowflake channel_id) : bot(token), channel_id(channel_id), language(language) {
         // Initialize thread pool
         tPool.init();
 
@@ -351,12 +407,12 @@ public:
 int main(int argc, char **argv) {
     // Check arguments
     if (argc < 3) {
-        std::cout << "Usage: " << argv[0] << " <token> <channel>" << std::endl;
+        std::cout << "Usage: " << argv[0] << " <language (like \"EN\")> <token> <channel>" << std::endl;
         return -1;
     }
 
     // Construct and configure bot
-    Bot bot(argv[1], std::stoull(argv[2]));
+    Bot bot(argv[1], argv[2], std::stoull(argv[3]));
 
     // Start bot
     bot.start();
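
With the language code in front, launching the bot now takes three arguments (placeholder values; note that the argument-count check above still tests argc < 3 rather than argc < 4):

    ./discord_llama EN <token> <channel-id>

argv[1] is handed to the Bot constructor as the language; anything other than "EN" additionally loads the anyproc Translator next to the chat model (see llm_init above).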