
Added mirostat support

niansa 2023-05-26 00:43:07 +02:00
parent ad0b7e3c71
commit 53a4623aef
6 changed files with 56 additions and 20 deletions
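
This commit adds three sampling knobs to Inference::Params (prefer_mirostat, mirostat_learning_rate, mirostat_target_entropy), renames eos_ignores to n_eos_ignores, and exposes an is_mirostat_available() capability probe. As a rough caller-side sketch, not part of the commit (the justlm.hpp header name, the LM namespace, and an owning-pointer return from Inference::construct() are assumptions):

#include <justlm.hpp>

int main() {
    LM::Inference::Params params;
    params.temp = 0.72f;
    params.mirostat_target_entropy = 5.0f; // tau: target surprise per token
    params.mirostat_learning_rate = 0.1f;  // eta: how fast the threshold adapts
    auto lm = LM::Inference::construct("model.bin", params); // assumed smart-pointer return
    // params is a public member, so the mirostat version can be picked after probing:
    lm->params.prefer_mirostat = lm->is_mirostat_available() ? 2 : 0;
    lm->append("Hello");
    return 0;
}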

View file

@@ -77,17 +77,20 @@ public:
         unsigned n_ctx = 2024; // Context size
         unsigned n_ctx_window_top_bar = 0; // Top bar of context window. Must be smaller than context size
         unsigned n_batch = 8; // Batch size
-        unsigned n_repeat_last = 0; // llama.cpp specific
+        unsigned n_repeat_last = 0;
+        unsigned n_eos_ignores = 0;
         float scroll_keep = 0.0f; // 0.4f to keep 40% of context below top bar when scrolling; 0.0f to remove everything after top bar
         unsigned top_k = 40;
-        float top_p = 0.9f;
-        float temp = 0.72f;
-        float repeat_penalty = 1.0f; // llama.cpp specific
-        unsigned eos_ignores = 0; // llama.cpp specific
+        float top_p = 0.9f;
+        float temp = 0.72f;
+        float mirostat_learning_rate = 0.1f; // mirostat specific
+        float mirostat_target_entropy = 5.0f; // mirostat specific
+        float repeat_penalty = 1.0f;
-        bool use_mlock = true; // llama.cpp specific
+        bool use_mlock = true; // llama specific
+        int prefer_mirostat = 0; // Use given mirostat version if available (see is_mirostat_available()); llama specific
     } params;
     struct Savestate {
@@ -138,6 +141,8 @@ public:
     virtual const std::string& get_prompt() const LM_NOEXCEPTDECL = 0;
+
+    virtual bool is_mirostat_available() const noexcept {return false;}
 
     LM_LAST_ERROR_GETTER
 };

View file

@@ -174,7 +174,7 @@ public:
             auto id = gpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data()+state->tokens.size()-n_repeat_last, n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
             if (id == 50256) {
-                if (eos_count++ == params.eos_ignores) {
+                if (eos_count++ == params.n_eos_ignores) {
                     abort = true;
                     continue;
                 }

View file

@@ -127,12 +127,28 @@ class LLaMAInference final : public Inference {
         auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
         llama_sample_repetition_penalty(state->ctx, &candidates_p, params.n_repeat_last?(state->tokens.data()+state->tokens.size()-n_repeat_last):nullptr, n_repeat_last, params.repeat_penalty);
         // Temperature sampling
-        llama_sample_top_k(state->ctx, &candidates_p, params.top_k, 1);
-        llama_sample_tail_free(state->ctx, &candidates_p, 1.0f, 1);
-        llama_sample_typical(state->ctx, &candidates_p, 1.0f, 1);
-        llama_sample_top_p(state->ctx, &candidates_p, params.top_p, 1);
-        llama_sample_temperature(state->ctx, &candidates_p, params.temp);
-        return llama_sample_token(state->ctx, &candidates_p);
+        switch (params.prefer_mirostat) {
+        case 0: {
+            llama_sample_top_k(state->ctx, &candidates_p, params.top_k, 1);
+            llama_sample_tail_free(state->ctx, &candidates_p, 1.0f, 1);
+            llama_sample_typical(state->ctx, &candidates_p, 1.0f, 1);
+            llama_sample_top_p(state->ctx, &candidates_p, params.top_p, 1);
+            llama_sample_temperature(state->ctx, &candidates_p, params.temp);
+            return llama_sample_token(state->ctx, &candidates_p);
+        }
+        case 1: {
+            float mirostat_mu = 2.0f * params.mirostat_target_entropy;
+            const int mirostat_m = 100;
+            llama_sample_temperature(state->ctx, &candidates_p, params.temp);
+            return llama_sample_token_mirostat(state->ctx, &candidates_p, params.mirostat_target_entropy, params.mirostat_learning_rate, mirostat_m, &mirostat_mu);
+        }
+        case 2: {
+            float mirostat_mu = 2.0f * params.mirostat_target_entropy;
+            llama_sample_temperature(state->ctx, &candidates_p, params.temp);
+            return llama_sample_token_mirostat_v2(state->ctx, &candidates_p, params.mirostat_target_entropy, params.mirostat_learning_rate, &mirostat_mu);
+        }
+        default: LM_THROW("Invalid mirostat version "+std::to_string(params.prefer_mirostat), LM_BOOL_ERROR);
+        }
     }
 #else
     int llama_sample_top_p_top_k() {
@@ -191,10 +207,15 @@ public:
         unsigned eos_count = 0;
         while (!abort && (end.empty() || fres.find(end) == fres.npos)) {
             // Sample top p and top k
-            auto id = llama_sample_top_p_top_k();
+            int id;
+            try {
+                id = llama_sample_top_p_top_k();
+            } catch (const std::exception& e) {
+                LM_COTHROW(e.what(), "");
+            }
             if (id == llama_token_eos()) {
-                if (eos_count++ == params.eos_ignores) {
+                if (eos_count++ == params.n_eos_ignores) {
                     abort = true;
                     continue;
                 }
@@ -321,5 +342,11 @@ public:
     const std::string &get_prompt() const LM_NOEXCEPTDECL override {
         return get_state()->prompt;
     }
+
+#if LLAMA_DATE >= 230519
+    bool is_mirostat_available() const noexcept override {
+        return true;
+    }
+#endif
 };
 }
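
In the sampling switch above, mirostat_mu is seeded with 2.0f * mirostat_target_entropy and then updated by the sampler after every token, steering the observed surprise toward the target. Below is a simplified sketch of the v2 feedback rule from the mirostat paper (Basu et al., 2021), shown only to illustrate what llama_sample_token_mirostat_v2 does conceptually; it is not the library's actual code:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <random>
#include <vector>

// p: token probabilities (already softmaxed); tau = mirostat_target_entropy,
// eta = mirostat_learning_rate; mu starts at 2*tau and persists across calls.
int mirostat_v2_step(const std::vector<float>& p, float tau, float eta,
                     float& mu, std::mt19937& rng) {
    std::vector<int> kept;      // indices of tokens that survive truncation
    std::vector<float> weights; // their probabilities
    for (std::size_t i = 0; i < p.size(); i++)
        if (-std::log2(p[i]) <= mu) { // drop tokens more surprising than mu
            kept.push_back((int)i);
            weights.push_back(p[i]);
        }
    if (kept.empty()) {         // nothing survived: fall back to the argmax
        int best = (int)(std::max_element(p.begin(), p.end()) - p.begin());
        kept.push_back(best);
        weights.push_back(p[best]);
    }
    // Sample among the survivors, proportionally to probability.
    std::discrete_distribution<int> pick(weights.begin(), weights.end());
    int tok = kept[pick(rng)];
    // Feedback step: nudge mu so observed surprise converges on tau.
    mu -= eta * (-std::log2(p[tok]) - tau);
    return tok;
}

This is also why case 1 above passes mirostat_m = 100: v1 additionally estimates a Zipf exponent from the top m candidates to derive its cutoff, a step v2 replaces with the direct truncation against mu.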

View file

@@ -183,13 +183,13 @@ public:
             auto id = gpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data()+state->tokens.size()-n_repeat_last, n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
             if (state->im_end && id == state->im_end) {
-                if (eos_count++ == params.eos_ignores) {
+                if (eos_count++ == params.n_eos_ignores) {
                     abort = true;
                     continue;
                 }
                 id = gpt_tokenize(state->vocab, "\n")[0];
             } else if (id == 0) {
-                if (eos_count++ == params.eos_ignores) {
+                if (eos_count++ == params.n_eos_ignores) {
                     abort = true;
                     continue;
                 }

@@ -1 +1 @@
-Subproject commit 2e6cd4b02549e343bef3768e6b946f999c82e823
+Subproject commit 29cf5596fe0c37213f9b74e80d8f631193a93f0f

View file

@@ -24,8 +24,11 @@ PYBIND11_MODULE(justlm_py, m) {
         .def_readwrite("top_p", &Inference::Params::top_p)
         .def_readwrite("temp", &Inference::Params::temp)
         .def_readwrite("repeat_penalty", &Inference::Params::repeat_penalty)
-        .def_readwrite("eos_ignores", &Inference::Params::eos_ignores)
-        .def_readwrite("use_mlock", &Inference::Params::use_mlock);
+        .def_readwrite("eos_ignores", &Inference::Params::n_eos_ignores)
+        .def_readwrite("use_mlock", &Inference::Params::use_mlock)
+        .def_readwrite("prefer_mirostat", &Inference::Params::prefer_mirostat)
+        .def_readwrite("mirostat_learning_rate", &Inference::Params::mirostat_learning_rate)
+        .def_readwrite("mirostat_target_entropy", &Inference::Params::mirostat_target_entropy);
     py::class_<Inference>(m, "Inference")
         .def_static("construct", &Inference::construct, py::arg("weights_path"), py::arg("params") = Inference::Params())
         .def("append", &Inference::append, py::arg("prompt"), py::arg("on_tick") = nullptr)
@@ -34,6 +37,7 @@ PYBIND11_MODULE(justlm_py, m) {
         .def("restore_savestate", &Inference::restore_savestate)
         .def("get_prompt", &Inference::get_prompt)
         .def("get_context_size", &Inference::get_context_size)
+        .def("is_mirostat_available", &Inference::is_mirostat_available)
         .def_readwrite("params", &Inference::params);
     py::class_<Inference::Savestate>(m, "Savestate")
         .def(py::init<>());