From 553697d65e1380a30fb8c0b6e191e0a3502c6146 Mon Sep 17 00:00:00 2001
From: niansa
Date: Sat, 1 Apr 2023 15:04:52 +0200
Subject: [PATCH] Use libjustlm

---
 .gitmodules    |   6 +-
 CMakeLists.txt |   4 +-
 libjustlm      |   1 +
 llama.cpp      |   1 -
 main.cpp       | 222 +++++++------------------------------------
 5 files changed, 37 insertions(+), 197 deletions(-)
 create mode 160000 libjustlm
 delete mode 160000 llama.cpp

diff --git a/.gitmodules b/.gitmodules
index 3b68747..cf7a7e5 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,6 +1,6 @@
 [submodule "DPP"]
 	path = DPP
 	url = https://github.com/brainboxdotcc/DPP.git
-[submodule "llama.cpp"]
-	path = llama.cpp
-	url = https://github.com/ggerganov/llama.cpp.git
+[submodule "libjustlm"]
+	path = libjustlm
+	url = https://gitlab.com/niansa/libjustlm.git
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3cd4c51..d3f4a9c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -5,11 +5,11 @@ project(discord_llama LANGUAGES C CXX)
 set(CMAKE_CXX_STANDARD 20)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 
-add_subdirectory(llama.cpp)
+add_subdirectory(libjustlm)
 add_subdirectory(DPP)
 
 add_executable(discord_llama main.cpp)
-target_link_libraries(discord_llama PUBLIC dpp pthread llama ggml)
+target_link_libraries(discord_llama PUBLIC dpp pthread libjustlm ggml)
 
 install(TARGETS discord_llama LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})
diff --git a/libjustlm b/libjustlm
new file mode 160000
index 0000000..dc7fe7f
--- /dev/null
+++ b/libjustlm
@@ -0,0 +1 @@
+Subproject commit dc7fe7f9f01681544916da4d41ef704d254a4ca4
diff --git a/llama.cpp b/llama.cpp
deleted file mode 160000
index 1972616..0000000
--- a/llama.cpp
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 19726169b379bebc96189673a19b89ab1d307659
diff --git a/main.cpp b/main.cpp
index 5f452ec..44722b0 100644
--- a/main.cpp
+++ b/main.cpp
@@ -15,8 +15,7 @@
 #include
 #include
 #include
-#include
-#include
+#include
@@ -43,194 +42,33 @@ void str_replace_in_place(std::string& subject, std::string_view search,
     }
 }
 
-
-class LLM {
-    struct {
-        std::string model = "7B-ggml-model-quant.bin";
-
-        int32_t seed; // RNG seed
-        int32_t n_threads = static_cast<int32_t>(std::thread::hardware_concurrency()) / 4;
-        int32_t n_ctx = 2024; // Context size
-        int32_t n_batch = 8; // Batch size, unused for now
-
-        int32_t top_k = 40;
-        float top_p = 0.5f;
-        float temp = 0.72f;
-
-        bool no_repeat = true;
-    } params;
-
-    struct State {
-        std::string prompt;
-        std::vector<llama_token> embd;
-        int n_ctx;
-        std::string last_result;
-        int repeats;
-    } state;
-
-    llama_context *ctx = nullptr;
-    std::mutex lock;
-
-    static inline
-    std::string clean_string(const std::string& str) {
-        std::string fres;
-        for (const auto c : str) {
-            if ((c >= 0x20 && c <= 0x7E)
-                    || c == '\n'
-                    || c == "ä"[0] || c == "ä"[1] || c == "ä"[2]
-                    || c == "ö"[0] || c == "ö"[1] || c == "ö"[2]
-                    || c == "ü"[0] || c == "ü"[1] || c == "ü"[2]
-                    || c == "Ä"[0] || c == "Ä"[1] || c == "Ä"[2]
-                    || c == "Ö"[0] || c == "Ö"[1] || c == "Ö"[2]
-                    || c == "Ü"[0] || c == "Ü"[1] || c == "Ü"[2]
-                    || c == "ß"[0] || c == "ß"[1] || c == "ß"[2]) {
-                fres.push_back(c);
-            }
+static inline
+std::string clean_string(std::string_view str) {
+    std::string fres;
+    for (const auto c : str) {
+        if ((c >= 0x20 && c <= 0x7E)
+                || c == '\n'
+                || c == "ä"[0] || c == "ä"[1] || c == "ä"[2]
+                || c == "ö"[0] || c == "ö"[1] || c == "ö"[2]
+                || c == "ü"[0] || c == "ü"[1] || c == "ü"[2]
+                || c == "Ä"[0] || c == "Ä"[1] || c == "Ä"[2]
+                || c == "Ö"[0] || c == "Ö"[1] || c == "Ö"[2]
+                || c == "Ü"[0] || c == "Ü"[1] || c == "Ü"[2]
+                || c == "ß"[0] || c == "ß"[1] || c == "ß"[2]) {
+            fres.push_back(c);
         }
-        return fres;
     }
-
-    void init() {
-        // Get llama parameters
-        auto lparams = llama_context_default_params();
-        lparams.seed = params.seed;
-        lparams.n_ctx = 2024;
-
-        // Create context
-        ctx = llama_init_from_file(params.model.c_str(), lparams);
-        if (!ctx) {
-            throw Exception("Failed to initialize llama from file");
-        }
-
-        // Initialize some variables
-        state.n_ctx = llama_n_ctx(ctx);
-        state.repeats = 0;
-    }
-
-public:
-    struct Exception : public std::runtime_error {
-        using std::runtime_error::runtime_error;
-    };
-    struct ContextLengthException : public Exception {
-        ContextLengthException() : Exception("Max. context length exceeded") {}
-    };
-
-
-    LLM(int32_t seed = 0) {
-        // Set random seed
-        params.seed = seed?seed:time(NULL);
-
-        // Initialize llama
-        init();
-    }
-    ~LLM() {
-        std::scoped_lock L(lock);
-        if (ctx) llama_free(ctx);
-    }
-
-    void append(std::string prompt, const std::function<bool (float)>& on_tick = nullptr) {
-        std::scoped_lock L(lock);
-
-        // Remove non-printables
-        prompt = clean_string(prompt);
-
-        // Check if prompt was empty
-        const bool was_empty = state.prompt.empty();
-
-        // Append to current prompt
-        state.prompt.append(prompt);
-
-        // Debug
-        std::ofstream("prompt.txt") << state.prompt;
-
-        // Resize buffer for tokens
-        const auto old_token_count = state.embd.size();
-        state.embd.resize(old_token_count+state.prompt.size()+1);
-
-        // Run tokenizer
-        const auto token_count = llama_tokenize(ctx, prompt.data(), state.embd.data()+old_token_count, state.embd.size()-old_token_count, was_empty);
-        state.embd.resize(old_token_count+token_count);
-
-        // Make sure limit is far from being hit
-        if (state.embd.size() > state.n_ctx-6) {
-            // Yup. *this MUST be decomposed now.
-            throw ContextLengthException();
-        }
-
-        // Evaluate new tokens
-        // TODO: Larger batch size
-        std::cout << "Context size: " << old_token_count << '+' << token_count << '=' << state.embd.size() << '/' << state.n_ctx << std::endl;
-        for (int it = old_token_count; it != state.embd.size(); it++) {
-            std::cout << llama_token_to_str(ctx, state.embd.data()[it]) << std::flush;
-            llama_eval(ctx, state.embd.data()+it, 1, it, params.n_threads);
-
-            // Tick
-            if (on_tick) {
-                // Calculate progress
-                auto progress = float(it) / (state.embd.size()) * 100.f;
-                // Run callback
-                if (!on_tick(progress)) break;
-            }
-        }
-        std::cout << std::endl;
-    }
-
-    std::string run(std::string_view end, const std::function<bool ()>& on_tick = nullptr) {
-        std::scoped_lock L(lock);
-        std::string fres;
-
-        // Loop until done
-        bool abort = false;
-        while (!abort && !fres.ends_with(end)) {
-            // Sample top p and top k
-            bool has_repeated = state.repeats>=4;
-            const auto id = llama_sample_top_p_top_k(ctx, nullptr, 0, params.top_k, has_repeated?(params.top_p+0.15f):params.top_p, has_repeated?(params.temp+0.4f):params.temp, 1.0f);
-
-            // Add token
-            state.embd.push_back(id);
-
-            // Get token as string
-            const auto str = llama_token_to_str(ctx, id);
-
-            // Debug
-            std::cout << str << std::flush;
-
-            // Append string to function result
-            fres.append(str);
-
-            // Evaluate token
-            // TODO: Respect batch size
-            llama_eval(ctx, state.embd.data()+state.embd.size()-1, 1, state.embd.size()-1, params.n_threads);
-
-            // Tick
-            if (on_tick && !on_tick()) abort = true;
-        }
-
-        // Create final string
-        state.prompt.append(fres);
-        fres = std::string(fres.data(), fres.size()-end.size());
-
-        // Check for repetition
-        if (state.last_result == fres && params.no_repeat) {
-            state.repeats++;
-        } else {
-            state.repeats = 0;
-            state.last_result = fres;
-        }
-
-        // Return final string
-        return fres;
-    }
-};
+    return fres;
+}
 
 class Bot {
     RandomGenerator rng;
     Timer last_message_timer;
     std::shared_ptr stopping;
-    std::unique_ptr<LLM> llm;
+    std::unique_ptr<LM::Inference> llm;
     std::vector my_messages;
-    std::mutex llm_init_lock;
+    std::mutex llm_lock;
 
     dpp::cluster bot;
     dpp::channel channel;
@@ -269,8 +107,8 @@ class Bot {
         if (!llm) {
             // Make sure llm is initialized
             {
-                std::unique_lock L(llm_init_lock);
-                llm = std::make_unique<LLM>();
+                std::unique_lock L(llm_lock);
+                llm = std::make_unique<LM::Inference>("7B-ggml-model-quant.bin");
             }
             // Create message for reporting progress
             dpp::message msg(channel_id, "Wird initialisiert...");
@@ -295,6 +133,7 @@ class Bot {
                 return true;
             };
             // Add initial context
+            std::unique_lock L(llm_lock);
             llm->append("Verlauf des #"+channel.name+" Kanals.\n"
                         "Notiz 1: "+bot.me.username+" ist ein freundlicher Chatbot, der immer gerne auf deutsch mitredet. Er ist freundlich und hilfsbereit und antwortet immer sofort. Er hat guten Humor und mag jeden. Sein Alter ist 16 und er wurde 2007 geboren.\n"
                         "Notiz 2: Ecki heisst in Wirklichkeit Eckhard Kohlhuber und kommt aus Bayern.\n"
@@ -311,9 +150,10 @@ class Bot {
                return;
            }
            // Format and append line
+           std::unique_lock L(llm_lock);
            for (const auto line : str_split(msg.content, '\n')) {
                Timer timeout;
-               llm->append(msg.author.username+": "+std::string(line)+'\n', [&] (float) {
+               llm->append(msg.author.username+": "+clean_string(line)+'\n', [&] (float) {
                    if (timeout.get() > 1) {
                        std::cerr << "\nWarning: Timeout reached processing message" << std::endl;
                        return false;
@@ -321,15 +161,16 @@ class Bot {
                    return true;
                });
            }
-       } catch (const LLM::ContextLengthException&) {
+       } catch (const LM::Inference::ContextLengthException&) {
            llm.reset();
            llm_init();
        }
    }
    void prompt_add_trigger() {
        try {
+           std::unique_lock L(llm_lock);
            llm->append(bot.me.username+':');
-       } catch (const LLM::ContextLengthException&) {
+       } catch (const LM::Inference::ContextLengthException&) {
            llm.reset();
            llm_init();
        }
@@ -346,14 +187,16 @@ class Bot {
            // Run model
            Timer timeout;
            bool timed_out = false;
-           auto output = llm->run("\n", [&] () {
+           auto output = llm->run("\n", [&] (std::string_view str) {
+               std::cout << str << std::flush;
                if (timeout.get() > 2) {
                    timed_out = true;
-                   std::cerr << "\nWarning: Timeout reached generating message" << std::endl;
+                   std::cerr << "\nWarning: Timeout reached generating message";
                    return false;
                }
                return true;
            });
+           std::cout << std::endl;
            if (timed_out) output = "Fehler: Zeitüberschreitung";
            // Send resulting message
            msg.content = output;
@@ -463,9 +306,6 @@ public:
 int main(int argc, char **argv) {
-    // Init GGML
-    ggml_time_init();
-
     // Check arguments
     if (argc < 3) {
         std::cout << "Usage: " << argv[0] << " " << std::endl;
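
For reference, a minimal, self-contained sketch of the LM::Inference call pattern the patched bot code relies on. The one-argument construction, the append()/run() callback shapes, the single-argument append() overload and the ContextLengthException type are all taken from the diff above; the <justlm.hpp> header name and the exact parameter/return types are assumptions about libjustlm rather than something this patch spells out.

// Standalone sketch (not part of the patch): the libjustlm usage pattern
// this commit switches to, mirroring Bot::llm_init()/prompt handling above.
#include <iostream>
#include <string>
#include <string_view>

#include <justlm.hpp> // assumed header name for libjustlm

int main() {
    try {
        // Load a model, as the bot does when it lazily creates its `llm`.
        LM::Inference llm("7B-ggml-model-quant.bin");

        // Feed prompt text; the callback receives a progress value and may
        // return false to abort early (the bot only uses it for timeouts).
        llm.append("Ecki: Hallo!\n", [] (float progress) {
            std::cout << "\rIngesting prompt: " << progress << std::flush;
            return true;
        });

        // Add the generation trigger; the patch shows the single-argument form.
        llm.append("Bot:");

        // Generate until the end string ("\n") is produced; the callback sees
        // each newly generated piece of text and can also abort.
        const auto output = llm.run("\n", [] (std::string_view str) {
            std::cout << str << std::flush;
            return true;
        });
        std::cout << "\nReply: " << output << std::endl;
    } catch (const LM::Inference::ContextLengthException&) {
        // The bot resets and re-initializes the model when the context is full.
        std::cerr << "Context length exceeded" << std::endl;
    }
}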