
Added GPT-J serialization/deserialization

niansa 2023-05-07 12:02:04 +02:00
parent ca33a27e05
commit e4832f1077
6 changed files with 278 additions and 54 deletions
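This commit adds three public state functions to the GPT-J backend: gptj_get_state_size, gptj_copy_state_data and gptj_set_state_data (declared in the header diff below). A minimal usage sketch of a save/restore round trip follows; the helper names save_state/restore_state and the "gptj.hpp" include path are illustrative assumptions, not part of the commit:

#include "gptj.hpp"   // assumed include path for the declarations added by this commit
#include <cstdint>
#include <random>
#include <vector>

// Snapshot the mutable model state (rng + KV self-attention cache) into a byte buffer.
std::vector<uint8_t> save_state(const gptj_model &model, const std::mt19937 &rng) {
    std::vector<uint8_t> buf(gptj_get_state_size(model));
    gptj_copy_state_data(model, rng, buf.data());
    return buf;
}

// Roll the model and rng back to a previously saved snapshot.
void restore_state(gptj_model &model, std::mt19937 &rng, const std::vector<uint8_t> &buf) {
    gptj_set_state_data(&model, &rng, buf.data());
}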

@@ -1,5 +1,7 @@
#include "gptj.hpp"
#include "utils.hpp"
#include <cassert>
#include <cmath>
#include <cstdio>
@@ -10,6 +12,43 @@
#include <vector>
#include <iostream>
#include <unistd.h>
#include <sstream>
#include <unordered_set>
#include <ggml.h>
// default hparams (GPT-J 6B)
static const size_t MB = 1024*1024;
static bool kv_cache_init(
const struct gptj_hparams & hparams,
struct gptj_kv_cache & cache,
ggml_type wtype,
int n_ctx) {
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int64_t n_mem = (int64_t)n_layer*n_ctx;
const int64_t n_elements = n_embd*n_mem;
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
struct ggml_init_params params;
params.mem_size = cache.buf.size;
params.mem_buffer = cache.buf.addr;
params.no_alloc = false;
cache.ctx = ggml_init(params);
if (!cache.ctx) {
fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
return false;
}
cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
return true;
}
// load the model's weights from a stream
bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & model, gpt_vocab & vocab) {
@@ -151,7 +190,6 @@ bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & m
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
model.layers.resize(n_layer);
@@ -213,19 +251,14 @@ bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & m
{
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
if (!kv_cache_init(hparams, model.kv_self, GGML_TYPE_F32, model.hparams.n_ctx)) {
fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
ggml_free(ctx);
return false;
}
const int n_mem = n_layer*n_ctx;
const int n_elements = n_embd*n_mem;
model.memory_k = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
model.memory_v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_elements);
const size_t memory_size = ggml_nbytes(model.memory_k) + ggml_nbytes(model.memory_v);
printf("%s: memory_size = %8.2f MB, n_mem = %d\n", __func__, memory_size/1024.0/1024.0, n_mem);
const size_t memory_size = ggml_nbytes(model.kv_self.k) + ggml_nbytes(model.kv_self.v);
printf("%s: kv self size = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
}
// load weights
@@ -320,7 +353,6 @@ bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & m
// load the model's weights from a file path
bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab & vocab) {
auto fin = std::ifstream(fname, std::ios::binary);
if (!fin) {
fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
@@ -343,7 +375,7 @@ bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab &
// The GPT-J model requires about 16MB of memory per input token.
//
bool gptj_eval(
const gptj_model & model,
gptj_model & model,
const int n_threads,
const int n_past,
const std::vector<gpt_vocab::id> & embd_inp,
@@ -360,25 +392,25 @@ bool gptj_eval(
const int n_vocab = hparams.n_vocab;
const int n_rot = hparams.n_rot;
static size_t buf_size = 1024u*1024*1024;
static void * buf = malloc(buf_size);
static size_t buf_size = 1024u*MB;
if (!model.buf.addr || model.buf.size < buf_size)
model.buf.resize(buf_size);
if (mem_per_token > 0 && mem_per_token*N > buf_size) {
if (mem_per_token > 0 && mem_per_token*N > model.buf.size) {
const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, model.buf.size, buf_size_new);
// reallocate
buf_size = buf_size_new;
buf = realloc(buf, buf_size);
if (buf == nullptr) {
fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
model.buf.resize(buf_size_new);
if (model.buf.addr == nullptr) {
fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, model.buf.size);
return false;
}
}
struct ggml_init_params params = {
.mem_size = buf_size,
.mem_buffer = buf,
.mem_size = model.buf.size,
.mem_buffer = model.buf.addr,
};
struct ggml_context * ctx0 = ggml_init(params);
@@ -415,8 +447,8 @@ bool gptj_eval(
// store key and value to memory
if (N >= 1) {
struct ggml_tensor * k = ggml_view_1d(ctx0, model.memory_k, N*n_embd, (ggml_element_size(model.memory_k)*n_embd)*(il*n_ctx + n_past));
struct ggml_tensor * v = ggml_view_1d(ctx0, model.memory_v, N*n_embd, (ggml_element_size(model.memory_v)*n_embd)*(il*n_ctx + n_past));
struct ggml_tensor * k = ggml_view_1d(ctx0, model.kv_self.k, N*n_embd, (ggml_element_size(model.kv_self.k)*n_embd)*(il*n_ctx + n_past));
struct ggml_tensor * v = ggml_view_1d(ctx0, model.kv_self.v, N*n_embd, (ggml_element_size(model.kv_self.v)*n_embd)*(il*n_ctx + n_past));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
@@ -437,7 +469,7 @@ bool gptj_eval(
ggml_permute(ctx0,
ggml_rope(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, model.memory_k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_k)*n_embd),
ggml_view_1d(ctx0, model.kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.kv_self.k)*n_embd),
n_embd/n_head, n_head, n_past + N),
n_past, n_rot, 1),
0, 2, 1, 3);
@@ -463,10 +495,10 @@ bool gptj_eval(
ggml_cpy(ctx0,
ggml_permute(ctx0,
ggml_reshape_3d(ctx0,
ggml_view_1d(ctx0, model.memory_v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.memory_v)*n_embd),
ggml_view_1d(ctx0, model.kv_self.v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.kv_self.v)*n_embd),
n_embd/n_head, n_head, n_past + N),
1, 2, 0, 3),
ggml_new_tensor_3d(ctx0, model.memory_v->type, n_past + N, n_embd/n_head, n_head));
ggml_new_tensor_3d(ctx0, model.kv_self.v->type, n_past + N, n_embd/n_head, n_head));
// KQV = transpose(V) * KQ_soft_max
struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
@@ -560,7 +592,7 @@ bool gptj_eval(
embd_w.resize(n_vocab);
memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
if (mem_per_token == 0) {
if (mem_per_token == 0 && N != 0) {
mem_per_token = ggml_used_mem(ctx0)/N;
}
//printf("used_mem = %zu\n", ggml_used_mem(ctx0));
@@ -569,3 +601,112 @@ bool gptj_eval(
return true;
}
#define GPTJ_MAX_RNG_STATE 64*1024
size_t gptj_get_state_size(const gptj_model &model)
{
// we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
// for reference, std::mt19937(1337) serializes to 6701 bytes.
const size_t s_rng_size = sizeof(size_t);
const size_t s_rng = GPTJ_MAX_RNG_STATE;
const size_t s_kv_size = sizeof(size_t);
const size_t s_kv_ntok = sizeof(int);
const size_t s_kv = model.kv_self.buf.size;
const size_t s_total = (
+ s_rng_size
+ s_rng
+ s_kv_size
+ s_kv_ntok
+ s_kv
);
fflush(stdout);
return s_total;
}
size_t gptj_copy_state_data(const gptj_model &model, const std::mt19937 &rng, uint8_t *dest)
{
uint8_t * out = dest;
fflush(stdout);
// copy rng
{
std::stringstream rng_ss;
rng_ss << rng;
const size_t rng_size = rng_ss.str().size();
char rng_buf[GPTJ_MAX_RNG_STATE];
memset(&rng_buf[0], 0, GPTJ_MAX_RNG_STATE);
memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());
memcpy(out, &rng_size, sizeof(rng_size)); out += sizeof(rng_size);
memcpy(out, &rng_buf[0], GPTJ_MAX_RNG_STATE); out += GPTJ_MAX_RNG_STATE;
}
// copy kv cache
{
const size_t kv_size = model.kv_self.buf.size;
const int kv_ntok = model.kv_self.n;
memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size);
memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok);
if (kv_size) {
memcpy(out, model.kv_self.buf.addr, kv_size); out += kv_size;
}
}
const size_t written = out - dest;
const size_t expected = gptj_get_state_size(model);
assert(written == expected);
fflush(stdout);
return written;
}
size_t gptj_set_state_data(gptj_model *model, std::mt19937 *rng, const uint8_t *src)
{
const uint8_t * in = src;
// set rng
{
size_t rng_size;
char rng_buf[GPTJ_MAX_RNG_STATE];
memcpy(&rng_size, in, sizeof(rng_size)); in += sizeof(rng_size);
memcpy(&rng_buf[0], in, GPTJ_MAX_RNG_STATE); in += GPTJ_MAX_RNG_STATE;
std::stringstream rng_ss;
rng_ss.str(std::string(&rng_buf[0], rng_size));
rng_ss >> *rng;
assert(rng_ss.fail() == false);
}
// set kv cache
{
size_t kv_size;
int kv_ntok;
memcpy(&kv_size, in, sizeof(kv_size)); in += sizeof(kv_size);
memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok);
if (kv_size) {
assert(model->kv_self.buf.size == kv_size);
void * k_data = model->kv_self.k->data; // remember data pointers
void * v_data = model->kv_self.v->data; // because their value is stored in buf and overwritten by memcpy
memcpy(model->kv_self.buf.addr, in, kv_size); in += kv_size;
model->kv_self.k->data = k_data; // restore correct data pointers
model->kv_self.v->data = v_data;
}
model->kv_self.n = kv_ntok;
}
const size_t nread = in - src;
const size_t expected = gptj_get_state_size(*model);
assert(nread == expected);
fflush(stdout);
return nread;
}
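As gptj_copy_state_data and gptj_set_state_data above lay it out, the state blob is a flat concatenation of an rng section (a size_t length plus a fixed 64 KiB text buffer) and a KV-cache section (a size_t byte count, an int token count, then the raw cache bytes). A hedged sketch that dumps just the fixed-size header fields of such a blob; dump_state_header and GPTJ_MAX_RNG_STATE_SKETCH are hypothetical names used only for this note:

#include <cstdint>
#include <cstdio>
#include <cstring>

static const size_t GPTJ_MAX_RNG_STATE_SKETCH = 64*1024; // mirrors GPTJ_MAX_RNG_STATE above

// Reads only the fixed-size fields of a blob written by gptj_copy_state_data.
void dump_state_header(const uint8_t *blob) {
    size_t rng_size, kv_size;
    int kv_ntok;
    const uint8_t *p = blob;
    memcpy(&rng_size, p, sizeof(rng_size)); p += sizeof(rng_size) + GPTJ_MAX_RNG_STATE_SKETCH;
    memcpy(&kv_size,  p, sizeof(kv_size));  p += sizeof(kv_size);
    memcpy(&kv_ntok,  p, sizeof(kv_ntok));
    printf("rng text: %zu bytes, kv cache: %zu bytes, %d tokens cached\n", rng_size, kv_size, kv_ntok);
}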

@@ -39,6 +39,38 @@ struct gptj_layer {
struct ggml_tensor * c_mlp_proj_b;
};
struct gptj_buffer {
uint8_t * addr = NULL;
size_t size = 0;
void resize(size_t size) {
delete[] addr;
addr = new uint8_t[size];
this->size = size;
}
~gptj_buffer() {
delete[] addr;
}
};
struct gptj_kv_cache {
struct ggml_tensor * k;
struct ggml_tensor * v;
struct ggml_context * ctx = NULL;
gptj_buffer buf;
int n; // number of tokens currently in the cache
~gptj_kv_cache() {
if (ctx) {
ggml_free(ctx);
}
}
};
struct gptj_model {
gptj_hparams hparams;
@@ -54,16 +86,26 @@ struct gptj_model {
std::vector<gptj_layer> layers;
// key + value memory
struct ggml_tensor * memory_k;
struct ggml_tensor * memory_v;
struct gptj_kv_cache kv_self;
//
struct ggml_context * ctx;
std::map<std::string, struct ggml_tensor *> tensors;
gptj_buffer buf;
~gptj_model() {
if (ctx) {
ggml_free(ctx);
}
}
};
bool gptj_model_load(const std::string &fname, std::istream &fin, gptj_model & model, gpt_vocab & vocab);
bool gptj_model_load(const std::string &fname, gptj_model &model, gpt_vocab & vocab);
bool gptj_eval(const gptj_model& model, const int n_threads, const int n_past,
const std::vector<gpt_vocab::id>& embd_inp, std::vector<float>& embd_w, size_t& mem_per_token);
bool gptj_model_load(const std::string & fname, gptj_model & model, gpt_vocab & vocab);
bool gptj_eval(gptj_model& model, const int n_threads, const int n_past, const std::vector<gpt_vocab::id>& embd_inp, std::vector<float>& embd_w, size_t& mem_per_token);
size_t gptj_get_state_size(const gptj_model &model);
size_t gptj_copy_state_data(const gptj_model &model, const std::mt19937 &rng, uint8_t *dest);
size_t gptj_set_state_data(gptj_model *model, std::mt19937 *rng, const uint8_t *src);
#endif // GPTJ_HPP
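For a sense of scale, kv_cache_init sizes the new kv_self cache as two F32 tensors of n_embd*n_layer*n_ctx elements each, plus 2 MB of ggml object slack. A back-of-the-envelope sketch follows; the GPT-J 6B hyperparameter values in it are the usual defaults and an assumption of this note, not values taken from the diff:

#include <cstdio>

int main() {
    const long long n_embd = 4096, n_layer = 28, n_ctx = 2048; // assumed GPT-J 6B hparams
    const long long f32_size = 4;                              // ggml_type_size(GGML_TYPE_F32)
    const long long kv_bytes = 2 * n_embd * n_layer * n_ctx * f32_size;
    // Corresponds to the "kv self size" log line in gptj_model_load; kv_cache_init
    // reserves an extra 2 MB on top of this for ggml bookkeeping.
    printf("kv self size = %.2f MB\n", kv_bytes / 1024.0 / 1024.0); // ~1792.00 MB
    return 0;
}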

@@ -27,6 +27,13 @@ struct gpt_params {
std::string model = "models/gpt-2-117M/ggml-model.bin"; // model path
std::string prompt;
};
bool gpt_params_parse(int argc, char ** argv, gpt_params & params);
void gpt_print_usage(int argc, char ** argv, const gpt_params & params);
std::string gpt_random_prompt(std::mt19937 & rng);
//
// Vocab utils
//
@@ -41,6 +48,9 @@ struct gpt_vocab {
void replace(std::string & str, const std::string & needle, const std::string & replacement);
// poor-man's JSON parsing
std::map<std::string, int32_t> json_parse(const std::string & fname);
// split text into tokens
//
// ref: https://github.com/openai/gpt-2/blob/a74da5d99abaaba920de8131d64da2862a8f213b/src/encoder.py#L53

@@ -222,9 +222,11 @@ public:
return get_state()->tokens.size();
}
//TODO: The following functions are just stub implementations and should be implemented properly asap
LM_SCHEDULABLE(void) create_savestate(Savestate &sv) const override {
auto& state = get_state();
sv.buf.resize(gptj_get_state_size(state->model));
gptj_copy_state_data(state->model, state->rng, sv.buf.data());
sv.tokens = state->tokens;
sv.prompt = state->prompt;
sv.ctx = generic_state;
LM_CORETURN;
@@ -233,36 +235,67 @@ public:
auto& state = get_state();
if (sv.ctx != generic_state)
throw Exception("Savestate does not match context");
reinit();
LM_COAWAIT append(sv.prompt);
gptj_set_state_data(&state->model, &state->rng, sv.buf.data());
state->tokens = sv.tokens;
state->prompt = sv.prompt;
LM_CORETURN;
}
LM_SCHEDULABLE(void) serialize(std::ostream &o) const override {
auto& state = get_state();
size_t size = state->prompt.size();
o.write(reinterpret_cast<const char*>(&size), sizeof(size));
if (!o.write(state->prompt.data(), size)) {
// Get state size
auto state_size = gptj_get_state_size(state->model);
// Write sizes
for (const uint32_t s : {state->tokens.size(), state->prompt.size(), state_size}) {
if (!o.write(reinterpret_cast<const char*>(&s), sizeof(s))) {
throw Exception("Failed to serialize data sizes");
}
}
// Write tokens
if (!o.write(reinterpret_cast<const char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
throw Exception("Failed to serialize tokens");
}
// Write prompt
if (!o.write(state->prompt.data(), state->prompt.size())) {
throw Exception("Failed to serialize prompt");
}
// Write state
std::vector<uint8_t> state_buf(state_size);
gptj_copy_state_data(state->model, state->rng, state_buf.data());
if (!o.write(reinterpret_cast<const char*>(state_buf.data()), state_size)) {
throw Exception("Failed to serialize state");
}
LM_CORETURN;
}
LM_SCHEDULABLE(void) deserialize(std::istream &i) override {
auto& state = get_state();
std::string prompt;
size_t size;
if (!i.read(reinterpret_cast<char*>(&size), sizeof(size))) {
throw Exception("Failed to deserialize prompt size");
uint32_t embd_size, prompt_size, state_size;
// Initialization to prevent compiler complaints
embd_size = prompt_size = state_size = 0;
// Read sizes
for (uint32_t *s : {&embd_size, &prompt_size, &state_size}) {
if (!i.read(reinterpret_cast<char*>(s), sizeof(*s))) {
throw Exception("Failed to deserialize data sizes");
}
}
prompt.resize(size);
if (!i.read(prompt.data(), size)) {
// Read tokens
state->tokens.resize(embd_size);
if (!i.read(reinterpret_cast<char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
throw Exception("Failed to deserialize tokens");
}
// Read prompt
state->prompt.resize(prompt_size);
if (!i.read(state->prompt.data(), state->prompt.size())) {
throw Exception("Failed to deserialize prompt");
}
reinit();
LM_COAWAIT append(prompt);
// Read state
std::vector<uint8_t> state_buf(state_size);
if (!i.read(reinterpret_cast<char*>(state_buf.data()), state_buf.size())) {
throw Exception("Failed to deserialize state");
}
gptj_set_state_data(&state->model, &state->rng, state_buf.data());
LM_CORETURN;
}
const std::string &get_prompt() const override {
return get_state()->prompt;
}
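The new serialize() above frames its output as three uint32 sizes (token count, prompt length, state size) followed by the raw token ids, the prompt text and the opaque state blob. A hedged sketch of a standalone reader for that framing; read_savestate is a hypothetical helper written for this note, not part of the commit:

#include <cstdint>
#include <istream>
#include <string>
#include <vector>

// Reads back, in order, what serialize() writes: sizes, tokens, prompt, state blob.
bool read_savestate(std::istream &i, std::vector<int> &tokens, std::string &prompt, std::vector<uint8_t> &state) {
    uint32_t n_tokens = 0, prompt_size = 0, state_size = 0;
    for (uint32_t *s : {&n_tokens, &prompt_size, &state_size}) {
        if (!i.read(reinterpret_cast<char*>(s), sizeof(*s))) return false;
    }
    tokens.resize(n_tokens);
    if (!i.read(reinterpret_cast<char*>(tokens.data()), tokens.size()*sizeof(int))) return false;
    prompt.resize(prompt_size);
    if (!i.read(prompt.data(), prompt.size())) return false;
    state.resize(state_size);
    if (!i.read(reinterpret_cast<char*>(state.data()), state.size())) return false;
    return true;
}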

@@ -1 +1 @@
Subproject commit 0b2da20538d01926b77ea237dd1c930c4d20b686
Subproject commit 0e018fe008eacebdbcfa2d61b6c988c245c961cd

@@ -45,7 +45,5 @@ PYBIND11_MODULE(libjustlm_py, m) {
.def("get_or_create_inference", &InferencePool::create_inference, py::arg("id"), py::arg("weights_path"), py::arg("parameters"), py::return_value_policy::reference_internal)
.def("delete_inference", &InferencePool::delete_inference, py::arg("id"))
.def("store_all", &InferencePool::store_all)
.def("set_store_on_destruct", &InferencePool::set_store_on_destruct)
.def("is_stored_on_destruction", &InferencePool::is_stored_on_destruction)
.def("get_active_slot_ids", &InferencePool::get_active_slot_ids);
}