Mirror of https://gitlab.com/niansa/libjustlm.git (synced 2025-03-06 20:49:17 +01:00)

Properly implemented MPT

parent 6d2910b7b9
commit 5b01daa764

6 changed files with 35 additions and 58 deletions
.gitmodules (vendored): 3 additions

@@ -1,3 +1,6 @@
 [submodule "llama.cpp"]
 	path = llama.cpp
 	url = https://github.com/ggerganov/llama.cpp.git
+[submodule "llama.cpp-alibi"]
+	path = llama.cpp-alibi
+	url = https://github.com/manyoso/llama.cpp.git
CMakeLists.txt

@@ -9,22 +9,34 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
 set(LM_PYBIND No CACHE BOOL "If Libjustlm Python bindings should be build")
 set(LM_COSCHED No CACHE BOOL "If Libjustlm should make use of CoSched")
 set(LM_NOEXCEPT No CACHE BOOL "If exceptions should be disabled")
+set(LM_MPT No CACHE BOOL "If MPT model support should be built")
 
 if (LM_COSCHED)
     set(CMAKE_CXX_STANDARD 20)
 endif()
 
-add_subdirectory(llama.cpp)
+if (LM_MPT)
+    set(LM_MPT_SOURCES justlm_mpt.hpp mpt/mpt.cpp mpt/mpt.hpp)
+    add_subdirectory(llama.cpp-alibi)
+else()
+    set(LM_MPT_SOURCES )
+    add_subdirectory(llama.cpp)
+endif()
 
 add_library(libjustlm STATIC
     include/justlm.hpp justlm.cpp
     justlm_llama.hpp
     g4a-common.cpp g4a-common.hpp
-    justlm_mpt.hpp mpt/mpt.cpp mpt/mpt.hpp
     justlm_gptj.hpp gptj/gptj.cpp gptj/gptj.hpp
+    ${LM_MPT_SOURCES}
     include/justlm_pool.hpp justlm_pool.cpp
 )
 target_link_libraries(libjustlm PRIVATE llama)
 
+if (LM_MPT)
+    target_compile_definitions(libjustlm PUBLIC LM_MPT)
+endif()
+
 if (LM_COSCHED)
     target_compile_definitions(libjustlm PUBLIC LM_COSCHED)
     target_link_libraries(libjustlm PRIVATE cosched)
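Note on the new LM_MPT option: because target_compile_definitions(libjustlm PUBLIC LM_MPT) marks the macro PUBLIC, any target that links against libjustlm is also compiled with LM_MPT defined when the library was configured with -DLM_MPT=Yes. A minimal sketch of how consuming code could key off that macro (report_mpt_support is a hypothetical function, not part of the library):

    // Illustrative only: compiled as part of a target that links libjustlm.
    #include <cstdio>

    void report_mpt_support() {
    #ifdef LM_MPT
        // Defined by CMake when libjustlm was configured with MPT support.
        std::puts("libjustlm built with MPT support (llama.cpp-alibi backend)");
    #else
        std::puts("libjustlm built without MPT support");
    #endif
    }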
gptj/gptj.cpp

@@ -16,8 +16,10 @@
 #include <unordered_set>
 #include <ggml.h>
 
-// default hparams (GPT-J 6B)
-static const size_t MB = 1024*1024;
+constexpr inline
+unsigned long long operator ""_MB(unsigned long long bytes) {
+    return bytes*1024*1024;
+}
 
 static bool kv_cache_init(
         const struct gptj_hparams & hparams,
@@ -30,7 +32,7 @@ static bool kv_cache_init(
     const int64_t n_mem = (int64_t)n_layer*n_ctx;
     const int64_t n_elements = n_embd*n_mem;
 
-    cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
+    cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2_MB);
 
     struct ggml_init_params params;
     params.mem_size = cache.buf.size;
@@ -392,7 +394,7 @@ bool gptj_eval(
     const int n_vocab = hparams.n_vocab;
     const int n_rot = hparams.n_rot;
 
-    static size_t buf_size = 1024u*MB;
+    static size_t buf_size = 1024_MB;
     if (!model.buf.addr || model.buf.size < buf_size)
         model.buf.resize(buf_size);
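The MB constant is replaced by a user-defined literal, so sizes read as 2_MB and 1024_MB instead of 2u*MB and 1024u*MB. A self-contained sketch of the same idiom (standalone example, not library code):

    // udl_mb.cpp: user-defined literal for mebibytes, mirroring operator""_MB above.
    #include <cstdio>

    constexpr unsigned long long operator ""_MB(unsigned long long n) {
        return n * 1024 * 1024;   // n is the literal's value, e.g. 2 in 2_MB
    }

    int main() {
        static_assert(2_MB == 2u * 1024 * 1024, "2_MB is two mebibytes");
        std::printf("1024_MB = %llu bytes\n", 1024_MB);   // prints 1073741824
        return 0;
    }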
justlm.cpp

@@ -1,7 +1,9 @@
 #include "justlm.hpp"
 #include "justlm_llama.hpp"
 #include "justlm_gptj.hpp"
-#include "justlm_mpt.hpp"
+#ifdef LM_MPT
+# include "justlm_mpt.hpp"
+#endif
 
 #include <fstream>
 
@@ -12,13 +14,15 @@ LM::Inference *LM::Inference::construct(const std::string &weights_path, const P
     std::ifstream f(weights_path, std::ios::binary);
     uint32_t magic;
     f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
-    // Create model
+    // Create inference instance
     if (magic == 0x67676d6c) {
         f.seekg(0);
         return new GPTJInference(weights_path, f, p);
+# ifdef LM_MPT
     } else if (magic == 0x67676d6d) {
         f.seekg(0);
         return new MPTInference(weights_path, f, p);
+# endif
     } else {
         f.close();
         return new LLaMaInference(weights_path, p);
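construct() dispatches on the 4-byte magic at the start of the weights file: 0x67676d6c selects the GPT-J backend, 0x67676d6d selects MPT (only when built with LM_MPT), and anything else falls through to llama. A standalone sketch of that detection logic, using only the constants shown above (detect_backend is a hypothetical helper, not part of libjustlm):

    #include <cstdint>
    #include <fstream>
    #include <string>

    enum class Backend { gptj, mpt, llama, unknown };

    // Reads the leading uint32 magic the same way construct() does and classifies the file.
    Backend detect_backend(const std::string &weights_path) {
        std::ifstream f(weights_path, std::ios::binary);
        if (!f) return Backend::unknown;
        uint32_t magic = 0;
        f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
        if (magic == 0x67676d6c) return Backend::gptj;  // GPT-J ggml file
        if (magic == 0x67676d6d) return Backend::mpt;   // MPT ggml file
        return Backend::llama;                          // everything else is treated as llama
    }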
llama.cpp-alibi (new submodule): 1 addition

@@ -0,0 +1 @@
+Subproject commit 03ceb39c1e729bed4ad1dfa16638a72f1843bf0c
mpt/mpt.cpp: 55 changes

@@ -17,7 +17,10 @@
 #include <regex>
 #include <ggml.h>
 
-static const size_t MB = 1024*1024;
+inline
+unsigned long long operator ""_MB(unsigned long long bytes) {
+    return bytes*1024*1024;
+}
 
 static bool kv_cache_init(
         const struct mpt_hparams & hparams,
@@ -30,7 +33,7 @@ static bool kv_cache_init(
     const int64_t n_mem = (int64_t)n_layer*n_ctx;
     const int64_t n_elements = n_embd*n_mem;
 
-    cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
+    cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2_MB);
 
     struct ggml_init_params params;
     params.mem_size = cache.buf.size;
@@ -202,7 +205,6 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
 
     const int n_embd = hparams.n_embd;
     const int n_layer = hparams.n_layer;
-    const int n_ctx = hparams.n_ctx;
     const int n_vocab = hparams.n_vocab;
     const int expand = hparams.expand;
 
@@ -241,13 +243,6 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
     {
         const auto & hparams = model.hparams;
 
-        const int n_embd = hparams.n_embd;
-        const int n_layer = hparams.n_layer;
-        const int n_ctx = hparams.n_ctx;
-
-        const int n_mem = n_layer*n_ctx;
-        const int n_elements = n_embd*n_mem;
-
         if (!kv_cache_init(hparams, model.kv_self, GGML_TYPE_F16, model.hparams.n_ctx)) {
             fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
             ggml_free(ctx);
@@ -350,37 +345,6 @@ bool mpt_model_load(const std::string & fname, mpt_model & model, mpt_vocab & vo
     return loaded;
 }
 
-struct ggml_tensor * ggml_alibi(
-        struct ggml_context * ctx,
-        struct ggml_tensor * a,
-        int n_past,
-        int n_head) {
-    if (n_past < 0) {
-        return NULL;
-    }
-    bool is_node = false;
-
-    if (a->grad) {
-        return NULL; // TODO: implement backward
-        is_node = true;
-    }
-
-    // TODO: when implement backward, fix this:
-    //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
-
-    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
-    ((int32_t *) b->data)[0] = n_past;
-    ((int32_t *) b->data)[1] = n_head;
-
-    result->op = GGML_OP_COUNT; // Dirty hack
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src0 = a;
-    result->src1 = b;
-
-    return result;
-}
-
 bool mpt_eval(
         mpt_model & model,
         const int n_threads,
@@ -397,9 +361,6 @@ bool mpt_eval(
     const int n_ctx = hparams.n_ctx;
     const int n_head = hparams.n_head;
     const int n_vocab = hparams.n_vocab;
-    const int expand = hparams.expand;
-
-    const int d_key = n_embd/n_head;
 
     static size_t buf_size = 256u*1024*1024;
     static void * buf = malloc(buf_size);
@@ -560,12 +521,10 @@ bool mpt_eval(
         out = ggml_mul_mat(ctx0, model.wte, out);
     }
 
-
     // run the computation
     ggml_build_forward_expand(&gf, out);
     ggml_graph_compute (ctx0, &gf);
 
-
     // return result for just the last token
     embd_w.resize(n_vocab);
     memcpy(embd_w.data(), (float *) ggml_get_data(out) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
@@ -728,8 +687,6 @@ size_t mpt_copy_state_data(const mpt_model &model, const std::mt19937 &rng, uint
     }
 
     const size_t written = out - dest;
-    const size_t expected = mpt_get_state_size(model);
-    assert(written == expected);
     fflush(stdout);
     return written;
 }
@@ -876,8 +833,6 @@ size_t mpt_set_state_data(mpt_model *model, std::mt19937 *rng, const uint8_t *sr
     }
 
     const size_t nread = in - src;
-    const size_t expected = mpt_get_state_size(*model);
-    assert(nread == expected);
     fflush(stdout);
     return nread;
 }
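Two notes on the mpt/mpt.cpp changes: the local ggml_alibi stub (which abused GGML_OP_COUNT as a placeholder) is dropped, presumably because the llama.cpp-alibi submodule added in this commit supplies a real implementation, and kv_cache_init still sizes its buffer as 2*n_embd*n_layer*n_ctx elements plus 2_MB of headroom. A worked sketch of that sizing formula, with illustrative placeholder hyperparameters rather than a real MPT configuration:

    // kv_budget.cpp: worked example of the KV-cache sizing used by kv_cache_init().
    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t n_embd  = 4096;   // example embedding width
        const int64_t n_layer = 32;     // example layer count
        const int64_t n_ctx   = 2048;   // example context length
        const size_t  wsize   = 2;      // bytes per element for GGML_TYPE_F16

        const int64_t n_mem      = n_layer * n_ctx;   // cached positions across all layers
        const int64_t n_elements = n_embd * n_mem;    // elements per K (and per V) cache

        // Same shape as cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2_MB):
        const size_t bytes = 2u * n_elements * wsize + 2ull * 1024 * 1024;
        std::printf("KV cache budget: %zu bytes (~%.2f GiB)\n",
                    bytes, bytes / (1024.0 * 1024.0 * 1024.0));
        return 0;
    }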