1
0
Fork 0
mirror of https://gitlab.com/niansa/libjustlm.git synced 2025-03-06 20:49:17 +01:00

Fixed an abort()

This commit is contained in:
niansa 2023-05-20 02:53:32 +02:00
parent 5feca59be7
commit 30a0a77cb2
4 changed files with 4 additions and 4 deletions

View file

@@ -168,7 +168,7 @@ public:
// Loop until done
bool abort = false;
unsigned eos_count = 0;
-while (!abort && fres.find(end) != fres.npos) {
+while (!abort && fres.size() >= end.size() && fres.find(end) != fres.npos) {
// Sample top p and top k
const auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
auto id = gpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data()+state->tokens.size()-n_repeat_last, n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);

View file

@@ -181,7 +181,7 @@ public:
// Loop until done
bool abort = false;
unsigned eos_count = 0;
-while (!abort && fres.find(end) != fres.npos) {
+while (!abort && fres.size() >= end.size() && fres.find(end) != fres.npos) {
// Sample top p and top k
auto id = llama_sample_top_p_top_k();

View file

@@ -158,7 +158,7 @@ public:
// Loop until done
bool abort = false;
unsigned eos_count = 0;
-while (!abort && fres.find(end) != fres.npos) {
+while (!abort && fres.size() >= end.size() && fres.find(end) != fres.npos) {
// Sample top p and top k
const auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
auto id = llama_sample_top_p_top_k(state->ctx, params.n_repeat_last?(state->tokens.data()+state->tokens.size()-n_repeat_last):nullptr, n_repeat_last, params.top_k, params.top_p, params.temp, params.repeat_penalty);

View file

@@ -177,7 +177,7 @@ public:
// Loop until done
bool abort = false;
unsigned eos_count = 0;
-while (!abort && fres.find(end) != fres.npos) {
+while (!abort && fres.size() >= end.size() && fres.find(end) != fres.npos) {
// Sample top p and top k
const auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
auto id = gpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data()+state->tokens.size()-n_repeat_last, n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);