// Mirror of https://gitlab.com/niansa/reaktor.git, synced 2025-03-06 20:53:30 +01:00
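// Reaktor: a Discord bot that occasionally reacts to messages with a unicode emoji chosen by
// a language model (loaded from a local model file via justlm) and, when built with
// REAKTOR_WITH_CHAT, also chats in a configured channel.
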
#include "unicode_emojis.h"
|
||
|
||
#include <iostream>
|
||
#include <string>
|
||
#include <chrono>
|
||
#include <memory>
|
||
#include <unordered_map>
|
||
#include <limits>
|
||
#include <cstdlib>
|
||
#include <ctime>
|
||
#include <dpp/dpp.h>
|
||
#include <justlm.hpp>
|
||
#include <halo.hpp>
|
||
#include <colorama.hpp>
|
||
#include <commoncpp/utils.hpp>
|
||
#include <commoncpp/timer.hpp>
|
||
#include <commoncpp/random.hpp>
|
||
#include <commoncpp/pooled_thread.hpp>
|
||
|
||
|
||
#ifdef REAKTOR_WITH_CHAT
std::unordered_map<dpp::snowflake, dpp::user> user_cache;  // Users seen so far, for resolving mentions
std::vector<std::string> past_responses;                   // Previous bot responses, for repetition detection

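// Sampling parameters used for the chat model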
LM::Inference::Params get_chat_params() {
    LM::Inference::Params fres;
    fres.n_eos_ignores = 0;
    fres.temp = 0.3f;
    fres.repeat_penalty = 1.3f;
    fres.scroll_keep = 0.3f;
    return fres;
}

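// Maps an arbitrary value to a capital letter in the range 'A'..'T' (used to nudge the model
// out of repetition loops)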
char get_limited_alpha(unsigned hv) {
    return (hv % ('U' - 'A')) + 'A';
}

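// Appends a "<username> message" line to the chat model's context.
// Returns false if evaluation exceeded the 45 second timeout.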
bool chat_message_add(LM::Inference *model, const std::string& username, const std::string& content) {
    Halo halo({"Evaluating message...", spinners::line});
    halo.start();

    if (!common::utils::ends_with(model->get_prompt(), "\n<")) model->append("\n<"); // Make SURE message is terminated

    common::Timer timeout;
    bool timed_out = false;
    model->append(username+"> "+content+"\n<", [&] (float progress) -> bool {
        const auto seconds_passed = timeout.get<std::chrono::seconds>();
        if (seconds_passed > 45) {
            timed_out = true;
            halo.stop();
            halo.warning("Timed out processing message!");
            return false;
        }
        halo.settings.text = "Evaluating message... ("+std::to_string(seconds_passed)+"s) "+std::to_string(unsigned(progress))+"%";
        return true;
    });

    if (!timed_out) {
        halo.stop();
        halo.success("Message evaluated");
    }
    return !timed_out;
}

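// State carried across chat responses: the previous response, the current sampling temperature
// and how many times in a row the model has repeated itself.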
struct ChatMessageGenerateContext {
    constexpr static float init_temp = 0.4f;

    std::string last_content;
    float temp;
    unsigned repeats;

    void reset() {
        last_content.clear();
        temp = init_temp;
        repeats = 0;
    }

    ChatMessageGenerateContext() {
        reset();
    }
};

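// Generates the bot's next chat line into `content`, streaming intermediate text into `new_msg`
// (which is edited every few seconds). Adjusts the sampling temperature on the fly to avoid
// repeating earlier responses; returns false if generation exceeded the 90 second timeout.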
bool chat_message_generate(LM::Inference *model, dpp::cluster& bot, dpp::message& new_msg, ChatMessageGenerateContext& ctx, const std::string& username, std::string& content) {
    Halo halo({"Generating response... "+content, spinners::line});
    halo.start();

    if (!common::utils::ends_with(model->get_prompt(), "\n<")) model->append("\n<"); // Make SURE message is terminated

    model->append(username+">");

    if (ctx.repeats > 0) {
        // The model is stuck repeating itself: seed the reply with an arbitrary letter and raise the temperature
        content = std::string(1, get_limited_alpha(time(nullptr)));
        model->append(" "+content);
        halo.settings.text += content;
        ctx.temp += 1.8f;
    } else {
        content.clear();
    }

    common::Timer timeout, edit_timer;
    bool timed_out = false;
    float temp_before_linebreak = ctx.temp;
    bool did_linebreak = false;
    unsigned token_count = 0;
    content += model->run("\n<", [&, buf = std::string()] (const char *token) mutable -> bool {
        buf += token;
        // Update temperature
        if (common::utils::contains(token, "\n")) {
            did_linebreak = true;
            temp_before_linebreak = ctx.temp;
            ctx.temp = 0.1f;
        } else if (did_linebreak) {
            ctx.temp = temp_before_linebreak;
            did_linebreak = false;
        } else if (common::utils::starts_with(ctx.last_content, buf) && ctx.temp < 0.9f) {
            ctx.temp += 0.2f;
        } else if (buf.size() > 6) {
            for (const auto& response : past_responses) {
                if (response == ctx.last_content) continue;
                if (common::utils::starts_with(response, buf) && ctx.temp < 0.9f) ctx.temp += 0.1f;
            }
        }
        if (ctx.temp > ctx.init_temp + std::numeric_limits<float>::epsilon()) ctx.temp -= 0.05f;
        model->params.temp = ctx.temp;
        // Update repeat penalty
        model->params.n_repeat_last = token_count++;
        // Check for timeout
        if (timeout.get<std::chrono::seconds>() > 90) {
            timed_out = true;
            halo.stop();
            halo.warning("Timed out generating message!");
            return false;
        }
        // Update message
        if (edit_timer.get<std::chrono::seconds>() > 8) {
            new_msg.content = std::move(buf);
            try {
                bot.message_edit_sync(new_msg);
            } catch (...) {}
            buf = std::move(new_msg.content);
            bot.channel_typing(new_msg.channel_id);
            edit_timer.reset();
        }
        // Update halo text
        halo.settings.text += common::utils::remove_nonprintable(token);
        return true;
    });

    if (!timed_out) halo.stop();

    halo.success("Generated message:");
    std::cout << common::utils::remove_nonprintable(content) << std::endl;

    // Raise the temperature for the next turn if this response ends with most of an earlier one
    for (auto response : past_responses) {
        response.erase(std::max(int(response.size()) - int(response.size()) / 4, 1));
        ctx.temp += float(common::utils::ends_with(content, response)) * 0.4f;
    }
    ctx.repeats = (ctx.repeats+1) * unsigned(content == ctx.last_content);
    ctx.last_content = content;
    if (!content.empty())
        past_responses.push_back(content);
    return !timed_out;
}

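// Replaces raw mentions in `msg` with cached usernames and records the message author in the
// user cache. Returns true if `user` (i.e. the bot) was mentioned, either directly or by name.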
bool resolve_mentions(const dpp::user& user, dpp::message& msg) {
    user_cache[user.id] = user;
    user_cache[msg.author.id] = msg.author;
    if (common::utils::str_replace_in_place(msg.content, user.get_mention(), user.username))
        return true;
    if (common::utils::contains(msg.content, user.username))
        return true;
    bool fres = false;
    for (const auto& [mentioned_user, guild] : msg.mentions) {
        auto res = user_cache.find(mentioned_user.id);
        if (res != user_cache.end())
            common::utils::str_replace_in_place(msg.content, mentioned_user.get_mention(), res->second.username);
        if (mentioned_user.id == user.id)
            fres = true;
    }
    return fres;
}
#endif

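// Sampling parameters used for the reaction model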
LM::Inference::Params get_reaction_params() {
    LM::Inference::Params fres;
    fres.temp = 0.2f;
    return fres;
}

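// Main bot routine: prepares the model(s) on a worker thread, then connects to Discord with the
// chat handler (if built with REAKTOR_WITH_CHAT) and the emoji reaction handler installed.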
void run(const std::string& token, const std::string& model_path, const std::string& system_prompt, dpp::snowflake chat_channel_id) {
    // Create last message timer
    common::Timer last_reaction_timer;

    // Prepare models
    common::PooledThread thread;
    std::unique_ptr<LM::Inference> chatModel;
    std::unique_ptr<LM::Inference> reactionModel;
    thread.start();

    // Prepare model
    LM::Inference::Savestate reactionSavestate;
    thread.enqueue([&] () {
        Halo halo;
        halo.settings.spinner = spinners::line;
        halo.start();

#ifdef REAKTOR_WITH_CHAT
        halo.settings.text = "Preparing chat model...";
        chatModel.reset(LM::Inference::construct(model_path, get_chat_params()));
        chatModel->append(" "+system_prompt, [&halo] (float progress) {
            halo.settings.text = "Preparing chat model... "+std::to_string(unsigned(progress))+'%';
            return true;
        });
        chatModel->params.n_ctx_window_top_bar = chatModel->get_context_size();
        chatModel->params.n_ctx = chatModel->get_context_size() + 100;
#endif

        halo.settings.text = "Preparing reaction model...";
        reactionModel.reset(LM::Inference::construct(model_path, get_reaction_params()));
        reactionModel->append("An unicode emoji fitting this message:\n\n> ");
        reactionModel->create_savestate(reactionSavestate);

        halo.stop();
        halo.success("Models have been prepared");
    });

    // Configure bot
    dpp::cluster bot(token);
    bot.on_log(dpp::utility::cout_logger());
    bot.intents |= dpp::i_guild_messages | dpp::i_message_content | dpp::i_guild_voice_states;

    // Create random generator
    common::RandomGenerator rng;
    rng.seed();

#ifdef REAKTOR_WITH_CHAT
    // Message generator
    ChatMessageGenerateContext genCtx;
    bot.on_message_create([&] (const dpp::message_create_t& event) {
        // Skip empty messages
        if (event.msg.content.empty()) return;
        // Skip messages outside of chat channel
        if (event.msg.channel_id != chat_channel_id) return;
        // Skip own messages
        if (event.msg.author.id == bot.me.id) return;
        // Respond to reset command
        if (event.msg.content == ";reset") {
            try {
                bot.message_delete_sync(event.msg.id, event.msg.channel_id);
            } catch (...) {}
            exit(74);
        }
        // Move to another thread
        thread.enqueue([=, &chatModel, &bot, &genCtx, &rng/*no mutex needed*/] () {
            auto msg = event.msg;
            bool ok;
            // Check if mentioned and resolve mentions
            const bool mentioned = resolve_mentions(bot.me, msg);
            // Append message; skip on error
            ok = chat_message_add(chatModel.get(), msg.author.username, msg.content);
            if (!ok) return;
            // If not mentioned, only reply with a 12.5% chance
            if (!mentioned && !rng.getBool(0.125f)) return;
            // Create initial message
            auto new_msg = bot.message_create_sync(dpp::message(msg.channel_id, "⠀"));
            bot.channel_typing(new_msg.channel_id);
            // Generate response
            std::string response;
            ok = chat_message_generate(chatModel.get(), bot, new_msg, genCtx, bot.me.username, response);
            // Add ... to response on error
            if (!ok) response += "...";
            // Replace the placeholder with the final response, sent as a reply to the original message
            new_msg.content = response;
            new_msg.set_reference(msg.id, msg.guild_id, msg.channel_id, true);
            bot.message_delete(new_msg.id, new_msg.channel_id);
            bot.message_create(new_msg);
        });
    });
#endif

    // Reaction generator
    bot.on_message_create([&] (const dpp::message_create_t& event) {
        // Only react to messages that are sufficiently long
        if (event.msg.content.size() < 34) return;
        // Only react with a roughly 4% chance, unless the last reaction was more than 2 hours ago
        if (!rng.getBool(0.04f) && last_reaction_timer.get<std::chrono::hours>() < 2) return;
        // Get shortened message content
        std::string content{event.msg.content.data(), std::min<size_t>(event.msg.content.size(), 160)};
        // Move to another thread
        thread.enqueue([=, &reactionModel, &bot, &last_reaction_timer] () {
            Halo halo({"Generating reaction to: "+std::string(content), spinners::line});
            halo.start();
            // Prepare model
            reactionModel->append(std::string(content)+"\n\nis:\n\n>");
            // Run model
            common::Timer timeout;
            std::string result;
            reactionModel->run("", [&result, &timeout] (const char *token) {
                // Check for timeout
                if (timeout.get<std::chrono::seconds>() > 10) {
                    return false;
                }
                // Skip leading whitespaces
                while (*token == ' ') token++;
                if (*token == '\0') return true;
                // Check for completion
                result += token;
                bool fres = false;
                for (const char delim : {' ', '\n', '\r', '.', ',', ':'}) {
                    fres |= common::utils::chop_down(result, delim);
                }
                // Stop if emoji is done
                return !fres;
            });
            // Extract unicode emoji
            while (!result.empty()) {
                if (is_unicode_emoji(result))
                    break;
                result.pop_back();
            }
            // Check that there is anything left
            if (result.empty()) {
                // Nope, it went wrong
                halo.stop();
                halo.warning("Got an invalid response, discarding");
            } else {
                // We got it!
                halo.stop();
                halo.success("Response generated: "+result);
                // Add emoji to message
                bot.message_add_reaction(event.msg, result, [&last_reaction_timer] (const dpp::confirmation_callback_t& ccb) {
                    if (!ccb.is_error()) last_reaction_timer.reset();
                });
            }
            // Finalize model
            reactionModel->restore_savestate(reactionSavestate);
        });
    });

    // Connection success message
    bot.on_ready([] (const dpp::ready_t& event) {
        std::cout << "Connected to Discord!" << std::endl;
    });

    // Start bot
    bot.start(dpp::st_wait);
}

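// Entry point. Parses the command line (see the usage strings below), reads the bot token from
// the LMFUN_BOT_TOKEN environment variable or from stdin, and runs the bot. Example invocation
// for the chat build (binary name and argument values here are illustrative):
//   LMFUN_BOT_TOKEN=... ./reaktor model.bin 123456789012345678 system_prompt.txt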
int main(int argc, char **argv) {
    colorama::init();

    // Check args
#ifdef REAKTOR_WITH_CHAT
    if (argc != 4) {
        std::cout << "Usage: " << argv[0] << " <model file> <chat channel id> <system prompt file>" << std::endl;
#else
    if (argc != 2) {
        std::cout << "Usage: " << argv[0] << " <model file>" << std::endl;
#endif
        return -1;
    }

    // Get args
    const auto model_path(argv[1]);
#ifdef REAKTOR_WITH_CHAT
    const dpp::snowflake chat_channel_id(argv[2]);
    std::string system_prompt = common::utils::read_text_file(argv[3]);
    common::utils::force_trailing(system_prompt, "\n");
#else
    // Chat channel and system prompt are only used when built with REAKTOR_WITH_CHAT
    const dpp::snowflake chat_channel_id{};
    std::string system_prompt;
#endif

    // Get token
    std::string token;
    const char *token_env = getenv("LMFUN_BOT_TOKEN");
    if (token_env) {
        // Use token from environment
        token = token_env;
        unsetenv("LMFUN_BOT_TOKEN");
    } else {
        // Request token from stdin
        std::cout << "Token: ";
        std::getline(std::cin, token);

        // Redact token
        std::cout << '\r' << colorama::Cursor::UP() << "Token: <redacted>";
        for (size_t it = 0; it != std::max<std::size_t>(token.size(), 10)-10; it++) {
            std::cout << ' ';
        }
        std::cout << std::endl;
    }

    run(token, model_path, system_prompt, chat_channel_id);
}