Mirror of https://gitlab.com/niansa/libjustlm.git (synced 2025-03-06 20:49:17 +01:00)

Commit 01b0d059ed: Added pre_tick
Parent commit: bcacfc3d54

5 changed files with 42 additions and 28 deletions
@@ -62,9 +62,12 @@ namespace LM {
 using ssize_t = SSIZE_T;
 #endif
 
+using GenerateCallback = std::function<bool (const char *generated)>;
+using AppendCallback = std::function<bool (float progress)>;
+
 class Inference {
 protected:
-    std::function<bool (float)> on_scroll = nullptr;
+    AppendCallback on_scroll = nullptr;
 
     void *generic_state = nullptr;
 
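The two aliases above name the callback shapes used throughout the API. For illustration only (not part of this commit), callbacks matching them could look like the following minimal, self-contained sketch; the printing bodies are arbitrary placeholder choices, and the aliases are redeclared locally only so the sketch compiles on its own. Returning false from a callback is how a caller asks a backend to stop, as the loop hunks further down show for pre_tick.

#include <cstdio>
#include <functional>

namespace LM {
    // Redeclared here only to keep the sketch self-contained;
    // in real use, include the library header instead.
    using GenerateCallback = std::function<bool (const char *generated)>;
    using AppendCallback = std::function<bool (float progress)>;
}

int main() {
    // Reports prompt-ingestion progress; returning true keeps going.
    LM::AppendCallback report_progress = [] (float progress) {
        std::printf("append progress: %f\n", progress);
        return true;
    };

    // Receives each generated piece of text; returning false would abort.
    LM::GenerateCallback print_token = [] (const char *generated) {
        std::fputs(generated, stdout);
        return true;
    };

    report_progress(0.5f);
    print_token("hello");
    return 0;
}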
@@ -126,15 +129,15 @@ public:
     static
     Inference *construct(const std::string& weights_path, const Params& p);
 
-    void set_scroll_callback(const std::function<bool (float)>& scroll_cb) noexcept {
+    void set_scroll_callback(const AppendCallback& scroll_cb) noexcept {
         on_scroll = scroll_cb;
     }
 
     // This must be called with a non-empty prompt!
-    virtual LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const std::function<bool (float progress)>& on_tick = nullptr) LM_NOEXCEPTDECL = 0;
+    virtual LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const AppendCallback& on_tick = nullptr) LM_NOEXCEPTDECL = 0;
 
     // append() must have been called at least once before calling this!
-    virtual LM_SCHEDULABLE(std::string) run(std::string_view end = "", const std::function<bool (const char *generated)>& on_tick = nullptr) LM_NOEXCEPTDECL = 0;
+    virtual LM_SCHEDULABLE(std::string) run(std::string_view end = "", const GenerateCallback& on_tick = nullptr, const GenerateCallback& pre_tick = nullptr) LM_NOEXCEPTDECL = 0;
 
     virtual unsigned get_context_size() const noexcept = 0;
 
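From the caller's side, the observable change is the extra pre_tick parameter on run(). Below is a hedged usage sketch, assuming the public header is named justlm.hpp and a build in which LM_SCHEDULABLE(...) expands to a plain (non-coroutine) return type; the weights path and the callback bodies are placeholders, and error handling is omitted.

#include <justlm.hpp>   // assumed header name

#include <iostream>
#include <string>

int main() {
    LM::Inference::Params params;
    LM::Inference *model =
        LM::Inference::construct("/path/to/weights.bin", params);   // placeholder path

    // append() still takes a single AppendCallback for ingestion progress.
    model->append("Hello", [] (float progress) {
        std::cerr << "append: " << progress << '\n';
        return true;
    });

    // run() now takes two GenerateCallbacks: pre_tick fires before a token is
    // evaluated (returning false aborts, per the backend hunks below), on_tick
    // fires afterwards as before.
    std::string out = model->run("\n",
        /*on_tick=*/  [] (const char *generated) { std::cout << generated << std::flush; return true; },
        /*pre_tick=*/ [] (const char *generated) { return true; });

    std::cout << '\n' << out << '\n';
    return 0;
}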
@@ -85,7 +85,7 @@ class GPTJInference final : public Inference {
         LM_CORETURN true;
     }
 
-    LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL {
+    LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const AppendCallback &on_tick = nullptr) LM_NOEXCEPTDECL {
         auto& state = get_state();
 
         // Evaluate tokens in batches
@@ -134,7 +134,7 @@ public:
         deinit();
     }
 
-    LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
+    LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const AppendCallback &on_tick) LM_NOEXCEPTDECL override {
         auto& state = get_state();
 
         // Append to current prompt
@@ -161,7 +161,7 @@ public:
         LM_CORETURN LM_COAWAIT evaluate_tokens(old_token_count, on_tick);
     }
 
-    LM_SCHEDULABLE(std::string) run(std::string_view end, const std::function<bool (const char *)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
+    LM_SCHEDULABLE(std::string) run(std::string_view end, const GenerateCallback &on_tick, const GenerateCallback& pre_tick) LM_NOEXCEPTDECL override {
         auto& state = get_state();
         std::string fres;
 
@@ -194,11 +194,14 @@ public:
             state->prompt.append(str);
             fres.append(str);
 
-            // Evaluate token
-            // TODO: Respect batch size
-            std::vector<int> batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size());
-            if (!gptj_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) {
-                LM_COTHROW("Failed to evaluate new tokens", "");
+            if (on_tick && !pre_tick(str.data())) abort = true;
+            else {
+                // Evaluate token
+                // TODO: Respect batch size
+                std::vector<int> batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size());
+                if (!gptj_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) {
+                    LM_COTHROW("Failed to evaluate new tokens", "");
+                }
             }
 
             // Tick
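All three backends (GPT-J above, LLaMA and MPT below) apply the same change to their generation loop: the newly generated text is offered to pre_tick before the eval call, and the evaluation step only runs if pre_tick does not ask to stop. The following standalone, simplified sketch of that control flow uses a fake token source and a placeholder standing in for gptj_eval/llama_eval/mpt_eval; the guard deliberately mirrors the committed code, which tests on_tick for null but invokes pre_tick.

#include <cstdio>
#include <functional>
#include <string>
#include <vector>

using GenerateCallback = std::function<bool (const char *generated)>;

int main() {
    // Fake "generation": in the real backends each str comes from sampling.
    const std::vector<std::string> generated = {"Hello", ",", " world", "!"};
    std::string fres;
    bool abort = false;

    GenerateCallback on_tick = [] (const char *g) {
        std::printf("tick: %s\n", g);
        return true;
    };
    GenerateCallback pre_tick = [] (const char *g) {
        return std::string(g) != "!";   // veto generation before "!" is evaluated
    };

    // Placeholder for the expensive gptj_eval/llama_eval/mpt_eval call.
    auto evaluate = [] (const std::string&) { return true; };

    for (const auto& str : generated) {
        if (abort) break;
        fres.append(str);   // as in the hunks, the text is appended before the gate

        // Same gate as the committed hunks: when pre_tick returns false, the
        // evaluation step is skipped and the loop is flagged to stop.
        if (on_tick && !pre_tick(str.data())) abort = true;
        else {
            if (!evaluate(str)) {
                std::fprintf(stderr, "Failed to evaluate new tokens\n");
                return 1;
            }
        }

        // Simplified stand-in for the section marked "// Tick" in the hunks.
        if (on_tick && !on_tick(str.data())) abort = true;
    }

    // Prints "Hello, world!": the vetoed token was still appended, because the
    // result string is extended before the pre_tick gate.
    std::printf("%s\n", fres.c_str());
    return 0;
}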
@@ -76,7 +76,7 @@ class LLaMAInference final : public Inference {
         LM_CORETURN true;
     }
 
-    LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL {
+    LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const AppendCallback &on_tick = nullptr) LM_NOEXCEPTDECL {
         auto& state = get_state();
 
         // Evaluate tokens in batches
@@ -174,7 +174,7 @@ public:
         }
     }
 
-    LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
+    LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const AppendCallback &on_tick) LM_NOEXCEPTDECL override {
         auto& state = get_state();
 
         // Check if prompt was empty
@@ -201,7 +201,7 @@ public:
         LM_CORETURN LM_COAWAIT evaluate_tokens(old_token_count, on_tick);
     }
 
-    LM_SCHEDULABLE(std::string) run(std::string_view end, const std::function<bool (const char *)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
+    LM_SCHEDULABLE(std::string) run(std::string_view end, const GenerateCallback &on_tick, const GenerateCallback& pre_tick) LM_NOEXCEPTDECL override {
         auto& state = get_state();
         std::string fres;
 
@@ -240,10 +240,14 @@ public:
             state->prompt.append(str);
             fres.append(str);
 
-            // Evaluate token
-            // TODO: Respect batch size
-            if (llama_eval(state->ctx, state->tokens.data()+state->tokens.size()-1, 1, state->tokens.size()-1, params.n_threads)) {
-                LM_COTHROW("Failed to evaluate new tokens", "");
+            // Tick
+            if (on_tick && !pre_tick(str.data())) abort = true;
+            else {
+                // Evaluate token
+                // TODO: Respect batch size
+                if (llama_eval(state->ctx, state->tokens.data()+state->tokens.size()-1, 1, state->tokens.size()-1, params.n_threads)) {
+                    LM_COTHROW("Failed to evaluate new tokens", "");
+                }
             }
 
             // Tick and yield
@@ -94,7 +94,7 @@ class MPTInference final : public Inference {
         LM_CORETURN true;
     }
 
-    LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL {
+    LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const AppendCallback &on_tick) LM_NOEXCEPTDECL {
         auto& state = get_state();
 
         // Evaluate tokens in batches
@@ -143,7 +143,7 @@ public:
         deinit();
     }
 
-    LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
+    LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const AppendCallback &on_tick) LM_NOEXCEPTDECL override {
         auto& state = get_state();
 
         // Append to current prompt
@@ -170,7 +170,7 @@ public:
         LM_CORETURN LM_COAWAIT evaluate_tokens(old_token_count, on_tick);
     }
 
-    LM_SCHEDULABLE(std::string) run(std::string_view end, const std::function<bool (const char *)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
+    LM_SCHEDULABLE(std::string) run(std::string_view end, const GenerateCallback &on_tick, const GenerateCallback& pre_tick) LM_NOEXCEPTDECL override {
         auto& state = get_state();
         std::string fres;
 
@@ -209,11 +209,15 @@ public:
             fres.append(str);
             state->prompt.append(str);
 
-            // Evaluate token
-            // TODO: Respect batch size
-            std::vector<int> batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size());
-            if (!mpt_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) {
-                LM_COTHROW("Failed to evaluate new tokens", "");
+            // Tick
+            if (on_tick && !pre_tick(str.data())) abort = true;
+            else {
+                // Evaluate token
+                // TODO: Respect batch size
+                std::vector<int> batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size());
+                if (!mpt_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) {
+                    LM_COTHROW("Failed to evaluate new tokens", "");
+                }
             }
 
             // Tick
@@ -32,7 +32,7 @@ PYBIND11_MODULE(justlm_py, m) {
     py::class_<Inference>(m, "Inference")
             .def_static("construct", &Inference::construct, py::arg("weights_path"), py::arg("params") = Inference::Params())
            .def("append", &Inference::append, py::arg("prompt"), py::arg("on_tick") = nullptr)
-            .def("run", &Inference::run, py::arg("end") = "", py::arg("on_tick") = nullptr)
+            .def("run", &Inference::run, py::arg("end") = "", py::arg("on_tick") = nullptr, py::arg("pre_tick") = nullptr)
            .def("create_savestate", &Inference::create_savestate)
            .def("restore_savestate", &Inference::restore_savestate)
            .def("get_prompt", &Inference::get_prompt)