mirror of
https://gitlab.com/niansa/libjustlm.git
synced 2025-03-06 20:49:17 +01:00
Added mirostat support
This commit is contained in:
parent
ad0b7e3c71
commit
53a4623aef
6 changed files with 56 additions and 20 deletions
|
@ -77,17 +77,20 @@ public:
|
|||
unsigned n_ctx = 2024; // Context size
|
||||
unsigned n_ctx_window_top_bar = 0; // Top bar of context window. Must be smaller than context size
|
||||
unsigned n_batch = 8; // Batch size
|
||||
unsigned n_repeat_last = 0; // llama.cpp specific
|
||||
unsigned n_repeat_last = 0;
|
||||
unsigned n_eos_ignores = 0;
|
||||
|
||||
float scroll_keep = 0.0f; // 0.4f to keep 40% of context below top bar when scrolling; 0.0f to remove everything after top bar
|
||||
|
||||
unsigned top_k = 40;
|
||||
float top_p = 0.9f;
|
||||
float temp = 0.72f;
|
||||
float repeat_penalty = 1.0f; // llama.cpp specific
|
||||
unsigned eos_ignores = 0; // llama.cpp specific
|
||||
float top_p = 0.9f;
|
||||
float temp = 0.72f;
|
||||
float mirostat_learning_rate = 0.1f; // mirostat specific
|
||||
float mirostat_target_entropy = 5.0f; // mirostat specific
|
||||
float repeat_penalty = 1.0f;
|
||||
|
||||
bool use_mlock = true; // llama.cpp specific
|
||||
bool use_mlock = true; // llama specific
|
||||
int prefer_mirostat = 0; // Use given mirostat version if available (see is_mirostat_available()); llama specific
|
||||
} params;
|
||||
|
||||
struct Savestate {
|
||||
|
@ -138,6 +141,8 @@ public:
|
|||
|
||||
virtual const std::string& get_prompt() const LM_NOEXCEPTDECL = 0;
|
||||
|
||||
virtual bool is_mirostat_available() const noexcept {return false;} // Base default: reports mirostat sampling as unavailable; virtual, so a derived class can override it
|
||||
|
||||
LM_LAST_ERROR_GETTER
|
||||
};
|
||||
|
||||
|
|
|
@ -174,7 +174,7 @@ public:
|
|||
auto id = gpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data()+state->tokens.size()-n_repeat_last, n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
|
||||
|
||||
if (id == 50256) {
|
||||
if (eos_count++ == params.eos_ignores) {
|
||||
if (eos_count++ == params.n_eos_ignores) {
|
||||
abort = true;
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -127,12 +127,28 @@ class LLaMAInference final : public Inference {
|
|||
auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
|
||||
llama_sample_repetition_penalty(state->ctx, &candidates_p, params.n_repeat_last?(state->tokens.data()+state->tokens.size()-n_repeat_last):nullptr, n_repeat_last, params.repeat_penalty);
|
||||
// Temperature sampling
|
||||
llama_sample_top_k(state->ctx, &candidates_p, params.top_k, 1);
|
||||
llama_sample_tail_free(state->ctx, &candidates_p, 1.0f, 1);
|
||||
llama_sample_typical(state->ctx, &candidates_p, 1.0f, 1);
|
||||
llama_sample_top_p(state->ctx, &candidates_p, params.top_p, 1);
|
||||
llama_sample_temperature(state->ctx, &candidates_p, params.temp);
|
||||
return llama_sample_token(state->ctx, &candidates_p);
|
||||
switch (params.prefer_mirostat) {
|
||||
case 0: {
|
||||
llama_sample_top_k(state->ctx, &candidates_p, params.top_k, 1);
|
||||
llama_sample_tail_free(state->ctx, &candidates_p, 1.0f, 1);
|
||||
llama_sample_typical(state->ctx, &candidates_p, 1.0f, 1);
|
||||
llama_sample_top_p(state->ctx, &candidates_p, params.top_p, 1);
|
||||
llama_sample_temperature(state->ctx, &candidates_p, params.temp);
|
||||
return llama_sample_token(state->ctx, &candidates_p);
|
||||
}
|
||||
case 1: {
|
||||
float mirostat_mu = 2.0f * params.mirostat_target_entropy;
|
||||
const int mirostat_m = 100;
|
||||
llama_sample_temperature(state->ctx, &candidates_p, params.temp);
|
||||
return llama_sample_token_mirostat(state->ctx, &candidates_p, params.mirostat_target_entropy, params.mirostat_learning_rate, mirostat_m, &mirostat_mu);
|
||||
}
|
||||
case 2: {
|
||||
float mirostat_mu = 2.0f * params.mirostat_target_entropy;
|
||||
llama_sample_temperature(state->ctx, &candidates_p, params.temp);
|
||||
return llama_sample_token_mirostat_v2(state->ctx, &candidates_p, params.mirostat_target_entropy, params.mirostat_learning_rate, &mirostat_mu);
|
||||
}
|
||||
default: LM_THROW("Invalid mirostat version "+std::to_string(params.prefer_mirostat), LM_BOOL_ERROR);
|
||||
}
|
||||
}
|
||||
#else
|
||||
int llama_sample_top_p_top_k() {
|
||||
|
@ -191,10 +207,15 @@ public:
|
|||
unsigned eos_count = 0;
|
||||
while (!abort && (end.empty() || fres.find(end) == fres.npos)) {
|
||||
// Sample top p and top k
|
||||
auto id = llama_sample_top_p_top_k();
|
||||
int id;
|
||||
try {
|
||||
id = llama_sample_top_p_top_k();
|
||||
} catch (const std::exception& e) {
|
||||
LM_COTHROW(e.what(), "");
|
||||
}
|
||||
|
||||
if (id == llama_token_eos()) {
|
||||
if (eos_count++ == params.eos_ignores) {
|
||||
if (eos_count++ == params.n_eos_ignores) {
|
||||
abort = true;
|
||||
continue;
|
||||
}
|
||||
|
@ -321,5 +342,11 @@ public:
|
|||
const std::string &get_prompt() const LM_NOEXCEPTDECL override { // Accessor for the prompt held by the current state object
|
||||
return get_state()->prompt; // Hands back a const reference to the state's prompt member
|
||||
}
|
||||
|
||||
#if LLAMA_DATE >= 230519 // Override is compiled only when LLAMA_DATE is at least 230519 (presumably the date of the bundled llama.cpp snapshot; TODO confirm)
|
||||
bool is_mirostat_available() const noexcept override { // Reports mirostat sampling as available for this backend
|
||||
return true;
|
||||
}
|
||||
#endif
|
||||
};
|
||||
}
|
||||
|
|
|
@ -183,13 +183,13 @@ public:
|
|||
auto id = gpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data()+state->tokens.size()-n_repeat_last, n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
|
||||
|
||||
if (state->im_end && id == state->im_end) {
|
||||
if (eos_count++ == params.eos_ignores) {
|
||||
if (eos_count++ == params.n_eos_ignores) {
|
||||
abort = true;
|
||||
continue;
|
||||
}
|
||||
id = gpt_tokenize(state->vocab, "\n")[0];
|
||||
} else if (id == 0) {
|
||||
if (eos_count++ == params.eos_ignores) {
|
||||
if (eos_count++ == params.n_eos_ignores) {
|
||||
abort = true;
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit 2e6cd4b02549e343bef3768e6b946f999c82e823
|
||||
Subproject commit 29cf5596fe0c37213f9b74e80d8f631193a93f0f
|
|
@ -24,8 +24,11 @@ PYBIND11_MODULE(justlm_py, m) {
|
|||
.def_readwrite("top_p", &Inference::Params::top_p)
|
||||
.def_readwrite("temp", &Inference::Params::temp)
|
||||
.def_readwrite("repeat_penalty", &Inference::Params::repeat_penalty)
|
||||
.def_readwrite("eos_ignores", &Inference::Params::eos_ignores)
|
||||
.def_readwrite("use_mlock", &Inference::Params::use_mlock);
|
||||
.def_readwrite("eos_ignores", &Inference::Params::n_eos_ignores)
|
||||
.def_readwrite("use_mlock", &Inference::Params::use_mlock)
|
||||
.def_readwrite("prefer_mirostat", &Inference::Params::prefer_mirostat)
|
||||
.def_readwrite("mirostat_learning_rate", &Inference::Params::mirostat_learning_rate)
|
||||
.def_readwrite("mirostat_target_entropy", &Inference::Params::mirostat_target_entropy);
|
||||
py::class_<Inference>(m, "Inference")
|
||||
.def_static("construct", &Inference::construct, py::arg("weights_path"), py::arg("params") = Inference::Params())
|
||||
.def("append", &Inference::append, py::arg("prompt"), py::arg("on_tick") = nullptr)
|
||||
|
@ -34,6 +37,7 @@ PYBIND11_MODULE(justlm_py, m) {
|
|||
.def("restore_savestate", &Inference::restore_savestate)
|
||||
.def("get_prompt", &Inference::get_prompt)
|
||||
.def("get_context_size", &Inference::get_context_size)
|
||||
.def("is_mirostat_available", &Inference::is_mirostat_available)
|
||||
.def_readwrite("params", &Inference::params);
|
||||
py::class_<Inference::Savestate>(m, "Savestate")
|
||||
.def(py::init<>());
|
||||
|
|
Loading…
Add table
Reference in a new issue