mirror of https://gitlab.com/niansa/libjustlm.git
synced 2025-03-06 20:49:17 +01:00

Initial MPT support (buggy)

This commit is contained in:
parent f3a9092ca5
commit 6d2910b7b9

6 changed files with 1334 additions and 0 deletions
@@ -19,6 +19,7 @@ add_library(libjustlm STATIC
        include/justlm.hpp justlm.cpp
        justlm_llama.hpp
        g4a-common.cpp g4a-common.hpp
        justlm_mpt.hpp mpt/mpt.cpp mpt/mpt.hpp
        justlm_gptj.hpp gptj/gptj.cpp gptj/gptj.hpp
        include/justlm_pool.hpp justlm_pool.cpp
        )
@@ -1,6 +1,7 @@
#include "justlm.hpp"
#include "justlm_llama.hpp"
#include "justlm_gptj.hpp"
#include "justlm_mpt.hpp"

#include <fstream>

@@ -15,6 +16,9 @@ LM::Inference *LM::Inference::construct(const std::string &weights_path, const P
    if (magic == 0x67676d6c) {
        f.seekg(0);
        return new GPTJInference(weights_path, f, p);
    } else if (magic == 0x67676d6d) {
        f.seekg(0);
        return new MPTInference(weights_path, f, p);
    } else {
        f.close();
        return new LLaMaInference(weights_path, p);
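For context (not part of this commit's diff): construct() picks a backend from the first four bytes of the weights file. A minimal standalone sketch of that check, assuming only the magic values and class names visible in the hunk above:

#include <cstdint>
#include <cstdio>
#include <fstream>

// Illustrative only: read the leading magic and report which backend
// LM::Inference::construct() would choose (0x67676d6c -> GPT-J, the new
// 0x67676d6d -> MPT, anything else falls through to the LLaMA loader).
int main(int argc, char **argv) {
    if (argc < 2) return 1;
    std::ifstream f(argv[1], std::ios::binary);
    uint32_t magic = 0;
    f.read(reinterpret_cast<char *>(&magic), sizeof(magic));
    if (magic == 0x67676d6c)      std::puts("GPTJInference");
    else if (magic == 0x67676d6d) std::puts("MPTInference");
    else                          std::puts("LLaMaInference (fallback)");
    return 0;
}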
justlm_mpt.hpp (new file, 314 lines)
@@ -0,0 +1,314 @@
#include "justlm.hpp"
|
||||
|
||||
#include <fstream>
|
||||
#include <random>
|
||||
#include <cstring>
|
||||
#include "mpt/mpt.hpp"
|
||||
#include "g4a-common.hpp"
|
||||
|
||||
|
||||
namespace LM {
|
||||
class MPTInference final : public Inference {
|
||||
std::string weights_path;
|
||||
|
||||
struct State {
|
||||
mpt_vocab vocab;
|
||||
mpt_model model;
|
||||
std::string prompt; // Mostly here for easy "debugging"
|
||||
std::vector<int> tokens;
|
||||
std::vector<float> logits;
|
||||
size_t mem_per_token = 0;
|
||||
std::mt19937 rng;
|
||||
|
||||
State(int32_t seed) : rng(seed) {}
|
||||
};
|
||||
|
||||
State*& get_state() LM_NOEXCEPTDECL {
|
||||
return *reinterpret_cast<State**>(&generic_state);
|
||||
}
|
||||
State* const& get_state() const LM_NOEXCEPTDECL {
|
||||
return *reinterpret_cast<State* const*>(&generic_state);
|
||||
}
|
||||
|
||||
LM_ERRBOOL init(const std::string& _weights_path, std::ifstream& f) LM_NOEXCEPTDECL {
|
||||
auto& state = get_state();
|
||||
weights_path = _weights_path;
|
||||
|
||||
// Allocate state
|
||||
state = new State(params.seed);
|
||||
|
||||
// Load model
|
||||
if (!mpt_model_load(weights_path, f, state->model, state->vocab)) {
|
||||
LM_THROW("Failed to initialize mpt_ from file", LM_BOOL_ERROR);
|
||||
}
|
||||
|
||||
// Calculate memory required per token
|
||||
static std::vector<gpt_vocab::id> p_instruct;
|
||||
static std::vector<gpt_vocab::id> r_instruct;
|
||||
mpt_eval(state->model, params.n_threads, 0, { 0, 1, 2, 3 }, state->logits, state->mem_per_token);
|
||||
|
||||
return LM_BOOL_SUCCESS;
|
||||
}
|
||||
void deinit() LM_NOEXCEPTDECL {
|
||||
auto& state = get_state();
|
||||
|
||||
if (state) {
|
||||
if (state->model.ctx) ggml_free(state->model.ctx); //TODO: Is that enough?
|
||||
delete state;
|
||||
}
|
||||
}
|
||||
|
||||
// This function reduces the size of our tokens vector according to some parameters
|
||||
// All tokens will be evaluated if scrolling was needed and true will be returned
|
||||
LM_SCHEDULABLE(bool) window_scroll() LM_NOEXCEPTDECL {
|
||||
auto &state = get_state();
|
||||
// Check that we actually need to scroll
|
||||
if (state->tokens.size() <= params.n_ctx) {
|
||||
// Nope
|
||||
LM_CORETURN false;
|
||||
}
|
||||
// Start scrolling
|
||||
if (params.scroll_keep > 0.0f) {
|
||||
// "Scroll" down the context window...
|
||||
unsigned keep_count = float(state->tokens.size() - params.n_ctx_window_top_bar) * 0.4f; // We keep about 40%
|
||||
// Get vector of tokens to keep
|
||||
std::vector<int> tokens_in_view(state->tokens.end()-keep_count, state->tokens.end());
|
||||
// Cut down tokens vector size
|
||||
state->tokens.resize(params.n_ctx_window_top_bar+keep_count);
|
||||
// Overwrite tokens after top bar with tokens in view
|
||||
std::memcpy(state->tokens.data()+params.n_ctx_window_top_bar, tokens_in_view.data(), tokens_in_view.size()*sizeof(int));
|
||||
} else {
|
||||
// Cut down tokens vector size to top bar
|
||||
state->tokens.resize(params.n_ctx_window_top_bar);
|
||||
}
|
||||
// Evaluate tokens
|
||||
LM_ERROR_FORWARD(LM_COAWAIT evaluate_tokens(0, on_scroll));
|
||||
LM_CORETURN true;
|
||||
}
|
||||
|
||||
LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL {
|
||||
auto& state = get_state();
|
||||
|
||||
// Evaluate tokens in batches
|
||||
unsigned it;
|
||||
for (it = starting_offset; ; it += params.n_batch) {
|
||||
if (it + params.n_batch >= ssize_t(state->tokens.size())) break;
|
||||
|
||||
// Evaluate
|
||||
std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+params.n_batch);
|
||||
if (!mpt_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token)) {
|
||||
LM_COTHROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR);
|
||||
}
|
||||
|
||||
// Tick
|
||||
if (on_tick) {
|
||||
// Calculate progress
|
||||
auto progress = float(it-starting_offset) / (state->tokens.size()-starting_offset) * 100.f;
|
||||
// Tick and yield
|
||||
if (!on_tick(progress)) LM_CORETURN LM_BOOL_SUCCESS;
|
||||
else if (!LM_TASKYIELD) LM_CORETURN LM_BOOL_SUCCESS;
|
||||
}
|
||||
}
|
||||
|
||||
// Evaluate remaining tokens
|
||||
if (it < state->tokens.size()) {
|
||||
for (; it != state->tokens.size(); it++) {
|
||||
//TODO: This is extremely inefficient! Don't do that...
|
||||
std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+1);
|
||||
if (!mpt_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token)) {
|
||||
LM_COTHROW("Failed to evaluate individual tokens", LM_BOOL_ERROR);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Notify about completion
|
||||
if (on_tick) on_tick(100.f);
|
||||
|
||||
LM_CORETURN LM_BOOL_SUCCESS;
|
||||
}
|
||||
|
||||
public:
    MPTInference(const std::string& weights_path, std::ifstream& f, const Params& p) : Inference(p) {
        init(weights_path, f);
    }
    ~MPTInference() LM_NOEXCEPTDECL override {
        deinit();
    }

    LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
        auto& state = get_state();

        // Append to current prompt
        state->prompt.append(prompt);

        // Resize buffer for tokens
        const auto old_token_count = state->tokens.size();

        // Run tokenizer
        const auto tokens = mpt_tokenize(state->vocab, prompt);
        state->tokens.insert(
                    state->tokens.end(),
                    std::make_move_iterator(tokens.begin()),
                    std::make_move_iterator(tokens.end())
                );

        // Make sure token limit isn't being hit
        if (LM_COAWAIT window_scroll()) {
            // That function already has evaluated our tokens since scrolling was needed
            LM_CORETURN LM_BOOL_SUCCESS;
        }

        // Evaluate new tokens
        LM_CORETURN LM_COAWAIT evaluate_tokens(old_token_count, on_tick);
    }

    /*mpt_vocab::id mpt_sample_top_k_top_p(
            const mpt_vocab & vocab,
            const size_t actualVocabSize,
            const int32_t * last_n_tokens_data,
            int last_n_tokens_size,
            const std::vector<float> logits,
            int top_k,
            double top_p,
            double temp,
            float repeat_penalty,
            std::mt19937 & rng)
    */

    LM_SCHEDULABLE(std::string) run(std::string_view end, const std::function<bool (const char *)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
        auto& state = get_state();
        std::string fres;

        // Loop until done
        bool abort = false;
        unsigned eos_count = 0;
        while (!abort && !ends_with(fres, end)) {
            // Sample top p and top k
            auto id = mpt_sample_top_k_top_p(state->vocab, state->model.hparams.n_vocab, state->tokens.data(), state->tokens.size(), state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);

            if (id == state->vocab.token_to_id["<|im_end|>"]) {
                if (eos_count++ == params.eos_ignores) {
                    abort = true;
                    continue;
                }
                id = mpt_tokenize(state->vocab, "\n")[0];
                state->tokens.push_back(id);
            } else {
                // Add token
                state->tokens.push_back(id);
            }

            // Make sure token limit isn't being hit
            LM_COAWAIT window_scroll();

            // Get token as string
            const auto str = state->vocab.id_to_token[id];

            // Append string to function result
            fres.append(str);

            // Evaluate token
            //  TODO: Respect batch size
            std::vector<int> batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size());
            if (!mpt_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) {
                LM_COTHROW("Failed to evaluate new tokens", "");
            }

            // Tick
            if (on_tick && !on_tick(str.c_str())) abort = true;
            else if (!LM_TASKYIELD) abort = true;
        }

        // Create final string  TODO: Could be optimized
        state->prompt.append(fres);
        if (!abort) {
            fres = std::string(fres.data(), fres.size()-end.size());
        }

        // Return final string
        LM_CORETURN fres;
    }

    unsigned get_context_size() const noexcept override {
        return get_state()->tokens.size();
    }

    LM_SCHEDULABLE(LM_ERRBOOL) create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override {
        auto& state = get_state();
        sv.buf.resize(mpt_get_state_size(state->model));
        mpt_copy_state_data(state->model, state->rng, sv.buf.data());
        sv.tokens = state->tokens;
        sv.prompt = state->prompt;
        sv.ctx = generic_state;
        LM_CORETURN LM_BOOL_SUCCESS;
    }
    LM_SCHEDULABLE(LM_ERRBOOL) restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override {
        auto& state = get_state();
        if (sv.ctx != generic_state)
            LM_COTHROW("Savestate does not match context", LM_BOOL_ERROR);
        mpt_set_state_data(&state->model, &state->rng, sv.buf.data());
        state->tokens = sv.tokens;
        state->prompt = sv.prompt;
        LM_CORETURN LM_BOOL_SUCCESS;
    }

    LM_SCHEDULABLE(LM_ERRBOOL) serialize(std::ostream &o) const LM_NOEXCEPTDECL override {
        auto& state = get_state();
        // Get state size
        auto state_size = mpt_get_state_size(state->model);
        // Write sizes
        for (const uint32_t s : {state->tokens.size(), state->prompt.size(), state_size}) {
            if (!o.write(reinterpret_cast<const char*>(&s), sizeof(s))) {
                LM_COTHROW("Failed to serialize data sizes", LM_BOOL_ERROR);
            }
        }
        // Write tokens
        if (!o.write(reinterpret_cast<const char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
            LM_COTHROW("Failed to serialize tokens", LM_BOOL_ERROR);
        }
        // Write prompt
        if (!o.write(state->prompt.data(), state->prompt.size())) {
            LM_COTHROW("Failed to serialize prompt", LM_BOOL_ERROR);
        }
        // Write state
        std::vector<uint8_t> state_buf(state_size);
        mpt_copy_state_data(state->model, state->rng, state_buf.data());
        if (!o.write(reinterpret_cast<const char*>(state_buf.data()), state_size)) {
            LM_COTHROW("Failed to serialize state", LM_BOOL_ERROR);
        }
        LM_CORETURN LM_BOOL_SUCCESS;
    }
    LM_SCHEDULABLE(LM_ERRBOOL) deserialize(std::istream &i) LM_NOEXCEPTDECL override {
        auto& state = get_state();
        uint32_t embd_size, promptsize, state_size;
        // Initialization to prevent compiler complaints
        embd_size = promptsize = state_size = 0;
        // Read sizes
        for (uint32_t *s : {&embd_size, &promptsize, &state_size}) {
            if (!i.read(reinterpret_cast<char*>(s), sizeof(*s))) {
                LM_COTHROW("Failed to deserialize data sizes", LM_BOOL_ERROR);
            }
        }
        // Read tokens
        state->tokens.resize(embd_size);
        if (!i.read(reinterpret_cast<char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
            LM_COTHROW("Failed to deserialize tokens", LM_BOOL_ERROR);
        }
        // Read prompt
        state->prompt.resize(promptsize);
        if (!i.read(state->prompt.data(), state->prompt.size())) {
            LM_COTHROW("Failed to deserialize prompt", LM_BOOL_ERROR);
        }
        // Read state
        std::vector<uint8_t> state_buf(state_size);
        if (!i.read(reinterpret_cast<char*>(state_buf.data()), state_buf.size())) {
            LM_COTHROW("Failed to deserialize state", LM_BOOL_ERROR);
        }
        mpt_set_state_data(&state->model, &state->rng, state_buf.data());
        LM_CORETURN LM_BOOL_SUCCESS;
    }
    const std::string &get_prompt() const LM_NOEXCEPTDECL override {
        return get_state()->prompt;
    }
};
}
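For context (not part of this commit's diff): window_scroll() above keeps the first n_ctx_window_top_bar tokens, drops the middle of the context, and keeps roughly the last 40% of the remainder before re-evaluating everything. A self-contained sketch of that policy on a plain token vector; the function and parameter names here are made up for illustration:

#include <cstddef>
#include <vector>

// Mirrors the scrolling policy of MPTInference::window_scroll(): keep the
// "top bar" prefix, then replace the rest with the most recent tokens.
std::vector<int> scroll_window(std::vector<int> tokens, size_t n_ctx, size_t top_bar) {
    if (tokens.size() <= n_ctx) return tokens;                  // nothing to scroll
    size_t keep_count = (tokens.size() - top_bar) * 0.4;        // keep about 40% of the tail
    std::vector<int> tail(tokens.end() - keep_count, tokens.end());
    tokens.resize(top_bar);                                     // keep only the prefix
    tokens.insert(tokens.end(), tail.begin(), tail.end());      // append the kept tail
    return tokens;                                              // caller re-evaluates from offset 0
}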
mpt/LICENSE (new file, 15 lines)
@@ -0,0 +1,15 @@
Copyright 2023 Nomic, Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

ADDENDUM:

Any LLM models that are loaded and used by the application are not themselves
subject to this license if indeed they are even copyrightable. The terms of
this license apply only to the application software and its accompanying
documentation and do not extend to any LLM models, whether created by the
author of the application or obtained from third-party sources.
mpt/mpt.cpp (new file, 883 lines)
@@ -0,0 +1,883 @@
#include "mpt.hpp"
|
||||
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <cstring>
|
||||
#include <fstream>
|
||||
#include <map>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <iostream>
|
||||
#include <unistd.h>
|
||||
#include <sstream>
|
||||
#include <thread>
|
||||
#include <unordered_set>
|
||||
#include <regex>
|
||||
#include <ggml.h>
|
||||
|
||||
static const size_t MB = 1024*1024;
|
||||
|
||||
static bool kv_cache_init(
|
||||
const struct mpt_hparams & hparams,
|
||||
struct mpt_kv_cache & cache,
|
||||
ggml_type wtype,
|
||||
int n_ctx) {
|
||||
const int n_embd = hparams.n_embd;
|
||||
const int n_layer = hparams.n_layer;
|
||||
|
||||
const int64_t n_mem = (int64_t)n_layer*n_ctx;
|
||||
const int64_t n_elements = n_embd*n_mem;
|
||||
|
||||
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
|
||||
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = cache.buf.size;
|
||||
params.mem_buffer = cache.buf.addr;
|
||||
params.no_alloc = false;
|
||||
|
||||
cache.ctx = ggml_init(params);
|
||||
|
||||
if (!cache.ctx) {
|
||||
fprintf(stderr, "%s: failed to allocate memory for kv cache\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
||||
cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::string regex_escape(const std::string &s) {
|
||||
static const std::regex metacharacters(R"([\.\^\$\-\+\(\)\[\]\{\}\|\?\*])");
|
||||
return std::regex_replace(s, metacharacters, "\\$&");
|
||||
}
|
||||
|
||||
// load the model's weights from a stream
bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & model, mpt_vocab & vocab) {
    printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());

    // verify magic
    {
        uint32_t magic;
        fin.read((char *) &magic, sizeof(magic));
        if (magic != 0x67676d6d) {
            fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
            return false;
        }
    }

    // load hparams
    {
        auto & hparams = model.hparams;

        fin.read((char *) &hparams.n_vocab,        sizeof(hparams.n_vocab));
        fin.read((char *) &hparams.n_ctx,          sizeof(hparams.n_ctx));
        fin.read((char *) &hparams.n_layer,        sizeof(hparams.n_layer));
        fin.read((char *) &hparams.n_head,         sizeof(hparams.n_head));
        fin.read((char *) &hparams.n_embd,         sizeof(hparams.n_embd));
        fin.read((char *) &hparams.alibi_bias_max, sizeof(hparams.alibi_bias_max));
        fin.read((char *) &hparams.clip_qkv,       sizeof(hparams.clip_qkv));
        fin.read((char *) &hparams.f16,            sizeof(hparams.f16));

        printf("%s: n_vocab        = %d\n", __func__, hparams.n_vocab);
        printf("%s: n_ctx          = %d\n", __func__, hparams.n_ctx);
        printf("%s: n_embd         = %d\n", __func__, hparams.n_embd);
        printf("%s: n_head         = %d\n", __func__, hparams.n_head);
        printf("%s: n_layer        = %d\n", __func__, hparams.n_layer);
        printf("%s: alibi_bias_max = %f\n", __func__, hparams.alibi_bias_max);
        printf("%s: clip_qkv       = %f\n", __func__, hparams.clip_qkv);
        printf("%s: ftype          = %d\n", __func__, hparams.f16);
    }

    // load vocab
    {
        int32_t n_vocab = model.hparams.n_vocab;
        fin.read((char *) &n_vocab, sizeof(n_vocab));

        if (n_vocab != model.hparams.n_vocab) {
            fprintf(stderr, "%s: invalid model file '%s' (bad vocab size %d != %d)\n",
                    __func__, fname.c_str(), n_vocab, model.hparams.n_vocab);
            return false;
        }

        std::string word;
        for (int i = 0; i < n_vocab; i++) {
            uint32_t len;
            fin.read((char *) &len, sizeof(len));
            bool special = false;
            if (len & (1<<31)) {
                len = len &~ (1<<31);
                special = true;
            }

            if (len > 0) {
                word.resize(len);
                fin.read((char *) word.data(), len);
                vocab.token_to_id[word] = i;
                vocab.id_to_token[i] = word;
            }

            // TODO: this only kind-of works, the gpt_tokenize can still incorrectly
            // tokenize special tokens
            if(special) {
                vocab.add_special_token(word);
            }
        }
    }

    // for the big tensors, we have the option to store the data in 16-bit floats or quantized
    // in order to save memory and also to speed up the computation
    ggml_type wtype = GGML_TYPE_COUNT;
    switch (model.hparams.f16) {
        case 0: wtype = GGML_TYPE_F32;  break;
        case 1: wtype = GGML_TYPE_F16;  break;
        case 2: wtype = GGML_TYPE_Q4_0; break;
        case 3: wtype = GGML_TYPE_Q4_1; break;
        case 5: wtype = GGML_TYPE_Q4_2; break;
        default:
            {
                fprintf(stderr, "%s: invalid model file '%s' (bad f16 value %d)\n",
                        __func__, fname.c_str(), model.hparams.f16);
                return false;
            }
    }

    auto & ctx = model.ctx;

    size_t ctx_size = 0;

    {
        const auto & hparams = model.hparams;

        const int n_embd  = hparams.n_embd;
        const int n_layer = hparams.n_layer;
        const int n_ctx   = hparams.n_ctx;
        const int n_vocab = hparams.n_vocab;
        const int expand  = hparams.expand;

        ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // ln_f_w

        ctx_size += n_embd*n_vocab*ggml_type_sizef(GGML_TYPE_F32); // wte

        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // norm_1_w
        ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // norm_2_w

        ctx_size += n_layer*(3*n_embd*n_embd*ggml_type_sizef(wtype)); // attn_Wqkv_w
        ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // attn_out_proj_w

        ctx_size += n_layer*(expand*n_embd*n_embd*ggml_type_sizef(wtype)); // ffn_up_proj_w
        ctx_size += n_layer*(expand*n_embd*n_embd*ggml_type_sizef(wtype)); // ffn_down_proj_w

        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F16); // memory_k
        ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(GGML_TYPE_F16); // memory_v

        // TODO probably less now?
        ctx_size += (5 + 10*n_layer)*256; // object overhead

        printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size/(1024.0*1024.0));
    }

    // create the ggml context
    {
        struct ggml_init_params params = {
            .mem_size   = ctx_size,
            .mem_buffer = NULL,
            .no_alloc   = false,
        };

        model.ctx = ggml_init(params);
        if (!model.ctx) {
            fprintf(stderr, "%s: ggml_init() failed\n", __func__);
            return false;
        }
    }

    // prepare memory for the weights
    {
        const auto & hparams = model.hparams;

        const int n_embd  = hparams.n_embd;
        const int n_layer = hparams.n_layer;
        const int n_ctx   = hparams.n_ctx;
        const int n_vocab = hparams.n_vocab;
        const int expand  = hparams.expand;

        model.layers.resize(n_layer);

        model.wte      = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab);
        model.norm_f_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);

        // map by name
        model.tensors["transformer.wte.weight"]    = model.wte;
        model.tensors["transformer.norm_f.weight"] = model.norm_f_w;

        for (int i = 0; i < n_layer; ++i) {
            auto & layer = model.layers[i];

            layer.norm_1_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);
            layer.norm_2_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_embd);

            layer.attn_Wqkv_w     = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd * 3);
            layer.attn_out_proj_w = ggml_new_tensor_2d(ctx, wtype, n_embd, n_embd);
            layer.ffn_up_proj_w   = ggml_new_tensor_2d(ctx, wtype, n_embd, expand*n_embd);
            layer.ffn_down_proj_w = ggml_new_tensor_2d(ctx, wtype, expand*n_embd, n_embd);

            // map by name
            model.tensors["transformer.blocks." + std::to_string(i) + ".norm_1.weight"]        = layer.norm_1_w;
            model.tensors["transformer.blocks." + std::to_string(i) + ".norm_2.weight"]        = layer.norm_2_w;
            model.tensors["transformer.blocks." + std::to_string(i) + ".attn.Wqkv.weight"]     = layer.attn_Wqkv_w;
            model.tensors["transformer.blocks." + std::to_string(i) + ".attn.out_proj.weight"] = layer.attn_out_proj_w;

            model.tensors["transformer.blocks." + std::to_string(i) + ".ffn.up_proj.weight"]   = layer.ffn_up_proj_w;
            model.tensors["transformer.blocks." + std::to_string(i) + ".ffn.down_proj.weight"] = layer.ffn_down_proj_w;
        }
    }

    // key + value memory
    {
        const auto & hparams = model.hparams;

        const int n_embd  = hparams.n_embd;
        const int n_layer = hparams.n_layer;
        const int n_ctx   = hparams.n_ctx;

        const int n_mem      = n_layer*n_ctx;
        const int n_elements = n_embd*n_mem;

        if (!kv_cache_init(hparams, model.kv_self, GGML_TYPE_F16, model.hparams.n_ctx)) {
            fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
            ggml_free(ctx);
            return false;
        }

        const size_t memory_size = ggml_nbytes(model.kv_self.k) + ggml_nbytes(model.kv_self.v);
        printf("%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
    }

    // load weights
    {
        int n_tensors = 0;
        size_t total_size = 0;

        printf("%s: ", __func__);

        while (true) {
            int32_t n_dims;
            int32_t length;
            int32_t ttype;

            fin.read(reinterpret_cast<char *>(&n_dims), sizeof(n_dims));
            fin.read(reinterpret_cast<char *>(&length), sizeof(length));
            fin.read(reinterpret_cast<char *>(&ttype),  sizeof(ttype));

            if (fin.eof()) {
                break;
            }

            int32_t nelements = 1;
            int32_t ne[2] = { 1, 1 };
            for (int i = 0; i < n_dims; ++i) {
                fin.read(reinterpret_cast<char *>(&ne[i]), sizeof(ne[i]));
                nelements *= ne[i];
            }

            std::string name(length, 0);
            fin.read(&name[0], length);

            if (model.tensors.find(name.data()) == model.tensors.end()) {
                fprintf(stderr, "%s: unknown tensor '%s' in model file\n", __func__, name.data());
                return false;
            }

            auto tensor = model.tensors[name.data()];
            if (ggml_nelements(tensor) != nelements) {
                fprintf(stderr, "%s: tensor '%s' has wrong size in model file\n", __func__, name.data());
                return false;
            }

            if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1]) {
                fprintf(stderr, "%s: tensor '%s' has wrong shape in model file: got [%d, %d], expected [%d, %d]\n",
                        __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], ne[0], ne[1]);
                return false;
            }

            // for debugging
            if (0) {
                printf("%24s - [%5d, %5d], type = %6s, %6.2f MB, %9zu bytes\n", name.data(), ne[0], ne[1], ggml_type_name(ggml_type(ttype)), ggml_nbytes(tensor)/1024.0/1024.0, ggml_nbytes(tensor));
            }

            const size_t bpe = ggml_type_size(ggml_type(ttype));

            if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) {
                fprintf(stderr, "%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n",
                        __func__, name.data(), ggml_nbytes(tensor), nelements*bpe);
                return false;
            }

            fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));

            //printf("%42s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ttype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
            total_size += ggml_nbytes(tensor);
            if (++n_tensors % 8 == 0) {
                printf(".");
                fflush(stdout);
            }
        }

        printf(" done\n");

        printf("%s: model size = %8.2f MB / num tensors = %d\n", __func__, total_size/1024.0/1024.0, n_tensors);
    }

    return true;
}
// load the model's weights from a file path
bool mpt_model_load(const std::string & fname, mpt_model & model, mpt_vocab & vocab) {

    auto fin = std::ifstream(fname, std::ios::binary);
    if (!fin) {
        fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
        return false;
    }

    bool loaded = mpt_model_load(fname, fin, model, vocab);
    fin.close();
    return loaded;
}

struct ggml_tensor * ggml_alibi(
        struct ggml_context * ctx,
        struct ggml_tensor  * a,
        int                   n_past,
        int                   n_head) {
    if (n_past < 0) {
        return NULL;
    }
    bool is_node = false;

    if (a->grad) {
        return NULL; // TODO: implement backward
        is_node = true;
    }

    // TODO: when implement backward, fix this:
    //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
    struct ggml_tensor * result = ggml_view_tensor(ctx, a);

    struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 3);
    ((int32_t *) b->data)[0] = n_past;
    ((int32_t *) b->data)[1] = n_head;

    result->op   = GGML_OP_COUNT; // Dirty hack
    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
    result->src0 = a;
    result->src1 = b;

    return result;
}
bool mpt_eval(
        mpt_model & model,
        const int n_threads,
        const int n_past,
        const std::vector<int> & embd_inp,
        std::vector<float>     & embd_w,
        size_t                 & mem_per_token) {
    const int N = embd_inp.size();

    const auto & hparams = model.hparams;

    const int n_embd  = hparams.n_embd;
    const int n_layer = hparams.n_layer;
    const int n_ctx   = hparams.n_ctx;
    const int n_head  = hparams.n_head;
    const int n_vocab = hparams.n_vocab;
    const int expand  = hparams.expand;

    const int d_key = n_embd/n_head;

    static size_t buf_size = 256u*1024*1024;
    static void * buf = malloc(buf_size);

    if (mem_per_token > 0 && mem_per_token*N > buf_size) {
        const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
        //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);

        // reallocate
        buf_size = buf_size_new;
        buf = realloc(buf, buf_size);
        if (buf == nullptr) {
            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
            return false;
        }
    }

    struct ggml_init_params params = {
        .mem_size   = buf_size,
        .mem_buffer = buf,
        .no_alloc   = false,
    };

    struct ggml_context * ctx0 = ggml_init(params);
    struct ggml_cgraph gf = { .n_threads = n_threads };

    struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
    memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));

    // wte
    struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.wte, embd);

    for (int il = 0; il < n_layer; ++il) {

        struct ggml_tensor * inpSA = inpL;
        struct ggml_tensor * cur = inpSA;
        // self-attention
        {

            // norm1
            cur = ggml_norm(ctx0, cur);
            cur = ggml_mul(ctx0,
                    ggml_repeat(ctx0, model.layers[il].norm_1_w, cur),
                    cur);
            // compute QKV
            cur = ggml_mul_mat(ctx0,
                    model.layers[il].attn_Wqkv_w,
                    cur);

            // TODO: clip_qkv
            struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0*ggml_element_size(cur)*n_embd));
            struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1*ggml_element_size(cur)*n_embd));
            struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2*ggml_element_size(cur)*n_embd));

            // TODO: qk_ln? (seems to be False in MPT-7B configs)
            {
                Vcur = ggml_transpose(ctx0, Vcur);

                struct ggml_tensor * k = ggml_view_1d(ctx0, model.kv_self.k, N*n_embd, (ggml_element_size(model.kv_self.k)*n_embd)*(il*n_ctx + n_past));
                struct ggml_tensor * v = ggml_view_2d(ctx0, model.kv_self.v, N, n_embd,
                        (   n_ctx)*ggml_element_size(model.kv_self.v),
                        (il*n_ctx)*ggml_element_size(model.kv_self.v)*n_embd + n_past*ggml_element_size(model.kv_self.v));

                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
                ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
            }
            // Q = Qcur.contiguous().view(n_embd/n_head, n_head, N).permute(0, 2, 1, 3)
            struct ggml_tensor * Q =
                ggml_permute(ctx0,
                        ggml_reshape_3d(ctx0, Qcur, n_embd/n_head, n_head, N),
                        0, 2, 1, 3);

            struct ggml_tensor * K =
                ggml_permute(ctx0,
                        ggml_reshape_3d(ctx0,
                            ggml_view_1d(ctx0, model.kv_self.k, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(model.kv_self.k)*n_embd),
                            n_embd/n_head, n_head, n_past + N),
                        0, 2, 1, 3);

            // K * Q
            struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

            // KQ_scaled = KQ / sqrt(n_embd/n_head)
            struct ggml_tensor * KQ_scaled =
                ggml_scale(ctx0,
                        KQ,
                        ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
                        );


            // Alibi
            struct ggml_tensor * KQ_scaled_biased = ggml_alibi(ctx0, ggml_cont(ctx0, KQ_scaled), n_past, n_head);

            // KQ_masked = mask_past(KQ_scaled)
            struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled_biased, n_past);

            // KQ = soft_max(KQ_masked)
            struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);

            // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
            struct ggml_tensor * V =
                ggml_view_3d(ctx0, model.kv_self.v,
                        n_past + N, n_embd/n_head, n_head,
                        n_ctx*ggml_element_size(model.kv_self.v),
                        n_ctx*ggml_element_size(model.kv_self.v)*n_embd/n_head,
                        il*n_ctx*ggml_element_size(model.kv_self.v)*n_embd);

            // KQV = transpose(V) * KQ_soft_max
            struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);

            // KQV_merged = KQV.permute(0, 2, 1, 3)
            struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);

            // cur = KQV_merged.contiguous().view(n_embd, N)
            cur = ggml_cpy(ctx0,
                    KQV_merged,
                    ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));

            // projection (no bias)
            cur = ggml_mul_mat(ctx0,
                    model.layers[il].attn_out_proj_w,
                    cur);
        }


        // residual
        struct ggml_tensor * resSA = ggml_add(ctx0, cur, inpSA);
        // feed-forward network
        {
            cur = resSA;
            // norm2
            cur = ggml_norm(ctx0, cur);
            cur = ggml_mul(ctx0,
                    ggml_repeat(ctx0, model.layers[il].norm_2_w, cur),
                    cur);
            // ffn
            cur = ggml_mul_mat(ctx0,
                    model.layers[il].ffn_up_proj_w,
                    cur);
            cur = ggml_gelu(ctx0, cur);
            cur = ggml_mul_mat(ctx0,
                    model.layers[il].ffn_down_proj_w,
                    cur);

        }

        // self-attention + FF
        inpL = ggml_add(ctx0, cur, resSA);
    }

    struct ggml_tensor * out = inpL;
    // -> logits
    {
        out = ggml_norm(ctx0, out);
        out = ggml_mul(ctx0,
                ggml_repeat(ctx0, model.norm_f_w, out),
                out);
        out = ggml_mul_mat(ctx0, model.wte, out);
    }


    // run the computation
    ggml_build_forward_expand(&gf, out);
    ggml_graph_compute       (ctx0, &gf);


    // return result for just the last token
    embd_w.resize(n_vocab);
    memcpy(embd_w.data(), (float *) ggml_get_data(out) + (n_vocab*(N-1)), sizeof(float)*n_vocab);

    if (mem_per_token == 0) {
        mem_per_token = ggml_used_mem(ctx0)/N;
    }
    //printf("used_mem = %zu\n", ggml_used_mem(ctx0));

    ggml_free(ctx0);

    return true;
}
std::vector<int> mpt_tokenize_inner(const mpt_vocab & vocab, const std::string & text) {
    // taken from stablelm example in ggml
    // they both use the gpt-neox tokenizer
    // not sure if this entirely right?
    std::vector<std::string> words;


    // first split the text into words
    {
        std::string str = text;
        std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
        std::regex re(pat);
        std::smatch m;

        while (std::regex_search(str, m, re)) {
            for (auto x : m) {
                words.push_back(x);
            }
            str = m.suffix();
        }
    }

    // find the longest tokens that form the words:
    std::vector<mpt_vocab::id> tokens;
    for (const auto & word : words) {
        if (word.size() == 0) continue;

        int i = 0;
        int n = word.size();
        while (i < n) {
            int j = n;
            while (j > i) {
                auto it = vocab.token_to_id.find(word.substr(i, j-i));
                if (it != vocab.token_to_id.end()) {
                    tokens.push_back(it->second);
                    i = j;
                    break;
                }
                --j;
            }
            if (i == n) {
                break;
            }
            if (j == i) {
                auto sub = word.substr(i, 1);
                if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
                    tokens.push_back(vocab.token_to_id.at(sub));
                } else {
                    fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
                }
                ++i;
            }
        }
    }

    return tokens;
}

std::vector<mpt_vocab::id> mpt_tokenize(const mpt_vocab & vocab, const std::string & text) {
    // Generate the subpattern from the special_tokens vector if it's not empty
    if (!vocab.special_tokens.empty()) {
        std::vector<mpt_vocab::id> out;
        std::vector<std::string> chunks;
        std::string str = text;
        std::string special_tokens_subpattern;
        for (const auto &token : vocab.special_tokens) {
            if (!special_tokens_subpattern.empty()) {
                special_tokens_subpattern += "|";
            }
            special_tokens_subpattern += regex_escape(token);
        }
        std::regex re(special_tokens_subpattern);
        std::smatch m;
        while (std::regex_search(str, m, re)) {
            auto tok = vocab.token_to_id.find(m.str());
            if (tok != vocab.token_to_id.end()) {
                auto tokid = tok->second;
                auto pfxtoks = mpt_tokenize_inner(vocab, m.prefix());
                out.insert(out.end(), pfxtoks.begin(), pfxtoks.end());
                out.push_back(tokid);
                str = m.suffix();
            }
        }
        if (!str.empty()) {
            auto tokrest = mpt_tokenize_inner(vocab, str);
            out.insert(out.end(), tokrest.begin(), tokrest.end());
        }
        return out;
    } else {
        return mpt_tokenize_inner(vocab, text);
    }
}
#define MPT_MAX_RNG_STATE 64*1024

size_t mpt_get_state_size(const mpt_model &model)
{
    // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state.
    // for reference, std::mt19937(1337) serializes to 6701 bytes.
    const size_t s_rng_size = sizeof(size_t);
    const size_t s_rng      = MPT_MAX_RNG_STATE;
    const size_t s_kv_size  = sizeof(size_t);
    const size_t s_kv_ntok  = sizeof(int);
    const size_t s_kv       = model.kv_self.buf.size;
    const size_t s_total = (
        + s_rng_size
        + s_rng
        + s_kv_size
        + s_kv_ntok
        + s_kv
    );
    fflush(stdout);
    return s_total;
}

size_t mpt_copy_state_data(const mpt_model &model, const std::mt19937 &rng, uint8_t *dest)
{
    uint8_t * out = dest;
    fflush(stdout);
    // copy rng
    {
        std::stringstream rng_ss;
        rng_ss << rng;

        const size_t rng_size = rng_ss.str().size();
        char rng_buf[MPT_MAX_RNG_STATE];

        memset(&rng_buf[0], 0, MPT_MAX_RNG_STATE);
        memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size());

        memcpy(out, &rng_size,   sizeof(rng_size));  out += sizeof(rng_size);
        memcpy(out, &rng_buf[0], MPT_MAX_RNG_STATE); out += MPT_MAX_RNG_STATE;
    }

    // copy kv cache
    {
        const size_t kv_size = model.kv_self.buf.size;
        const int    kv_ntok = model.kv_self.n;

        memcpy(out, &kv_size, sizeof(kv_size)); out += sizeof(kv_size);
        memcpy(out, &kv_ntok, sizeof(kv_ntok)); out += sizeof(kv_ntok);

        if (kv_size) {
            memcpy(out, model.kv_self.buf.addr, kv_size); out += kv_size;
        }
    }

    const size_t written  = out - dest;
    const size_t expected = mpt_get_state_size(model);
    assert(written == expected);
    fflush(stdout);
    return written;
}
mpt_vocab::id mpt_sample_top_k_top_p(
        const mpt_vocab & vocab,
        const size_t actualVocabSize,
        const int32_t * last_n_tokens_data,
        int last_n_tokens_size,
        const std::vector<float> logits,
        int top_k,
        double top_p,
        double temp,
        float repeat_penalty,
        std::mt19937 & rng) {
    int n_logits = actualVocabSize;

    const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
    const auto * plogits = logits.data() + logits.size() - n_logits;

    std::vector<std::pair<double, mpt_vocab::id>> logits_id;
    logits_id.reserve(n_logits);

    {
        const float scale = 1.0f/temp;
        for (int i = 0; i < n_logits; ++i) {
            // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
            // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
            if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
                // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
                if (plogits[i] < 0.0f) {
                    logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
                } else {
                    logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
                }
            } else {
                logits_id.push_back(std::make_pair(plogits[i]*scale, i));
            }
        }
    }

    // find the top K tokens
    std::partial_sort(
            logits_id.begin(),
            logits_id.begin() + top_k, logits_id.end(),
            [](const std::pair<double, mpt_vocab::id> & a, const std::pair<double, mpt_vocab::id> & b) {
        return a.first > b.first;
    });

    logits_id.resize(top_k);

    double maxl = -INFINITY;
    for (const auto & kv : logits_id) {
        maxl = std::max(maxl, kv.first);
    }

    // compute probs for the top K tokens
    std::vector<double> probs;
    probs.reserve(logits_id.size());

    double sum = 0.0;
    for (const auto & kv : logits_id) {
        double p = exp(kv.first - maxl);
        probs.push_back(p);
        sum += p;
    }

    // normalize the probs
    for (auto & p : probs) {
        p /= sum;
    }

    if (top_p < 1.0f) {
        double cumsum = 0.0f;
        for (int i = 0; i < top_k; i++) {
            cumsum += probs[i];
            if (cumsum >= top_p) {
                top_k = i + 1;
                probs.resize(top_k);
                logits_id.resize(top_k);
                break;
            }
        }

        cumsum = 1.0/cumsum;
        for (int i = 0; i < (int) probs.size(); i++) {
            probs[i] *= cumsum;
        }
    }

    //printf("\n");
    //for (int i = 0; i < (int) probs.size(); i++) {
    //    printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
    //}
    //exit(0);

    std::discrete_distribution<> dist(probs.begin(), probs.end());
    int idx = dist(rng);

    return logits_id[idx].second;
}
size_t mpt_set_state_data(mpt_model *model, std::mt19937 *rng, const uint8_t *src)
{
    const uint8_t * in = src;

    // set rng
    {
        size_t rng_size;
        char   rng_buf[MPT_MAX_RNG_STATE];

        memcpy(&rng_size,   in, sizeof(rng_size));  in += sizeof(rng_size);
        memcpy(&rng_buf[0], in, MPT_MAX_RNG_STATE); in += MPT_MAX_RNG_STATE;

        std::stringstream rng_ss;
        rng_ss.str(std::string(&rng_buf[0], rng_size));
        rng_ss >> *rng;

        assert(rng_ss.fail() == false);
    }

    // set kv cache
    {
        size_t kv_size;
        int    kv_ntok;

        memcpy(&kv_size, in, sizeof(kv_size)); in += sizeof(kv_size);
        memcpy(&kv_ntok, in, sizeof(kv_ntok)); in += sizeof(kv_ntok);

        if (kv_size) {
            assert(model->kv_self.buf.size == kv_size);

            void * k_data = model->kv_self.k->data; // remember data pointers
            void * v_data = model->kv_self.v->data; // because their value is stored in buf and overwritten by memcpy

            memcpy(model->kv_self.buf.addr, in, kv_size); in += kv_size;

            model->kv_self.k->data = k_data; // restore correct data pointers
            model->kv_self.v->data = v_data;

        }

        model->kv_self.n = kv_ntok;
    }

    const size_t nread    = in - src;
    const size_t expected = mpt_get_state_size(*model);
    assert(nread == expected);
    fflush(stdout);
    return nread;
}
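For context (not part of this commit's diff): a minimal sketch of how the functions in this file fit together for plain generation, using only the declarations from mpt/mpt.hpp below. The model path, prompt, thread count, and sampling parameters are made-up values for illustration:

#include <cstdio>
#include <random>
#include <vector>
#include "mpt/mpt.hpp"

int main() {
    mpt_model model;
    mpt_vocab vocab;
    if (!mpt_model_load("ggml-mpt-7b.bin", model, vocab)) return 1; // hypothetical file name

    std::mt19937 rng(1337);
    std::vector<float> logits;
    size_t mem_per_token = 0;
    mpt_eval(model, 4, 0, { 0, 1, 2, 3 }, logits, mem_per_token);   // warm-up call, fills mem_per_token

    std::vector<int> tokens = mpt_tokenize(vocab, "Hello, my name is");
    mpt_eval(model, 4, 0, tokens, logits, mem_per_token);           // evaluate the prompt

    for (int i = 0; i < 16; ++i) {
        auto id = mpt_sample_top_k_top_p(vocab, model.hparams.n_vocab,
                                         tokens.data(), tokens.size(), logits,
                                         /*top_k=*/40, /*top_p=*/0.9, /*temp=*/0.8,
                                         /*repeat_penalty=*/1.1, rng);
        std::printf("%s", vocab.id_to_token[id].c_str());
        tokens.push_back(id);
        mpt_eval(model, 4, tokens.size() - 1, { id }, logits, mem_per_token); // feed back one token
    }
    return 0;
}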
mpt/mpt.hpp (new file, 117 lines)
@@ -0,0 +1,117 @@
#ifndef MPT_H
#define MPT_H
#include <string>
#include <vector>
#include <map>
#include <random>
#include <ggml.h>


// default hparams (MPT 7B)
struct mpt_hparams {
    int32_t n_vocab      = 50432;
    int32_t n_ctx        = 2048;
    int32_t n_embd       = 4096;
    int32_t n_head       = 32;
    int32_t n_layer      = 32;
    float alibi_bias_max = 8;
    float clip_qkv       = 0;
    int32_t expand       = 4;
    int32_t f16          = 1;
};

struct mpt_layer {
    // normalization
    struct ggml_tensor * norm_1_w;
    struct ggml_tensor * norm_2_w;

    // attention
    struct ggml_tensor * attn_Wqkv_w;
    struct ggml_tensor * attn_out_proj_w;

    // ff
    struct ggml_tensor * ffn_up_proj_w;
    struct ggml_tensor * ffn_down_proj_w;
};

struct mpt_buffer {
    uint8_t * addr = NULL;
    size_t size = 0;

    void resize(size_t size) {
        delete[] addr;
        addr = new uint8_t[size];
        this->size = size;
    }

    ~mpt_buffer() {
        fflush(stdout);
        delete[] addr;
    }
};

struct mpt_kv_cache {
    struct ggml_tensor * k;
    struct ggml_tensor * v;

    struct ggml_context * ctx = NULL;

    mpt_buffer buf;

    int n; // number of tokens currently in the cache

    ~mpt_kv_cache() {
        if (ctx) {
            ggml_free(ctx);
        }
    }
};

struct mpt_model {
    mpt_hparams hparams;

    // normalization
    struct ggml_tensor * norm_f_w;

    struct ggml_tensor * wte; // position embedding

    // mpt does weight tying

    std::vector<mpt_layer> layers;

    struct mpt_kv_cache kv_self;
    struct ggml_context * ctx;
    std::map<std::string, struct ggml_tensor *> tensors;


    mpt_buffer buf;

    ~mpt_model() {
        if (ctx) {
            ggml_free(ctx);
        }
    }
};

struct mpt_vocab {
    using id    = int32_t;
    using token = std::string;

    std::map<token, id> token_to_id;
    std::map<id, token> id_to_token;
    std::vector<std::string> special_tokens;

    void add_special_token(const std::string &token) {
        special_tokens.push_back(token);
    }
};


bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & model, mpt_vocab & vocab);
bool mpt_eval(mpt_model& model, const int n_threads, const int n_past, const std::vector<int>& embd_inp, std::vector<float>& embd_w, size_t& mem_per_token);
std::vector<mpt_vocab::id> mpt_tokenize(const mpt_vocab & vocab, const std::string & text);
mpt_vocab::id mpt_sample_top_k_top_p(const mpt_vocab& vocab, const size_t actualVocabSize, const int32_t *last_n_tokens_data, int last_n_tokens_size, const std::vector<float> logits, int top_k, double top_p, double temp, float repeat_penalty, std::mt19937& rng);
size_t mpt_get_state_size(const mpt_model &model);
size_t mpt_copy_state_data(const mpt_model &model, const std::mt19937& rng, uint8_t *dest);
size_t mpt_set_state_data(mpt_model *model, std::mt19937 *rng, const uint8_t *src);
#endif // MPT_H
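For context (not part of this commit's diff): the byte layout that mpt_copy_state_data() writes and mpt_set_state_data() reads back, summarized from the implementation in mpt/mpt.cpp above:

// State blob layout (total size = mpt_get_state_size(model)):
//   size_t  rng_size                      actual length of the serialized std::mt19937 text
//   char    rng_buf[MPT_MAX_RNG_STATE]    64*1024 bytes, zero-padded RNG stream
//   size_t  kv_size                       model.kv_self.buf.size
//   int     kv_ntok                       number of tokens currently in the KV cache
//   uint8_t kv_data[kv_size]              raw copy of the KV cache buffer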