diff --git a/.gitmodules b/.gitmodules
index adfb07a..d59d7cd 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,12 +1,12 @@
-[submodule "llama.cpp"]
-    path = llama.cpp-old
-    url = https://github.com/ggerganov/llama.cpp.git
 [submodule "llama.cpp-alibi"]
     path = llama.cpp-alibi
     url = https://github.com/manyoso/llama.cpp.git
-[submodule "llama.cpp-mainline"]
-    path = llama.cpp-old2
+[submodule "llama.cpp-230511"]
+    path = llama.cpp-230511
     url = https://github.com/ggerganov/llama.cpp.git
-[submodule "llama.cpp-2"]
+[submodule "llama.cpp-230519"]
+    path = llama.cpp-230519
+    url = https://github.com/ggerganov/llama.cpp.git
+[submodule "llama.cpp-mainline"]
     path = llama.cpp-mainline
     url = https://github.com/ggerganov/llama.cpp.git
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7b1cd06..f5cf6ee 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -37,8 +37,8 @@ endfunction()
 
 include(llama.cpp.cmake)
 include_ggml(llama.cpp-mainline _mainline Yes)
-include_ggml(llama.cpp-old _old Yes)
-include_ggml(llama.cpp-old2 _old2 Yes)
+include_ggml(llama.cpp-230511 _230511 Yes)
+include_ggml(llama.cpp-230519 _230519 Yes)
 include_ggml(llama.cpp-alibi _alibi No)
@@ -53,24 +53,30 @@ endif()
 
 if (LM_GPTJ)
     add_library(justlm_gptj SHARED gptj.cpp justlm_gptj.hpp gptj/gptj.cpp gptj/gptj.hpp)
-    target_link_libraries(justlm_gptj PRIVATE ggml_old justlm_g4a_common)
+    target_link_libraries(justlm_gptj PRIVATE ggml_230511 justlm_g4a_common)
     target_justlm_setup(justlm_gptj)
 endif()
 
 if (LM_LLAMA)
-    add_library(justlm_llama SHARED llama_mainline.cpp justlm_llama.hpp)
+    add_library(justlm_llama SHARED llama.cpp justlm_llama.hpp)
     target_link_libraries(justlm_llama PRIVATE ggml_mainline llama_mainline)
+    target_compile_definitions(justlm_llama PRIVATE
+        LLAMA_VERSIONS=>=3 LLAMA_DATE=999999)
     target_justlm_setup(justlm_llama)
 endif()
 
 if (LM_LLAMA_OLD)
-    add_library(justlm_llama_old SHARED llama_old.cpp justlm_llama_old.hpp)
-    target_link_libraries(justlm_llama_old PRIVATE ggml_old llama_old)
-    target_justlm_setup(justlm_llama_old)
+    add_library(justLM_LLAMA_OLD SHARED llama.cpp justlm_llama.hpp)
+    target_link_libraries(justLM_LLAMA_OLD PRIVATE ggml_230511 llama_230511)
+    target_compile_definitions(justLM_LLAMA_OLD PRIVATE
+        LLAMA_VERSIONS=<=1 LLAMA_DATE=230511)
+    target_justlm_setup(justLM_LLAMA_OLD)
 
-    add_library(justlm_llama_old2 SHARED llama_old2.cpp justlm_llama.hpp)
-    target_link_libraries(justlm_llama_old2 PRIVATE ggml_old2 llama_old2)
-    target_justlm_setup(justlm_llama_old2)
+    add_library(justlm_llama_230519 SHARED llama.cpp justlm_llama.hpp)
+    target_link_libraries(justlm_llama_230519 PRIVATE ggml_230519 llama_230519)
+    target_compile_definitions(justlm_llama_230519 PRIVATE
+        LLAMA_VERSIONS===2 LLAMA_DATE=230519)
+    target_justlm_setup(justlm_llama_230519)
 endif()
diff --git a/justlm_llama.hpp b/justlm_llama.hpp
index 4840376..a3e42bb 100644
--- a/justlm_llama.hpp
+++ b/justlm_llama.hpp
@@ -111,6 +111,7 @@ class LLaMAInference final : public Inference {
         LM_CORETURN LM_BOOL_SUCCESS;
     }
 
