mirror of https://gitlab.com/niansa/libjustlm.git synced 2025-03-06 20:49:17 +01:00

Added pre_tick

niansa authored 2023-06-15 18:14:09 +02:00
parent bcacfc3d54
commit 01b0d059ed
5 changed files with 42 additions and 28 deletions


@@ -62,9 +62,12 @@ namespace LM {
 using ssize_t = SSIZE_T;
 #endif
+using GenerateCallback = std::function<bool (const char *generated)>;
+using AppendCallback = std::function<bool (float progress)>;
 class Inference {
 protected:
-    std::function<bool (float)> on_scroll = nullptr;
+    AppendCallback on_scroll = nullptr;
     void *generic_state = nullptr;
@@ -126,15 +129,15 @@ public:
     static
     Inference *construct(const std::string& weights_path, const Params& p);
-    void set_scroll_callback(const std::function<bool (float)>& scroll_cb) noexcept {
+    void set_scroll_callback(const AppendCallback& scroll_cb) noexcept {
         on_scroll = scroll_cb;
     }
     // This must be called with a non-empty prompt!
-    virtual LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const std::function<bool (float progress)>& on_tick = nullptr) LM_NOEXCEPTDECL = 0;
+    virtual LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const AppendCallback& on_tick = nullptr) LM_NOEXCEPTDECL = 0;
     // append() must have been called at least once before calling this!
-    virtual LM_SCHEDULABLE(std::string) run(std::string_view end = "", const std::function<bool (const char *generated)>& on_tick = nullptr) LM_NOEXCEPTDECL = 0;
+    virtual LM_SCHEDULABLE(std::string) run(std::string_view end = "", const GenerateCallback& on_tick = nullptr, const GenerateCallback& pre_tick = nullptr) LM_NOEXCEPTDECL = 0;
     virtual unsigned get_context_size() const noexcept = 0;
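
The interface changes above name the two callback shapes (AppendCallback for progress reporting, GenerateCallback for generated text) and give run() a second, optional pre_tick callback. The sketch below is a minimal caller-side example of the updated interface; it assumes a build where LM_SCHEDULABLE(x) expands to plain x (no coroutine scheduling), and the header name, weights path and callback bodies are placeholders:

    #include <cstdio>
    #include <string>
    #include "justlm.hpp" // header name assumed

    int main() {
        LM::Inference::Params params;
        auto *inference = LM::Inference::construct("/path/to/weights.bin", params); // placeholder path

        // AppendCallback: receives a progress value while tokens are being ingested.
        inference->set_scroll_callback([] (float progress) {
            std::printf("scrolling context: %.1f\n", progress);
            return true;
        });
        inference->append("Hello", [] (float progress) {
            std::printf("ingesting prompt: %.1f\n", progress);
            return true;
        });

        // GenerateCallback: as wired up in the backends below, pre_tick receives the
        // freshly sampled token text before it is evaluated, on_tick afterwards.
        const std::string output = inference->run("\n",
            [] (const char *token) { std::printf("%s", token); return true; }, // on_tick
            [] (const char *token) { return token != nullptr; });              // pre_tick
        std::printf("\nfull output: %s\n", output.c_str());

        delete inference;
    }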


@@ -85,7 +85,7 @@ class GPTJInference final : public Inference {
         LM_CORETURN true;
     }
-    LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL {
+    LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const AppendCallback &on_tick = nullptr) LM_NOEXCEPTDECL {
         auto& state = get_state();
         // Evaluate tokens in batches
@@ -134,7 +134,7 @@ public:
         deinit();
     }
-    LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
+    LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const AppendCallback &on_tick) LM_NOEXCEPTDECL override {
         auto& state = get_state();
         // Append to current prompt
@@ -161,7 +161,7 @@ public:
         LM_CORETURN LM_COAWAIT evaluate_tokens(old_token_count, on_tick);
     }
-    LM_SCHEDULABLE(std::string) run(std::string_view end, const std::function<bool (const char *)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
+    LM_SCHEDULABLE(std::string) run(std::string_view end, const GenerateCallback &on_tick, const GenerateCallback& pre_tick) LM_NOEXCEPTDECL override {
         auto& state = get_state();
         std::string fres;
@@ -194,11 +194,14 @@ public:
             state->prompt.append(str);
             fres.append(str);
-            // Evaluate token
-            // TODO: Respect batch size
-            std::vector<int> batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size());
-            if (!gptj_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) {
-                LM_COTHROW("Failed to evaluate new tokens", "");
-            }
+            if (on_tick && !pre_tick(str.data())) abort = true;
+            else {
+                // Evaluate token
+                // TODO: Respect batch size
+                std::vector<int> batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size());
+                if (!gptj_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) {
+                    LM_COTHROW("Failed to evaluate new tokens", "");
+                }
+            }
             // Tick
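
The gating added above is repeated for the llama.cpp and MPT backends further down: the freshly sampled token text is handed to pre_tick before the corresponding eval call, and a false return sets abort so the token is not evaluated. Note that the guard as committed tests on_tick, so pre_tick is only invoked when an on_tick callback is installed as well. A purely illustrative caller-side pre_tick, cutting generation off at a chosen marker (helper name and marker are placeholders, header name assumed):

    #include <string>
    #include <string_view>
    #include "justlm.hpp" // header name assumed

    // Illustrative helper: stop generation as soon as a sampled token contains the
    // given marker, without spending an eval call on that token. Returning false
    // makes the backend set `abort` and skip gptj_eval()/llama_eval()/mpt_eval().
    LM::GenerateCallback make_pre_tick(std::string_view marker) {
        return [stop = std::string(marker)] (const char *generated) -> bool {
            return std::string_view(generated).find(stop) == std::string_view::npos;
        };
    }

Such a callback would be passed as the third argument to run(), next to a regular on_tick.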


@@ -76,7 +76,7 @@ class LLaMAInference final : public Inference {
         LM_CORETURN true;
     }
-    LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL {
+    LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const AppendCallback &on_tick = nullptr) LM_NOEXCEPTDECL {
         auto& state = get_state();
         // Evaluate tokens in batches
@@ -174,7 +174,7 @@ public:
         }
     }
-    LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
+    LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const AppendCallback &on_tick) LM_NOEXCEPTDECL override {
         auto& state = get_state();
         // Check if prompt was empty
@@ -201,7 +201,7 @@ public:
         LM_CORETURN LM_COAWAIT evaluate_tokens(old_token_count, on_tick);
     }
-    LM_SCHEDULABLE(std::string) run(std::string_view end, const std::function<bool (const char *)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
+    LM_SCHEDULABLE(std::string) run(std::string_view end, const GenerateCallback &on_tick, const GenerateCallback& pre_tick) LM_NOEXCEPTDECL override {
         auto& state = get_state();
         std::string fres;
@@ -240,10 +240,14 @@ public:
             state->prompt.append(str);
             fres.append(str);
-            // Evaluate token
-            // TODO: Respect batch size
-            if (llama_eval(state->ctx, state->tokens.data()+state->tokens.size()-1, 1, state->tokens.size()-1, params.n_threads)) {
-                LM_COTHROW("Failed to evaluate new tokens", "");
-            }
+            // Tick
+            if (on_tick && !pre_tick(str.data())) abort = true;
+            else {
+                // Evaluate token
+                // TODO: Respect batch size
+                if (llama_eval(state->ctx, state->tokens.data()+state->tokens.size()-1, 1, state->tokens.size()-1, params.n_threads)) {
+                    LM_COTHROW("Failed to evaluate new tokens", "");
+                }
+            }
             // Tick and yield


@@ -94,7 +94,7 @@ class MPTInference final : public Inference {
         LM_CORETURN true;
     }
-    LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL {
+    LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const AppendCallback &on_tick) LM_NOEXCEPTDECL {
         auto& state = get_state();
         // Evaluate tokens in batches
@@ -143,7 +143,7 @@ public:
         deinit();
     }
-    LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
+    LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const AppendCallback &on_tick) LM_NOEXCEPTDECL override {
         auto& state = get_state();
         // Append to current prompt
@@ -170,7 +170,7 @@ public:
         LM_CORETURN LM_COAWAIT evaluate_tokens(old_token_count, on_tick);
     }
-    LM_SCHEDULABLE(std::string) run(std::string_view end, const std::function<bool (const char *)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
+    LM_SCHEDULABLE(std::string) run(std::string_view end, const GenerateCallback &on_tick, const GenerateCallback& pre_tick) LM_NOEXCEPTDECL override {
         auto& state = get_state();
         std::string fres;
@@ -209,11 +209,15 @@ public:
             fres.append(str);
             state->prompt.append(str);
-            // Evaluate token
-            // TODO: Respect batch size
-            std::vector<int> batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size());
-            if (!mpt_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) {
-                LM_COTHROW("Failed to evaluate new tokens", "");
-            }
+            // Tick
+            if (on_tick && !pre_tick(str.data())) abort = true;
+            else {
+                // Evaluate token
+                // TODO: Respect batch size
+                std::vector<int> batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size());
+                if (!mpt_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) {
+                    LM_COTHROW("Failed to evaluate new tokens", "");
+                }
+            }
             // Tick


@@ -32,7 +32,7 @@ PYBIND11_MODULE(justlm_py, m) {
     py::class_<Inference>(m, "Inference")
         .def_static("construct", &Inference::construct, py::arg("weights_path"), py::arg("params") = Inference::Params())
         .def("append", &Inference::append, py::arg("prompt"), py::arg("on_tick") = nullptr)
-        .def("run", &Inference::run, py::arg("end") = "", py::arg("on_tick") = nullptr)
+        .def("run", &Inference::run, py::arg("end") = "", py::arg("on_tick") = nullptr, py::arg("pre_tick") = nullptr)
         .def("create_savestate", &Inference::create_savestate)
        .def("restore_savestate", &Inference::restore_savestate)
         .def("get_prompt", &Inference::get_prompt)