diff --git a/justlm_mpt.hpp b/justlm_mpt.hpp
index 854b8e6..2c0d0e3 100644
--- a/justlm_mpt.hpp
+++ b/justlm_mpt.hpp
@@ -19,7 +19,7 @@ class MPTInference final : public Inference {
         std::vector<float> logits;
         size_t mem_per_token = 0;
         std::mt19937 rng;
-        bool has_im_end;
+        int im_end = 0;
 
         State(int32_t seed) : rng(seed) {}
     };
@@ -48,8 +48,13 @@ class MPTInference final : public Inference {
         static std::vector<gpt_vocab::id> r_instruct;
         mpt_eval(state->model, params.n_threads, 0, { 0, 1, 2, 3 }, state->logits, state->mem_per_token);
 
-        // Some other stuff
-        state->has_im_end = state->vocab.token_to_id.find("<|im_end|>") != state->vocab.token_to_id.end();
+        // Find im_end token
+        {
+            auto res = state->vocab.token_to_id.find("<|im_end|>");
+            if (res != state->vocab.token_to_id.end()) {
+                state->im_end = res->second;
+            }
+        }
 
         return LM_BOOL_SUCCESS;
     }
@@ -177,11 +182,13 @@ public:
             const auto n_repeat_last = std::min(state->tokens.size(), params.n_repeat_last);
             auto id = gpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data()+state->tokens.size()-n_repeat_last, n_repeat_last,
                                              state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
-            if (id == 0 || (state->has_im_end && id == state->vocab.token_to_id["<|im_end|>"])) {
+            if (id == 0 || id == state->im_end) {
                 if (eos_count++ == params.eos_ignores) {
                     abort = true;
+                    printf("Stopping due to EOS (%d)\n", id);
                     continue;
                 }
+                printf("Retrying after EOS (%d)... %d\n", id, eos_count);
                 id = gpt_tokenize(state->vocab, "\n")[0];
                 state->tokens.push_back(id);
             } else {
@@ -203,6 +210,7 @@ public:
             // TODO: Respect batch size
             std::vector<gpt_vocab::id> batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size());
             if (!mpt_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) {
+                printf("Stopping due to eval error (%d)\n", id);
                 LM_COTHROW("Failed to evaluate new tokens", "");
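
The first two hunks replace the `bool has_im_end` flag with a cached `int im_end` token id, looked up once at load time instead of on every sampled token. Because `im_end` defaults to 0, a vocabulary without `<|im_end|>` makes the stop check collapse into the existing `id == 0` EOS test, so the sentinel collision is benign. A minimal self-contained sketch of the same lookup-and-cache pattern, assuming a hypothetical `token_map` stand-in for `gpt_vocab::token_to_id` and a made-up token id:

    #include <cstdio>
    #include <string>
    #include <unordered_map>

    // Hypothetical stand-in for gpt_vocab's token_to_id map.
    using token_map = std::unordered_map<std::string, int>;

    // Returns the token's id, or the 0 sentinel if the vocab lacks it.
    int find_token_or_zero(const token_map &token_to_id, const std::string &token) {
        auto res = token_to_id.find(token);
        return res != token_to_id.end() ? res->second : 0;
    }

    int main() {
        token_map vocab{{"<|im_end|>", 50278}};                 // made-up id
        int im_end = find_token_or_zero(vocab, "<|im_end|>");   // cached once at load
        int id = 50278;                                         // pretend sampler output
        // Mirrors the patched stop check: plain EOS (0), im_end, or the 0 sentinel.
        if (id == 0 || id == im_end)
            printf("Stopping due to EOS (%d)\n", id);
        return 0;
    }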
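The third hunk only adds logging, but the surrounding `eos_ignores` logic is easy to misread: the first `params.eos_ignores` end-of-stream tokens are swallowed and replaced with a newline token, and only the next one aborts generation. A standalone sketch of that behavior, with made-up token ids and a fake sampler sequence:

    #include <cstdio>

    int main() {
        const int eos_ignores = 2;     // stand-in for params.eos_ignores
        const int im_end = 50278;      // made-up cached <|im_end|> id
        const int newline_id = 187;    // made-up id of "\n"
        const int sampled[] = {11, im_end, 7, 0, im_end};  // fake sampler output
        int eos_count = 0;
        for (int id : sampled) {
            if (id == 0 || id == im_end) {
                // The (eos_ignores + 1)-th EOS stops generation for good.
                if (eos_count++ == eos_ignores) {
                    printf("Stopping due to EOS (%d)\n", id);
                    break;
                }
                // Earlier EOS hits are ignored: substitute a newline and go on.
                printf("Retrying after EOS (%d)... %d\n", id, eos_count);
                id = newline_id;
            }
            printf("emitted token %d\n", id);
        }
        return 0;
    }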