1
0
Fork 0
mirror of https://gitlab.com/niansa/discord_llama.git synced 2025-03-06 20:48:25 +01:00

Initial commit

This commit is contained in:
niansa 2023-03-19 21:09:15 +01:00
parent 356a6f9823
commit 09a1cc5b09
9 changed files with 528 additions and 6 deletions

1
.gitignore vendored Normal file
View file

@ -0,0 +1 @@
CMakeLists.txt.user*

6
.gitmodules vendored Normal file
View file

@ -0,0 +1,6 @@
[submodule "llama.cpp"]
path = llama.cpp
url = https://github.com/ggerganov/llama.cpp.git
[submodule "DPP"]
path = DPP
url = https://github.com/brainboxdotcc/DPP.git

View file

@ -5,7 +5,11 @@ project(discord_llama LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

add_subdirectory(llama.cpp)
add_subdirectory(DPP)

add_executable(discord_llama main.cpp)
target_link_libraries(discord_llama PUBLIC dpp)

# discord_llama is an executable target, so install its RUNTIME artifact;
# LIBRARY DESTINATION only applies to shared libraries/modules and would
# leave the binary uninstalled.
install(TARGETS discord_llama
    RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})

1
DPP Submodule

@ -0,0 +1 @@
Subproject commit 06a0483a454e3b339436fbc0c25cc37a4fe11c4a

174
ProcPipe.hpp Normal file
View file

@ -0,0 +1,174 @@
#ifndef __WIN32
#ifndef _PROCPIPE_HPP
#define _PROCPIPE_HPP
#include <array>
#include <vector>
#include <tuple>
#include <string>
#include <string_view>
#include <stdexcept>
#include <cstdio>
#include <cstring>
#include <cerrno>
#include <csignal>
#include <unistd.h>
#include <fcntl.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <sys/socket.h>

/// Spawns a child process and optionally attaches pipes to its
/// stdin/stdout/stderr, selected at compile time by the template flags.
/// POSIX-only (fork/exec); the whole header is disabled on Windows.
template<bool redir_stdin, bool redir_stdout, bool redir_stderr>
class ProcPipe {
    // RAII pair of pipe fds; -1 marks "not open".
    struct Pipe {
        int readFd = -1, writeFd = -1;
        // Create the pipe; returns 0 on success, -1 on error (same as pipe(2)).
        // Fixed: previously wrote through reinterpret_cast<int*>(this), which
        // relied on unspecified member layout; use a local array instead.
        auto make() {
            int fds[2];
            const auto res = pipe(fds);
            if (res == 0) {
                readFd = fds[0];   // pipe(2): fds[0] is the read end
                writeFd = fds[1];  // fds[1] is the write end
            }
            return res;
        }
        ~Pipe() {
            // Only close descriptors that were actually opened
            if (readFd != -1) close(readFd);
            if (writeFd != -1) close(writeFd);
        }
    };
    // One stdio fd to swap out around fork()
    struct Redirect {
        Pipe *pipe;
        int fd;         // STDIN_FILENO / STDOUT_FILENO / STDERR_FILENO
        bool output;    // true: child reads from the pipe; false: child writes into it
        int fdbak = -1; // dup() backup of the original fd, restored in the parent
    };
    constexpr static int errExit = 48; // child exit code when exec fails
    pid_t pid = 0; // 0 means "no child running"
    // Renamed from stdin/stdout/stderr: those names are macros in <cstdio>
    // on some platforms and must not be used as identifiers.
    Pipe stdin_pipe,
         stdout_pipe,
         stderr_pipe;
    // Read up to `size` bytes from fd; result is shrunk to the bytes
    // actually read (empty vector on EOF). Blocks until data is available.
    template<unsigned size>
    auto recvFrom(int fd) {
        static_assert (size != 0, "Can't read zero bytes");
        std::vector<char> fres(size);
        const auto bytes_read = read(fd, fres.data(), fres.size());
        if (bytes_read < 0) {
            throw FdError("Failed to read() from pipe");
        }
        fres.resize(bytes_read);
        return fres;
    }
public:
    struct ExecutionError : public std::runtime_error {
        using std::runtime_error::runtime_error;
    };
    struct AlreadyRunning : public std::runtime_error {
        using std::runtime_error::runtime_error;
    };
    struct FdError : public std::runtime_error {
        using std::runtime_error::runtime_error;
    };

    ProcPipe() {}
    // Convenience: construct and immediately start the process
    template<typename... Args>
    ProcPipe(Args&&... args) {
        start(args...);
    }
    ~ProcPipe() {
        terminate();
    }

