From 09a1cc5b095ef2b53e3b149426c2f0369078f1c6 Mon Sep 17 00:00:00 2001
From: niansa
Date: Sun, 19 Mar 2023 21:09:15 +0100
Subject: [PATCH] Initial commit

---
 .gitignore     |   1 +
 .gitmodules    |   6 ++
 CMakeLists.txt |   4 +
 DPP            |   1 +
 ProcPipe.hpp   | 174 ++++++++++++++++++++++++++++++
 Random.hpp     |  43 ++++++++
 Timer.hpp      |  25 +++++
 llama.cpp      |   1 +
 main.cpp       | 279 +++++++++++++++++++++++++++++++++++++++++++++++--
 9 files changed, 528 insertions(+), 6 deletions(-)
 create mode 100644 .gitignore
 create mode 100644 .gitmodules
 create mode 160000 DPP
 create mode 100644 ProcPipe.hpp
 create mode 100644 Random.hpp
 create mode 100644 Timer.hpp
 create mode 160000 llama.cpp

diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..cf25ac6
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1 @@
+CMakeLists.txt.user*
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..0621655
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,6 @@
+[submodule "llama.cpp"]
+	path = llama.cpp
+	url = https://github.com/ggerganov/llama.cpp.git
+[submodule "DPP"]
+	path = DPP
+	url = https://github.com/brainboxdotcc/DPP.git
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8f3a59f..dc886e4 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -5,7 +5,11 @@ project(discord_llama LANGUAGES CXX)
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 
+add_subdirectory(llama.cpp)
+add_subdirectory(DPP)
+
 add_executable(discord_llama main.cpp)
+target_link_libraries(discord_llama PUBLIC dpp)
 
 install(TARGETS discord_llama
     LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})
diff --git a/DPP b/DPP
new file mode 160000
index 0000000..06a0483
--- /dev/null
+++ b/DPP
@@ -0,0 +1 @@
+Subproject commit 06a0483a454e3b339436fbc0c25cc37a4fe11c4a
diff --git a/ProcPipe.hpp b/ProcPipe.hpp
new file mode 100644
index 0000000..181587b
--- /dev/null
+++ b/ProcPipe.hpp
@@ -0,0 +1,174 @@
+#ifndef __WIN32
+#ifndef _PROCPIPE_HPP
+#define _PROCPIPE_HPP
+#include <string>
+#include <string_view>
+#include <vector>
+#include <array>
+#include <tuple>
+#include <stdexcept>
+#include <cstddef>
+#include <cstdio>
+#include <cstdlib>
+#include <unistd.h>
+#include <signal.h>
+#include <sys/types.h>
+#include <sys/wait.h>
+
+
+
+template<bool redir_stdin, bool redir_stdout, bool redir_stderr>
+class ProcPipe {
+    struct Pipe {
+        int readFd = -1, writeFd = -1;
+
+        auto make() {
+            // Reuses this struct as the int[2] that pipe() expects;
+            // readFd/writeFd must remain its first two members.
+            return pipe(reinterpret_cast<int*>(this));
+        }
+        ~Pipe() {
+            close(readFd); close(writeFd);
+        }
+    };
+    struct Redirect {
+        Pipe *pipe;
+        int fd;
+        bool output; // true: parent writes to the child; false: parent reads from it
+        int fdbak = -1;
+    };
+
+    constexpr static int errExit = 48;
+    pid_t pid = 0;
+    Pipe stdinPipe,
+         stdoutPipe,
+         stderrPipe;
+
+    template<size_t size>
+    auto recvFrom(int fd) {
+        static_assert (size != 0, "Can't read zero bytes");
+        std::vector<char> fres(size);
+        ssize_t bytes_read;
+        if ((bytes_read = read(fd, fres.data(), fres.size())) < 0) {
+            throw FdError("Failed to read() from child process");
+        }
+        fres.resize(bytes_read);
+        return fres;
+    }
+
+public:
+    struct ExecutionError : public std::runtime_error {
+        using std::runtime_error::runtime_error;
+    };
+    struct AlreadyRunning : public std::runtime_error {
+        using std::runtime_error::runtime_error;
+    };
+    struct FdError : public std::runtime_error {
+        using std::runtime_error::runtime_error;
+    };
+
+    ProcPipe() {}
+    template<typename... Args>
+    ProcPipe(Args&&... args) {
+        start(args...);
+    }
+    ~ProcPipe() {
+        terminate();
+    }
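+
+    // Usage sketch (illustrative only; "cat" is an arbitrary example command,
+    // not something this header depends on):
+    //
+    //     ProcPipe<true, true, false> proc("cat");
+    //     proc.send("hello\n");
+    //     auto bytes = proc.recvStd<64>(); // blocks; returns up to 64 bytes
+    //
+    // The three template flags choose which of the child's standard streams
+    // get piped back to the parent.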
+
+    void send(std::string_view str) {
+        static_assert (redir_stdin, "Can't write to stdin if not redirected");
+        if (write(stdinPipe.writeFd, str.data(), str.size()) < 0) {
+            throw FdError("Failed to write() to stdin");
+        }
+    }
+
+    template<size_t size>
+    auto recvStd() {
+        static_assert (redir_stdout, "Can't read from stdout if not redirected");
+        return recvFrom<size>(stdoutPipe.readFd);
+    }
+
+    template<size_t size>
+    auto recvErr() {
+        static_assert (redir_stderr, "Can't read from stderr if not redirected");
+        return recvFrom<size>(stderrPipe.readFd);
+    }
+
+    auto makeRedirs() {
+        constexpr int redirs_size = redir_stdin + redir_stdout + redir_stderr;
+        std::array<Redirect, redirs_size> redirs = {};
+        {
+            int idx = 0;
+            if constexpr(redir_stdin) {
+                redirs[idx++] = {&stdinPipe, STDIN_FILENO, true};
+            }
+            if constexpr(redir_stdout) {
+                redirs[idx++] = {&stdoutPipe, STDOUT_FILENO, false};
+            }
+            if constexpr(redir_stderr) {
+                redirs[idx++] = {&stderrPipe, STDERR_FILENO, false};
+            }
+        }
+        return redirs;
+    }
+
+    template<typename... Args>
+    void start(Args&&... args) {
+        if (pid) {
+            throw AlreadyRunning("Tried to run process in an instance where it is already running");
+        } else {
+            // Make redirects
+            auto redirs = makeRedirs();
+            // Redirect fds
+            for (auto& io : redirs) {
+                // Backup fd
+                io.fdbak = dup(io.fd);
+                // Create new pipe and splice it over the standard stream
+                io.pipe->make();
+                dup2((io.output ? io.pipe->readFd : io.pipe->writeFd), io.fd);
+            }
+            // Run process
+            pid = fork();
+            if (pid == 0) {
+                // Child: the first argument doubles as the executable path
+                const auto executable = std::get<0>(std::tuple{args...});
+                execlp(executable, args..., nullptr);
+                perror((std::string("Failed to launch ")+executable).c_str());
+                exit(errExit);
+            }
+            // Restore fds
+            for (const auto& io : redirs) {
+                // Restore
+                dup2(io.fdbak, io.fd);
+            }
+        }
+    }
+
+    auto waitExit() noexcept {
+        if (pid) {
+            int status = 0;
+            waitpid(pid, &status, 0);
+            pid = 0;
+            return status;
+        } else {
+            return -1;
+        }
+    }
+
+    auto terminate() noexcept {
+        if (pid) {
+            ::kill(pid, SIGTERM);
+            return waitExit();
+        } else {
+            return -1;
+        }
+    }
+    auto kill() noexcept {
+        ::kill(pid, SIGKILL);
+    }
+
+    auto isRunning() noexcept {
+        return !(::kill(pid, 0) < 0);
+    }
+};
+#endif
+#endif
diff --git a/Random.hpp b/Random.hpp
new file mode 100644
index 0000000..c6d2345
--- /dev/null
+++ b/Random.hpp
@@ -0,0 +1,43 @@
+#ifndef _PHASMOENGINE_RANDOM_HPP
+#define _PHASMOENGINE_RANDOM_HPP
+#include <random>
+#include <cstdint>
+
+
+
+class RandomGenerator {
+    std::mt19937 rng;
+    uint32_t initialSeed;
+
+public:
+    void seed() {
+        rng.seed(initialSeed = std::random_device{}());
+    }
+    void seed(uint32_t customSeed) {
+        rng.seed(initialSeed = customSeed);
+    }
+
+    unsigned getUInt() {
+        std::uniform_int_distribution<unsigned> dist;
+        return dist(rng);
+    }
+    unsigned getUInt(unsigned max) {
+        std::uniform_int_distribution<unsigned> dist(0, max);
+        return dist(rng);
+    }
+    unsigned getUInt(unsigned min, unsigned max) {
+        std::uniform_int_distribution<unsigned> dist(min, max);
+        return dist(rng);
+    }
+    double getDouble(double max) {
+        std::uniform_real_distribution<double> dist(0.0, max);
+        return dist(rng);
+    }
+    double getDouble(double min, double max) {
+        std::uniform_real_distribution<double> dist(min, max);
+        return dist(rng);
+    }
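+    // Returns true with probability `chance` (0.0f = never, 1.0f = always);
+    // e.g. getBool(0.075f) comes out true for roughly 7.5% of calls.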
+    bool getBool(float chance) {
+        return getDouble(1.0) <= chance && chance != 0.0f;
+    }
+};
+#endif
diff --git a/Timer.hpp b/Timer.hpp
new file mode 100644
index 0000000..1276974
--- /dev/null
+++ b/Timer.hpp
@@ -0,0 +1,25 @@
+#ifndef _PHASMOENGINE_TIMER_HPP
+#define _PHASMOENGINE_TIMER_HPP
+#include <chrono>
+
+
+
+class Timer {
+    std::chrono::time_point<std::chrono::high_resolution_clock> value;
+
+public:
+    Timer() {
+        reset();
+    }
+
+    void reset() {
+        value = std::chrono::high_resolution_clock::now();
+    }
+
+    template<typename Unit>
+    auto get() {
+        auto duration = std::chrono::duration_cast<Unit>(std::chrono::high_resolution_clock::now() - value);
+        return duration.count();
+    }
+};
+#endif
diff --git a/llama.cpp b/llama.cpp
new file mode 160000
index 0000000..d7def1a
--- /dev/null
+++ b/llama.cpp
@@ -0,0 +1 @@
+Subproject commit d7def1a7524f712e5ebb7cd02bab0f13aa56a7f9
diff --git a/main.cpp b/main.cpp
index 3129fb9..d852dd7 100644
--- a/main.cpp
+++ b/main.cpp
@@ -1,9 +1,276 @@
-#include <iostream>
+#include "ProcPipe.hpp"
+#include "Random.hpp"
+#include "Timer.hpp"
 
-using namespace std;
+#include <dpp/dpp.h>
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <string>
+#include <string_view>
+#include <vector>
+#include <memory>
+#include <mutex>
+#include <thread>
+#include <chrono>
+#include <cstdio>
+#include <cstring>
+#include <ctime>
 
-int main()
-{
-    cout << "Hello World!" << endl;
-    return 0;
+#ifndef _POSIX_VERSION
+# error "Not compatible with non-POSIX systems"
+#endif
+
+
+
+static
+std::vector<std::string_view> str_split(std::string_view s, char delimiter, size_t times = -1) {
+    std::vector<std::string_view> to_return;
+    decltype(s.size()) start = 0, finish = 0;
+    while ((finish = s.find_first_of(delimiter, start)) != std::string_view::npos) {
+        to_return.emplace_back(s.substr(start, finish - start));
+        start = finish + 1;
+        if (to_return.size() == times) { break; }
+    }
+    to_return.emplace_back(s.substr(start));
+    return to_return;
+}
+
+
+class LLM {
+    struct Exception : public std::runtime_error {
+        using std::runtime_error::runtime_error;
+    };
+
+    // Only stdout is piped back; stderr stays on the terminal for diagnostics
+    ProcPipe<false, true, false> llama;
+    struct {
+        std::string model = "7B-ggml-model-quant.bin";
+
+        int32_t seed; // RNG seed
+        int32_t n_threads = static_cast<int32_t>(std::thread::hardware_concurrency()) / 2;
+        int32_t n_predict = 20000; // new tokens to predict
+        int32_t repeat_last_n = 256; // last n tokens to penalize
+        int32_t n_ctx = 2024; // context size
+
+        int32_t top_k = 40;
+        float top_p = 0.5f;
+        float temp = 0.84f;
+        float repeat_penalty = 1.17647f;
+    } params;
+
+    std::string get_temp_file_path() {
+        return "/tmp/discord_llama_"+std::to_string(getpid())+".txt";
+    }
+
+    void start() {
+        // Start process
+        const auto exe_path = "./llama.cpp/llama";
+        llama.start(exe_path, "-m", params.model.c_str(), "-s", std::to_string(params.seed).c_str(),
+                    "-t", std::to_string(params.n_threads).c_str(),
+                    "-f", get_temp_file_path().c_str(), "-n", std::to_string(params.n_predict).c_str(),
+                    "--top_k", std::to_string(params.top_k).c_str(), "--top_p", std::to_string(params.top_p).c_str(),
+                    "--repeat_last_n", std::to_string(params.repeat_last_n).c_str(), "--repeat_penalty", std::to_string(params.repeat_penalty).c_str(),
+                    "-c", std::to_string(params.n_ctx).c_str(), "--temp", std::to_string(params.temp).c_str());
+    }
+
+public:
+    LLM(int32_t seed = 0) {
+        // Set random seed; fall back to the current time if none was given
+        params.seed = seed?seed:time(NULL);
+    }
+
+    std::string run(std::string_view prompt, const char *end = nullptr) {
+        std::string fres;
+
+        // Write prompt into file
+        if (!(std::ofstream(get_temp_file_path()) << prompt)) {
+            throw Exception("Failed to write out initial prompt");
+        }
+
+        // Start AI; without an explicit end marker, fall back to the prompt's first character
+        const char prompt_based_end[2] = {prompt[0], '\0'};
+        auto end_length = end?strlen(end):sizeof(prompt_based_end)-1;
+        end = end?end:prompt_based_end;
+        start();
+
+        // Wait for a bit
+        std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+        // Make sure everything is alright
+        if (!llama.isRunning()) {
+            throw Exception("Llama didn't start. Read stderr for more info.");
+        }
+
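+        // The loop below streams the child's stdout one byte at a time,
+        // echoing each byte to the console, and stops once the `end` marker
+        // appears past the prompt. Byte-wise reads keep the marker scan
+        // simple, at the cost of one read() syscall per character.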
+        // Read until done
+        do {
+            // Receive a byte
+            const auto text = llama.recvStd<1>();
+            // Break on EOF
+            if (text.empty()) break;
+            // Debug
+            putchar(text[0]);
+            fflush(stdout);
+            // Append byte to fres
+            fres.append(std::string_view{text.data(), text.size()});
+            // Check if end is reached
+            auto res = fres.rfind(end);
+            if (res != fres.npos && res > prompt.size()) {
+                break;
+            }
+        } while (llama.isRunning());
+
+        // Erase end
+        fres.erase(fres.size()-end_length, end_length);
+
+        // Kill llama
+        llama.kill();
+
+        // Return final result, stripping the prompt echo
+        std::cout << fres << std::endl;
+        return fres.substr(prompt.size()+1);
+    }
+};
+
+
+class Bot {
+    RandomGenerator rng;
+    Timer last_message_timer;
+    std::shared_ptr<bool> stopping;
+    std::mutex llm_lock;
+
+    dpp::cluster bot;
+    dpp::channel channel;
+    dpp::snowflake channel_id;
+    std::vector<dpp::message> history;
+    std::vector<dpp::snowflake> my_messages;
+
+    void reply() {
+        // Generate prompt
+        std::string prompt;
+        {
+            std::ostringstream prompts;
+            // Append log header (the channel name is hardcoded)
+            prompts << "Log of #general channel.\n\n";
+            // Append each message to stream
+            for (const auto& msg : history) {
+                for (const auto line : str_split(msg.content, '\n')) {
+                    prompts << msg.author.username << ": " << line << '\n';
+                }
+            }
+            // Make LLM respond
+            prompts << bot.me.username << ':';
+            // Keep resulting string
+            prompt = prompts.str();
+        }
+        // Make sure prompt isn't too long; if so, erase the oldest message and retry
+        if (prompt.size() > 200) {
+            history.erase(history.begin());
+            return reply();
+        }
+        // Start new thread
+        std::thread([this, prompt = std::move(prompt)] () {
+            // Run model
+            std::scoped_lock L(llm_lock);
+            std::string output;
+            try {
+                output = LLM().run(prompt, "\n");
+            } catch (...) {
+                std::rethrow_exception(std::current_exception());
+            }
+            // Send resulting message
+            auto msg = bot.message_create_sync(dpp::message(channel_id, output));
+            // Add message to list of my messages
+            my_messages.push_back(msg.id); // Unsafe!!
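+            // ("Unsafe" because my_messages is also read by the event thread
+            // in attempt_reply() without synchronization; guarding both
+            // accesses with a mutex would remove the race.)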
+        }).detach();
+    }
+
+    void idle_auto_reply() {
+        auto s = stopping;
+        do {
+            // Wait for a bit
+            std::this_thread::sleep_for(std::chrono::minutes(5));
+            // Check if last message was more than 20 minutes ago
+            if (last_message_timer.get<std::chrono::minutes>() > 20) {
+                // Force reply
+                reply();
+            }
+        } while (!*s);
+    }
+
+    void attempt_reply(const dpp::message& msg) {
+        // Always reply to the 5th message
+        if (history.size() == 5) {
+            return reply();
+        }
+        // Past the 5th message, reply only selectively
+        if (history.size() > 5) {
+            // Decide randomly
+            if (rng.getBool(0.075f)) {
+                return reply();
+            }
+            // Reply if message contains username or ID
+            if (msg.content.find(bot.me.username) != std::string::npos
+                || msg.content.find(std::to_string(bot.me.id)) != std::string::npos) {
+                return reply();
+            }
+            // Reply if message references one of this bot's messages
+            for (const auto msg_id : my_messages) {
+                if (msg.message_reference.message_id == msg_id) {
+                    return reply();
+                }
+            }
+        }
+    }
+
+public:
+    Bot(const char *token, dpp::snowflake channel_id) : bot(token), channel_id(channel_id) {
+        bot.on_log(dpp::utility::cout_logger());
+        bot.intents = dpp::i_guild_messages | dpp::i_message_content;
+
+        // Set callbacks
+        bot.on_ready([=] (const dpp::ready_t&) {
+            // Get channel
+            bot.channel_get(channel_id, [=] (const dpp::confirmation_callback_t& cbt) {
+                if (cbt.is_error()) {
+                    throw std::runtime_error("Failed to get channel: "+cbt.get_error().message);
+                }
+                channel = cbt.get<dpp::channel>();
+            });
+            // Initialize random generator
+            rng.seed(bot.me.id);
+            // Start idle auto reply thread
+            std::thread([this] () {
+                idle_auto_reply();
+            }).detach();
+        });
+        bot.on_message_create([=] (const dpp::message_create_t& event) {
+            // Make sure message source is correct
+            if (event.msg.channel_id != channel_id) return;
+            // Make sure message has content
+            if (event.msg.content.empty()) return;
+            // Append message to history
+            history.push_back(event.msg);
+            // Attempt to send a reply
+            attempt_reply(event.msg);
+            // Reset last message timer
+            last_message_timer.reset();
+        });
+    }
+
+    void start() {
+        stopping = std::make_shared<bool>(false);
+        bot.start(dpp::st_wait);
+        *stopping = true;
+    }
+};
+
+
+int main(int argc, char **argv) {
+    // Check arguments
+    if (argc < 3) {
+        std::cout << "Usage: " << argv[0] << " <token> <channel_id>" << std::endl;
+        return -1;
+    }
+
+    // Construct and configure bot
+    Bot bot(argv[1], std::stoull(argv[2]));
+
+    // Start bot
+    bot.start();
 }