From 5b01daa7645741a2ca519ca9f293b8f0e637bba6 Mon Sep 17 00:00:00 2001
From: niansa
Date: Mon, 15 May 2023 14:46:19 +0200
Subject: [PATCH] Properly implemented MPT

---
 .gitmodules     |  3 +++
 CMakeLists.txt  | 16 ++++++++++++--
 gptj/gptj.cpp   | 10 +++++----
 justlm.cpp      |  8 +++++--
 llama.cpp-alibi |  1 +
 mpt/mpt.cpp     | 55 +++++--------------------------------------------
 6 files changed, 35 insertions(+), 58 deletions(-)
 create mode 160000 llama.cpp-alibi

diff --git a/.gitmodules b/.gitmodules
index 0477fdd..3f9cc66 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,6 @@
 [submodule "llama.cpp"]
 	path = llama.cpp
 	url = https://github.com/ggerganov/llama.cpp.git
+[submodule "llama.cpp-alibi"]
+	path = llama.cpp-alibi
+	url = https://github.com/manyoso/llama.cpp.git
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7053afc..38b50af 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -9,22 +9,34 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
 set(LM_PYBIND No CACHE BOOL "If Libjustlm Python bindings should be build")
 set(LM_COSCHED No CACHE BOOL "If Libjustlm should make use of CoSched")
 set(LM_NOEXCEPT No CACHE BOOL "If exceptions should be disabled")
+set(LM_MPT No CACHE BOOL "If MPT model support should be built")
 
 if (LM_COSCHED)
     set(CMAKE_CXX_STANDARD 20)
 endif()
 
-add_subdirectory(llama.cpp)
+if (LM_MPT)
+    set(LM_MPT_SOURCES justlm_mpt.hpp mpt/mpt.cpp mpt/mpt.hpp)
+    add_subdirectory(llama.cpp-alibi)
+else()
+    set(LM_MPT_SOURCES )
+    add_subdirectory(llama.cpp)
+endif()
+
 add_library(libjustlm STATIC
     include/justlm.hpp justlm.cpp justlm_llama.hpp
     g4a-common.cpp g4a-common.hpp
-    justlm_mpt.hpp mpt/mpt.cpp mpt/mpt.hpp
     justlm_gptj.hpp gptj/gptj.cpp gptj/gptj.hpp
+    ${LM_MPT_SOURCES}
     include/justlm_pool.hpp justlm_pool.cpp
 )
 target_link_libraries(libjustlm PRIVATE llama)
 
+if (LM_MPT)
+    target_compile_definitions(libjustlm PUBLIC LM_MPT)
+endif()
+
 if (LM_COSCHED)
     target_compile_definitions(libjustlm PUBLIC LM_COSCHED)
     target_link_libraries(libjustlm PRIVATE cosched)
diff --git a/gptj/gptj.cpp b/gptj/gptj.cpp
index 02e255d..9f4a062 100644
--- a/gptj/gptj.cpp
+++ b/gptj/gptj.cpp
@@ -16,8 +16,10 @@
 #include
 #include
 
-// default hparams (GPT-J 6B)
-static const size_t MB = 1024*1024;
+constexpr inline
+unsigned long long operator ""_MB(unsigned long long bytes) {
+    return bytes*1024*1024;
+}
 
 static bool kv_cache_init(
         const struct gptj_hparams & hparams,
@@ -30,7 +32,7 @@ static bool kv_cache_init(
     const int64_t n_mem = (int64_t)n_layer*n_ctx;
     const int64_t n_elements = n_embd*n_mem;
 
-    cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
+    cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2_MB);
 
     struct ggml_init_params params;
     params.mem_size = cache.buf.size;
@@ -392,7 +394,7 @@ bool gptj_eval(
     const int n_vocab = hparams.n_vocab;
     const int n_rot = hparams.n_rot;
 
-    static size_t buf_size = 1024u*MB;
+    static size_t buf_size = 1024_MB;
     if (!model.buf.addr || model.buf.size < buf_size)
         model.buf.resize(buf_size);
diff --git a/justlm.cpp b/justlm.cpp
index fbbf0b3..4f68b6e 100644
--- a/justlm.cpp
+++ b/justlm.cpp
@@ -1,7 +1,9 @@
 #include "justlm.hpp"
 #include "justlm_llama.hpp"
 #include "justlm_gptj.hpp"
-#include "justlm_mpt.hpp"
+#ifdef LM_MPT
+#   include "justlm_mpt.hpp"
+#endif
 
 #include
 
@@ -12,13 +14,15 @@ LM::Inference *LM::Inference::construct(const std::string &weights_path, const P
     std::ifstream f(weights_path, std::ios::binary);
     uint32_t magic;
     f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
-    // Create model
+    // Create inference instance
     if (magic == 0x67676d6c) {
         f.seekg(0);
         return new GPTJInference(weights_path, f, p);
+#   ifdef LM_MPT
     } else if (magic == 0x67676d6d) {
         f.seekg(0);
         return new MPTInference(weights_path, f, p);
+#   endif
     } else {
         f.close();
         return new LLaMaInference(weights_path, p);
diff --git a/llama.cpp-alibi b/llama.cpp-alibi
new file mode 160000
index 0000000..03ceb39
--- /dev/null
+++ b/llama.cpp-alibi
@@ -0,0 +1 @@
+Subproject commit 03ceb39c1e729bed4ad1dfa16638a72f1843bf0c
diff --git a/mpt/mpt.cpp b/mpt/mpt.cpp
index d31eaba..ef972b3 100644
--- a/mpt/mpt.cpp
+++ b/mpt/mpt.cpp
@@ -17,7 +17,10 @@
 #include
 #include
 
-static const size_t MB = 1024*1024;
+inline
+unsigned long long operator ""_MB(unsigned long long bytes) {
+    return bytes*1024*1024;
+}
 
 static bool kv_cache_init(
         const struct mpt_hparams & hparams,
@@ -30,7 +33,7 @@ static bool kv_cache_init(
     const int64_t n_mem = (int64_t)n_layer*n_ctx;
     const int64_t n_elements = n_embd*n_mem;
 
-    cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
+    cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2_MB);
 
     struct ggml_init_params params;
     params.mem_size = cache.buf.size;
@@ -202,7 +205,6 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
     const int n_embd = hparams.n_embd;
     const int n_layer = hparams.n_layer;
-    const int n_ctx = hparams.n_ctx;
     const int n_vocab = hparams.n_vocab;
     const int expand = hparams.expand;
 
@@ -241,13 +243,6 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
     {
         const auto & hparams = model.hparams;
 
-        const int n_embd = hparams.n_embd;
-        const int n_layer = hparams.n_layer;
-        const int n_ctx = hparams.n_ctx;
-
-        const int n_mem = n_layer*n_ctx;
-        const int n_elements = n_embd*n_mem;
-
         if (!kv_cache_init(hparams, model.kv_self, GGML_TYPE_F16, model.hparams.n_ctx)) {
             fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
             ggml_free(ctx);
@@ -350,37 +345,6 @@ bool mpt_model_load(const std::string & fname, mpt_model & model, mpt_vocab & vo
     return loaded;
 }
 
-struct ggml_tensor * ggml_alibi(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        int n_past,
-        int n_head) {
-    if (n_past < 0) {
-        return NULL;
-    }
-    bool is_node = false;
-
-    if (a->grad) {
-        return NULL; // TODO: implement backward
-        is_node = true;
-    }
-
-    // TODO: when implement backward, fix this:
-    //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
-
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
-    ((int32_t *) b->data)[0] = n_past;
-    ((int32_t *) b->data)[1] = n_head;
-
-    result->op = GGML_OP_COUNT; // Dirty hack
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src0 = a;
-    result->src1 = b;
-
-    return result;
-}
-
 bool mpt_eval(
         mpt_model & model,
         const int n_threads,
@@ -397,9 +361,6 @@ bool mpt_eval(
     const int n_ctx = hparams.n_ctx;
     const int n_head = hparams.n_head;
     const int n_vocab = hparams.n_vocab;
-    const int expand = hparams.expand;
-
-    const int d_key = n_embd/n_head;
 
     static size_t buf_size = 256u*1024*1024;
     static void * buf = malloc(buf_size);
@@ -560,12 +521,10 @@ bool mpt_eval(
         out = ggml_mul_mat(ctx0, model.wte, out);
     }
 
-    // run the computation
     ggml_build_forward_expand(&gf, out);
     ggml_graph_compute (ctx0, &gf);
 
-    // return result for just the last token
     embd_w.resize(n_vocab);
     memcpy(embd_w.data(), (float *) ggml_get_data(out) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
 
@@ -728,8 +687,6 @@ size_t mpt_copy_state_data(const mpt_model &model, const std::mt19937 &rng, uint
     }
 
     const size_t written = out - dest;
-    const size_t expected = mpt_get_state_size(model);
-    assert(written == expected);
     fflush(stdout);
     return written;
 }
@@ -876,8 +833,6 @@ size_t mpt_set_state_data(mpt_model *model, std::mt19937 *rng, const uint8_t *sr
     }
 
     const size_t nread = in - src;
-    const size_t expected = mpt_get_state_size(*model);
-    assert(nread == expected);
     fflush(stdout);
     return nread;
 }
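
Note (editor, not part of the patch): with MPT support enabled at configure time (e.g. `cmake -DLM_MPT=ON`), libjustlm is built against the llama.cpp-alibi submodule and the MPT loader is compiled in; otherwise the stock llama.cpp submodule is used, the `LM_MPT` guard compiles the MPT branch out, and a file with the 0x67676d6d magic falls through to the LLaMA path in LM::Inference::construct. The buffer-size change is just a user-defined literal replacing the old `MB` constant, so sizes read as `2_MB` or `1024_MB`. A minimal standalone sketch of that idiom, assuming nothing beyond what the patch adds to gptj.cpp and mpt.cpp (the main() is illustrative only):

    #include <cstdio>

    // Mebibyte literal as introduced by the patch: 2_MB == 2*1024*1024 bytes.
    constexpr unsigned long long operator ""_MB(unsigned long long bytes) {
        return bytes*1024*1024;
    }

    int main() {
        std::printf("%llu\n", 2_MB);  // prints 2097152, the value used for the KV cache slack
        return 0;
    }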