mirror of
https://gitlab.com/niansa/discord_llama.git
synced 2025-03-06 20:48:25 +01:00
Initial commit
This commit is contained in:
parent
356a6f9823
commit
09a1cc5b09
9 changed files with 528 additions and 6 deletions
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
|
@ -0,0 +1 @@
|
|||
CMakeLists.txt.user*
|
6
.gitmodules
vendored
Normal file
6
.gitmodules
vendored
Normal file
|
@ -0,0 +1,6 @@
|
|||
[submodule "llama.cpp"]
|
||||
path = llama.cpp
|
||||
url = https://github.com/ggerganov/llama.cpp.git
|
||||
[submodule "DPP"]
|
||||
path = DPP
|
||||
url = https://github.com/brainboxdotcc/DPP.git
|
|
@ -5,7 +5,11 @@ project(discord_llama LANGUAGES CXX)
|
|||
set(CMAKE_CXX_STANDARD 17)
|
||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||
|
||||
add_subdirectory(llama.cpp)
|
||||
add_subdirectory(DPP)
|
||||
|
||||
add_executable(discord_llama main.cpp)
|
||||
target_link_libraries(discord_llama PUBLIC dpp)
|
||||
|
||||
install(TARGETS discord_llama
|
||||
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})
|
||||
|
|
1
DPP
Submodule
1
DPP
Submodule
|
@ -0,0 +1 @@
|
|||
Subproject commit 06a0483a454e3b339436fbc0c25cc37a4fe11c4a
|
174
ProcPipe.hpp
Normal file
174
ProcPipe.hpp
Normal file
|
@ -0,0 +1,174 @@
|
|||
#ifndef __WIN32
|
||||
#ifndef _PROCPIPE_HPP
|
||||
#define _PROCPIPE_HPP
|
||||
#include <fcntl.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <unistd.h>

#include <array>
#include <cerrno>
#include <csignal>
#include <cstdio>
#include <cstring>
#include <stdexcept>
#include <string>
#include <string_view>
#include <tuple>
#include <vector>
|
||||
|
||||
|
||||
|
||||
/// Spawns a child process via fork()+execlp() and optionally redirects its
/// stdin/stdout/stderr through pipes (selected at compile time by the three
/// template flags). The child is SIGTERM'd and reaped on destruction.
template<bool redir_stdin, bool redir_stdout, bool redir_stderr>
class ProcPipe {
    // One unidirectional pipe; both fds are closed on destruction (RAII).
    struct Pipe {
        int readFd = -1, writeFd = -1;

        // Creates the pipe. Returns 0 on success, -1 on failure (errno set).
        // Uses a local fd pair instead of reinterpret_cast'ing `this` to
        // int* as before, which relied on unspecified object layout.
        auto make() {
            int fds[2];
            const auto ret = ::pipe(fds);
            if (ret == 0) {
                readFd = fds[0];
                writeFd = fds[1];
            }
            return ret;
        }
        ~Pipe() {
            if (readFd >= 0) close(readFd);
            if (writeFd >= 0) close(writeFd);
        }
    };
    // One standard fd to redirect into/out of the child.
    struct Redirect {
        Pipe *pipe;     // pipe backing the redirection
        int fd;         // STDIN_/STDOUT_/STDERR_FILENO being replaced
        bool output;    // true: child reads from the pipe; false: child writes into it
        int fdbak = -1; // dup() backup of the original fd, restored after fork()
    };

    constexpr static int errExit = 48; // child exit code when exec fails
    pid_t pid = 0; // 0 means "no child running"
    // Suffixed names: `stdin`/`stdout`/`stderr` are macros in <cstdio> and
    // must not be used as identifiers.
    Pipe stdinPipe,
         stdoutPipe,
         stderrPipe;

    // Reads up to `size` bytes from fd; the result is shrunk to what was
    // actually read. An empty result means EOF. Throws FdError on failure.
    template<unsigned size>
    auto recvFrom(int fd) {
        static_assert (size != 0, "Can't read zero bytes");
        std::vector<char> fres(size);
        ssize_t bytes_read;
        if ((bytes_read = read(fd, fres.data(), fres.size())) < 0) {
            throw FdError("Failed to read() from child fd");
        }
        fres.resize(bytes_read);
        return fres;
    }

public:
    struct ExecutionError : public std::runtime_error {
        using std::runtime_error::runtime_error;
    };
    struct AlreadyRunning : public std::runtime_error {
        using std::runtime_error::runtime_error;
    };
    struct FdError : public std::runtime_error {
        using std::runtime_error::runtime_error;
    };