    // Write str to the child's stdin (only valid if redirected)
    void send(std::string_view str) {
        static_assert (redir_stdin, "Can't write to stdin if not redirected");
        if (write(stdin_pipe.writeFd, str.data(), str.size()) < 0) {
            throw FdError("Failed to write() to stdin");
        }
    }
    // Read up to `size` bytes from the child's stdout
    template<unsigned size>
    auto recvStd() {
        static_assert (redir_stdout, "Can't read from stdout if not redirected");
        return recvFrom<size>(stdout_pipe.readFd);
    }
    // Read up to `size` bytes from the child's stderr
    template<unsigned size>
    auto recvErr() {
        // Fixed: assertion message previously said "stdout"
        static_assert (redir_stderr, "Can't read from stderr if not redirected");
        return recvFrom<size>(stderr_pipe.readFd);
    }
    // Build the compile-time-selected list of redirections
    auto makeRedirs() {
        constexpr int redirs_size = redir_stdin + redir_stdout + redir_stderr;
        std::array<Redirect, redirs_size> redirs = {};
        int idx = 0;
        if constexpr(redir_stdin) {
            redirs[idx++] = {&stdin_pipe, STDIN_FILENO, true};
        }
        if constexpr(redir_stdout) {
            redirs[idx++] = {&stdout_pipe, STDOUT_FILENO, false};
        }
        if constexpr(redir_stderr) {
            redirs[idx++] = {&stderr_pipe, STDERR_FILENO, false};
        }
        return redirs;
    }
    // Launch the process; args are passed straight to execlp(), so the
    // first argument is the executable (searched in PATH).
    // Throws AlreadyRunning if a child is still attached.
    template<typename... Args>
    void start(Args&&... args) {
        if (pid) {
            throw AlreadyRunning("Tried to run process in an instance where it is already running");
        }
        // Temporarily swap our own stdio fds for the pipe ends so the
        // child inherits them across fork()
        auto redirs = makeRedirs();
        for (auto& io : redirs) {
            io.fdbak = dup(io.fd); // Backup original fd
            io.pipe->make();       // Fresh pipe for this run
            dup2((io.output ? io.pipe->readFd : io.pipe->writeFd), io.fd);
        }
        pid = fork();
        if (pid == 0) {
            // Child: replace process image; execlp only returns on failure
            const auto executable = std::get<0>(std::tuple{args...});
            execlp(executable, args..., nullptr);
            perror((std::string("Failed to launch ")+executable).c_str());
            exit(errExit);
        }
        // Parent: restore the original stdio fds
        for (const auto& io : redirs) {
            dup2(io.fdbak, io.fd);
            close(io.fdbak); // Fixed: the dup() backup used to leak
        }
        // NOTE(review): the parent still holds the child-side pipe ends, so
        // reads only see EOF once this object is destroyed — callers appear
        // to rely on isRunning() instead; confirm before changing.
    }
    // Block until the child exits; returns its wait status, or -1 if no
    // child is running
    auto waitExit() noexcept {
        if (pid) {
            int status = 0;
            waitpid(pid, &status, 0);
            pid = 0;
            return status;
        } else {
            return -1;
        }
    }
    // Politely stop the child (SIGTERM) and reap it; -1 if not running
    auto terminate() noexcept {
        if (pid) {
            ::kill(pid, SIGTERM);
            return waitExit();
        } else {
            return -1;
        }
    }
    // Force-kill the child. Fixed: guarded against pid == 0, which would
    // have sent SIGKILL to our entire process group.
    auto kill() noexcept {
        if (pid) ::kill(pid, SIGKILL);
    }
    // True while the child exists (including as a zombie).
    // Fixed: guarded against pid == 0 (kill(0, 0) targets the process group
    // and would have spuriously reported "running").
    auto isRunning() noexcept {
        return pid != 0 && !(::kill(pid, 0) < 0);
    }
};
#endif
#endif

43
Random.hpp Normal file
View file

@ -0,0 +1,43 @@
#ifndef _PHASMOENGINE_RANDOM_HPP
#define _PHASMOENGINE_RANDOM_HPP
#include <random>

