mirror of
https://gitlab.com/niansa/libjustlm.git
synced 2025-03-06 20:49:17 +01:00
Cut off ending from run() result properly
This commit is contained in:
parent
08ff1e72e7
commit
d8f4efb0c9
3 changed files with 9 additions and 3 deletions
|
@ -168,7 +168,9 @@ public:
|
|||
// Loop until done
|
||||
bool abort = false;
|
||||
unsigned eos_count = 0;
|
||||
size_t last_size = 0;
|
||||
while (!abort && (end.empty() || fres.find(end) == fres.npos)) {
|
||||
last_size = fres.size();
|
||||
// Sample top p and top k
|
||||
const auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
|
||||
auto id = gpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data()+state->tokens.size()-n_repeat_last, n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
|
||||
|
@ -211,7 +213,7 @@ public:
|
|||
|
||||
// Create final string TODO: Could be optimized
|
||||
if (!abort) {
|
||||
fres = std::string(fres.data(), fres.size()-end.size());
|
||||
fres = std::string(fres.data(), last_size);
|
||||
}
|
||||
|
||||
// Return final string
|
||||
|
|
|
@ -208,7 +208,9 @@ public:
|
|||
// Loop until done
|
||||
bool abort = false;
|
||||
unsigned eos_count = 0;
|
||||
size_t last_size = 0;
|
||||
while (!abort && (end.empty() || fres.find(end) == fres.npos)) {
|
||||
last_size = fres.size();
|
||||
// Sample top p and top k
|
||||
int id;
|
||||
try {
|
||||
|
@ -257,7 +259,7 @@ public:
|
|||
|
||||
// Create final string TODO: Could be optimized
|
||||
if (!abort && fres.size() > end.size()) {
|
||||
fres = std::string(fres.data(), fres.size()-end.size());
|
||||
fres = std::string(fres.data(), last_size);
|
||||
}
|
||||
|
||||
// Return final string
|
||||
|
|
|
@ -177,7 +177,9 @@ public:
|
|||
// Loop until done
|
||||
bool abort = false;
|
||||
unsigned eos_count = 0;
|
||||
size_t last_size = 0;
|
||||
while (!abort && (end.empty() || fres.find(end) == fres.npos)) {
|
||||
last_size = fres.size();
|
||||
// Sample top p and top k
|
||||
const auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
|
||||
auto id = gpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data()+state->tokens.size()-n_repeat_last, n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
|
||||
|
@ -227,7 +229,7 @@ public:
|
|||
|
||||
// Create final string TODO: Could be optimized
|
||||
if (!abort) {
|
||||
fres = std::string(fres.data(), fres.size()-end.size());
|
||||
fres = std::string(fres.data(), last_size);
|
||||
}
|
||||
|
||||
// Return final string
|
||||
|
|
Loading…
Add table
Reference in a new issue