    ProcPipe() {}
    // Convenience: construct and immediately start the child process.
    template<typename... Args>
    ProcPipe(Args&&... args) {
        start(args...);
    }
    ~ProcPipe() {
        terminate();
    }

    // Writes str to the child's stdin. Requires redir_stdin.
    void send(std::string_view str) {
        static_assert (redir_stdin, "Can't write to stdin if not redirected");
        if (write(stdinPipe.writeFd, str.data(), str.size()) < 0) {
            throw FdError("Failed to write() to stdin");
        }
    }

    // Reads up to `size` bytes from the child's stdout. Requires redir_stdout.
    template<unsigned size>
    auto recvStd() {
        static_assert (redir_stdout, "Can't read from stdout if not redirected");
        return recvFrom<size>(stdoutPipe.readFd);
    }

    // Reads up to `size` bytes from the child's stderr. Requires redir_stderr.
    template<unsigned size>
    auto recvErr() {
        static_assert (redir_stderr, "Can't read from stderr if not redirected");
        return recvFrom<size>(stderrPipe.readFd);
    }

    // Builds the compile-time-sized list of redirections applied in start().
    auto makeRedirs() {
        constexpr int redirs_size = redir_stdin + redir_stdout + redir_stderr;
        std::array<Redirect, redirs_size> redirs = {};
        {
            int idx = 0;
            if constexpr(redir_stdin) {
                redirs[idx++] = {&stdinPipe, STDIN_FILENO, true};
            }
            if constexpr(redir_stdout) {
                redirs[idx++] = {&stdoutPipe, STDOUT_FILENO, false};
            }
            if constexpr(redir_stderr) {
                redirs[idx++] = {&stderrPipe, STDERR_FILENO, false};
            }
        }
        return redirs;
    }

    // Forks and exec()s. The first argument is the executable (looked up via
    // PATH by execlp); all arguments become the child's argv. Throws
    // AlreadyRunning if a child is still active, ExecutionError if fork()
    // fails.
    template<typename... Args>
    void start(Args&&... args) {
        if (pid) {
            throw AlreadyRunning("Tried to run process in an instance where it is already running");
        }
        // Make redirects
        auto redirs = makeRedirs();
        // Redirect fds so the child inherits the pipe ends
        for (auto& io : redirs) {
            // Backup fd
            io.fdbak = dup(io.fd);
            // Create new pipe
            io.pipe->make();
            dup2((io.output ? io.pipe->readFd : io.pipe->writeFd), io.fd);
        }
        // Run process
        pid = fork();
        if (pid < 0) {
            pid = 0;
            throw ExecutionError("Failed to fork()");
        }
        if (pid == 0) {
            const auto executable = std::get<0>(std::tuple{args...});
            execlp(executable, args..., nullptr);
            perror((std::string("Failed to launch ")+executable).c_str());
            exit(errExit);
        }
        // Restore fds, close the dup() backups (previously leaked), and drop
        // the child-side pipe ends so the parent observes EOF on child exit.
        for (auto& io : redirs) {
            dup2(io.fdbak, io.fd);
            close(io.fdbak);
            auto& childEnd = io.output ? io.pipe->readFd : io.pipe->writeFd;
            close(childEnd);
            childEnd = -1;
        }
    }

    // Blocks until the child exits and reaps it. Returns the waitpid()
    // status, or -1 if no child is running.
    auto waitExit() noexcept {
        if (pid) {
            int status = 0;
            waitpid(pid, &status, 0);
            pid = 0;
            return status;
        } else {
            return -1;
        }
    }