// Thin convenience wrapper around a Mersenne-Twister engine with helpers
// for uniformly distributed integers, doubles and biased booleans.
class RandomGenerator {
    std::mt19937 engine;
    uint32_t lastSeed; // seed applied most recently (kept for reproducibility)
public:
    // Seed from the OS entropy source
    void seed() {
        lastSeed = std::random_device{}();
        engine.seed(lastSeed);
    }
    // Seed deterministically from a caller-provided value
    void seed(uint32_t customSeed) {
        lastSeed = customSeed;
        engine.seed(lastSeed);
    }
    // Uniform unsigned over the distribution's full default range
    unsigned getUInt() {
        return std::uniform_int_distribution<unsigned>{}(engine);
    }
    // Uniform unsigned in [0, max]
    unsigned getUInt(unsigned max) {
        return std::uniform_int_distribution<unsigned>{0, max}(engine);
    }
    // Uniform unsigned in [min, max]
    unsigned getUInt(unsigned min, unsigned max) {
        return std::uniform_int_distribution<unsigned>{min, max}(engine);
    }
    // Uniform double in [0, max)
    double getDouble(double max) {
        return std::uniform_real_distribution<double>{0.0, max}(engine);
    }
    // Uniform double in [min, max)
    double getDouble(double min, double max) {
        return std::uniform_real_distribution<double>{min, max}(engine);
    }
    // True with probability `chance`; the engine always advances (one draw
    // per call), and chance == 0 never yields true.
    bool getBool(float chance) {
        const double roll = getDouble(1.0);
        return chance != 0.0f && roll <= chance;
    }
};
#endif

25
Timer.hpp Normal file
View file

@ -0,0 +1,25 @@
#ifndef _PHASMOENGINE_TIMER_HPP
#define _PHASMOENGINE_TIMER_HPP
#include <chrono>

// Measures elapsed wall-clock time since construction or the last reset().
class Timer {
    using Clock = std::chrono::high_resolution_clock;
    Clock::time_point started_at;
public:
    Timer() {
        reset();
    }
    // Restart the measurement from "now"
    void reset() {
        started_at = Clock::now();
    }
    // Elapsed time since the last reset, truncated to Unit (default: ms)
    template<typename Unit = std::chrono::milliseconds>
    auto get() {
        const auto elapsed = Clock::now() - started_at;
        return std::chrono::duration_cast<Unit>(elapsed).count();
    }
};
#endif

1
llama.cpp Submodule

@ -0,0 +1 @@
Subproject commit d7def1a7524f712e5ebb7cd02bab0f13aa56a7f9

279
main.cpp
View file

