
Updated MPT implementation

niansa 2023-05-16 23:49:43 +02:00
parent a98784aa53
commit ddd130b2d9
6 changed files with 76 additions and 233 deletions

View file

@@ -102,7 +102,7 @@ std::map<std::string, int32_t> json_parse(const std::string & fname) {
    return result;
}
-std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
+std::vector<gpt_vocab::id> gpt_tokenize_inner(const gpt_vocab & vocab, const std::string & text) {
    std::vector<std::string> words;
    // first split the text into words
@@ -157,6 +157,47 @@ std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::stri
    return tokens;
}
+std::string regex_escape(const std::string &s) {
+    static const std::regex metacharacters(R"([\.\^\$\-\+\(\)\[\]\{\}\|\?\*])");
+    return std::regex_replace(s, metacharacters, "\\$&");
+}
+
+std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
+    // Generate the subpattern from the special_tokens vector if it's not empty
+    if (!vocab.special_tokens.empty()) {
+        std::vector<gpt_vocab::id> out;
+        std::vector<std::string> chunks;
+        std::string str = text;
+        std::string special_tokens_subpattern;
+        for (const auto &token : vocab.special_tokens) {
+            if (!special_tokens_subpattern.empty()) {
+                special_tokens_subpattern += "|";
+            }
+            special_tokens_subpattern += regex_escape(token);
+        }
+
+        std::regex re(special_tokens_subpattern);
+        std::smatch m;
+        while (std::regex_search(str, m, re)) {
+            auto tok = vocab.token_to_id.find(m.str());
+            if (tok != vocab.token_to_id.end()) {
+                auto tokid = tok->second;
+                auto pfxtoks = gpt_tokenize_inner(vocab, m.prefix());
+                out.insert(out.end(), pfxtoks.begin(), pfxtoks.end());
+                out.push_back(tokid);
+                str = m.suffix();
+            }
+        }
+        if (!str.empty()) {
+            auto tokrest = gpt_tokenize_inner(vocab, str);
+            out.insert(out.end(), tokrest.begin(), tokrest.end());
+        }
+        return out;
+    } else {
+        return gpt_tokenize_inner(vocab, text);
+    }
+}
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
    printf("%s: loading vocab from '%s'\n", __func__, fname.c_str());
@@ -177,7 +218,7 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
}
gpt_vocab::id gpt_sample_top_k_top_p(
-        const gpt_vocab & vocab,
+        const size_t actualVocabSize,
        const int32_t * last_n_tokens_data,
        int last_n_tokens_size,
        const std::vector<float> logits,
@@ -186,7 +227,7 @@ gpt_vocab::id gpt_sample_top_k_top_p(
        double temp,
        float repeat_penalty,
        std::mt19937 & rng) {
-    int n_logits = vocab.id_to_token.size();
+    int n_logits = actualVocabSize;
    const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
    const auto * plogits = logits.data() + logits.size() - n_logits;
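Note (not part of the commit): the new two-level tokenizer first cuts the input at registered special tokens and only runs the ordinary longest-match pass (gpt_tokenize_inner) on the text in between. A minimal usage sketch, assuming a gpt_vocab whose maps are already populated; the token string and id below are hypothetical:

    // Sketch only. Assumes gpt_vocab from g4a-common.hpp is in scope.
    gpt_vocab vocab;
    vocab.token_to_id["<|im_end|>"] = 50278;   // hypothetical special-token id
    vocab.add_special_token("<|im_end|>");     // registers it for the regex splitter
    // gpt_tokenize() escapes each special token (regex_escape), joins them with
    // '|', splits the text on that pattern, emits the special ids verbatim, and
    // tokenizes only the surrounding chunks with gpt_tokenize_inner().
    std::vector<gpt_vocab::id> ids = gpt_tokenize(vocab, "Hi<|im_end|>");
    // -> the tokens for "Hi", followed by 50278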

View file

@@ -44,6 +44,11 @@ struct gpt_vocab {
    std::map<token, id> token_to_id;
    std::map<id, token> id_to_token;
+    std::vector<std::string> special_tokens;
+
+    void add_special_token(const std::string &token) {
+        special_tokens.push_back(token);
+    }
};
void replace(std::string & str, const std::string & needle, const std::string & replacement);
@@ -74,7 +79,7 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
// TODO: not sure if this implementation is correct
//
gpt_vocab::id gpt_sample_top_k_top_p(
-        const gpt_vocab & vocab,
+        const size_t actualVocabSize,
        const int32_t * last_n_tokens_data,
        int last_n_tokens_size,
        const std::vector<float> logits,

View file

@@ -170,7 +170,7 @@ public:
    unsigned eos_count = 0;
    while (!abort && !ends_with(fres, end)) {
        // Sample top p and top k
-        auto id = gpt_sample_top_k_top_p(state->vocab, params.n_repeat_last?(state->tokens.data()+state->tokens.size()-params.n_repeat_last):nullptr, params.n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
+        auto id = gpt_sample_top_k_top_p(state->model.hparams.n_vocab, params.n_repeat_last?(state->tokens.data()+state->tokens.size()-params.n_repeat_last):nullptr, params.n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
        if (id == 50256) {
            if (eos_count++ == params.eos_ignores) {
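Note (not part of the commit): passing the logit count instead of the whole vocab matters because the model's hparams.n_vocab can exceed vocab.id_to_token.size() (padded vocabulary), and the sampler indexes the last n_logits entries of the logits buffer. A sketch of the call under those assumptions, with made-up sampling values:

    // Sketch only. The sampler reads the final row of the logits buffer:
    //   plogits = logits.data() + logits.size() - n_logits;
    // so n_logits must match what the model emits per token (hparams.n_vocab),
    // not the size of the token map.
    auto id = gpt_sample_top_k_top_p(
        state->model.hparams.n_vocab,   // logit count (was: state->vocab)
        nullptr, 0,                     // no repetition-penalty window
        state->logits,
        40, 0.9, 0.7, 1.1,              // top_k, top_p, temp, repeat_penalty (example values)
        state->rng);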

View file

@@ -12,7 +12,7 @@ class MPTInference final : public Inference {
    std::string weights_path;
    struct State {
-        mpt_vocab vocab;
+        gpt_vocab vocab;
        mpt_model model;
        std::string prompt; // Mostly here for easy "debugging"
        std::vector<int> tokens;
@@ -148,7 +148,7 @@ public:
        const auto old_token_count = state->tokens.size();
        // Run tokenizer
-        const auto tokens = mpt_tokenize(state->vocab, prompt);
+        const auto tokens = gpt_tokenize(state->vocab, prompt);
        state->tokens.insert(
            state->tokens.end(),
            std::make_move_iterator(tokens.begin()),
@@ -181,7 +181,7 @@ public:
                abort = true;
                continue;
            }
-            id = mpt_tokenize(state->vocab, "\n")[0];
+            id = gpt_tokenize(state->vocab, "\n")[0];
            state->tokens.push_back(id);
        } else {
            // Add token

View file

@@ -1,4 +1,5 @@
#include "mpt.hpp"
+#include "../g4a-common.hpp"
#include <cassert>
#include <cmath>
@@ -53,13 +54,8 @@ static bool kv_cache_init(
    return true;
}
-std::string regex_escape(const std::string &s) {
-    static const std::regex metacharacters(R"([\.\^\$\-\+\(\)\[\]\{\}\|\?\*])");
-    return std::regex_replace(s, metacharacters, "\\$&");
-}
// load the model's weights from a stream
-bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & model, mpt_vocab & vocab) {
+bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & model, gpt_vocab & vocab) {
    printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
    // verify magic
@@ -123,8 +119,6 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
        vocab.id_to_token[i] = word;
    }
-    // TODO: this only kind-of works, the gpt_tokenize can still incorrectly
-    // tokenize special tokens
    if(special) {
        vocab.add_special_token(word);
    }
@@ -187,9 +181,9 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
    // create the ggml context
    {
        struct ggml_init_params params = {
-            ctx_size,
-            NULL,
-            false,
+            .mem_size = ctx_size,
+            .mem_buffer = NULL,
+            .no_alloc = false,
        };

        model.ctx = ggml_init(params);
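Note (not part of the commit): switching ggml_init_params to designated initializers ties each value to its field name, so the initialization stays correct, or fails to compile, if the struct is ever reordered or extended upstream. A standalone illustration with a locally defined struct, not ggml's actual definition:

    // Sketch only: same shape as ggml_init_params, defined here for illustration.
    #include <cstddef>

    struct init_params { size_t mem_size; void *mem_buffer; bool no_alloc; };

    init_params positional = { 4096, nullptr, false };  // meaning depends on field order
    init_params designated = {
        .mem_size   = 4096,
        .mem_buffer = nullptr,   // let the library allocate
        .no_alloc   = false,
    };                           // self-documenting, order-checked by the compiler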
@@ -332,7 +326,7 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
}

// load the model's weights from a file path
-bool mpt_model_load(const std::string & fname, mpt_model & model, mpt_vocab & vocab) {
+bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vocab) {
    auto fin = std::ifstream(fname, std::ios::binary);
    if (!fin) {
@@ -362,23 +356,25 @@ bool mpt_eval(
    const int n_head  = hparams.n_head;
    const int n_vocab = hparams.n_vocab;

+    const size_t init_buf_size = 1024_MB;
+    if (!model.buf.addr || model.buf.size < init_buf_size)
+        model.buf.resize(init_buf_size);
+
-    if (mem_per_token > 0 && mem_per_token*N > model.eval_buf_size) {
+    if (mem_per_token > 0 && mem_per_token*N > model.buf.size) {
        const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
-        //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
+        // printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, model.buf.size, buf_size_new);

        // reallocate
-        model.eval_buf_size = buf_size_new;
-        model.eval_buf = realloc(model.eval_buf, model.eval_buf_size);
-        if (model.eval_buf == nullptr) {
-            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, model.eval_buf_size);
+        model.buf.resize(buf_size_new);
+        if (model.buf.addr == nullptr) {
+            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, model.buf.size);
            return false;
        }
    }

    struct ggml_init_params params = {
-        model.eval_buf_size,
-        model.eval_buf,
+        model.buf.size,
+        model.buf.addr,
        false
    };
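Note (not part of the commit): the raw eval_buf/eval_buf_size pair plus manual realloc is replaced by a model.buf object exposing addr, size, and resize(). The real type lives elsewhere in the repo and may differ; a minimal sketch of a buffer with that interface, assuming the caller checks addr after resize exactly as the diff does:

    // Sketch only: a plausible shape for model.buf, not the repo's actual type.
    #include <cstdlib>

    struct eval_buffer {
        void  *addr = nullptr;
        size_t size = 0;

        void resize(size_t new_size) {
            addr = std::realloc(addr, new_size);  // on failure addr becomes nullptr
            size = addr ? new_size : 0;           // (old block is leaked; kept simple)
        }
        ~eval_buffer() { std::free(addr); }
    };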
@@ -520,10 +516,12 @@
        out = ggml_mul_mat(ctx0, model.wte, out);
    }

    // run the computation
+    ggml_build_forward_expand(&gf, out);
+    ggml_graph_compute       (ctx0, &gf);

    // return result for just the last token
    embd_w.resize(n_vocab);
    memcpy(embd_w.data(), (float *) ggml_get_data(out) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
@@ -538,98 +536,6 @@ bool mpt_eval(
    return true;
}
-std::vector<int> mpt_tokenize_inner(const mpt_vocab & vocab, const std::string & text) {
-    // taken from stablelm example in ggml
-    // they both use the gpt-neox tokenizer
-    // not sure if this entirely right?
-    std::vector<std::string> words;
-
-    // first split the text into words
-    {
-        std::string str = text;
-        std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
-
-        std::regex re(pat);
-        std::smatch m;
-
-        while (std::regex_search(str, m, re)) {
-            for (auto x : m) {
-                words.push_back(x);
-            }
-            str = m.suffix();
-        }
-    }
-
-    // find the longest tokens that form the words:
-    std::vector<mpt_vocab::id> tokens;
-    for (const auto & word : words) {
-        if (word.size() == 0) continue;
-
-        int i = 0;
-        int n = word.size();
-        while (i < n) {
-            int j = n;
-            while (j > i) {
-                auto it = vocab.token_to_id.find(word.substr(i, j-i));
-                if (it != vocab.token_to_id.end()) {
-                    tokens.push_back(it->second);
-                    i = j;
-                    break;
-                }
-                --j;
-            }
-            if (i == n) {
-                break;
-            }
-            if (j == i) {
-                auto sub = word.substr(i, 1);
-                if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
-                    tokens.push_back(vocab.token_to_id.at(sub));
-                } else {
-                    fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
-                }
-                ++i;
-            }
-        }
-    }
-
-    return tokens;
-}
-
-std::vector<mpt_vocab::id> mpt_tokenize(const mpt_vocab & vocab, const std::string & text) {
-    // Generate the subpattern from the special_tokens vector if it's not empty
-    if (!vocab.special_tokens.empty()) {
-        std::vector<mpt_vocab::id> out;
-        std::vector<std::string> chunks;
-        std::string str = text;
-        std::string special_tokens_subpattern;
-        for (const auto &token : vocab.special_tokens) {
-            if (!special_tokens_subpattern.empty()) {
-                special_tokens_subpattern += "|";
-            }
-            special_tokens_subpattern += regex_escape(token);
-        }
-
-        std::regex re(special_tokens_subpattern);
-        std::smatch m;
-        while (std::regex_search(str, m, re)) {
-            auto tok = vocab.token_to_id.find(m.str());
-            if (tok != vocab.token_to_id.end()) {
-                auto tokid = tok->second;
-                auto pfxtoks = mpt_tokenize_inner(vocab, m.prefix());
-                out.insert(out.end(), pfxtoks.begin(), pfxtoks.end());
-                out.push_back(tokid);
-                str = m.suffix();
-            }
-        }
-        if (!str.empty()) {
-            auto tokrest = mpt_tokenize_inner(vocab, str);
-            out.insert(out.end(), tokrest.begin(), tokrest.end());
-        }
-        return out;
-    } else {
-        return mpt_tokenize_inner(vocab, text);
-    }
-}
-
#define MPT_MAX_RNG_STATE 64*1024
@@ -690,103 +596,6 @@ size_t mpt_copy_state_data(const mpt_model &model, const std::mt19937 &rng, uint
    return written;
}
-mpt_vocab::id mpt_sample_top_k_top_p(
-        const size_t actualVocabSize,
-        const int32_t * last_n_tokens_data,
-        int last_n_tokens_size,
-        const std::vector<float> logits,
-        int top_k,
-        double top_p,
-        double temp,
-        float repeat_penalty,
-        std::mt19937 & rng) {
-    int n_logits = actualVocabSize;
-
-    const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
-    const auto * plogits = logits.data() + logits.size() - n_logits;
-
-    std::vector<std::pair<double, mpt_vocab::id>> logits_id;
-    logits_id.reserve(n_logits);
-
-    {
-        const float scale = 1.0f/temp;
-        for (int i = 0; i < n_logits; ++i) {
-            // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
-            // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
-            if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
-                // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
-                if (plogits[i] < 0.0f) {
-                    logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
-                } else {
-                    logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
-                }
-            } else {
-                logits_id.push_back(std::make_pair(plogits[i]*scale, i));
-            }
-        }
-    }
-
-    // find the top K tokens
-    std::partial_sort(
-            logits_id.begin(),
-            logits_id.begin() + top_k, logits_id.end(),
-            [](const std::pair<double, mpt_vocab::id> & a, const std::pair<double, mpt_vocab::id> & b) {
-        return a.first > b.first;
-    });
-
-    logits_id.resize(top_k);
-
-    double maxl = -INFINITY;
-    for (const auto & kv : logits_id) {
-        maxl = std::max(maxl, kv.first);
-    }
-
-    // compute probs for the top K tokens
-    std::vector<double> probs;
-    probs.reserve(logits_id.size());
-
-    double sum = 0.0;
-    for (const auto & kv : logits_id) {
-        double p = exp(kv.first - maxl);
-        probs.push_back(p);
-        sum += p;
-    }
-
-    // normalize the probs
-    for (auto & p : probs) {
-        p /= sum;
-    }
-
-    if (top_p < 1.0f) {
-        double cumsum = 0.0f;
-        for (int i = 0; i < top_k; i++) {
-            cumsum += probs[i];
-            if (cumsum >= top_p) {
-                top_k = i + 1;
-                probs.resize(top_k);
-                logits_id.resize(top_k);
-                break;
-            }
-        }
-
-        cumsum = 1.0/cumsum;
-        for (int i = 0; i < (int) probs.size(); i++) {
-            probs[i] *= cumsum;
-        }
-    }
-
-    //printf("\n");
-    //for (int i = 0; i < (int) probs.size(); i++) {
-    //    printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
-    //}
-    //exit(0);
-
-    std::discrete_distribution<> dist(probs.begin(), probs.end());
-    int idx = dist(rng);
-
-    return logits_id[idx].second;
-}
-
size_t mpt_set_state_data(mpt_model *model, std::mt19937 *rng, const uint8_t *src)
{
    const uint8_t * in = src;
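Note (not part of the commit): the state (de)serialization API is untouched here, and its declared signatures make a save/restore round trip straightforward. A sketch against the declarations shown in the header diff below:

    // Sketch only. Assumes the mpt.hpp declarations are in scope.
    #include <cstdint>
    #include <random>
    #include <vector>

    std::vector<uint8_t> save_state(const mpt_model &model, const std::mt19937 &rng) {
        std::vector<uint8_t> buf(mpt_get_state_size(model));
        mpt_copy_state_data(model, rng, buf.data());
        return buf;
    }

    void load_state(mpt_model &model, std::mt19937 &rng, const std::vector<uint8_t> &buf) {
        mpt_set_state_data(&model, &rng, buf.data());
    }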

View file

@@ -1,5 +1,7 @@
#ifndef MPT_H
#define MPT_H
+
+#include "../g4a-common.hpp"
#include <string>
#include <vector>
#include <map>
@@ -99,24 +101,10 @@ struct mpt_model {
    }
};
-struct mpt_vocab {
-    using id    = int32_t;
-    using token = std::string;
-
-    std::map<token, id> token_to_id;
-    std::map<id, token> id_to_token;
-    std::vector<std::string> special_tokens;
-
-    void add_special_token(const std::string &token) {
-        special_tokens.push_back(token);
-    }
-};
-
-bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & model, mpt_vocab & vocab);
+bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & model, gpt_vocab& vocab);
bool mpt_eval(mpt_model& model, const int n_threads, const int n_past, const std::vector<int>& embd_inp, std::vector<float>& embd_w, size_t& mem_per_token);
-std::vector<mpt_vocab::id> mpt_tokenize(const mpt_vocab & vocab, const std::string & text);
-mpt_vocab::id mpt_sample_top_k_top_p(const size_t actualVocabSize, const int32_t *last_n_tokens_data, int last_n_tokens_size, const std::vector<float> logits, int top_k, double top_p, double temp, float repeat_penalty, std::mt19937& rng);
+gpt_vocab::id mpt_sample_top_k_top_p(const size_t actualVocabSize, const int32_t *last_n_tokens_data, int last_n_tokens_size, const std::vector<float> logits, int top_k, double top_p, double temp, float repeat_penalty, std::mt19937& rng);
size_t mpt_get_state_size(const mpt_model &model);
size_t mpt_copy_state_data(const mpt_model &model, const std::mt19937& rng, uint8_t *dest);
size_t mpt_set_state_data(mpt_model *model, std::mt19937 *rng, const uint8_t *src);
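Note (not part of the commit): taken together, the commit leaves MPT on the shared gpt_vocab/gpt_tokenize pipeline, with the duplicated MPT tokenizer and sampler deleted and the remaining sampler keyed on the logit count. An end-to-end sketch of the resulting API; the thread count, seed, and sampling values are placeholders:

    // Sketch only: one prompt through the post-commit API.
    #include <random>

    bool run_once(const std::string &weights, const std::string &prompt) {
        mpt_model model;
        gpt_vocab vocab;                                  // shared type, was mpt_vocab
        if (!mpt_model_load(weights, model, vocab)) return false;

        const auto tokens = gpt_tokenize(vocab, prompt);  // shared tokenizer, was mpt_tokenize
        std::vector<float> logits;
        size_t mem_per_token = 0;
        if (!mpt_eval(model, /*n_threads=*/4, /*n_past=*/0, tokens, logits, mem_per_token))
            return false;

        std::mt19937 rng(1234);
        const auto id = mpt_sample_top_k_top_p(
            model.hparams.n_vocab,                        // logit count, not the token map size
            nullptr, 0, logits,
            /*top_k=*/40, /*top_p=*/0.9, /*temp=*/0.7, /*repeat_penalty=*/1.1, rng);
        return vocab.id_to_token.count(id) > 0;
    }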