reaktor/main.cpp

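// Reaktor: a Discord bot built on D++ (dpp) and justlm. It can hold a
// conversation in a dedicated channel via a local language model (when built
// with REAKTOR_WITH_CHAT) and reacts to messages with generated unicode emojis.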
#include "unicode_emojis.h"
#include <iostream>
#include <string>
#include <chrono>
#include <memory>
#include <unordered_map>
#include <vector>
#include <algorithm>
#include <limits>
#include <cstdlib>
#include <ctime>
#include <dpp/dpp.h>
#include <justlm.hpp>
#include <halo.hpp>
#include <colorama.hpp>
#include <commoncpp/utils.hpp>
#include <commoncpp/timer.hpp>
#include <commoncpp/random.hpp>
#include <commoncpp/pooled_thread.hpp>
#ifdef REAKTOR_WITH_CHAT
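// Chat support keeps the whole conversation inside the model prompt as a
// transcript of "<username> message" lines; past bot responses are remembered
// globally so repetition can be detected later.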
std::unordered_map<dpp::snowflake, dpp::user> user_cache;
std::vector<std::string> past_responses;
LM::Inference::Params get_chat_params() {
    LM::Inference::Params fres;
    fres.n_eos_ignores = 0;
    fres.temp = 0.3f;
    fres.repeat_penalty = 1.3f;
    fres.scroll_keep = 0.3f;
    return fres;
}
// Map a hash-like value onto a letter in 'A'..'T'; used to seed a fresh response
char get_limited_alpha(unsigned hv) {
    return (hv % ('U' - 'A')) + 'A';
}
bool chat_message_add(LM::Inference *model, const std::string& username, const std::string& content) {
    Halo halo({"Evaluating message...", spinners::line});
    halo.start();
    if (!common::utils::ends_with(model->get_prompt(), "\n<")) model->append("\n<"); // Make SURE message is terminated
    common::Timer timeout;
    bool timed_out = false;
    model->append(username+"> "+content+"\n<", [&] (float progress) -> bool {
        const auto seconds_passed = timeout.get<std::chrono::seconds>();
        if (seconds_passed > 45) {
            timed_out = true;
            halo.stop();
            halo.warning("Timed out processing message!");
            return false;
        }
        halo.settings.text = "Evaluating message... ("+std::to_string(seconds_passed)+"s) "+std::to_string(unsigned(progress))+"%";
        return true;
    });
    if (!timed_out) {
        halo.stop();
        halo.success("Message evaluated");
    }
    return !timed_out;
}
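// Generation state carried across responses, used to detect and break repetition
// loops: the sampling temperature is raised whenever the model starts repeating
// itself and decays back towards init_temp over time.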
struct ChatMessageGenerateContext {
    constexpr static float init_temp = 0.4f;
    std::string last_content;
    float temp;
    unsigned repeats;

    void reset() {
        last_content.clear();
        temp = init_temp;
        repeats = 0;
    }

    ChatMessageGenerateContext() {
        reset();
    }
};
bool chat_message_generate(LM::Inference *model, dpp::cluster& bot, dpp::message& new_msg, ChatMessageGenerateContext& ctx, const std::string& username, std::string& content) {
    Halo halo({"Generating response... "+content, spinners::line});
    halo.start();
    if (!common::utils::ends_with(model->get_prompt(), "\n<")) model->append("\n<"); // Make SURE message is terminated
    model->append(username+">");
    if (ctx.repeats > 0) {
        // Break out of a repetition loop: seed the response with a pseudo-random letter and raise the temperature
        content = std::string(1, get_limited_alpha(time(nullptr)));
        model->append(" "+content);
        halo.settings.text += content;
        ctx.temp += 1.8f;
    } else {
        content.clear();
    }
    common::Timer timeout, edit_timer;
    bool timed_out = false;
    float temp_before_linebreak = ctx.temp; // defensively initialized
    bool did_linebreak = false;
    unsigned token_count = 0;
    content += model->run("\n<", [&, buf = std::string()] (const char *token) mutable -> bool {
        buf += token;
        // Update temperature: cool down right after a linebreak, heat up when the
        // output starts looking like a previous response
        if (common::utils::contains(token, "\n")) {
            did_linebreak = true;
            temp_before_linebreak = ctx.temp;
            ctx.temp = 0.1f;
        } else if (did_linebreak) {
            ctx.temp = temp_before_linebreak;
            did_linebreak = false;
        } else if (common::utils::starts_with(ctx.last_content, buf) && ctx.temp < 0.9f) {
            ctx.temp += 0.2f;
        } else if (buf.size() > 6) {
            for (const auto& response : past_responses) {
                if (response == ctx.last_content) continue;
                if (common::utils::starts_with(response, buf) && ctx.temp < 0.9f) ctx.temp += 0.1f;
            }
        }
        if (ctx.temp > ctx.init_temp + std::numeric_limits<float>::epsilon()) ctx.temp -= 0.05f;
        model->params.temp = ctx.temp;
        // Update repeat penalty
        model->params.n_repeat_last = token_count++;
        // Check for timeout
        if (timeout.get<std::chrono::seconds>() > 90) {
            timed_out = true;
            halo.stop();
            halo.warning("Timed out generating message!");
            return false;
        }
        // Update message
        if (edit_timer.get<std::chrono::seconds>() > 8) {
            new_msg.content = std::move(buf);
            try {
                bot.message_edit_sync(new_msg);
            } catch (...) {}
            buf = std::move(new_msg.content);
            bot.channel_typing(new_msg.channel_id);
            edit_timer.reset();
        }
        // Update halo text
        halo.settings.text += common::utils::remove_nonprintable(token);
        return true;
    });
    if (!timed_out) {
        halo.stop();
        halo.success("Generated message:");
    }
    std::cout << common::utils::remove_nonprintable(content) << std::endl;
    // Penalize endings seen before: compare against each past response with its last quarter chopped off
    for (auto response : past_responses) {
        response.erase(std::max(int(response.size()) - int(response.size()) / 4, 1));
        ctx.temp += float(common::utils::ends_with(content, response)) * 0.4f;
    }
    // Count consecutive identical responses; reset to zero once the response changes
    ctx.repeats = (ctx.repeats+1) * unsigned(content == ctx.last_content);
    ctx.last_content = content;
    if (!content.empty())
        past_responses.push_back(content);
    return !timed_out;
}
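// Replace raw <@id> mention tags in msg.content with plain usernames (looked up
// in a local cache of previously seen users) and report whether `user` was
// mentioned by tag or by name.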
bool resolve_mentions(const dpp::user& user, dpp::message& msg) {
    user_cache[user.id] = user;
    user_cache[msg.author.id] = msg.author;
    if (common::utils::str_replace_in_place(msg.content, user.get_mention(), user.username))
        return true;
    if (common::utils::contains(msg.content, user.username))
        return true;
    bool fres = false;
    for (const auto& [mentioned_user, guild] : msg.mentions) {
        auto res = user_cache.find(mentioned_user.id);
        if (res != user_cache.end())
            common::utils::str_replace_in_place(msg.content, mentioned_user.get_mention(), res->second.username);
        if (mentioned_user.id == user.id)
            fres = true;
    }
    return fres;
}
#endif
LM::Inference::Params get_reaction_params() {
    LM::Inference::Params fres;
    fres.temp = 0.2f;
    return fres;
}
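// Load the model(s) on a worker thread, register the Discord event handlers and
// block on the bot's event loop.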
void run(const std::string& token, const std::string& model_path, const std::string& system_prompt, dpp::snowflake chat_channel_id) {
    // Create last message timer
    common::Timer last_reaction_timer;
    // Prepare models
    common::PooledThread thread;
    std::unique_ptr<LM::Inference> chatModel;
    std::unique_ptr<LM::Inference> reactionModel;
    thread.start();
    // Reaction model savestate, restored after every generated reaction
    LM::Inference::Savestate reactionSavestate;
    thread.enqueue([&] () {
        Halo halo;
        halo.settings.spinner = spinners::line;
        halo.start();
#ifdef REAKTOR_WITH_CHAT
        halo.settings.text = "Preparing chat model...";
        chatModel.reset(LM::Inference::construct(model_path, get_chat_params()));
        chatModel->append(" "+system_prompt, [&halo] (float progress) {
            halo.settings.text = "Preparing chat model... "+std::to_string(unsigned(progress))+'%';
            return true;
        });
        chatModel->params.n_ctx_window_top_bar = chatModel->get_context_size();
        chatModel->params.n_ctx = chatModel->get_context_size() + 100;
#endif
        halo.settings.text = "Preparing reaction model...";
        reactionModel.reset(LM::Inference::construct(model_path, get_reaction_params()));
        reactionModel->append("An unicode emoji fitting this message:\n\n> ");
        reactionModel->create_savestate(reactionSavestate);
        halo.stop();
        halo.success("Models have been prepared");
    });
    // Configure bot
    dpp::cluster bot(token);
    bot.on_log(dpp::utility::cout_logger());
    bot.intents |= dpp::i_guild_messages | dpp::i_message_content | dpp::i_guild_voice_states;
    // Create random generator
    common::RandomGenerator rng;
    rng.seed();
#ifdef REAKTOR_WITH_CHAT
    // Message generator
    ChatMessageGenerateContext genCtx;
    bot.on_message_create([&] (const dpp::message_create_t& event) {
        // Skip empty messages
        if (event.msg.content.empty()) return;
        // Skip messages outside of chat channel
        if (event.msg.channel_id != chat_channel_id) return;
        // Skip own messages
        if (event.msg.author.id == bot.me.id) return;
        // Respond to reset command
        if (event.msg.content == ";reset") {
            try {
                bot.message_delete_sync(event.msg.id, event.msg.channel_id);
            } catch (...) {}
            exit(74);
        }
        // Move to another thread
        thread.enqueue([=, &chatModel, &bot, &genCtx, &rng/*no mutex needed*/] () {
            auto msg = event.msg;
            bool ok;
            // Check if mentioned and resolve mentions
            const bool mentioned = resolve_mentions(bot.me, msg);
            // Append message; skip on error
            ok = chat_message_add(chatModel.get(), msg.author.username, msg.content);
            if (!ok) return;
            // Skip if not mentioned and random chance
            if (!mentioned && !rng.getBool(0.125f)) return;
            // Create initial message
            auto new_msg = bot.message_create_sync(dpp::message(msg.channel_id, ""));
            bot.channel_typing(new_msg.channel_id);
            // Generate response
            std::string response;
            ok = chat_message_generate(chatModel.get(), bot, new_msg, genCtx, bot.me.username, response);
            // Add ... to response on error
            if (!ok) response += "...";
            // Delete the streamed message and re-send the final response as a reply
            new_msg.content = response;
            new_msg.set_reference(msg.id, msg.guild_id, msg.channel_id, true);
            bot.message_delete(new_msg.id, new_msg.channel_id);
            bot.message_create(new_msg);
        });
    });
#endif
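// The reaction model was primed above with "An unicode emoji fitting this
// message:\n\n> "; each candidate message is appended to that prompt and the
// savestate is restored afterwards, so every reaction starts from the same
// clean context.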
    // Reaction generator
    bot.on_message_create([&] (const dpp::message_create_t& event) {
        // Only react to messages that are sufficiently long
        if (event.msg.content.size() < 34) return;
        // Only react with a small random chance, unless there was no reaction for 2 hours
        if (!rng.getBool(0.04f) && last_reaction_timer.get<std::chrono::hours>() < 2) return;
        // Get shortened message content
        std::string content{event.msg.content.data(), std::min<size_t>(event.msg.content.size(), 160)};
        // Move to another thread
        thread.enqueue([=, &reactionModel, &bot, &last_reaction_timer] () {
            Halo halo({"Generating reaction to: "+std::string(content), spinners::line});
            halo.start();
            // Prepare model
            reactionModel->append(std::string(content)+"\n\nis:\n\n>");
            // Run model
            common::Timer timeout;
            std::string result;
            reactionModel->run("", [&result, &timeout] (const char *token) {
                // Check for timeout
                if (timeout.get<std::chrono::seconds>() > 10) {
                    return false;
                }
                // Skip leading whitespaces
                while (*token == ' ') token++;
                if (*token == '\0') return true;
                // Check for completion
                result += token;
                bool fres = false;
                for (const char delim : {' ', '\n', '\r', '.', ',', ':'}) {
                    fres |= common::utils::chop_down(result, delim);
                }
                // Stop if emoji is done
                return !fres;
            });
            // Extract unicode emoji
            while (!result.empty()) {
                if (is_unicode_emoji(result))
                    break;
                result.pop_back();
            }
            // Check that there is anything left
            if (result.empty()) {
                // Nope, it went wrong
                halo.stop();
                halo.warning("Got an invalid response, discarding");
            } else {
                // We got it!
                halo.stop();
                halo.success("Response generated: "+result);
                // Add emoji to message
                bot.message_add_reaction(event.msg, result, [&last_reaction_timer] (const dpp::confirmation_callback_t& ccb) {
                    if (!ccb.is_error()) last_reaction_timer.reset();
                });
            }
            // Finalize model
            reactionModel->restore_savestate(reactionSavestate);
        });
    });
    // Connection success message
    bot.on_ready([] (const dpp::ready_t& event) {
        std::cout << "Connected to Discord!" << std::endl;
    });
    // Start bot
    bot.start(dpp::st_wait);
}
int main(int argc, char **argv) {
    colorama::init();
    // Check args
#ifdef REAKTOR_WITH_CHAT
    if (argc != 4) {
        std::cout << "Usage: " << argv[0] << " <model file> <chat channel id> <system prompt file>" << std::endl;
#else
    if (argc != 2) {
        std::cout << "Usage: " << argv[0] << " <model file>" << std::endl;
#endif
        return -1;
    }
    // Get args
    const auto model_path(argv[1]);
#ifdef REAKTOR_WITH_CHAT
    const dpp::snowflake chat_channel_id(argv[2]);
    std::string system_prompt = common::utils::read_text_file(argv[3]);
    common::utils::force_trailing(system_prompt, "\n");
#else
    // Without chat support only the model path is passed; keep placeholders for run()
    const dpp::snowflake chat_channel_id(0);
    const std::string system_prompt;
#endif
    // Get token
    std::string token;
    const char *token_env = getenv("LMFUN_BOT_TOKEN");
    if (token_env) {
        // Use token from environment
        token = token_env;
        unsetenv("LMFUN_BOT_TOKEN");
    } else {
        // Request token from stdin
        std::cout << "Token: ";
        std::getline(std::cin, token);
        // Redact token
        std::cout << '\r' << colorama::Cursor::UP() << "Token: <redacted>";
        for (size_t it = 0; it != std::max<std::size_t>(token.size(), 10)-10; it++) {
            std::cout << ' ';
        }
        std::cout << std::endl;
    }
    run(token, model_path, system_prompt, chat_channel_id);
}