@ -1,9 +1,276 @@
#include <iostream>
#include "ProcPipe.hpp"
#include "Random.hpp"
#include "Timer.hpp"
using namespace std;
#include <string>
#include <string_view>
#include <sstream>
#include <stdexcept>
#include <fstream>
#include <thread>
#include <chrono>
#include <mutex>
#include <memory>
#include <dpp/dpp.h>
int main()
{
cout << "Hello World!" << endl;
return 0;
#ifndef _POSIX_VERSION
# error "Not compatible with non-POSIX systems"
#endif
static
// Splits s on `delimiter` into non-owning views. At most `times` splits
// are performed (default: unlimited); the remainder is appended as the
// final element, so the result is never empty (an empty input yields {""}).
std::vector<std::string_view> str_split(std::string_view s, char delimiter, size_t times = -1) {
    std::vector<std::string_view> parts;
    std::string_view::size_type pos = 0;
    for (;;) {
        const auto hit = s.find_first_of(delimiter, pos);
        if (hit == std::string_view::npos) break;
        parts.emplace_back(s.substr(pos, hit - pos));
        pos = hit + 1;
        // Stop splitting once the requested count is reached
        if (parts.size() == times) break;
    }
    // Whatever is left (possibly empty) becomes the last element
    parts.emplace_back(s.substr(pos));
    return parts;
}
// Wrapper around the llama.cpp executable: writes the prompt to a temp
// file, spawns the binary with fixed sampling parameters, and streams its
// stdout byte-by-byte until an end marker is seen or the process stops.
class LLM {
struct Exception : public std::runtime_error {
using std::runtime_error::runtime_error;
};
// Only stdout is redirected; stderr stays on the terminal for diagnostics
ProcPipe<false, true, false> llama;
// Command-line parameters handed to llama.cpp
struct {
std::string model = "7B-ggml-model-quant.bin";
int32_t seed; // RNG seed
int32_t n_threads = static_cast<int32_t>(std::thread::hardware_concurrency()) / 2;
int32_t n_predict = 20000; // new tokens to predict
int32_t repeat_last_n = 256; // last n tokens to penalize
int32_t n_ctx = 2024; //context size
int32_t top_k = 40;
float top_p = 0.5f;
float temp = 0.84f;
float repeat_penalty = 1.17647f;
} params;
// Per-process temp file used to pass the prompt via llama.cpp's -f flag
std::string get_temp_file_path() {
return "/tmp/discord_llama_"+std::to_string(getpid())+".txt";
}
// Spawns ./llama.cpp/llama with params serialized onto the command line.
// The std::to_string temporaries stay alive for the whole call expression,
// so the .c_str() pointers are valid for the duration of start().
void start() {
// Start process
const auto exe_path = "./llama.cpp/llama";
llama.start(exe_path, "-m", params.model.c_str(), "-s", std::to_string(params.seed).c_str(),
"-t", std::to_string(params.n_threads).c_str(),
"-f", get_temp_file_path().c_str(), "-n", std::to_string(params.n_predict).c_str(),
"--top_k", std::to_string(params.top_k).c_str(), "--top_p", std::to_string(params.top_p).c_str(),
"--repeat_last_n", std::to_string(params.repeat_last_n).c_str(), "--repeat_penalty", std::to_string(params.repeat_penalty).c_str(),
"-c", std::to_string(params.n_ctx).c_str(), "--temp", std::to_string(params.temp).c_str());
}
public:
// seed == 0 means "seed from the current time"
LLM(int32_t seed = 0) {
// Set random seed
params.seed = seed?seed:time(NULL);
}
// Runs one completion: prompt goes in, the model's continuation (text
// after the prompt, up to `end`) comes back. If `end` is null, the
// prompt's first character is used as the end marker.
std::string run(std::string_view prompt, const char *end = nullptr) {
std::string fres;
// Write prompt into file
if (!(std::ofstream(get_temp_file_path()) << prompt)) {
throw Exception("Failed to write out initial prompt");
}
// Start AI
// NOTE(review): end_length is sizeof(prompt_based_end) == 2 here even
// though the marker string is 1 char — the erase below then strips one
// extra byte. Confirm whether that is intentional.
const char prompt_based_end[2] = {prompt[0], '\0'};
auto end_length = end?strlen(end):sizeof(prompt_based_end);
end = end?end:prompt_based_end;
start();
// Wait for a bit
// NOTE(review): fixed 100ms startup grace period is a race — a slow
// machine may not have exec'd yet; verify against slower hosts
std::this_thread::sleep_for(std::chrono::milliseconds(100));
// Make sure everything is alright
if (!llama.isRunning()) {
throw Exception("Llama didn't start. Read stderr for more info.");
}
// Read until done
do {
// Receive a byte
const auto text = llama.recvStd<1>();
// Break on EOF
if (text.empty()) break;
// Debug
putchar(text[0]);
fflush(stdout);
// Append byte to fres
fres.append(std::string_view{text.data(), text.size()});
// Check if end is reached
// Marker must appear after the echoed prompt to count as "done"
auto res = fres.rfind(end);
if (res != fres.npos && res > prompt.size()) {
break;
}
} while (llama.isRunning());
// Erase end
// NOTE(review): this erases unconditionally — if the loop ended on EOF
// or process exit without the marker, the last end_length bytes of real
// output are dropped (and may underflow if fres is shorter); confirm
fres.erase(fres.size()-end_length, end_length);
// Kill llama
llama.kill();
// Return final result
// NOTE(review): substr(prompt.size()+1) throws std::out_of_range when
// the output is shorter than the prompt — verify upstream guarantees
std::cout << fres << std::endl;
return fres.substr(prompt.size()+1);
}
};
// Discord bot that mirrors a single channel into an LLM prompt and posts
// the model's continuation back, either on trigger conditions or after a
// period of channel inactivity.
class Bot {
RandomGenerator rng;
Timer last_message_timer;
// Heap flag shared with the idle thread so it can outlive start()
std::shared_ptr<bool> stopping;
// Serializes LLM runs — only one llama process at a time
std::mutex llm_lock;
dpp::cluster bot;
dpp::channel channel;
dpp::snowflake channel_id;
// Raw channel log used to build the prompt
std::vector<dpp::message> history;
// IDs of messages this bot has sent (used to detect replies to us)
std::vector<dpp::snowflake> my_messages;
// Builds a chat-log-style prompt from history and runs the LLM on a
// detached worker thread; posts the result to the channel.
void reply() {
// Generate prompt
std::string prompt;
{
std::ostringstream prompts;
// Append channel name
// NOTE(review): header is hard-coded to "#general" regardless of the
// actual channel — confirm whether this is intentional
prompts << "Log of #general channel.\n\n";
// Append each message to stream
for (const auto& msg : history) {
for (const auto line : str_split(msg.content, '\n')) {
prompts << msg.author.username << ": " << line << '\n';
}
}
// Make LLM respond
prompts << bot.me.username << ':';
// Keep resulting string
prompt = prompts.str();
}
// Make sure prompt isn't too long; if so, drop the oldest message and
// rebuild (recurses until the prompt fits in 200 chars)
if (prompt.size() > 200) {
history.erase(history.begin());
return reply();
}
// Start new thread
// NOTE(review): detached thread captures `this` — it must not outlive
// the Bot instance; confirm shutdown ordering
std::thread([this, prompt = std::move(prompt)] () {
// Run model
std::scoped_lock L(llm_lock);
std::string output;
try {
output = LLM().run(prompt, "\n");
} catch (...) {
std::rethrow_exception(std::current_exception());
}
// Send resulting message
auto msg = bot.message_create_sync(dpp::message(channel_id, output));
// Add message to list of my messages
my_messages.push_back(msg.id); // Unsafe!! — written from this worker thread while the event thread reads it, with no lock
}).detach();
}
// Background loop: every 5 minutes, force a reply if the channel has
// been quiet for over 20 minutes. Holds its own copy of `stopping` so
// the flag stays alive even if the Bot is torn down.
void idle_auto_reply() {
auto s = stopping;
do {
// Wait for a bit
std::this_thread::sleep_for(std::chrono::minutes(5));
// Check if last message was more than 20 minutes ago
if (last_message_timer.get<std::chrono::minutes>() > 20) {
// Force reply
reply();
}
} while (!*s);
}
// Decides whether an incoming message warrants a reply
void attempt_reply(const dpp::message& msg) {
// Always reply once history reaches exactly 5 messages
if (history.size() == 5) {
return reply();
}
// Beyond 5 messages, reply on triggers; below 5, stay silent
if (history.size() > 5) {
// Decide randomly
if (rng.getBool(0.075f)) {
return reply();
}
// Reply if message contains username or ID
if (msg.content.find(bot.me.username) != std::string::npos
|| msg.content.find(bot.me.id) != std::string::npos) {
return reply();
}
// Reply if message references user
for (const auto msg_id : my_messages) {
if (msg.message_reference.message_id == msg_id) {
return reply();
}
}
}
}
public:
Bot(const char *token, dpp::snowflake channel_id) : bot(token), channel_id(channel_id) {
bot.on_log(dpp::utility::cout_logger());
bot.intents = dpp::i_guild_messages | dpp::i_message_content;
// Set callbacks
bot.on_ready([=] (const dpp::ready_t&) {
// Get channel
bot.channel_get(channel_id, [=] (const dpp::confirmation_callback_t& cbt) {
if (cbt.is_error()) {
throw std::runtime_error("Failed to get channel: "+cbt.get_error().message);
}
channel = cbt.get<dpp::channel>();
});
// Seed the RNG from our own ID so behavior differs per bot account
// NOTE(review): the snowflake is narrowed to uint32 here — confirm
// that losing the high bits is acceptable
rng.seed(bot.me.id);
// Start idle auto reply thread
std::thread([this] () {
idle_auto_reply();
}).detach();
});
bot.on_message_create([=] (const dpp::message_create_t& event) {
// Make sure message source is correct
if (event.msg.channel_id != channel_id) return;
// Make sure message has content
if (event.msg.content.empty()) return;
// Append message to history
history.push_back(event.msg);
// Attempt to send a reply
attempt_reply(event.msg);
// Reset last message timer
last_message_timer.reset();
});
}
// Blocks until the cluster shuts down; flips `stopping` so the idle
// thread exits on its next wakeup
void start() {
stopping = std::make_shared<bool>(false);
bot.start(dpp::st_wait);
*stopping = true;
}
};
// Entry point: discord_llama <token> <channel>
int main(int argc, char **argv) {
    // Require both the bot token and the numeric channel ID
    if (argc < 3) {
        std::cout << "Usage: " << argv[0] << " <token> <channel>" << std::endl;
        return -1;
    }
    // Parse the channel ID defensively: std::stoull throws on non-numeric
    // or out-of-range input, which previously terminated via an uncaught
    // exception instead of a usage error
    unsigned long long channel_id = 0;
    try {
        channel_id = std::stoull(argv[2]);
    } catch (const std::exception&) {
        std::cout << "Error: <channel> must be a numeric channel ID" << std::endl;
        return -1;
    }
    // Construct and configure bot
    Bot bot(argv[1], channel_id);
    // Run until the cluster shuts down
    bot.start();
}