+#if LLAMA_DATE >= 230519
     int llama_sample_top_p_top_k() {
         auto& state = get_state();
         auto logits = llama_get_logits(state->ctx);
@@ -133,6 +134,13 @@ class LLaMAInference final : public Inference {
         llama_sample_temperature(state->ctx, &candidates_p, params.temp);
         return llama_sample_token(state->ctx, &candidates_p);
     }
+#else
+    int llama_sample_top_p_top_k() {
+        auto& state = get_state();
+        auto n_repeat_last = std::min(state->tokens.size(), params.n_repeat_last);
+        return ::llama_sample_top_p_top_k(state->ctx, params.n_repeat_last?(state->tokens.data()+state->tokens.size()-n_repeat_last):nullptr, n_repeat_last, params.top_k, params.top_p, params.temp, params.repeat_penalty);
+    }
+#endif
 
 public:
     LLaMAInference(const std::string& weights_path, const Params& p) : Inference(p) {
@@ -245,7 +253,7 @@
         auto& state = get_state();
         if (sv.ctx != generic_state)
             LM_COTHROW("Savestate does not match context", LM_BOOL_ERROR);
-        llama_set_state_data(state->ctx, sv.buf.data());
+        llama_set_state_data(state->ctx, const_cast<uint8_t*>(sv.buf.data()));
         state->tokens = sv.tokens;
         state->prompt = sv.prompt;
         LM_CORETURN LM_BOOL_SUCCESS;
diff --git a/justlm_llama_old.hpp b/justlm_llama_old.hpp
deleted file mode 100644
index bf960e3..0000000
--- a/justlm_llama_old.hpp
+++ /dev/null
@@ -1,295 +0,0 @@
-#include "justlm.hpp"
-
-#include
-#include
-#include
-
-
-namespace LM {
-class LLaMAInference final : public Inference {
-    struct State {
-        llama_context *ctx = nullptr;
-        std::string prompt; // Mostly here for easy "debugging"
-        std::vector<int> tokens;
-        unsigned n_ctx;
-    };
-
-    State*& get_state() {
-        return *reinterpret_cast<State**>(&generic_state);
-    }
-    State* const& get_state() const {
-        return *reinterpret_cast<State* const *>(&generic_state);
-    }
-
-    LM_ERRBOOL init(const std::string& weights_path) LM_NOEXCEPTDECL {
-        auto& state = get_state();
-
-        // Allocate state
-        state = new State;
-
-        // Get llama parameters
-        auto lparams = llama_context_default_params();
-        lparams.seed = params.seed;
-        lparams.n_ctx = params.n_ctx = params.n_ctx>0?params.n_ctx:2024;
-        lparams.use_mlock = params.use_mlock;
-
-        // Create context
-        state->ctx = llama_init_from_file(weights_path.c_str(), lparams);
-        if (!state->ctx) {
-            LM_THROW("Failed to initialize llama from file", LM_BOOL_ERROR);
-        }
-
-        // Initialize some variables
-        state->n_ctx = llama_n_ctx(state->ctx);
-
-        return LM_BOOL_SUCCESS;
-    }
-
-    // This function reduces the size of our tokens vector according to some parameters
-    // All tokens will be evaluated if scrolling was needed and true will be returned
-    LM_SCHEDULABLE(bool) window_scroll() LM_NOEXCEPTDECL {
-        auto &state = get_state();
-        // Check that we actually need to scroll
-        if (state->tokens.size() <= state->n_ctx) {
-            // Nope
-            LM_CORETURN false;
-        }
-        // Start scrolling
-        if (params.scroll_keep > 0.0f) {
-            // "Scroll" down the context window...
-            unsigned keep_count = float(state->tokens.size() - params.n_ctx_window_top_bar) * 0.4f; // We keep about 40%
-            // Get vector of tokens to keep
-            std::vector<int> tokens_in_view(state->tokens.end()-keep_count, state->tokens.end());
-            // Cut down tokens vector size
-            state->tokens.resize(params.n_ctx_window_top_bar+keep_count);
-            // Overwrite tokens after top bar with tokens in view
-            std::memcpy(state->tokens.data()+params.n_ctx_window_top_bar, tokens_in_view.data(), tokens_in_view.size()*sizeof(int));
-        } else {
-            // Cut down tokens vector size to top bar
-            state->tokens.resize(params.n_ctx_window_top_bar);
-        }
-        // Evaluate tokens
-        LM_ERROR_FORWARD(LM_COAWAIT evaluate_tokens(0, on_scroll));
-        LM_CORETURN true;
-    }
-
-    LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL {
-        auto& state = get_state();
-
-        // Evaluate tokens in batches
-        unsigned it;
-        for (it = starting_offset; ; it += params.n_batch) {
-            if (it + params.n_batch >= ssize_t(state->tokens.size())) break;
-
-            // Evaluate
-            if (llama_eval(state->ctx, state->tokens.data()+it, params.n_batch, it, params.n_threads)) {
-                LM_COTHROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR);
-            }
-
-            // Tick
-            if (on_tick) {
-                // Calculate progress
-                auto progress = float(it-starting_offset) / (state->tokens.size()-starting_offset) * 100.f;
-                // Tick and yield
-                if (!on_tick(progress)) LM_CORETURN LM_BOOL_SUCCESS;
-                else if (!LM_TASKYIELD) LM_CORETURN LM_BOOL_SUCCESS;
-            }
-        }
-
-        // Evaluate remaining tokens
-        if (it < state->tokens.size()) {
-            for (; it != state->tokens.size(); it++) {
-                if (llama_eval(state->ctx, state->tokens.data()+it, 1, it, params.n_threads)) {
-                    LM_COTHROW("Failed to evaluate individual tokens", LM_BOOL_ERROR);
-                }
-            }
-        }
-
-        // Notify about completion
-        if (on_tick) on_tick(100.f);
-
-        LM_CORETURN LM_BOOL_SUCCESS;
-    }
-
-public:
-    LLaMAInference(const std::string& weights_path, const Params& p) : Inference(p) {
-        init(weights_path);
-    }
-    ~LLaMAInference() override {
-        auto& state = get_state();
-
-        if (state) {
-            if (state->ctx) llama_free(state->ctx);
-            delete state;
-        }
-    }
-
-    LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
-        auto& state = get_state();
-
-        // Check if prompt was empty
-        const bool was_empty = state->prompt.empty();
-
-        // Append to current prompt
-        state->prompt.append(prompt);
-
-        // Resize buffer for tokens
-        const auto old_token_count = state->tokens.size();
-        state->tokens.resize(old_token_count+state->prompt.size());
-
-        // Run tokenizer
-        const auto token_count = llama_tokenize(state->ctx, prompt.c_str(), state->tokens.data()+old_token_count, state->tokens.size()-old_token_count, was_empty);
-        state->tokens.resize(old_token_count+token_count);
-
-        // Make sure token limit isn't being hit
-        if (LM_COAWAIT window_scroll()) {
-            // That function already has evaluated our tokens since scrolling was needed
-            LM_CORETURN LM_BOOL_SUCCESS;
-        }
-
-        // Evaluate new tokens
-        LM_CORETURN LM_COAWAIT evaluate_tokens(old_token_count, on_tick);
-    }
-
-    LM_SCHEDULABLE(std::string) run(std::string_view end, const std::function<bool (const char*)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
-        auto& state = get_state();
-        std::string fres;
-
-        // Loop until done
-        bool abort = false;
-        unsigned eos_count = 0;
-        while (!abort && (end.empty() || fres.find(end) == fres.npos)) {
-            // Sample top p and top k
-            const auto n_repeat_last = std::min(state->tokens.size(), params.n_repeat_last);
-            auto id = llama_sample_top_p_top_k(state->ctx, params.n_repeat_last?(state->tokens.data()+state->tokens.size()-n_repeat_last):nullptr, n_repeat_last, params.top_k, params.top_p, params.temp, params.repeat_penalty);
-
-            if (id == llama_token_eos()) {
-                if (eos_count++ == params.eos_ignores) {
-                    abort = true;
-                    continue;
-                }
-                state->tokens.push_back(0);
-                llama_tokenize(state->ctx, "\n", &state->tokens.back(), 1, false);
-                id = state->tokens.back();
-            } else {
-                // Add token
-                state->tokens.push_back(id);
-            }
-
-            // Make sure token limit isn't hit
-            LM_COAWAIT window_scroll();
-
-            // Get token as string
-            const std::string_view str = llama_token_to_str(state->ctx, id);
-
-            // Append string to function result
-            state->prompt.append(str);
-            fres.append(str);
-
-            // Evaluate token
-            //  TODO: Respect batch size
-            if (llama_eval(state->ctx, state->tokens.data()+state->tokens.size()-1, 1, state->tokens.size()-1, params.n_threads)) {
-                LM_COTHROW("Failed to evaluate new tokens", "");
-            }
-
-            // Tick and yield
-            if (on_tick && !on_tick(str.data())) abort = true;
-            else if (!LM_TASKYIELD) abort = true;
-        }
-
-        // Create final string  TODO: Could be optimized
-        if (!abort) {
-            fres = std::string(fres.data(), fres.size()-end.size());
-        }
-
-        // Return final string
-        LM_CORETURN fres;
-    }
-
-    unsigned get_context_size() const noexcept override {
-        return get_state()->tokens.size();
-    }
-
-    LM_SCHEDULABLE(LM_ERRBOOL) create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override {
-        auto& state = get_state();
-        sv.buf.resize(llama_get_state_size(state->ctx));
-        llama_copy_state_data(state->ctx, sv.buf.data());
-        sv.tokens = state->tokens;
-        sv.prompt = state->prompt;
-        sv.ctx = generic_state;
-        LM_CORETURN LM_BOOL_SUCCESS;
-    }
-    LM_SCHEDULABLE(LM_ERRBOOL) restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override {
-        auto& state = get_state();
-        if (sv.ctx != generic_state)
-            LM_COTHROW("Savestate does not match context", LM_BOOL_ERROR);
-        llama_set_state_data(state->ctx, sv.buf.data());
-        state->tokens = sv.tokens;
-        state->prompt = sv.prompt;
-        LM_CORETURN LM_BOOL_SUCCESS;
-    }
-
-    LM_SCHEDULABLE(LM_ERRBOOL) serialize(std::ostream &o) const LM_NOEXCEPTDECL override {
-        auto& state = get_state();
-        // Get state size
-        auto state_size = llama_get_state_size(state->ctx);
-        // Write sizes
-        for (const uint32_t s : {static_cast<uint32_t>(state->n_ctx), state->tokens.size(), state->prompt.size(), state_size}) {
-            if (!o.write(reinterpret_cast<const char*>(&s), sizeof(s))) {
-                LM_COTHROW("Failed to serialize data sizes", LM_BOOL_ERROR);
-            }
-        }
-        // Write tokens
-        if (!o.write(reinterpret_cast<const char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
-            LM_COTHROW("Failed to serialize tokens", LM_BOOL_ERROR);
-        }
-        // Write prompt
-        if (!o.write(state->prompt.data(), state->prompt.size())) {
-            LM_COTHROW("Failed to serialize prompt", LM_BOOL_ERROR);
-        }
-        // Write state
-        std::vector<uint8_t> state_buf(state_size);
-        llama_copy_state_data(state->ctx, state_buf.data());
-        if (!o.write(reinterpret_cast<const char*>(state_buf.data()), state_size)) {
-            LM_COTHROW("Failed to serialize state", LM_BOOL_ERROR);
-        }
-        LM_CORETURN LM_BOOL_SUCCESS;
-    }
-    LM_SCHEDULABLE(LM_ERRBOOL) deserialize(std::istream &i) LM_NOEXCEPTDECL override {
-        auto& state = get_state();
-        uint32_t n_ctx, embd_size, prompt_size, state_size;
-        // Initialization to prevent compiler complaints
-        n_ctx = embd_size = prompt_size = state_size = 0;
-        // Read sizes
-        for (uint32_t *s : {&n_ctx, &embd_size, &prompt_size, &state_size}) {
-            if (!i.read(reinterpret_cast<char*>(s), sizeof(*s))) {
-                LM_COTHROW("Failed to deserialize data sizes", LM_BOOL_ERROR);
-            }
-        }
-        if (state->n_ctx != n_ctx) {
-            LM_COTHROW("Context length differs (My "+std::to_string(state->n_ctx)+" vs. files "+std::to_string(n_ctx)+')', LM_BOOL_ERROR);
-        }
-        // Read tokens
-        state->tokens.resize(embd_size);
-        if (!i.read(reinterpret_cast<char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
-            LM_COTHROW("Failed to deserialize tokens", LM_BOOL_ERROR);
-        }
-        // Read prompt
-        state->prompt.resize(prompt_size);
-        if (!i.read(state->prompt.data(), state->prompt.size())) {
-            LM_COTHROW("Failed to deserialize prompt", LM_BOOL_ERROR);
-        }
-        // Read state
-        std::vector<uint8_t> state_buf(state_size);
-        if (!i.read(reinterpret_cast<char*>(state_buf.data()), state_buf.size())) {
-            LM_COTHROW("Failed to deserialize state", LM_BOOL_ERROR);
-        }
-        llama_set_state_data(state->ctx, state_buf.data());
-        LM_CORETURN LM_BOOL_SUCCESS;
-    }
-
-    const std::string &get_prompt() const LM_NOEXCEPTDECL override {
-        return get_state()->prompt;
-    }
-};
-}
diff --git a/llama_mainline.cpp b/llama.cpp
similarity index 92%
rename from llama_mainline.cpp
rename to llama.cpp
index 7b606de..0484162 100644
--- a/llama_mainline.cpp
+++ b/llama.cpp
@@ -16,13 +16,13 @@ const LM::Implementation *get_justlm_implementation() {
 
 bool magic_match(std::istream& f) {
     // Check magic
-    uint32_t magic;
+    uint32_t magic = 0;
     f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
     if (magic != 0x67676a74) return false;
     // Check version
     uint32_t version = 0;
     f.read(reinterpret_cast<char*>(&version), sizeof(version));
-    return version >= 3;
+    return version LLAMA_VERSIONS;
 }
 
 LM::Inference *construct(const std::string &weights_path, std::ifstream& f, const LM::Inference::Params &p) {
diff --git a/llama.cpp-old b/llama.cpp-230511
similarity index 100%
rename from llama.cpp-old
rename to llama.cpp-230511
diff --git a/llama.cpp-old2 b/llama.cpp-230519
similarity index 100%
rename from llama.cpp-old2
rename to llama.cpp-230519
diff --git a/llama.cpp-mainline b/llama.cpp-mainline
index 2d5db48..29cf559 160000
--- a/llama.cpp-mainline
+++ b/llama.cpp-mainline
@@ -1 +1 @@
-Subproject commit 2d5db48371052087a83974abda3767d1aedec598
+Subproject commit 29cf5596fe0c37213f9b74e80d8f631193a93f0f
diff --git a/llama_old.cpp b/llama_old.cpp
deleted file mode 100644
index 582e8ea..0000000
--- a/llama_old.cpp
+++ /dev/null
@@ -1,21 +0,0 @@
-#include "justlm_llama_old.hpp"
-#include "justlm.hpp"
-
-#include
-#include
-#include
-#include
-
-
-
-extern "C" {
-const LM::Implementation *get_justlm_implementation() {
-    static LM::Implementation fres{true};
-    return &fres;
-}
-
-LM::Inference *construct(const std::string &weights_path, std::ifstream& f, const LM::Inference::Params &p) {
-    f.close();
-    return new LM::LLaMAInference(weights_path, p);
-}
-}
diff --git a/llama_old2.cpp b/llama_old2.cpp
deleted file mode 100644
index b8897f0..0000000
--- a/llama_old2.cpp
+++ /dev/null
@@ -1,32 +0,0 @@
-#include "justlm_llama.hpp"
-#include "justlm.hpp"
-
-#include
-#include
-#include
-#include
-
-
-
-extern "C" {
-const LM::Implementation *get_justlm_implementation() {
-    static LM::Implementation fres{false};
-    return &fres;
-}
-
-bool magic_match(std::istream& f) {
-    // Check magic
-    uint32_t magic;
-    f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
-    if (magic != 0x67676a74) return false;
-    // Check version
-    uint32_t version = 0;
-    f.read(reinterpret_cast<char*>(&version), sizeof(version));
-    return version == 2;
-}
-
-LM::Inference *construct(const std::string &weights_path, std::ifstream& f, const LM::Inference::Params &p) {
-    f.close();
-    return new LM::LLaMAInference(weights_path, p);
-}
-}