From f5cf0ecff223636b05dfe3f744bd064927ade38b Mon Sep 17 00:00:00 2001
From: niansa
Date: Wed, 17 May 2023 09:33:16 +0200
Subject: [PATCH] MPT works now!

---
 README.md      | 2 +-
 justlm_mpt.hpp | 2 +-
 mpt/mpt.cpp    | 2 +-
 mpt/mpt.hpp    | 1 -
 4 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 78124a2..6155872 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
 Super easy to use library for doing LLaMA/GPT-J/MPT stuff!
 
 ## Overview
-This library implements an easy to use interface to LLaMa, GPT-J and MPT (not yet functional), with optional Python bindings.
+This library implements an easy to use interface to LLaMa, GPT-J and MPT, with optional Python bindings.
 
 Context scrolling is automatic and supports a top window bar.
 
diff --git a/justlm_mpt.hpp b/justlm_mpt.hpp
index 310e7ca..854b8e6 100644
--- a/justlm_mpt.hpp
+++ b/justlm_mpt.hpp
@@ -175,7 +175,7 @@ public:
         while (!abort && !ends_with(fres, end)) {
             // Sample top p and top k
             const auto n_repeat_last = std::min(state->tokens.size(), params.n_repeat_last);
-            auto id = mpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data()+state->tokens.size()-n_repeat_last, n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
+            auto id = gpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data()+state->tokens.size()-n_repeat_last, n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
 
             if (id == 0 || (state->has_im_end && id == state->vocab.token_to_id["<|im_end|>"])) {
                 if (eos_count++ == params.eos_ignores) {
diff --git a/mpt/mpt.cpp b/mpt/mpt.cpp
index 0371922..0c35579 100644
--- a/mpt/mpt.cpp
+++ b/mpt/mpt.cpp
@@ -379,7 +379,7 @@ bool mpt_eval(
     };
 
     struct ggml_context * ctx0 = ggml_init(params);
-    struct ggml_cgraph gf;
+    struct ggml_cgraph gf{};
     gf.n_threads = n_threads;
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
diff --git a/mpt/mpt.hpp b/mpt/mpt.hpp
index c106684..60d2f72 100644
--- a/mpt/mpt.hpp
+++ b/mpt/mpt.hpp
@@ -104,7 +104,6 @@ struct mpt_model {
 
 bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & model, gpt_vocab& vocab);
 bool mpt_eval(mpt_model& model, const int n_threads, const int n_past, const std::vector<gpt_vocab::id>& embd_inp, std::vector<float>& embd_w, size_t& mem_per_token);
-gpt_vocab::id mpt_sample_top_k_top_p(const size_t actualVocabSize, const int32_t *last_n_tokens_data, int last_n_tokens_size, const std::vector<float> logits, int top_k, double top_p, double temp, float repeat_penalty, std::mt19937& rng);
 size_t mpt_get_state_size(const mpt_model &model);
 size_t mpt_copy_state_data(const mpt_model &model, const std::mt19937& rng, uint8_t *dest);
 size_t mpt_set_state_data(mpt_model *model, std::mt19937 *rng, const uint8_t *src);
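
Note on the mpt/mpt.cpp hunk: struct ggml_cgraph is a plain aggregate, so
"struct ggml_cgraph gf;" at block scope leaves every member indeterminate;
only gf.n_threads was assigned before the graph was built and computed, so
fields such as n_nodes plausibly started out as garbage, which would explain
why MPT did not work before this patch. Writing "struct ggml_cgraph gf{};"
value-initializes the whole struct to zero. The sketch below is a minimal
illustration of that C++ rule using a hypothetical stand-in aggregate; Graph
is not the real ggml_cgraph layout.

    #include <cstdio>

    // Hypothetical stand-in for a plain aggregate like ggml_cgraph.
    struct Graph {
        int n_nodes;    // counterpart of ggml_cgraph::n_nodes
        int n_threads;  // counterpart of ggml_cgraph::n_threads
    };

    int main() {
        Graph a;    // default-initialized: members hold indeterminate values
        Graph b{};  // value-initialized: every member is zeroed

        a.n_threads = 4;  // mirrors "gf.n_threads = n_threads;"; a.n_nodes is still indeterminate
        b.n_threads = 4;  // b.n_nodes is a well-defined 0

        std::printf("%d\n", b.n_nodes);  // prints 0
        // Reading a.n_nodes here would be undefined behavior.
    }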