mirror of https://gitlab.com/niansa/libjustlm.git synced 2025-03-06 20:49:17 +01:00

Properly implemented MPT

Author: niansa/tuxifan
Date: 2023-05-15 14:46:19 +02:00
Parent: 6d2910b7b9
Commit: 5b01daa764
6 changed files with 35 additions and 58 deletions

.gitmodules

@@ -1,3 +1,6 @@
[submodule "llama.cpp"]
path = llama.cpp
url = https://github.com/ggerganov/llama.cpp.git
[submodule "llama.cpp-alibi"]
path = llama.cpp-alibi
url = https://github.com/manyoso/llama.cpp.git

CMakeLists.txt

@@ -9,22 +9,34 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(LM_PYBIND No CACHE BOOL "If Libjustlm Python bindings should be build")
set(LM_COSCHED No CACHE BOOL "If Libjustlm should make use of CoSched")
set(LM_NOEXCEPT No CACHE BOOL "If exceptions should be disabled")
set(LM_MPT No CACHE BOOL "If MPT model support should be built")
if (LM_COSCHED)
set(CMAKE_CXX_STANDARD 20)
endif()
add_subdirectory(llama.cpp)
if (LM_MPT)
set(LM_MPT_SOURCES justlm_mpt.hpp mpt/mpt.cpp mpt/mpt.hpp)
add_subdirectory(llama.cpp-alibi)
else()
set(LM_MPT_SOURCES )
add_subdirectory(llama.cpp)
endif()
add_library(libjustlm STATIC
include/justlm.hpp justlm.cpp
justlm_llama.hpp
g4a-common.cpp g4a-common.hpp
justlm_mpt.hpp mpt/mpt.cpp mpt/mpt.hpp
justlm_gptj.hpp gptj/gptj.cpp gptj/gptj.hpp
${LM_MPT_SOURCES}
include/justlm_pool.hpp justlm_pool.cpp
)
target_link_libraries(libjustlm PRIVATE llama)
if (LM_MPT)
target_compile_definitions(libjustlm PUBLIC LM_MPT)
endif()
if (LM_COSCHED)
target_compile_definitions(libjustlm PUBLIC LM_COSCHED)
target_link_libraries(libjustlm PRIVATE cosched)
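
The new LM_MPT switch selects both the extra sources and which llama.cpp tree gets built: with MPT enabled, the alibi-capable fork (llama.cpp-alibi) is used instead of the stock submodule. Because the LM_MPT compile definition is declared PUBLIC, it also propagates to any target that links the libjustlm CMake target, so a consumer can detect at build time whether MPT support was compiled in. A minimal, hypothetical consumer program (not part of this commit) as a sketch:

#include <cstdio>

int main() {
#ifdef LM_MPT
    // Defined when libjustlm was configured with -DLM_MPT=ON (PUBLIC definition above)
    std::puts("libjustlm was built with MPT support");
#else
    std::puts("libjustlm was built without MPT support");
#endif
    return 0;
}

Enabling the option is the usual CMake cache flag (configure with -DLM_MPT=ON); a fresh checkout also needs git submodule update --init --recursive so the new llama.cpp-alibi submodule is present.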

gptj/gptj.cpp

@@ -16,8 +16,10 @@
#include <unordered_set>
#include <ggml.h>
// default hparams (GPT-J 6B)
static const size_t MB = 1024*1024;
constexpr inline
unsigned long long operator ""_MB(unsigned long long bytes) {
return bytes*1024*1024;
}
static bool kv_cache_init(
const struct gptj_hparams & hparams,
@@ -30,7 +32,7 @@ static bool kv_cache_init(
const int64_t n_mem = (int64_t)n_layer*n_ctx;
const int64_t n_elements = n_embd*n_mem;
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2_MB);
struct ggml_init_params params;
params.mem_size = cache.buf.size;
@@ -392,7 +394,7 @@ bool gptj_eval(
const int n_vocab = hparams.n_vocab;
const int n_rot = hparams.n_rot;
static size_t buf_size = 1024u*MB;
static size_t buf_size = 1024_MB;
if (!model.buf.addr || model.buf.size < buf_size)
model.buf.resize(buf_size);
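
Both this file and mpt/mpt.cpp below replace the file-local MB constant with a user-defined literal, so buffer sizes read as 2_MB or 1024_MB. As a standalone illustration (not part of the diff), the operator works like this:

#include <cstdio>

// Same shape as the literal operator introduced above.
constexpr unsigned long long operator ""_MB(unsigned long long bytes) {
    return bytes * 1024 * 1024;
}

static_assert(2_MB == 2u * 1024 * 1024, "2_MB is two mebibytes");

int main() {
    std::printf("1024_MB = %llu bytes\n", 1024_MB);
    return 0;
}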

justlm.cpp

@@ -1,7 +1,9 @@
#include "justlm.hpp"
#include "justlm_llama.hpp"
#include "justlm_gptj.hpp"
#include "justlm_mpt.hpp"
#ifdef LM_MPT
# include "justlm_mpt.hpp"
#endif
#include <fstream>
@@ -12,13 +14,15 @@ LM::Inference *LM::Inference::construct(const std::string &weights_path, const P
std::ifstream f(weights_path, std::ios::binary);
uint32_t magic;
f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
// Create model
// Create inference instance
if (magic == 0x67676d6c) {
f.seekg(0);
return new GPTJInference(weights_path, f, p);
# ifdef LM_MPT
} else if (magic == 0x67676d6d) {
f.seekg(0);
return new MPTInference(weights_path, f, p);
# endif
} else {
f.close();
return new LLaMaInference(weights_path, p);
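
Dispatch keys off the first four bytes of the weights file, read as a uint32_t: 0x67676d6c spells "ggml" in ASCII (handled by the GPT-J loader here) and 0x67676d6d spells "ggmm" (the MPT variant), with anything else handed to the llama loader. A small self-contained sketch of the same check; the enum and function names are invented for illustration and are not part of libjustlm:

#include <cstdint>
#include <fstream>
#include <iostream>
#include <string>

enum class WeightsKind { gptj, mpt, other };

// Reads the 4-byte magic the same way construct() does above.
WeightsKind detect_weights_kind(const std::string &path) {
    std::ifstream f(path, std::ios::binary);
    uint32_t magic = 0;
    f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
    if (magic == 0x67676d6c) return WeightsKind::gptj; // "ggml"
    if (magic == 0x67676d6d) return WeightsKind::mpt;  // "ggmm"
    return WeightsKind::other;                         // treated as llama above
}

int main(int argc, char **argv) {
    if (argc < 2) return 1;
    switch (detect_weights_kind(argv[1])) {
    case WeightsKind::gptj:  std::cout << "gptj-style ggml file\n"; break;
    case WeightsKind::mpt:   std::cout << "mpt-style ggml file\n";  break;
    default:                 std::cout << "other (llama assumed)\n"; break;
    }
    return 0;
}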

llama.cpp-alibi Submodule

@@ -0,0 +1 @@
Subproject commit 03ceb39c1e729bed4ad1dfa16638a72f1843bf0c

mpt/mpt.cpp

@@ -17,7 +17,10 @@
#include <regex>
#include <ggml.h>
static const size_t MB = 1024*1024;
inline
unsigned long long operator ""_MB(unsigned long long bytes) {
return bytes*1024*1024;
}
static bool kv_cache_init(
const struct mpt_hparams & hparams,
@@ -30,7 +33,7 @@ static bool kv_cache_init(
const int64_t n_mem = (int64_t)n_layer*n_ctx;
const int64_t n_elements = n_embd*n_mem;
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2_MB);
struct ggml_init_params params;
params.mem_size = cache.buf.size;
@@ -202,7 +205,6 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
const int expand = hparams.expand;
@@ -241,13 +243,6 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
{
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_mem = n_layer*n_ctx;
const int n_elements = n_embd*n_mem;
if (!kv_cache_init(hparams, model.kv_self, GGML_TYPE_F16, model.hparams.n_ctx)) {
fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
ggml_free(ctx);
@@ -350,37 +345,6 @@ bool mpt_model_load(const std::string & fname, mpt_model & model, mpt_vocab & vo
return loaded;
}
struct ggml_tensor * ggml_alibi(
struct ggml_context * ctx,
struct ggml_tensor * a,
int n_past,
int n_head) {
if (n_past < 0) {
return NULL;
}
bool is_node = false;
if (a->grad) {
return NULL; // TODO: implement backward
is_node = true;
}
// TODO: when implement backward, fix this:
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
((int32_t *) b->data)[0] = n_past;
((int32_t *) b->data)[1] = n_head;
result->op = GGML_OP_COUNT; // Dirty hack
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src0 = a;
result->src1 = b;
return result;
}
bool mpt_eval(
mpt_model & model,
const int n_threads,
@@ -397,9 +361,6 @@ bool mpt_eval(
const int n_ctx = hparams.n_ctx;
const int n_head = hparams.n_head;
const int n_vocab = hparams.n_vocab;
const int expand = hparams.expand;
const int d_key = n_embd/n_head;
static size_t buf_size = 256u*1024*1024;
static void * buf = malloc(buf_size);
@@ -560,12 +521,10 @@ bool mpt_eval(
out = ggml_mul_mat(ctx0, model.wte, out);
}
// run the computation
ggml_build_forward_expand(&gf, out);
ggml_graph_compute (ctx0, &gf);
// return result for just the last token
embd_w.resize(n_vocab);
memcpy(embd_w.data(), (float *) ggml_get_data(out) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
@@ -728,8 +687,6 @@ size_t mpt_copy_state_data(const mpt_model &model, const std::mt19937 &rng, uint
}
const size_t written = out - dest;
const size_t expected = mpt_get_state_size(model);
assert(written == expected);
fflush(stdout);
return written;
}
@@ -876,8 +833,6 @@ size_t mpt_set_state_data(mpt_model *model, std::mt19937 *rng, const uint8_t *sr
}
const size_t nread = in - src;
const size_t expected = mpt_get_state_size(*model);
assert(nread == expected);
fflush(stdout);
return nread;
}
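
The state (de)serialization helpers above return how many bytes they wrote or read. A hedged caller-side sketch that re-creates the size check, assuming the mpt_copy_state_data and mpt_get_state_size signatures visible in the hunks are declared in mpt/mpt.hpp:

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <random>
#include <vector>

#include "mpt/mpt.hpp" // assumed header for the mpt_* declarations in this repo

// Hypothetical helper: serialize the model state and verify the byte count.
std::vector<uint8_t> serialize_mpt_state(const mpt_model &model, const std::mt19937 &rng) {
    std::vector<uint8_t> buf(mpt_get_state_size(model));
    const size_t written = mpt_copy_state_data(model, rng, buf.data());
    if (written != buf.size())
        std::fprintf(stderr, "state size mismatch: wrote %zu, expected %zu\n",
                     written, buf.size());
    return buf;
}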