1
0
Fork 0
mirror of https://gitlab.com/niansa/discord_llama.git synced 2025-03-06 20:48:25 +01:00

Added random response chance and guard translator with a mutex

This commit is contained in:
niansa 2023-05-22 22:55:59 +02:00
parent be4bd425fd
commit 4f0c95fef9
7 changed files with 26 additions and 46 deletions

View file

@ -1,43 +0,0 @@
#ifndef _PHASMOENGINE_RANDOM_HPP
#define _PHASMOENGINE_RANDOM_HPP
#include <random>
class RandomGenerator {
std::mt19937 rng;
uint32_t initialSeed;
public:
void seed() {
rng.seed(initialSeed = std::random_device{}());
}
void seed(uint32_t customSeed) {
rng.seed(initialSeed = customSeed);
}
unsigned getUInt() {
std::uniform_int_distribution<unsigned> dist;
return dist(rng);
}
unsigned getUInt(unsigned max) {
std::uniform_int_distribution<unsigned> dist(0, max);
return dist(rng);
}
unsigned getUInt(unsigned min, unsigned max) {
std::uniform_int_distribution<unsigned> dist(min, max);
return dist(rng);
}
double getDouble(double max) {
std::uniform_real_distribution<double> dist(0.0, max);
return dist(rng);
}
double getDouble(double min, double max) {
std::uniform_real_distribution<double> dist(min, max);
return dist(rng);
}
bool getBool(float chance) {
return getDouble(1.0) <= chance && chance != 0.0f;
}
};
#endif

View file

@ -145,6 +145,8 @@ void Configuration::fill(std::unordered_map<std::string, std::string>&& map, boo
ctx_size = std::stoi(value);
} else if (key == "max_context_age") {
max_context_age = std::stoi(value);
} else if (key == "random_response_chance") {
random_response_chance = std::stoi(value);
} else if (key == "mlock") {
mlock = parse_bool(value);
} else if (key == "live_edit") {
@ -192,6 +194,9 @@ void Configuration::check(bool allow_non_instruct) const {
if (shard_id >= shard_count) {
throw Exception("Error: Not enough shards for this ID to exist.");
}
if (random_response_chance && threads_only) {
throw Exception("Error: Random responses may only be given if responses outside threads are allowed.");
}
}
#include <iostream>

View file

@ -101,7 +101,8 @@ public:
scroll_keep = 20,
shard_count = 1,
shard_id = 0,
max_context_age = 0;
max_context_age = 0,
random_response_chance = 0;
bool persistance = true,
mlock = false,
live_edit = false,

@ -1 +1 @@
Subproject commit 9d1047be5460cfea8958464421efb8cafed065d0
Subproject commit 364971205902538bfc6981390319db08ae262edf

View file

@ -5,6 +5,7 @@ models_dir models
texts_file none
language EN
threads_only true
random_response_chance 0
live_edit false
default_inference_model 13b-vanilla

View file

@ -11,9 +11,12 @@ texts_file none
# Language everything is translated to (will be disabled if set to "EN" anyways)
language EN
# Weather the bot should respond to pings outside threads
# Weather the bot should respond to pings outside threads. Disabling this may increase load by a LOT
threads_only true
# Chance for bot to respond at random when allowed to talk outside threads (see option above). Chance in percent is 100 divided by given number (Example: 2 = 50%). 0 implies no random responses
random_response_chance 0
# Weather the bot should update messages periodically while writing them. Incompatible with translation
live_edit false

View file

@ -22,6 +22,7 @@
#include <justlm_pool.hpp>
#include <anyproc.hpp>
#include <scheduled_thread.hpp>
#include <scheduler_mutex.hpp>
@ -29,6 +30,7 @@ class Bot {
CoSched::ScheduledThread sched_thread;
LM::InferencePool llm_pool;
std::unique_ptr<Translator> translator;
CoSched::Mutex translator_mutex;
std::vector<dpp::snowflake> my_messages;
std::unordered_map<dpp::snowflake, dpp::user> users;
std::thread::id llm_tid;
@ -74,7 +76,9 @@ private:
// Replace bot username with [43]
utils::str_replace_in_place(fres, bot.me.username, "[43]");
// Run translation
co_await translator_mutex.lock();
fres = co_await translator->translate(fres, "EN", show_console_progress);
translator_mutex.unlock();
// Replace [43] back with bot username
utils::str_replace_in_place(fres, "[43]", bot.me.username);
std::cout << text << " --> (EN) " << fres << std::endl;
@ -93,7 +97,9 @@ private:
// Replace bot username with [43]
utils::str_replace_in_place(fres, bot.me.username, "[43]");
// Run translation
co_await translator_mutex.lock();
fres = co_await translator->translate(fres, config.language, show_console_progress);
translator_mutex.unlock();
// Replace [43] back with bot username
utils::str_replace_in_place(fres, "[43]", bot.me.username);
std::cout << text << " --> (" << config.language << ") " << fres << std::endl;
@ -437,6 +443,13 @@ private:
co_return true;
}
}
// Reply at random
if (config.random_response_chance) {
if (!(unsigned(msg.id.get_creation_time()) % config.random_response_chance)) {
co_await reply(msg.channel_id, placeholder_msg, channel_cfg);
co_return true;
}
}
// Don't reply otherwise
co_return false;
}