mirror of https://gitlab.com/niansa/libjustlm.git

Only look up im_end once

This commit is contained in:
niansa/tuxifan 2023-05-17 10:17:51 +02:00
parent f5cf0ecff2
commit 8e7e310757
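For context, a minimal standalone sketch of the pattern this commit moves to, using placeholder names (the Vocab struct, the sample token ids and the main() driver below are illustrative, not libjustlm's real API): the "<|im_end|>" id is resolved from the token_to_id map once up front and cached as a plain int, so the per-token EOS check becomes an integer comparison instead of a map lookup; 0 doubles as "not found" and as the EOS id, which matches the `id == 0 || id == state->im_end` check in the diff below.

// Sketch only: resolve the special token once, compare against the cached id
// in the sampling loop. Vocab, the sample ids and main() are placeholders.
#include <cstdio>
#include <map>
#include <string>
#include <vector>

struct Vocab {
    std::map<std::string, int> token_to_id; // stands in for gpt_vocab::token_to_id
};

int main() {
    Vocab vocab;
    vocab.token_to_id = {{"hello", 5}, {"<|im_end|>", 7}};

    // One-time lookup at load time; 0 doubles as "not found" and as the EOS id.
    int im_end = 0;
    if (auto res = vocab.token_to_id.find("<|im_end|>"); res != vocab.token_to_id.end()) {
        im_end = res->second;
    }

    // Hot loop: a plain int comparison per sampled token, no map access, and no
    // operator[] call that would insert a missing key as a side effect.
    std::vector<int> sampled = {5, 5, 7};
    for (int id : sampled) {
        if (id == 0 || id == im_end) {
            std::printf("Stopping due to EOS (%d)\n", id);
            break;
        }
        std::printf("token %d\n", id);
    }
}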


@@ -19,7 +19,7 @@ class MPTInference final : public Inference {
         std::vector<float> logits;
         size_t mem_per_token = 0;
         std::mt19937 rng;
-        bool has_im_end;
+        int im_end = 0;
 
         State(int32_t seed) : rng(seed) {}
     };
@@ -48,8 +48,13 @@ class MPTInference final : public Inference {
         static std::vector<gpt_vocab::id> r_instruct;
         mpt_eval(state->model, params.n_threads, 0, { 0, 1, 2, 3 }, state->logits, state->mem_per_token);
         // Some other stuff
-        state->has_im_end = state->vocab.token_to_id.find("<|im_end|>") != state->vocab.token_to_id.end();
+        // Find im_end token
+        {
+            auto res = state->vocab.token_to_id.find("<|im_end|>");
+            if (res != state->vocab.token_to_id.end()) {
+                state->im_end = res->second;
+            }
+        }
         return LM_BOOL_SUCCESS;
     }
@@ -177,11 +182,13 @@ public:
         const auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
         auto id = gpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data()+state->tokens.size()-n_repeat_last, n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
-        if (id == 0 || (state->has_im_end && id == state->vocab.token_to_id["<|im_end|>"])) {
+        if (id == 0 || id == state->im_end) {
             if (eos_count++ == params.eos_ignores) {
                 abort = true;
+                printf("Stopping due to EOS (%d)\n", id);
                 continue;
             }
+            printf("Retrying after EOS (%d)... %d\n", id, eos_count);
             id = gpt_tokenize(state->vocab, "\n")[0];
             state->tokens.push_back(id);
         } else {
@@ -203,6 +210,7 @@ public:
         // TODO: Respect batch size
         std::vector<int> batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size());
         if (!mpt_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) {
+            printf("Stopping due to eval error (%d)\n", id);
             LM_COTHROW("Failed to evaluate new tokens", "");
         }