discord_llama/main.cpp
mirror of https://gitlab.com/niansa/discord_llama.git
#include "Random.hpp"
#include "Timer.hpp"
#include <string>
#include <string_view>
#include <sstream>
#include <stdexcept>
#include <fstream>
#include <thread>
#include <chrono>
#include <functional>
#include <vector>
#include <map>
#include <mutex>
#include <memory>
#include <iostream>
#include <ctime>
#include <dpp/dpp.h>
#include <ggml.h>
#include <llama.h>
#ifndef _POSIX_VERSION
# error "Not compatible with non-POSIX systems"
#endif
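// Split `s` at `delimiter` into at most `times`+1 pieces; the unsplit remainder becomes the
// last element, and the default of (size_t)-1 effectively means "split everywhere".
// Example: str_split("a,b,c,d", ',', 2) yields {"a", "b", "c,d"}.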
static
std::vector<std::string_view> str_split(std::string_view s, char delimiter, size_t times = -1) {
    std::vector<std::string_view> to_return;
    decltype(s.size()) start = 0, finish = 0;
    while ((finish = s.find_first_of(delimiter, start)) != std::string_view::npos) {
        to_return.emplace_back(s.substr(start, finish - start));
        start = finish + 1;
        if (to_return.size() == times) { break; }
    }
    to_return.emplace_back(s.substr(start));
    return to_return;
}
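// Replace every occurrence of `search` in `subject` with `replace`, in place.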
static
void str_replace_in_place(std::string& subject, std::string_view search,
                          const std::string& replace) {
    size_t pos = 0;
    while ((pos = subject.find(search, pos)) != std::string::npos) {
        subject.replace(pos, search.length(), replace);
        pos += replace.length();
    }
}
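// Thin, mutex-guarded wrapper around a llama.cpp context. append() tokenizes and evaluates
// additional prompt text, run() samples tokens until the given end marker is produced (or the
// optional on_tick callback returns false), and ContextLengthException signals that the
// accumulated prompt no longer fits into the model's context window.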
class LLM {
    struct {
        std::string model = "7B-ggml-model-quant.bin";
        int32_t seed; // RNG seed
        int32_t n_threads = static_cast<int32_t>(std::thread::hardware_concurrency()) / 4;
        int32_t n_ctx = 2024; // Context size
        int32_t n_batch = 8; // Batch size
        int32_t top_k = 40;
        float top_p = 0.5f;
        float temp = 0.81f;
    } params;

    struct State {
        std::string prompt;
        std::vector<llama_token> embd;
        int n_ctx;
    } state;

    llama_context *ctx = nullptr;
    std::mutex lock;
    void init() {
        // Get llama parameters
        auto lparams = llama_context_default_params();
        lparams.seed = params.seed;
        lparams.n_ctx = params.n_ctx;
        // Create context
        ctx = llama_init_from_file(params.model.c_str(), lparams);
        if (!ctx) {
            throw Exception("Failed to initialize llama from file");
        }
        // Initialize some variables
        state.n_ctx = llama_n_ctx(ctx);
    }
public:
    struct Exception : public std::runtime_error {
        using std::runtime_error::runtime_error;
    };
    struct ContextLengthException : public Exception {
        ContextLengthException() : Exception("Max. context length exceeded") {}
    };

    LLM(int32_t seed = 0) {
        // Set random seed
        params.seed = seed ? seed : time(NULL);
        // Initialize llama
        init();
    }
    ~LLM() {
        if (ctx) llama_free(ctx);
    }
    void append(const std::string& prompt) {
        std::scoped_lock L(lock);
        // Check if prompt was empty
        const bool was_empty = state.prompt.empty();
        // Append to current prompt
        state.prompt.append(prompt);
        // Resize buffer for tokens
        const auto old_token_count = state.embd.size();
        state.embd.resize(old_token_count+state.prompt.size()+1);
        // Run tokenizer
        const auto token_count = llama_tokenize(ctx, prompt.data(), state.embd.data()+old_token_count, state.embd.size()-old_token_count, was_empty);
        state.embd.resize(old_token_count+token_count);
        // Make sure the context limit is far from being hit (checks the total token count)
        if (state.embd.size() > size_t(state.n_ctx)-6) {
            // Yup. *this MUST be decomposed now.
            throw ContextLengthException();
        }
        // Evaluate new tokens
        // TODO: Larger batch size
        for (size_t it = old_token_count; it != old_token_count+token_count; it++) {
            llama_eval(ctx, state.embd.data()+it, 1, it, params.n_threads);
        }
    }
    std::string run(std::string_view end, const std::function<bool ()>& on_tick = nullptr) {
        std::scoped_lock L(lock);
        std::string fres;
        // Loop until the end marker has been generated or the tick callback aborts
        bool abort = false;
        while (!abort && !fres.ends_with(end)) {
            // Sample top p and top k
            const auto id = llama_sample_top_p_top_k(ctx, nullptr, 0, params.top_k, params.top_p, params.temp, 1.0f);
            // Add token
            state.embd.push_back(id);
            // Get token as string
            const auto str = llama_token_to_str(ctx, id);
            // Stream the token to stdout for debugging
            std::cout << str << std::flush;
            // Append string to function result
            fres.append(str);
            // Evaluate token
            // TODO: Larger batch size
            llama_eval(ctx, state.embd.data()+state.embd.size()-1, 1, state.embd.size()-1, params.n_threads);
            // Tick
            if (on_tick && !on_tick()) abort = true;
        }
        // Append the result to the prompt and return it without the end marker
        state.prompt.append(fres);
        return std::string(fres.data(), fres.size()-end.size());
    }
};
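// Discord bot (built on D++/dpp) that feeds every message of a single channel into the LLM as
// a chat transcript and generates a reply whenever it is addressed by name, when one of its
// own messages is replied to, or after a long period of silence.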
class Bot {
    RandomGenerator rng;
    Timer last_message_timer;
    std::shared_ptr<bool> stopping;
    std::unique_ptr<LLM> llm;
    std::vector<dpp::snowflake> my_messages;
    std::mutex llm_init_lock;
    dpp::cluster bot;
    dpp::channel channel;
    dpp::snowflake channel_id;

    void llm_init() {
        if (!llm) {
            {
                std::unique_lock L(llm_init_lock);
                llm = std::make_unique<LLM>();
            }
            // German initial prompt: "History of the #chat channel. Note: <username> is a
            // friendly chatbot that is always happy to chat along in German."
            llm->append("Verlauf des #chat Kanals.\nNotiz: "+bot.me.username+" ist ein freundlicher Chatbot, der immer gerne auf deutsch mitredet.\n\n");
        }
    }
    void prompt_add_msg(const dpp::message& msg) {
        try {
            // Format and append the message line by line
            for (const auto line : str_split(msg.content, '\n')) {
                llm->append(msg.author.username+": ");
                llm->append(std::string(line));
                llm->append("\n");
            }
        } catch (const LLM::ContextLengthException&) {
            // Context is full: drop the model and start over with a fresh prompt
            llm.reset();
            llm_init();
        }
    }
    void prompt_add_trigger() {
        try {
            llm->append(bot.me.username+':');
        } catch (const LLM::ContextLengthException&) {
            llm.reset();
            llm_init();
        }
    }
    void reply() {
        // Start new thread
        std::thread([this] () {
            // Create placeholder message ("Please wait... :thinking:")
            auto msg = bot.message_create_sync(dpp::message(channel_id, "Bitte warte... :thinking:"));
            // Trigger LLM correctly
            prompt_add_trigger();
            // Run model, aborting generation after four minutes
            Timer timeout;
            bool timed_out = false;
            auto output = llm->run("\n", [&] () {
                if (timeout.get<std::chrono::minutes>() > 4) {
                    timed_out = true;
                    return false;
                }
                return true;
            });
            if (timed_out) output = "Fehler: Zeitüberschreitung"; // "Error: timeout"
            // Send resulting message
            msg.content = output;
            bot.message_edit(msg);
        }).detach();
    }
    void idle_auto_reply() {
        auto s = stopping;
        do {
            // Wait for a bit
            std::this_thread::sleep_for(std::chrono::minutes(5));
            // Check if the last message was more than three hours ago
            if (last_message_timer.get<std::chrono::hours>() > 3) {
                // Force reply
                reply();
            }
        } while (!*s);
    }
    void attempt_reply(const dpp::message& msg) {
        // Decide randomly
        /*if (rng.getBool(0.075f)) {
            return reply();
        }*/
        // Reply if the message contains the bot's username (mentions have already been
        // replaced with the username by the caller)
        if (msg.content.find(bot.me.username) != std::string::npos) {
            return reply();
        }
        // Reply if the message is a reply to one of the bot's own messages
        for (const auto msg_id : my_messages) {
            if (msg.message_reference.message_id == msg_id) {
                return reply();
            }
        }
    }
public:
    Bot(const char *token, dpp::snowflake channel_id) : bot(token), channel_id(channel_id) {
        // Configure bot
        bot.on_log(dpp::utility::cout_logger());
        bot.intents = dpp::i_guild_messages | dpp::i_message_content;
        // Set callbacks
        bot.on_ready([=, this] (const dpp::ready_t&) {
            // Get channel
            bot.channel_get(channel_id, [=, this] (const dpp::confirmation_callback_t& cbt) {
                if (cbt.is_error()) {
                    throw std::runtime_error("Failed to get channel: "+cbt.get_error().message);
                }
                channel = cbt.get<dpp::channel>();
                // Initialize random generator
                rng.seed(bot.me.id);
                // Append initial prompt
                llm_init();
                // Start idle auto reply thread
                std::thread([this] () {
                    idle_auto_reply();
                }).detach();
            });
        });
        bot.on_message_create([=, this] (const dpp::message_create_t& event) {
            // Make sure message source is correct
            if (event.msg.channel_id != channel_id) return;
            // Make sure message has content
            if (event.msg.content.empty()) return;
            // Ignore own messages
            if (event.msg.author.id == bot.me.id) {
                // Add message to list of own messages
                my_messages.push_back(event.msg.id);
                return;
            }
            // Replace bot mentions with bot username
            auto msg = event.msg;
            str_replace_in_place(msg.content, "<@"+std::to_string(bot.me.id)+'>', bot.me.username);
            // Attempt to send a reply
            attempt_reply(msg);
            // Append message to history
            prompt_add_msg(msg);
            // Reset last message timer
            last_message_timer.reset();
        });
    }
    void start() {
        stopping = std::make_shared<bool>(false);
        bot.start(dpp::st_wait);
        *stopping = true;
    }
};
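// Entry point: expects the Discord bot token and the target channel ID as command-line arguments.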
int main(int argc, char **argv) {
    // Init GGML
    ggml_time_init();
    // Check arguments
    if (argc < 3) {
        std::cout << "Usage: " << argv[0] << " <token> <channel>" << std::endl;
        return -1;
    }
    // Construct and configure bot
    Bot bot(argv[1], std::stoull(argv[2]));
    // Start bot
    bot.start();
}