
Use libjustlm

niansa 2023-04-01 15:04:52 +02:00
parent 766700602b
commit 553697d65e
5 changed files with 37 additions and 197 deletions
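For orientation, here is a minimal sketch of the inference flow this commit switches to. It uses only the libjustlm surface that is visible in the diff below: the LM::Inference constructor taking a model path, append() with a bool(float) progress callback, run() with a stop string and a bool(std::string_view) callback, and LM::Inference::ContextLengthException. The header name justlm.hpp comes from the new include line; everything beyond that (exact signatures, optional parameters, what the progress value means) is an assumption, not taken from libjustlm's documentation.

// Illustrative sketch only (not part of the diff); assumes the libjustlm API as used by the new main.cpp.
#include <justlm.hpp>
#include <iostream>
#include <memory>
#include <string>
#include <string_view>

int main() {
    // Load the same quantized model file the bot uses.
    auto llm = std::make_unique<LM::Inference>("7B-ggml-model-quant.bin");
    try {
        // Feed prompt text; the callback receives a progress value and may return false to abort.
        llm->append("Ecki: Hallo!\n", [] (float progress) {
            std::cout << "\rIngesting prompt: " << progress << std::flush;
            return true;
        });
        // Generate until "\n" is produced; the callback streams each generated piece as it arrives.
        const std::string reply = llm->run("\n", [] (std::string_view piece) {
            std::cout << piece << std::flush;
            return true;
        });
        std::cout << "\nFull reply: " << reply << std::endl;
    } catch (const LM::Inference::ContextLengthException&) {
        // Same recovery the bot uses: throw the context away and reinitialize.
        llm = std::make_unique<LM::Inference>("7B-ggml-model-quant.bin");
    }
    return 0;
}

The diff also widens the locking: the old llm_init_lock becomes llm_lock and is now taken around every append() and run() call, presumably to serialize all access to the single LM::Inference instance.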

.gitmodules (6 changed lines)

@@ -1,6 +1,6 @@
[submodule "DPP"]
path = DPP
url = https://github.com/brainboxdotcc/DPP.git
[submodule "llama.cpp"]
path = llama.cpp
url = https://github.com/ggerganov/llama.cpp.git
[submodule "libjustlm"]
path = libjustlm
url = https://gitlab.com/niansa/libjustlm.git

CMakeLists.txt

@ -5,11 +5,11 @@ project(discord_llama LANGUAGES C CXX)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
-add_subdirectory(llama.cpp)
+add_subdirectory(libjustlm)
add_subdirectory(DPP)
add_executable(discord_llama main.cpp)
-target_link_libraries(discord_llama PUBLIC dpp pthread llama ggml)
+target_link_libraries(discord_llama PUBLIC dpp pthread libjustlm ggml)
install(TARGETS discord_llama
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})

libjustlm Submodule (1 changed line)

@@ -0,0 +1 @@
+Subproject commit dc7fe7f9f01681544916da4d41ef704d254a4ca4

llama.cpp Submodule

@@ -1 +0,0 @@
-Subproject commit 19726169b379bebc96189673a19b89ab1d307659

main.cpp (196 changed lines)

@@ -15,8 +15,7 @@
#include <mutex>
#include <memory>
#include <dpp/dpp.h>
-#include <ggml.h>
-#include <llama.h>
+#include <justlm.hpp>
@@ -43,36 +42,8 @@ void str_replace_in_place(std::string& subject, std::string_view search,
}
}
-class LLM {
-struct {
-std::string model = "7B-ggml-model-quant.bin";
-int32_t seed; // RNG seed
-int32_t n_threads = static_cast<int32_t>(std::thread::hardware_concurrency()) / 4;
-int32_t n_ctx = 2024; // Context size
-int32_t n_batch = 8; // Batch size, unused for now
-int32_t top_k = 40;
-float top_p = 0.5f;
-float temp = 0.72f;
-bool no_repeat = true;
-} params;
-struct State {
-std::string prompt;
-std::vector<llama_token> embd;
-int n_ctx;
-std::string last_result;
-int repeats;
-} state;
-llama_context *ctx = nullptr;
-std::mutex lock;
-static inline
-std::string clean_string(const std::string& str) {
+static inline
+std::string clean_string(std::string_view str) {
std::string fres;
for (const auto c : str) {
if ((c >= 0x20 && c <= 0x7E)
@@ -88,149 +59,16 @@ class LLM {
}
}
return fres;
}
-void init() {
-// Get llama parameters
-auto lparams = llama_context_default_params();
-lparams.seed = params.seed;
-lparams.n_ctx = 2024;
-// Create context
-ctx = llama_init_from_file(params.model.c_str(), lparams);
-if (!ctx) {
-throw Exception("Failed to initialize llama from file");
-}
-// Initialize some variables
-state.n_ctx = llama_n_ctx(ctx);
-state.repeats = 0;
-}
-public:
-struct Exception : public std::runtime_error {
-using std::runtime_error::runtime_error;
-};
-struct ContextLengthException : public Exception {
-ContextLengthException() : Exception("Max. context length exceeded") {}
-};
-LLM(int32_t seed = 0) {
-// Set random seed
-params.seed = seed?seed:time(NULL);
-// Initialize llama
-init();
-}
-~LLM() {
-std::scoped_lock L(lock);
-if (ctx) llama_free(ctx);
-}
-void append(std::string prompt, const std::function<bool (float progress)>& on_tick = nullptr) {
-std::scoped_lock L(lock);
-// Remove non-printables
-prompt = clean_string(prompt);
-// Check if prompt was empty
-const bool was_empty = state.prompt.empty();
-// Append to current prompt
-state.prompt.append(prompt);
-// Debug
-std::ofstream("prompt.txt") << state.prompt;
-// Resize buffer for tokens
-const auto old_token_count = state.embd.size();
-state.embd.resize(old_token_count+state.prompt.size()+1);
-// Run tokenizer
-const auto token_count = llama_tokenize(ctx, prompt.data(), state.embd.data()+old_token_count, state.embd.size()-old_token_count, was_empty);
-state.embd.resize(old_token_count+token_count);
-// Make sure limit is far from being hit
-if (state.embd.size() > state.n_ctx-6) {
-// Yup. *this MUST be decomposed now.
-throw ContextLengthException();
-}
-// Evaluate new tokens
-// TODO: Larger batch size
-std::cout << "Context size: " << old_token_count << '+' << token_count << '=' << state.embd.size() << '/' << state.n_ctx << std::endl;
-for (int it = old_token_count; it != state.embd.size(); it++) {
-std::cout << llama_token_to_str(ctx, state.embd.data()[it]) << std::flush;
-llama_eval(ctx, state.embd.data()+it, 1, it, params.n_threads);
-// Tick
-if (on_tick) {
-// Calculate progress
-auto progress = float(it) / (state.embd.size()) * 100.f;
-// Run callback
-if (!on_tick(progress)) break;
-}
-}
-std::cout << std::endl;
-}
-std::string run(std::string_view end, const std::function<bool ()>& on_tick = nullptr) {
-std::scoped_lock L(lock);
-std::string fres;
-// Loop until done
-bool abort = false;
-while (!abort && !fres.ends_with(end)) {
-// Sample top p and top k
-bool has_repeated = state.repeats>=4;
-const auto id = llama_sample_top_p_top_k(ctx, nullptr, 0, params.top_k, has_repeated?(params.top_p+0.15f):params.top_p, has_repeated?(params.temp+0.4f):params.temp, 1.0f);
-// Add token
-state.embd.push_back(id);
-// Get token as string
-const auto str = llama_token_to_str(ctx, id);
-// Debug
-std::cout << str << std::flush;
-// Append string to function result
-fres.append(str);
-// Evaluate token
-// TODO: Respect batch size
-llama_eval(ctx, state.embd.data()+state.embd.size()-1, 1, state.embd.size()-1, params.n_threads);
-// Tick
-if (on_tick && !on_tick()) abort = true;
-}
-// Create final string
-state.prompt.append(fres);
-fres = std::string(fres.data(), fres.size()-end.size());
-// Check for repetition
-if (state.last_result == fres && params.no_repeat) {
-state.repeats++;
-} else {
-state.repeats = 0;
-state.last_result = fres;
-}
-// Return final string
-return fres;
-}
-};
}
class Bot {
RandomGenerator rng;
Timer last_message_timer;
std::shared_ptr<bool> stopping;
-std::unique_ptr<LLM> llm;
+std::unique_ptr<LM::Inference> llm;
std::vector<dpp::snowflake> my_messages;
-std::mutex llm_init_lock;
+std::mutex llm_lock;
dpp::cluster bot;
dpp::channel channel;
@@ -269,8 +107,8 @@ class Bot {
if (!llm) {
// Make sure llm is initialized
{
-std::unique_lock L(llm_init_lock);
-llm = std::make_unique<LLM>();
+std::unique_lock L(llm_lock);
+llm = std::make_unique<LM::Inference>("7B-ggml-model-quant.bin");
}
// Create message for reporting progress
dpp::message msg(channel_id, "Wird initialisiert...");
@@ -295,6 +133,7 @@
return true;
};
// Add initial context
+std::unique_lock L(llm_lock);
llm->append("Verlauf des #"+channel.name+" Kanals.\n"
"Notiz 1: "+bot.me.username+" ist ein freundlicher Chatbot, der immer gerne auf deutsch mitredet. Er ist freundlich und hilfsbereit und antwortet immer sofort. Er hat guten Humor und mag jeden. Sein Alter ist 16 und er wurde 2007 geboren.\n"
"Notiz 2: Ecki heisst in Wirklichkeit Eckhard Kohlhuber und kommt aus Bayern.\n"
@@ -311,9 +150,10 @@
return;
}
// Format and append line
+std::unique_lock L(llm_lock);
for (const auto line : str_split(msg.content, '\n')) {
Timer timeout;
-llm->append(msg.author.username+": "+std::string(line)+'\n', [&] (float) {
+llm->append(msg.author.username+": "+clean_string(line)+'\n', [&] (float) {
if (timeout.get<std::chrono::minutes>() > 1) {
std::cerr << "\nWarning: Timeout reached processing message" << std::endl;
return false;
@@ -321,15 +161,16 @@
return true;
});
}
-} catch (const LLM::ContextLengthException&) {
+} catch (const LM::Inference::ContextLengthException&) {
llm.reset();
llm_init();
}
}
void prompt_add_trigger() {
try {
+std::unique_lock L(llm_lock);
llm->append(bot.me.username+':');
-} catch (const LLM::ContextLengthException&) {
+} catch (const LM::Inference::ContextLengthException&) {
llm.reset();
llm_init();
}
@@ -346,14 +187,16 @@
// Run model
Timer timeout;
bool timed_out = false;
-auto output = llm->run("\n", [&] () {
+auto output = llm->run("\n", [&] (std::string_view str) {
+std::cout << str << std::flush;
if (timeout.get<std::chrono::minutes>() > 2) {
timed_out = true;
-std::cerr << "\nWarning: Timeout reached generating message" << std::endl;
+std::cerr << "\nWarning: Timeout reached generating message";
return false;
}
return true;
});
+std::cout << std::endl;
if (timed_out) output = "Fehler: Zeitüberschreitung";
// Send resulting message
msg.content = output;
@@ -463,9 +306,6 @@ public:
int main(int argc, char **argv) {
-// Init GGML
-ggml_time_init();
// Check arguments
if (argc < 3) {
std::cout << "Usage: " << argv[0] << " <token> <channel>" << std::endl;