#include "unicode_emojis.h" #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #ifdef REAKTOR_WITH_CHAT std::unordered_map user_cache; std::vector past_responses; LM::Inference::Params get_chat_params() { LM::Inference::Params fres; fres.n_eos_ignores = 0; fres.temp = 0.3f; fres.repeat_penalty = 1.3f; fres.scroll_keep = 0.3f; return fres; } char get_limited_alpha(unsigned hv) { return (hv % ('U' - 'A')) + 'A'; } bool chat_message_add(LM::Inference *model, const std::string& username, const std::string& content) { Halo halo({"Evaluating message...", spinners::line}); halo.start(); if (!common::utils::ends_with(model->get_prompt(), "\n<")) model->append("\n<"); // Make SURE message is terminated common::Timer timeout; bool timed_out = false; model->append(username+"> "+content+"\n<", [&] (float progress) -> bool { const auto seconds_passed = timeout.get(); if (seconds_passed > 45) { timed_out = true; halo.stop(); halo.warning("Timed out processing message!"); return false; } halo.settings.text = "Evaluating message... ("+std::to_string(seconds_passed)+"s) "+std::to_string(unsigned(progress))+"%"; return true; }); if (!timed_out) halo.stop(); halo.success("Message evaluated"); return !timed_out; } struct ChatMessageGenerateContext { constexpr static float init_temp = 0.4f; std::string last_content; float temp; unsigned repeats; void reset() { last_content.clear(); temp = init_temp; repeats = 0; } ChatMessageGenerateContext() { reset(); } }; bool chat_message_generate(LM::Inference *model, dpp::cluster& bot, dpp::message& new_msg, ChatMessageGenerateContext& ctx, const std::string& username, std::string& content) { Halo halo({"Generating response... "+content, spinners::line}); halo.start(); if (!common::utils::ends_with(model->get_prompt(), "\n<")) model->append("\n<"); // Make SURE message is terminated model->append(username+">"); if (ctx.repeats > 0) { content = std::string(1, get_limited_alpha(time(nullptr))); model->append(" "+content); halo.settings.text += content; ctx.temp += 1.8f; } else { content.clear(); } common::Timer timeout, edit_timer; bool timed_out = false; float temp_before_linebreak; bool did_linebreak = false; unsigned token_count = 0; content += model->run("\n<", [&, buf = std::string()] (const char *token) mutable -> bool { buf += token; // Update temperature if (common::utils::contains(token, "\n")) { did_linebreak = true; temp_before_linebreak = ctx.temp; ctx.temp = 0.1f; } else if (did_linebreak) { ctx.temp = temp_before_linebreak; did_linebreak = false; } else if (common::utils::starts_with(ctx.last_content, buf) && ctx.temp < 0.9f) ctx.temp += 0.2f; else if (buf.size() > 6) { for (const auto& response : past_responses) { if (response == ctx.last_content) continue; if (common::utils::starts_with(response, buf) && ctx.temp < 0.9f) ctx.temp += 0.1f; } } if (ctx.temp > ctx.init_temp + std::numeric_limits::epsilon()) ctx.temp -= 0.05f; model->params.temp = ctx.temp; // Update repeat penalty model->params.n_repeat_last = token_count++; // Check for timeout if (timeout.get() > 90) { timed_out = true; halo.stop(); halo.warning("Timed out generating message!"); return false; } // Update message if (edit_timer.get() > 8) { new_msg.content = std::move(buf); try { bot.message_edit_sync(new_msg); } catch (...) 
bool resolve_mentions(const dpp::user& user, dpp::message& msg) {
    user_cache[user.id] = user;
    user_cache[msg.author.id] = msg.author;
    if (common::utils::str_replace_in_place(msg.content, user.get_mention(), user.username))
        return true;
    if (common::utils::contains(msg.content, user.username))
        return true;
    // Replace all cached mentions with plain usernames; report whether we were mentioned
    bool fres = false;
    for (const auto& [mentioned_user, guild] : msg.mentions) {
        auto res = user_cache.find(mentioned_user.id);
        if (res != user_cache.end())
            common::utils::str_replace_in_place(msg.content, mentioned_user.get_mention(), res->second.username);
        if (mentioned_user.id == user.id)
            fres = true;
    }
    return fres;
}
#endif

LM::Inference::Params get_reaction_params() {
    LM::Inference::Params fres;
    fres.temp = 0.2f;
    return fres;
}

void run(const std::string& token, const std::string& model_path, const std::string& system_prompt, dpp::snowflake chat_channel_id) {
    // Create last reaction timer
    common::Timer last_reaction_timer;
    // Prepare models in the background
    common::PooledThread thread;
    std::unique_ptr<LM::Inference> chatModel;
    std::unique_ptr<LM::Inference> reactionModel;
    thread.start();
    LM::Inference::Savestate reactionSavestate;
    thread.enqueue([&] () {
        Halo halo;
        halo.settings.spinner = spinners::line;
        halo.start();
#ifdef REAKTOR_WITH_CHAT
        halo.settings.text = "Preparing chat model...";
        chatModel.reset(LM::Inference::construct(model_path, get_chat_params()));
        chatModel->append(" "+system_prompt, [&halo] (float progress) {
            halo.settings.text = "Preparing chat model... "+std::to_string(unsigned(progress))+'%';
            return true;
        });
        chatModel->params.n_ctx_window_top_bar = chatModel->get_context_size();
        chatModel->params.n_ctx = chatModel->get_context_size() + 100;
#endif
        halo.settings.text = "Preparing reaction model...";
        reactionModel.reset(LM::Inference::construct(model_path, get_reaction_params()));
        reactionModel->append("A unicode emoji fitting this message:\n\n> ");
        reactionModel->create_savestate(reactionSavestate);
        halo.stop();
        halo.success("Models have been prepared");
    });
    // Configure bot
    dpp::cluster bot(token);
    bot.on_log(dpp::utility::cout_logger());
    bot.intents |= dpp::i_guild_messages | dpp::i_message_content | dpp::i_guild_voice_states;
    // Create random generator
    common::RandomGenerator rng;
    rng.seed();
#ifdef REAKTOR_WITH_CHAT
    // Message generator
    ChatMessageGenerateContext genCtx;
    bot.on_message_create([&] (const dpp::message_create_t& event) {
        // Skip empty messages
        if (event.msg.content.empty()) return;
        // Skip messages outside of chat channel
        if (event.msg.channel_id != chat_channel_id) return;
        // Skip own messages
        if (event.msg.author.id == bot.me.id) return;
        // Respond to reset command
        if (event.msg.content == ";reset") {
            try {
                bot.message_delete_sync(event.msg.id, event.msg.channel_id);
            } catch (...) {}
            exit(74);
        }
        // Move to another thread
        thread.enqueue([=, &chatModel, &bot, &genCtx, &rng/*no mutex needed*/] () {
            auto msg = event.msg;
            bool ok;
            // Check if mentioned and resolve mentions
            const bool mentioned = resolve_mentions(bot.me, msg);
            // Append message; skip on error
            ok = chat_message_add(chatModel.get(), msg.author.username, msg.content);
            if (!ok) return;
            // Skip if not mentioned and random chance
            if (!mentioned && !rng.getBool(0.125f)) return;
            // Create initial placeholder message
            auto new_msg = bot.message_create_sync(dpp::message(msg.channel_id, "⠀"));
            bot.channel_typing(new_msg.channel_id);
            // Generate response
            std::string response;
            ok = chat_message_generate(chatModel.get(), bot, new_msg, genCtx, bot.me.username, response);
            // Add ... to response on error
            if (!ok) response += "...";
            // Replace the placeholder with the final response, sent as a reply
            new_msg.content = response;
            new_msg.set_reference(msg.id, msg.guild_id, msg.channel_id, true);
            bot.message_delete(new_msg.id, new_msg.channel_id);
            bot.message_create(new_msg);
        });
    });
#endif
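/* The reaction generator below never reloads or re-evaluates the static part
 * of its prompt: the savestate captured right after appending the prefix
 * ("A unicode emoji fitting this message:\n\n> ") is restored once each
 * reaction has been produced, discarding the per-message suffix. */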
    // Reaction generator
    bot.on_message_create([&] (const dpp::message_create_t& event) {
        // Only react to messages that are sufficiently long
        if (event.msg.content.size() < 34) return;
        // Only react occasionally: on a 4% random hit, or once the last reaction is at least 2 seconds old
        if (!rng.getBool(0.04f) && last_reaction_timer.get() < 2) return;
        // Get shortened message content
        std::string content{event.msg.content.data(), std::min(event.msg.content.size(), size_t(160))};
        // Move to another thread
        thread.enqueue([=, &reactionModel, &bot, &last_reaction_timer] () {
            Halo halo({"Generating reaction to: "+std::string(content), spinners::line});
            halo.start();
            // Prepare model
            reactionModel->append(std::string(content)+"\n\nis:\n\n>");
            // Run model
            common::Timer timeout;
            std::string result;
            reactionModel->run("", [&result, &timeout] (const char *token) {
                // Check for timeout
                if (timeout.get() > 10) {
                    return false;
                }
                // Skip leading whitespaces
                while (*token == ' ') token++;
                if (*token == '\0') return true;
                // Check for completion
                result += token;
                bool fres = false;
                for (const char delim : {' ', '\n', '\r', '.', ',', ':'}) {
                    fres |= common::utils::chop_down(result, delim);
                }
                // Stop if emoji is done
                return !fres;
            });
            // Extract unicode emoji
            while (!result.empty()) {
                if (is_unicode_emoji(result)) break;
                result.pop_back();
            }
            // Check that there is anything left
            if (result.empty()) {
                // Nope, it went wrong
                halo.stop();
                halo.warning("Got an invalid response, discarding");
            } else {
                // We got it!
                halo.stop();
                halo.success("Response generated: "+result);
                // Add emoji to message
                bot.message_add_reaction(event.msg, result, [&last_reaction_timer] (const dpp::confirmation_callback_t& ccb) {
                    if (!ccb.is_error()) last_reaction_timer.reset();
                });
            }
            // Finalize model: drop the per-message suffix again
            reactionModel->restore_savestate(reactionSavestate);
        });
    });
    // Connection success message
    bot.on_ready([] (const dpp::ready_t& event) {
        std::cout << "Connected to Discord!" << std::endl;
    });
    // Start bot
    bot.start(dpp::st_wait);
}
<< std::endl; }); // Start bot bot.start(dpp::st_wait); } int main(int argc, char **argv) { colorama::init(); // Check args #ifdef REAKTOR_WITH_CHAT if (argc != 4) { std::cout << "Usage: " << argv[0] << " " << std::endl; #else if (argc != 2) { std::cout << "Usage: " << argv[0] << " " << std::endl; #endif return -1; } // Get args const auto model_path(argv[1]); const dpp::snowflake chat_channel_id(argv[2]); std::string system_prompt = common::utils::read_text_file(argv[3]); common::utils::force_trailing(system_prompt, "\n"); // Get token std::string token; const char *token_env = getenv("LMFUN_BOT_TOKEN"); if (token_env) { // Use token from environment token = token_env; unsetenv("LMFUN_BOT_TOKEN"); } else { // Request token from stdin std::cout << "Token: "; std::getline(std::cin, token); // Redact token std::cout << '\r' << colorama::Cursor::UP() << "Token: "; for (size_t it = 0; it != std::max(token.size(), 10)-10; it++) { std::cout << ' '; } std::cout << std::endl; } run(token, model_path, system_prompt, chat_channel_id); }