Mirror of https://gitlab.com/niansa/libjustlm.git (synced 2025-03-06 20:49:17 +01:00)
Minor improvements on EOS handling
commit abbb35c6a9
parent 8e7e310757
3 changed files with 19 additions and 18 deletions
File 1 of 3:

@@ -179,17 +179,16 @@ public:
                     continue;
                 }
                 id = gpt_tokenize(state->vocab, "\n")[0];
-                state->tokens.push_back(id);
-            } else {
-                // Add token
-                state->tokens.push_back(id);
             }
 
+            // Add token
+            state->tokens.push_back(id);
+
             // Make sure token limit isn't being hit
             LM_COAWAIT window_scroll();
 
             // Get token as string
-            const auto str = state->vocab.id_to_token[id];
+            const std::string_view str = state->vocab.id_to_token[id];
 
             // Append string to function result
             state->prompt.append(str);
@@ -203,7 +202,7 @@ public:
             }
 
             // Tick
-            if (on_tick && !on_tick(str.c_str())) abort = true;
+            if (on_tick && !on_tick(str.data())) abort = true;
             else if (!LM_TASKYIELD) abort = true;
         }
 
File 2 of 3:

@@ -180,7 +180,7 @@ public:
             LM_COAWAIT window_scroll();
 
             // Get token as string
-            const auto str = llama_token_to_str(state->ctx, id);
+            const std::string_view str = llama_token_to_str(state->ctx, id);
 
             // Append string to function result
             state->prompt.append(str);
@@ -193,7 +193,7 @@ public:
             }
 
             // Tick and yield
-            if (on_tick && !on_tick(str)) abort = true;
+            if (on_tick && !on_tick(str.data())) abort = true;
             else if (!LM_TASKYIELD) abort = true;
         }
 
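A note on the type change repeated in the hunks above (this note and sketch are additions, not part of the commit): `const auto` previously deduced what is presumably an owning std::string copy of the vocab entry in the first file, and a raw C string returned by llama_token_to_str in the second, while `const std::string_view` gives both a uniform, non-owning view. std::string_view has no c_str(), so the tick callback now receives str.data(); that pointer works as a C string here only because the viewed buffer is itself null-terminated. A minimal, self-contained illustration with hypothetical names:

    // Sketch only: hypothetical on_tick and vocab stand-ins, not libjustlm code.
    #include <cstdio>
    #include <string>
    #include <string_view>
    #include <vector>

    // Stand-in for the tick callback, which takes a C string.
    static bool on_tick(const char *token_str) {
        std::fputs(token_str, stdout);
        return true; // returning false would abort generation
    }

    int main() {
        std::vector<std::string> id_to_token = {"Hello", ",", " world"}; // stand-in vocab
        std::string prompt;

        for (int id = 0; id < 3; ++id) {
            // Non-owning view; 'const auto' here would have copied the std::string.
            const std::string_view str = id_to_token[id];
            prompt.append(str);              // std::string::append accepts string_view (C++17)
            if (!on_tick(str.data())) break; // .data() is null-terminated because it points
                                             // into the std::string's own buffer
        }
        std::puts("");
        return 0;
    }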
File 3 of 3:

@@ -182,25 +182,28 @@ public:
             const auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
             auto id = gpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data()+state->tokens.size()-n_repeat_last, n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
 
-            if (id == 0 || id == state->im_end) {
+            if (state->im_end && id == state->im_end) {
+                if (eos_count++ == params.eos_ignores) {
+                    abort = true;
+                    continue;
+                }
+                id = gpt_tokenize(state->vocab, "\n")[0];
+            } else if (id == 0) {
                 if (eos_count++ == params.eos_ignores) {
                     abort = true;
-                    printf("Stopping due to EOS (%d)\n", id);
                     continue;
                 }
-                printf("Retrying after EOS (%d)... %d\n", id, eos_count);
                 id = gpt_tokenize(state->vocab, "\n")[0];
-                state->tokens.push_back(id);
-            } else {
-                // Add token
-                state->tokens.push_back(id);
             }
 
+            // Add token
+            state->tokens.push_back(id);
+
             // Make sure token limit isn't being hit
             LM_COAWAIT window_scroll();
 
             // Get token as string
-            const auto str = state->vocab.id_to_token[id];
+            const std::string_view str = state->vocab.id_to_token[id];
 
             // Append string to function result
             fres.append(str);
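The hunk above carries the substance of the commit (the sketch below is an addition, not commit code): the end-of-stream check now handles a configured im_end token and the plain end-of-text token (id 0) in separate branches, with the im_end comparison guarded on im_end actually being set; the debug printf calls are dropped; and the sampled token is pushed onto the context once, after the branch, rather than in both arms. A minimal standalone sketch of the resulting control flow, using hypothetical stand-ins (State, Params, handle_token) instead of the real libjustlm types:

    // Hedged sketch of the new EOS branch structure; names are stand-ins.
    #include <vector>

    struct Params { unsigned eos_ignores = 0; };  // how many EOS hits to skip over
    struct State  { int im_end = 0;               // 0 = no im_end token configured
                    std::vector<int> tokens; };

    // Returns false when generation should abort; otherwise leaves a usable token
    // in `id` (possibly rewritten to `newline_id`) and records it in the context.
    bool handle_token(int& id, unsigned& eos_count,
                      const Params& params, State& state, int newline_id) {
        if (state.im_end && id == state.im_end) {            // chat end marker, only if set
            if (eos_count++ == params.eos_ignores) return false;
            id = newline_id;                                 // keep going on a fresh line
        } else if (id == 0) {                                // plain end-of-text token
            if (eos_count++ == params.eos_ignores) return false;
            id = newline_id;
        }

        // Add token (done once here, no longer duplicated in both branches)
        state.tokens.push_back(id);
        return true;
    }

Read this way, the first eos_ignores end-of-stream hits are converted to a newline token and generation continues; the following one aborts, which is what the `eos_count++ == params.eos_ignores` comparison expresses.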
@@ -210,12 +213,11 @@ public:
             // TODO: Respect batch size
             std::vector<int> batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size());
             if (!mpt_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) {
-                printf("Stopping due to eval error (%d)\n", id);
                 LM_COTHROW("Failed to evaluate new tokens", "");
             }
 
             // Tick
-            if (on_tick && !on_tick(str.c_str())) abort = true;
+            if (on_tick && !on_tick(str.data())) abort = true;
             else if (!LM_TASKYIELD) abort = true;
         }
 
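One piece of unchanged context in the hunk above worth spelling out (again an addition, with hypothetical values): the `batch` built from `tokens.begin()+tokens.size()-1` to `tokens.begin()+tokens.size()` is a one-element vector holding only the most recently sampled token, which is what the `// TODO: Respect batch size` comment refers to.

    #include <cassert>
    #include <vector>

    int main() {
        std::vector<int> tokens = {11, 22, 33, 44};
        // Same iterator-range construction as in the diff above:
        // a range covering exactly the last element.
        std::vector<int> batch(tokens.begin() + tokens.size() - 1,
                               tokens.begin() + tokens.size());
        assert(batch.size() == 1 && batch.front() == 44);
        (void)batch;
        return 0;
    }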