mirror of
https://gitlab.com/niansa/libjustlm.git
synced 2025-03-06 20:49:17 +01:00
Properly implemented MPT
This commit is contained in:
parent
6d2910b7b9
commit
5b01daa764
6 changed files with 35 additions and 58 deletions
3
.gitmodules
vendored
3
.gitmodules
vendored
|
@ -1,3 +1,6 @@
|
|||
[submodule "llama.cpp"]
|
||||
path = llama.cpp
|
||||
url = https://github.com/ggerganov/llama.cpp.git
|
||||
[submodule "llama.cpp-alibi"]
|
||||
path = llama.cpp-alibi
|
||||
url = https://github.com/manyoso/llama.cpp.git
|
||||
|
|
|
@ -9,22 +9,34 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
|||
set(LM_PYBIND No CACHE BOOL "If Libjustlm Python bindings should be build")
|
||||
set(LM_COSCHED No CACHE BOOL "If Libjustlm should make use of CoSched")
|
||||
set(LM_NOEXCEPT No CACHE BOOL "If exceptions should be disabled")
|
||||
set(LM_MPT No CACHE BOOL "If MPT model support should be built")
|
||||
|
||||
if (LM_COSCHED)
|
||||
set(CMAKE_CXX_STANDARD 20)
|
||||
endif()
|
||||
|
||||
add_subdirectory(llama.cpp)
|
||||
if (LM_MPT)
|
||||
set(LM_MPT_SOURCES justlm_mpt.hpp mpt/mpt.cpp mpt/mpt.hpp)
|
||||
add_subdirectory(llama.cpp-alibi)
|
||||
else()
|
||||
set(LM_MPT_SOURCES )
|
||||
add_subdirectory(llama.cpp)
|
||||
endif()
|
||||
|
||||
add_library(libjustlm STATIC
|
||||
include/justlm.hpp justlm.cpp
|
||||
justlm_llama.hpp
|
||||
g4a-common.cpp g4a-common.hpp
|
||||
justlm_mpt.hpp mpt/mpt.cpp mpt/mpt.hpp
|
||||
justlm_gptj.hpp gptj/gptj.cpp gptj/gptj.hpp
|
||||
${LM_MPT_SOURCES}
|
||||
include/justlm_pool.hpp justlm_pool.cpp
|
||||
)
|
||||
target_link_libraries(libjustlm PRIVATE llama)
|
||||
|
||||
if (LM_MPT)
|
||||
target_compile_definitions(libjustlm PUBLIC LM_MPT)
|
||||
endif()
|
||||
|
||||
if (LM_COSCHED)
|
||||
target_compile_definitions(libjustlm PUBLIC LM_COSCHED)
|
||||
target_link_libraries(libjustlm PRIVATE cosched)
|
||||
|
|
|
@ -16,8 +16,10 @@
|
|||
#include <unordered_set>
|
||||
#include <ggml.h>
|
||||
|
||||
// default hparams (GPT-J 6B)
|
||||
static const size_t MB = 1024*1024;
|
||||
constexpr inline
|
||||
unsigned long long operator ""_MB(unsigned long long bytes) {
|
||||
return bytes*1024*1024;
|
||||
}
|
||||
|
||||
static bool kv_cache_init(
|
||||
const struct gptj_hparams & hparams,
|
||||
|
@ -30,7 +32,7 @@ static bool kv_cache_init(
|
|||
const int64_t n_mem = (int64_t)n_layer*n_ctx;
|
||||
const int64_t n_elements = n_embd*n_mem;
|
||||
|
||||
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
|
||||
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2_MB);
|
||||
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = cache.buf.size;
|
||||
|
@ -392,7 +394,7 @@ bool gptj_eval(
|
|||
const int n_vocab = hparams.n_vocab;
|
||||
const int n_rot = hparams.n_rot;
|
||||
|
||||
static size_t buf_size = 1024u*MB;
|
||||
static size_t buf_size = 1024_MB;
|
||||
if (!model.buf.addr || model.buf.size < buf_size)
|
||||
model.buf.resize(buf_size);
|
||||
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
#include "justlm.hpp"
|
||||
#include "justlm_llama.hpp"
|
||||
#include "justlm_gptj.hpp"
|
||||
#include "justlm_mpt.hpp"
|
||||
#ifdef LM_MPT
|
||||
# include "justlm_mpt.hpp"
|
||||
#endif
|
||||
|
||||
#include <fstream>
|
||||
|
||||
|
@ -12,13 +14,15 @@ LM::Inference *LM::Inference::construct(const std::string &weights_path, const P
|
|||
std::ifstream f(weights_path, std::ios::binary);
|
||||
uint32_t magic;
|
||||
f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
|
||||
// Create model
|
||||
// Create inference instance
|
||||
if (magic == 0x67676d6c) {
|
||||
f.seekg(0);
|
||||
return new GPTJInference(weights_path, f, p);
|
||||
# ifdef LM_MPT
|
||||
} else if (magic == 0x67676d6d) {
|
||||
f.seekg(0);
|
||||
return new MPTInference(weights_path, f, p);
|
||||
# endif
|
||||
} else {
|
||||
f.close();
|
||||
return new LLaMaInference(weights_path, p);
|
||||
|
|
1
llama.cpp-alibi
Submodule
1
llama.cpp-alibi
Submodule
|
@ -0,0 +1 @@
|
|||
Subproject commit 03ceb39c1e729bed4ad1dfa16638a72f1843bf0c
|
55
mpt/mpt.cpp
55
mpt/mpt.cpp
|
@ -17,7 +17,10 @@
|
|||
#include <regex>
|
||||
#include <ggml.h>
|
||||
|
||||
static const size_t MB = 1024*1024;
|
||||
inline
|
||||
unsigned long long operator ""_MB(unsigned long long bytes) {
|
||||
return bytes*1024*1024;
|
||||
}
|
||||
|
||||
static bool kv_cache_init(
|
||||
const struct mpt_hparams & hparams,
|
||||
|
@ -30,7 +33,7 @@ static bool kv_cache_init(
|
|||
const int64_t n_mem = (int64_t)n_layer*n_ctx;
|
||||
const int64_t n_elements = n_embd*n_mem;
|
||||
|
||||
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
|
||||
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2_MB);
|
||||
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = cache.buf.size;
|
||||
|
@ -202,7 +205,6 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
|
|||
|
||||
const int n_embd = hparams.n_embd;
|
||||
const int n_layer = hparams.n_layer;
|
||||
const int n_ctx = hparams.n_ctx;
|
||||
const int n_vocab = hparams.n_vocab;
|
||||
const int expand = hparams.expand;
|
||||
|
||||
|
@ -241,13 +243,6 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
|
|||
{
|
||||
const auto & hparams = model.hparams;
|
||||
|
||||
const int n_embd = hparams.n_embd;
|
||||
const int n_layer = hparams.n_layer;
|
||||
const int n_ctx = hparams.n_ctx;
|
||||
|
||||
const int n_mem = n_layer*n_ctx;
|
||||
const int n_elements = n_embd*n_mem;
|
||||
|
||||
if (!kv_cache_init(hparams, model.kv_self, GGML_TYPE_F16, model.hparams.n_ctx)) {
|
||||
fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
|
||||
ggml_free(ctx);
|
||||
|
@ -350,37 +345,6 @@ bool mpt_model_load(const std::string & fname, mpt_model & model, mpt_vocab & vo
|
|||
return loaded;
|
||||
}
|
||||
|
||||
struct ggml_tensor * ggml_alibi(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
int n_past,
|
||||
int n_head) {
|
||||
if (n_past < 0) {
|
||||
return NULL;
|
||||
}
|
||||
bool is_node = false;
|
||||
|
||||
if (a->grad) {
|
||||
return NULL; // TODO: implement backward
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
// TODO: when implement backward, fix this:
|
||||
//struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
|
||||
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
|
||||
|
||||
struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
|
||||
((int32_t *) b->data)[0] = n_past;
|
||||
((int32_t *) b->data)[1] = n_head;
|
||||
|
||||
result->op = GGML_OP_COUNT; // Dirty hack
|
||||
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
|
||||
result->src0 = a;
|
||||
result->src1 = b;
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
bool mpt_eval(
|
||||
mpt_model & model,
|
||||
const int n_threads,
|
||||
|
@ -397,9 +361,6 @@ bool mpt_eval(
|
|||
const int n_ctx = hparams.n_ctx;
|
||||
const int n_head = hparams.n_head;
|
||||
const int n_vocab = hparams.n_vocab;
|
||||
const int expand = hparams.expand;
|
||||
|
||||
const int d_key = n_embd/n_head;
|
||||
|
||||
static size_t buf_size = 256u*1024*1024;
|
||||
static void * buf = malloc(buf_size);
|
||||
|
@ -560,12 +521,10 @@ bool mpt_eval(
|
|||
out = ggml_mul_mat(ctx0, model.wte, out);
|
||||
}
|
||||
|
||||
|
||||
// run the computation
|
||||
ggml_build_forward_expand(&gf, out);
|
||||
ggml_graph_compute (ctx0, &gf);
|
||||
|
||||
|
||||
// return result for just the last token
|
||||
embd_w.resize(n_vocab);
|
||||
memcpy(embd_w.data(), (float *) ggml_get_data(out) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
|
||||
|
@ -728,8 +687,6 @@ size_t mpt_copy_state_data(const mpt_model &model, const std::mt19937 &rng, uint
|
|||
}
|
||||
|
||||
const size_t written = out - dest;
|
||||
const size_t expected = mpt_get_state_size(model);
|
||||
assert(written == expected);
|
||||
fflush(stdout);
|
||||
return written;
|
||||
}
|
||||
|
@ -876,8 +833,6 @@ size_t mpt_set_state_data(mpt_model *model, std::mt19937 *rng, const uint8_t *sr
|
|||
}
|
||||
|
||||
const size_t nread = in - src;
|
||||
const size_t expected = mpt_get_state_size(*model);
|
||||
assert(nread == expected);
|
||||
fflush(stdout);
|
||||
return nread;
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue