Mirror of https://gitlab.com/niansa/libjustlm.git

Minor MPT improvements

commit a98784aa53
parent 4c4ef9e441
Author: niansa
Date:   2023-05-16 23:35:42 +02:00
2 changed files with 6 additions and 2 deletions

File 1 of 2:

@@ -80,7 +80,7 @@ public:
     struct Params {
         int seed = 0; // RNG seed
         unsigned n_threads = 0;
-        unsigned n_ctx = 2012; // Context size
+        unsigned n_ctx = 2024; // Context size
         unsigned n_ctx_window_top_bar = 0; // Top bar of context window. Must be smaller than context size
         unsigned n_batch = 8; // Batch size
         unsigned n_repeat_last = 0; // llama.cpp specific

File 2 of 2:

@@ -19,6 +19,7 @@ class MPTInference final : public Inference {
         std::vector<float> logits;
         size_t mem_per_token = 0;
         std::mt19937 rng;
+        bool has_im_end;

         State(int32_t seed) : rng(seed) {}
     };
@@ -47,6 +48,9 @@ class MPTInference final : public Inference {
         static std::vector<gpt_vocab::id> r_instruct;
         mpt_eval(state->model, params.n_threads, 0, { 0, 1, 2, 3 }, state->logits, state->mem_per_token);
         // Some other stuff
+        state->has_im_end = state->vocab.token_to_id.find("<|im_end|>") != state->vocab.token_to_id.end();
         return LM_BOOL_SUCCESS;
     }
     void deinit() LM_NOEXCEPTDECL {
@@ -172,7 +176,7 @@ public:
             // Sample top p and top k
             auto id = mpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data(), state->tokens.size(), state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
-            if (id == state->vocab.token_to_id["<|im_end|>"]) {
+            if (id == 0 || (state->has_im_end && id == state->vocab.token_to_id["<|im_end|>"])) {
                 if (eos_count++ == params.eos_ignores) {
                     abort = true;
                     continue;
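
The substance of the fix is the cached has_im_end flag. In the ggml examples that libjustlm builds on, gpt_vocab::token_to_id is a std::map, and the old check used operator[], which default-inserts a missing key with value 0. For models whose vocab lacks "<|im_end|>", that silently grew the map and made every token with id 0 look like an end-of-message token. The new condition treats id 0 (presumably the model's end-of-text token) as a stop explicitly, and only consults the map when the token is known to exist. Below is a minimal standalone sketch of the pitfall and the fix; the map contents and variable names are illustrative, not taken from the patch:

// Sketch of the operator[] pitfall vs. the cached-find fix. Assumption:
// token_to_id is a std::map<std::string, int>, mirroring ggml's gpt_vocab.
#include <cstdio>
#include <map>
#include <string>

int main() {
    std::map<std::string, int> token_to_id = {{"<|endoftext|>", 0}};
    int id = 0; // a sampled token id; 0 is the end-of-text token here

    // Fixed pattern: probe once with find() (the patch caches this in
    // State::has_im_end at init time), then only consult the map when the
    // token is known to exist. The map is never modified.
    bool has_im_end = token_to_id.find("<|im_end|>") != token_to_id.end(); // false
    bool fixed = (id == 0) || (has_im_end && id == token_to_id["<|im_end|>"]);

    // Old pattern: operator[] default-inserts a missing key with value 0,
    // so the comparison silently grows the vocab map and then matches any
    // token whose id is 0, whether or not the model knows "<|im_end|>".
    bool old = (id == token_to_id["<|im_end|>"]); // inserts "<|im_end|>" -> 0

    std::printf("old: %d, fixed: %d, has_im_end: %d, map size: %zu\n",
                old, fixed, has_im_end, token_to_id.size()); // 1, 1, 0, 2
    return 0;
}

Both checks fire here because id is 0, but only the old one mutates the vocab map; the explicit id == 0 branch keeps end-of-text as a stop condition without relying on the phantom map entry.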