Mirror of https://gitlab.com/niansa/discord_llama.git

Minor cleanup

niansa 2023-04-22 23:59:06 +02:00
parent 3716eb9db0
commit 6364f49d4a


@@ -1,4 +1,3 @@
#include "Random.hpp"
#include "Timer.hpp"
#include <string>
@@ -44,28 +43,8 @@ void str_replace_in_place(std::string& subject, std::string_view search,
}
}
static inline
std::string clean_string(std::string_view str) {
std::string fres;
for (const auto c : str) {
if ((c >= 0x20 && c <= 0x7E)
|| c == '\n'
|| c == "ä"[0] || c == "ä"[1] || c == "ä"[2]
|| c == "ö"[0] || c == "ö"[1] || c == "ö"[2]
|| c == "ü"[0] || c == "ü"[1] || c == "ü"[2]
|| c == "Ä"[0] || c == "Ä"[1] || c == "Ä"[2]
|| c == "Ö"[0] || c == "Ö"[1] || c == "Ö"[2]
|| c == "Ü"[0] || c == "Ü"[1] || c == "Ü"[2]
|| c == "ß"[0] || c == "ß"[1] || c == "ß"[2]) {
fres.push_back(c);
}
}
return fres;
}
class Bot {
RandomGenerator rng;
ThreadPool tPool{1};
Timer last_message_timer;
std::shared_ptr<bool> stopping;
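
For context on the hunk above: the removed clean_string() filtered input byte by byte, keeping printable ASCII (0x20-0x7E), newlines, and any byte occurring in the UTF-8 encodings of the German umlauts and sharp s. Comparisons like c == "ä"[0] test single bytes of a multi-byte encoding, so the filter also admitted unrelated characters sharing one of those bytes. A minimal sketch of the byte values involved (assumes a UTF-8 source file):

#include <cstdio>

int main() {
    // In a UTF-8 source file, "ä" is the byte sequence 0xC3 0xA4, so
    // "ä"[0] == (char)0xC3 and "ä"[1] == (char)0xA4, while "ä"[2] is the
    // string's terminating '\0'. Comparing single bytes also lets through
    // any other code point sharing one of them, and the [2] comparison
    // even admits stray NUL bytes.
    printf("%02x %02x %02x\n",
           (unsigned char)"ä"[0], (unsigned char)"ä"[1], (unsigned char)"ä"[2]);
    return 0;
}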
@@ -129,7 +108,7 @@ class Bot {
# define ENSURE_LLM_THREAD() if (std::this_thread::get_id() != llm_tid) {throw std::runtime_error("LLM execution of '"+std::string(__PRETTY_FUNCTION__)+"' on wrong thread detected");} 0
// Must run in llama thread
const std::string& llm_translate_to_en(const std::string& text) {
std::string_view llm_translate_to_en(std::string_view text) {
ENSURE_LLM_THREAD();
// No need for translation if language is english already
if (language == "EN") return text;
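
The ENSURE_LLM_THREAD() macro in this hunk is a thread-affinity guard: it compares the calling thread's id against the stored llm_tid and throws on a mismatch. A standalone sketch of the same pattern (the class and member names here are illustrative, not from the project):

#include <stdexcept>
#include <string>
#include <thread>

class LlmThreadGuard {
    std::thread::id llm_tid = std::this_thread::get_id();  // thread that owns the LLM
public:
    // Throws if called from any thread other than the owning one.
    // __PRETTY_FUNCTION__ in the original macro is a GCC/Clang extension
    // that names the enclosing function for the error message.
    void ensure_llm_thread(const char* func) const {
        if (std::this_thread::get_id() != llm_tid)
            throw std::runtime_error("LLM execution of '" + std::string(func) +
                                     "' on wrong thread detected");
    }
};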
@@ -153,7 +132,7 @@ class Bot {
}
// Must run in llama thread
const std::string& llm_translate_from_en(const std::string& text) {
std::string_view llm_translate_from_en(std::string_view text) {
ENSURE_LLM_THREAD();
// No need for translation if language is english already
if (language == "EN") return text;
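
Both translate hunks switch the signature from const std::string& to std::string_view. The diff only shows the pass-through branch (language == "EN" returns the input view unchanged); for an actual translation, the returned view must point at storage that outlives the call. A sketch of that constraint, with hypothetical names (Translator, last_translation, and do_translate are illustrative, not from the project):

#include <string>
#include <string_view>

struct Translator {
    std::string language = "EN";
    std::string last_translation;  // hypothetical persistent backing storage

    std::string_view translate_to_en(std::string_view text) {
        // Pass-through: the view still points at the caller's buffer.
        if (language == "EN") return text;
        // A real translation must land in storage that outlives this call;
        // returning a view into a local std::string would dangle.
        last_translation = do_translate(text);
        return last_translation;
    }

    std::string do_translate(std::string_view text) {
        return std::string(text);  // placeholder; real code would call the LLM
    }
};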
@@ -261,7 +240,7 @@ class Bot {
for (const auto line : str_split(msg.content, '\n')) {
Timer timeout;
bool timeout_exceeded = false;
llm.append(msg.author.username+": "+llm_translate_to_en(clean_string(line))+'\n', [&] (float progress) {
llm.append(msg.author.username+": "+std::string(llm_translate_to_en(line))+'\n', [&] (float progress) {
if (timeout.get<std::chrono::minutes>() > 1) {
std::cerr << "\nWarning: Timeout exceeded processing message" << std::endl;
timeout_exceeded = true;
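
The hunk above appends each chat line to the model with a progress callback that doubles as a watchdog: once a minute has elapsed, it warns and sets timeout_exceeded. A simplified sketch of the pattern, assuming the callback can cancel by returning false, which the flag suggests but this diff does not itself show:

#include <chrono>
#include <functional>
#include <iostream>

// Runs 'append' with a callback that aborts after a one-minute budget.
// 'append' stands in for llm.append(); its bool-returning callback is an
// assumption about the interface, not confirmed by the diff.
bool append_with_watchdog(const std::function<void(const std::function<bool(float)>&)>& append) {
    const auto start = std::chrono::steady_clock::now();
    bool timeout_exceeded = false;
    append([&](float /*progress*/) {
        if (std::chrono::steady_clock::now() - start > std::chrono::minutes(1)) {
            std::cerr << "\nWarning: Timeout exceeded processing message" << std::endl;
            timeout_exceeded = true;
            return false;  // request cancellation
        }
        return true;  // keep processing
    });
    return !timeout_exceeded;
}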
@@ -324,10 +303,6 @@ class Bot {
// Must run in llama thread
bool attempt_reply(const dpp::message& msg, const std::function<void ()>& after_placeholder_creation = nullptr) {
ENSURE_LLM_THREAD();
// Decide randomly
/*if (rng.getBool(0.075f)) {
return reply();
}*/
// Reply if message contains username, mention or ID
if (msg.content.find(bot.me.username) != std::string::npos) {
reply(after_placeholder_creation);
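
The attempt_reply hunk drops the commented-out random-reply experiment and keeps the deterministic trigger: reply when the message contains the bot's username, mention, or ID (only the username branch is visible here). A sketch of such a check as an illustrative free function, not from the project:

#include <string>

bool mentions_bot(const std::string& content, const std::string& username, const std::string& id) {
    // A raw Discord mention looks like <@ID>, so searching for the bare
    // snowflake ID also covers the mention form.
    return content.find(username) != std::string::npos
        || content.find(id) != std::string::npos;
}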
@@ -379,8 +354,6 @@ public:
throw std::runtime_error("Failed to get channel: "+cbt.get_error().message);
}
channel = cbt.get<dpp::channel>();
// Initialize random generator
rng.seed(bot.me.id);
// Append initial prompt
tPool.submit(std::bind(&Bot::llm_init, this));
// Start idle auto reply thread
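
The init hunk submits llm_init onto tPool, declared earlier in the diff as ThreadPool tPool{1}. A one-thread pool serializes every submitted job onto the same worker, which is what allows the ENSURE_LLM_THREAD() check above to hold. A minimal sketch of such a pool (the project's real ThreadPool is assumed to behave roughly like this):

#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>

class SingleThreadPool {
    std::queue<std::function<void()>> tasks;
    std::mutex mtx;
    std::condition_variable cv;
    bool stop = false;
    std::thread worker{[this] {
        for (;;) {
            std::unique_lock lock(mtx);
            cv.wait(lock, [this] { return stop || !tasks.empty(); });
            if (stop && tasks.empty()) return;
            auto task = std::move(tasks.front());
            tasks.pop();
            lock.unlock();
            task();  // every job runs on this one worker thread
        }
    }};
public:
    void submit(std::function<void()> task) {
        { std::lock_guard lock(mtx); tasks.push(std::move(task)); }
        cv.notify_one();
    }
    ~SingleThreadPool() {
        { std::lock_guard lock(mtx); stop = true; }
        cv.notify_one();
        worker.join();
    }
};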