1
0
Fork 0
mirror of https://gitlab.com/niansa/libjustlm.git synced 2025-03-06 20:49:17 +01:00

Repeat penalty fixes

This commit is contained in:
niansa/tuxifan 2023-05-17 08:44:25 +02:00
parent a77d25d01d
commit 4ec47699f0
3 changed files with 6 additions and 3 deletions

View file

@ -170,7 +170,8 @@ public:
unsigned eos_count = 0;
while (!abort && !ends_with(fres, end)) {
// Sample top p and top k
auto id = gpt_sample_top_k_top_p(state->model.hparams.n_vocab, params.n_repeat_last?(state->tokens.data()+state->tokens.size()-params.n_repeat_last):nullptr, params.n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
const auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
auto id = gpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data()+state->tokens.size()-n_repeat_last, n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
if (id == 50256) {
if (eos_count++ == params.eos_ignores) {

View file

@ -160,7 +160,8 @@ public:
unsigned eos_count = 0;
while (!abort && !ends_with(fres, end)) {
// Sample top p and top k
auto id = llama_sample_top_p_top_k(state->ctx, params.n_repeat_last?(state->tokens.data()+state->tokens.size()-params.n_repeat_last):nullptr, params.n_repeat_last, params.top_k, params.top_p, params.temp, params.repeat_penalty);
const auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
auto id = llama_sample_top_p_top_k(state->ctx, params.n_repeat_last?(state->tokens.data()+state->tokens.size()-n_repeat_last):nullptr, n_repeat_last, params.top_k, params.top_p, params.temp, params.repeat_penalty);
if (id == llama_token_eos()) {
if (eos_count++ == params.eos_ignores) {

View file

@ -174,7 +174,8 @@ public:
unsigned eos_count = 0;
while (!abort && !ends_with(fres, end)) {
// Sample top p and top k
auto id = mpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data(), state->tokens.size(), state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
const auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
auto id = mpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data()+state->tokens.size()-n_repeat_last, n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
if (id == 0 || (state->has_im_end && id == state->vocab.token_to_id["<|im_end|>"])) {
if (eos_count++ == params.eos_ignores) {