1
0
Fork 0
mirror of https://gitlab.com/niansa/libjustlm.git synced 2025-03-06 20:49:17 +01:00

MPT works now!

This commit is contained in:
niansa/tuxifan 2023-05-17 09:33:16 +02:00
parent 4ec47699f0
commit f5cf0ecff2
4 changed files with 3 additions and 4 deletions

View file

@ -2,7 +2,7 @@
Super easy to use library for doing LLaMA/GPT-J/MPT stuff!
## Overview
This library implements an easy to use interface to LLaMa, GPT-J and MPT (not yet functional), with optional Python bindings.
This library implements an easy to use interface to LLaMa, GPT-J and MPT, with optional Python bindings.
Context scrolling is automatic and supports a top window bar.

View file

@ -175,7 +175,7 @@ public:
while (!abort && !ends_with(fres, end)) {
// Sample top p and top k
const auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
auto id = mpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data()+state->tokens.size()-n_repeat_last, n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
auto id = gpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data()+state->tokens.size()-n_repeat_last, n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
if (id == 0 || (state->has_im_end && id == state->vocab.token_to_id["<|im_end|>"])) {
if (eos_count++ == params.eos_ignores) {

View file

@ -379,7 +379,7 @@ bool mpt_eval(
};
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf;
struct ggml_cgraph gf{};
gf.n_threads = n_threads;
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);

View file

@ -104,7 +104,6 @@ struct mpt_model {
bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & model, gpt_vocab& vocab);
bool mpt_eval(mpt_model& model, const int n_threads, const int n_past, const std::vector<int>& embd_inp, std::vector<float>& embd_w, size_t& mem_per_token);
gpt_vocab::id mpt_sample_top_k_top_p(const size_t actualVocabSize, const int32_t *last_n_tokens_data, int last_n_tokens_size, const std::vector<float> logits, int top_k, double top_p, double temp, float repeat_penalty, std::mt19937& rng);
size_t mpt_get_state_size(const mpt_model &model);
size_t mpt_copy_state_data(const mpt_model &model, const std::mt19937& rng, uint8_t *dest);
size_t mpt_set_state_data(mpt_model *model, std::mt19937 *rng, const uint8_t *src);