    // Sends SIGTERM and reaps the child. Returns its exit status, or -1 if
    // nothing was running.
    auto terminate() noexcept {
        if (pid) {
            ::kill(pid, SIGTERM);
            return waitExit();
        } else {
            return -1;
        }
    }
    // Sends SIGKILL. Guarded: previously kill(0, SIGKILL) would have
    // signalled the caller's entire process group when no child was running.
    auto kill() noexcept {
        if (pid) {
            ::kill(pid, SIGKILL);
        }
    }

    // True while a child exists (probed with signal 0). Reports false when
    // no child was started/it was reaped; previously pid==0 probed the whole
    // process group and returned true.
    auto isRunning() noexcept {
        return pid && !(::kill(pid, 0) < 0);
    }
};
|
||||
#endif
|
||||
#endif
|
43
Random.hpp
Normal file
43
Random.hpp
Normal file
|
@ -0,0 +1,43 @@
|
|||
#ifndef _PHASMOENGINE_RANDOM_HPP
|
||||
#define _PHASMOENGINE_RANDOM_HPP
|
||||
#include <random>
|
||||
|
||||
|
||||
|
||||
/// Thin convenience wrapper around a Mersenne Twister engine with uniform
/// integer, floating-point and boolean draws. Call seed() before use.
class RandomGenerator {
    std::mt19937 engine;  // engine backing every draw
    uint32_t lastSeed;    // seed most recently applied via seed()

public:
    // Re-seeds from the OS entropy source, remembering the seed used.
    void seed() {
        lastSeed = std::random_device{}();
        engine.seed(lastSeed);
    }
    // Re-seeds deterministically with the given value.
    void seed(uint32_t customSeed) {
        lastSeed = customSeed;
        engine.seed(lastSeed);
    }

    // Uniform unsigned over the distribution's full default range.
    unsigned getUInt() {
        return std::uniform_int_distribution<unsigned>{}(engine);
    }
    // Uniform unsigned in [0, max].
    unsigned getUInt(unsigned max) {
        return std::uniform_int_distribution<unsigned>{0, max}(engine);
    }
    // Uniform unsigned in [min, max].
    unsigned getUInt(unsigned min, unsigned max) {
        return std::uniform_int_distribution<unsigned>{min, max}(engine);
    }
    // Uniform double in [0, max).
    double getDouble(double max) {
        return std::uniform_real_distribution<double>{0.0, max}(engine);
    }
    // Uniform double in [min, max).
    double getDouble(double min, double max) {
        return std::uniform_real_distribution<double>{min, max}(engine);
    }
    // True with probability `chance`; a chance of exactly 0 is always false.
    bool getBool(float chance) {
        return getDouble(1.0) <= chance && chance != 0.0f;
    }
};
|
||||
#endif
|
25
Timer.hpp
Normal file
25
Timer.hpp
Normal file
|
@ -0,0 +1,25 @@
|
|||
#ifndef _PHASMOENGINE_TIMER_HPP
|
||||
#define _PHASMOENGINE_TIMER_HPP
|
||||
#include <chrono>
|
||||
|
||||
|
||||
|
||||
/// Simple stopwatch measuring elapsed time since construction or reset().
class Timer {
    // steady_clock is guaranteed monotonic; high_resolution_clock (used
    // previously) may alias system_clock and jump when the wall clock is
    // adjusted, which would corrupt elapsed-time measurements.
    std::chrono::time_point<std::chrono::steady_clock> value;

public:
    Timer() {
        reset();
    }

    // Restarts the stopwatch.
    void reset() {
        value = std::chrono::steady_clock::now();
    }

    // Elapsed time since the last reset, expressed in Unit (default:
    // milliseconds). Returns the raw tick count of that unit.
    template<typename Unit = std::chrono::milliseconds>
    auto get() {
        auto duration = std::chrono::duration_cast<Unit>(std::chrono::steady_clock::now() - value);
        return duration.count();
    }
};
|
||||
#endif
|
1
llama.cpp
Submodule
1
llama.cpp
Submodule
|
@ -0,0 +1 @@
|
|||
Subproject commit d7def1a7524f712e5ebb7cd02bab0f13aa56a7f9
|
279
main.cpp
279
main.cpp
|
@ -1,9 +1,276 @@
|
|||
#include <iostream>
|
||||
#include "ProcPipe.hpp"
|
||||
#include "Random.hpp"
|
||||
#include "Timer.hpp"
|
||||
|
||||
using namespace std;
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <fstream>
|
||||
#include <thread>
|
||||
#include <chrono>
|
||||
#include <mutex>
|
||||
#include <memory>
|
||||
#include <dpp/dpp.h>
|
||||
|
||||
int main()
|
||||
{
|
||||
cout << "Hello World!" << endl;
|
||||
return 0;
|
||||
#ifndef _POSIX_VERSION
|
||||
# error "Not compatible with non-POSIX systems"
|
||||
#endif
|
||||
|
||||
|
||||
|
||||
// Splits `s` on `delimiter` into non-owning views over `s`.
// At most `times` splits are performed; everything after the last split
// (including further delimiters) becomes the final element. Always returns
// at least one element. The views alias `s`, so the caller must keep the
// underlying buffer alive.
static
std::vector<std::string_view> str_split(std::string_view s, char delimiter, size_t times = -1) {
    std::vector<std::string_view> pieces;
    for (std::string_view::size_type begin = 0;;) {
        const auto pos = s.find_first_of(delimiter, begin);
        if (pos == std::string_view::npos) {
            // No delimiter left: the rest is the final piece.
            pieces.emplace_back(s.substr(begin));
            break;
        }
        pieces.emplace_back(s.substr(begin, pos - begin));
        begin = pos + 1;
        if (pieces.size() == times) {
            // Split limit reached: keep the remainder unsplit.
            pieces.emplace_back(s.substr(begin));
            break;
        }
    }
    return pieces;
}
|
||||
|
||||
|
||||
class LLM {
|
||||
struct Exception : public std::runtime_error {
|
||||
using std::runtime_error::runtime_error;
|
||||
};
|
||||
|
||||
ProcPipe<false, true, false> llama;
|
||||
struct {
|
||||
std::string model = "7B-ggml-model-quant.bin";
|
||||
|
||||
int32_t seed; // RNG seed
|
||||
int32_t n_threads = static_cast<int32_t>(std::thread::hardware_concurrency()) / 2;
|
||||
int32_t n_predict = 20000; // new tokens to predict
|
||||
int32_t repeat_last_n = 256; // last n tokens to penalize
|
||||
int32_t n_ctx = 2024; //context size
|
||||
|
||||
int32_t top_k = 40;
|
||||
float top_p = 0.5f;
|
||||
float temp = 0.84f;
|
||||
float repeat_penalty = 1.17647f;
|
||||
} params;
|
||||
|
||||
std::string get_temp_file_path() {
|
||||
return "/tmp/discord_llama_"+std::to_string(getpid())+".txt";
|
||||
}
|
||||
|
||||
void start() {
|
||||
// Start process
|
||||
const auto exe_path = "./llama.cpp/llama";
|
||||
llama.start(exe_path, "-m", params.model.c_str(), "-s", std::to_string(params.seed).c_str(),
|
||||
"-t", std::to_string(params.n_threads).c_str(),
|
||||
"-f", get_temp_file_path().c_str(), "-n", std::to_string(params.n_predict).c_str(),
|
||||
"--top_k", std::to_string(params.top_k).c_str(), "--top_p", std::to_string(params.top_p).c_str(),
|
||||
"--repeat_last_n", std::to_string(params.repeat_last_n).c_str(), "--repeat_penalty", std::to_string(params.repeat_penalty).c_str(),
|
||||
"-c", std::to_string(params.n_ctx).c_str(), "--temp", std::to_string(params.temp).c_str());
|
||||
}
|
||||
|
||||
public:
|
||||
LLM(int32_t seed = 0) {
|
||||
// Set random seed
|
||||
params.seed = seed?seed:time(NULL);
|
||||
}
|
||||
|
||||
std::string run(std::string_view prompt, const char *end = nullptr) {
|
||||
std::string fres;
|
||||
|
||||
// Write prompt into file
|
||||
if (!(std::ofstream(get_temp_file_path()) << prompt)) {
|
||||
throw Exception("Failed to write out initial prompt");
|
||||
}
|
||||
|
||||
// Start AI
|
||||
const char prompt_based_end[2] = {prompt[0], '\0'};
|
||||
auto end_length = end?strlen(end):sizeof(prompt_based_end);
|
||||
end = end?end:prompt_based_end;
|
||||
start();
|
||||
|
||||
// Wait for a bit
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
|
||||
// Make sure everything is alright
|
||||
if (!llama.isRunning()) {
|
||||
throw Exception("Llama didn't start. Read stderr for more info.");
|
||||
}
|
||||
|
||||
// Read until done
|
||||
do {
|
||||
// Receive a byte
|
||||
const auto text = llama.recvStd<1>();
|
||||
// Break on EOF
|
||||
if (text.empty()) break;
|
||||
// Debug
|
||||
putchar(text[0]);
|
||||
fflush(stdout);
|
||||
// Append byte to fres
|
||||
fres.append(std::string_view{text.data(), text.size()});
|
||||
// Check if end is reached
|
||||
auto res = fres.rfind(end);
|
||||
if (res != fres.npos && res > prompt.size()) {
|
||||
break;
|
||||
}
|
||||
} while (llama.isRunning());
|
||||
|
||||
// Erase end
|
||||
fres.erase(fres.size()-end_length, end_length);
|
||||
|
||||
// Kill llama
|
||||
llama.kill();
|
||||
|
||||
// Return final result
|
||||
std::cout << fres << std::endl;
|
||||
return fres.substr(prompt.size()+1);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
/// Discord bot that mirrors a channel into a chat-log style prompt and has
/// the LLM write the next line as itself.
///
/// NOTE(review): reply() and idle_auto_reply() run on detached threads while
/// D++ event callbacks mutate `history`/`my_messages`; that shared state is
/// unsynchronized (only LLM execution is serialized via llm_lock). The
/// "Unsafe!!" below is real — TODO confirm and add locking.
class Bot {
    RandomGenerator rng;              // decides random replies; seeded from the bot's own ID
    Timer last_message_timer;         // time since the last message seen in the channel
    std::shared_ptr<bool> stopping;   // shared stop flag so detached threads can outlive start()
    std::mutex llm_lock;              // serializes LLM runs (one model process at a time)

    dpp::cluster bot;                 // D++ client
    dpp::channel channel;             // resolved channel object (fetched on ready; unused afterwards here)
    dpp::snowflake channel_id;        // the only channel we listen and reply to
    std::vector<dpp::message> history;        // rolling message log used to build the prompt
    std::vector<dpp::snowflake> my_messages;  // IDs of messages we sent (for reply detection)

    // Builds a chat-log prompt from `history` and runs the LLM on a detached
    // thread; the model's completion is posted back to the channel.
    void reply() {
        // Generate prompt
        std::string prompt;
        {
            std::ostringstream prompts;
            // Append channel name
            // NOTE(review): hard-coded "#general" regardless of channel_id —
            // presumably a placeholder; confirm against `channel`.
            prompts << "Log of #general channel.\n\n";
            // Append each message to stream, one "author: line" per line
            for (const auto& msg : history) {
                for (const auto line : str_split(msg.content, '\n')) {
                    prompts << msg.author.username << ": " << line << '\n';
                }
            }
            // End with our own username so the LLM completes our next line
            prompts << bot.me.username << ':';
            // Keep resulting string
            prompt = prompts.str();
        }
        // Make sure prompt isn't too long (200 bytes, not tokens); if so,
        // drop the oldest message and retry recursively.
        if (prompt.size() > 200) {
            history.erase(history.begin());
            return reply();
        }
        // Start new thread so the event loop isn't blocked by inference
        std::thread([this, prompt = std::move(prompt)] () {
            // Run model; llm_lock ensures only one llama process at a time
            std::scoped_lock L(llm_lock);
            std::string output;
            try {
                output = LLM().run(prompt, "\n");
            } catch (...) {
                // NOTE(review): rethrowing on a detached thread calls
                // std::terminate — likely intended as a crash-on-error.
                std::rethrow_exception(std::current_exception());
            }
            // Send resulting message (synchronous variant: we need msg.id)
            auto msg = bot.message_create_sync(dpp::message(channel_id, output));
            // Add message to list of my messages
            my_messages.push_back(msg.id); // Unsafe!!
        }).detach();
    }

    // Background loop: forces a reply when the channel has been quiet.
    // Holds a copy of `stopping` so the flag outlives the Bot if needed.
    void idle_auto_reply() {
        auto s = stopping;
        do {
            // Wait for a bit (also delays shutdown by up to 5 minutes)
            std::this_thread::sleep_for(std::chrono::minutes(5));
            // Check if last message was more than 20 minutes ago
            if (last_message_timer.get<std::chrono::minutes>() > 20) {
                // Force reply
                reply();
            }
        } while (!*s);
    }

    // Decides whether an incoming message warrants a reply.
    void attempt_reply(const dpp::message& msg) {
        // Always reply once history reaches exactly 5 messages.
        // NOTE(review): original comment said "10th message" but the code
        // checks 5 — confirm which threshold is intended.
        if (history.size() == 5) {
            return reply();
        }
        // Below the threshold, never reply; above it, use the checks below
        if (history.size() > 5) {
            // Decide randomly (7.5% chance per message)
            if (rng.getBool(0.075f)) {
                return reply();
            }
            // Reply if message contains username or ID
            if (msg.content.find(bot.me.username) != std::string::npos
                || msg.content.find(bot.me.id) != std::string::npos) {
                return reply();
            }
            // Reply if message references (replies to) one of our messages
            for (const auto msg_id : my_messages) {
                if (msg.message_reference.message_id == msg_id) {
                    return reply();
                }
            }
        }
    }

public:
    // Configures the D++ cluster and registers all event callbacks; nothing
    // runs until start() is called.
    Bot(const char *token, dpp::snowflake channel_id) : bot(token), channel_id(channel_id) {
        bot.on_log(dpp::utility::cout_logger());
        bot.intents = dpp::i_guild_messages | dpp::i_message_content;

        // Set callbacks
        bot.on_ready([=] (const dpp::ready_t&) {
            // Get channel (async; `channel` is filled in when the API responds)
            bot.channel_get(channel_id, [=] (const dpp::confirmation_callback_t& cbt) {
                if (cbt.is_error()) {
                    // NOTE(review): thrown inside a D++ callback — verify the
                    // library propagates this rather than swallowing it.
                    throw std::runtime_error("Failed to get channel: "+cbt.get_error().message);
                }
                channel = cbt.get<dpp::channel>();
            });
            // Initialize random generator from our own snowflake
            // (narrowed to uint32_t by RandomGenerator::seed)
            rng.seed(bot.me.id);
            // Start idle auto reply thread (detached; stopped via `stopping`)
            std::thread([this] () {
                idle_auto_reply();
            }).detach();
        });
        bot.on_message_create([=] (const dpp::message_create_t& event) {
            // Make sure message source is correct
            if (event.msg.channel_id != channel_id) return;
            // Make sure message has content
            if (event.msg.content.empty()) return;
            // Append message to history (includes our own messages)
            history.push_back(event.msg);
            // Attempt to send a reply
            attempt_reply(event.msg);
            // Reset last message timer
            last_message_timer.reset();
        });
    }

    // Blocks running the bot; on return, signals background threads to stop.
    void start() {
        stopping = std::make_shared<bool>(false);
        bot.start(dpp::st_wait);
        *stopping = true;
    }
};
|
||||
|
||||
|
||||
// Entry point: discord_llama <token> <channel>
// <token>   Discord bot token.
// <channel> Numeric snowflake ID of the channel to join.
int main(int argc, char **argv) {
    // Check arguments
    if (argc < 3) {
        std::cout << "Usage: " << argv[0] << " <token> <channel>" << std::endl;
        return -1;
    }

    // Parse the channel ID up front: previously a non-numeric argument made
    // std::stoull throw an uncaught std::invalid_argument and terminate.
    unsigned long long channel_id = 0;
    try {
        channel_id = std::stoull(argv[2]);
    } catch (const std::exception&) {
        std::cout << "Error: <channel> must be a numeric channel ID" << std::endl;
        return -1;
    }

    // Construct and configure bot
    Bot bot(argv[1], channel_id);

    // Start bot (blocks until shutdown)
    bot.start();
}
|
||||
|
|
Loading…
Add table
Reference in a new issue