mirror of
https://gitlab.com/niansa/libjustlm.git
synced 2025-03-06 20:49:17 +01:00
Removed cosched support
This commit is contained in:
parent
1a04a0e6d9
commit
90e54d66d0
7 changed files with 175 additions and 211 deletions
|
@ -9,24 +9,15 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
|
|||
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
|
||||
|
||||
set(LM_PYBIND No CACHE BOOL "If justlm Python bindings should be build")
|
||||
set(LM_COSCHED No CACHE BOOL "If justlm should make use of CoSched")
|
||||
set(LM_NOEXCEPT No CACHE BOOL "If justlm exceptions should be disabled")
|
||||
set(LM_LLAMA Yes CACHE BOOL "If LLaMa model support should be built into justlm")
|
||||
set(LM_GPTJ Yes CACHE BOOL "If GPT-J model support should be built into justlm")
|
||||
set(LM_MPT Yes CACHE BOOL "If MPT model support should be built into justlm")
|
||||
|
||||
if (LM_COSCHED)
|
||||
set(CMAKE_CXX_STANDARD 20)
|
||||
endif()
|
||||
|
||||
|
||||
function(target_justlm_setup TARGET_NAME)
|
||||
message(STATUS "Configuring model implementation target ${TARGET_NAME}")
|
||||
target_include_directories(${TARGET_NAME} PUBLIC include/)
|
||||
if (LM_COSCHED)
|
||||
target_compile_definitions(${TARGET_NAME} PUBLIC LM_COSCHED)
|
||||
target_link_libraries(${TARGET_NAME} PUBLIC cosched)
|
||||
endif()
|
||||
if (LM_NOEXCEPT)
|
||||
target_compile_definitions(${TARGET_NAME} PUBLIC LM_NOEXCEPT)
|
||||
endif()
|
||||
|
|
|
@ -7,35 +7,20 @@
|
|||
#include <memory>
|
||||
#include <thread>
|
||||
|
||||
#ifdef LM_COSCHED
|
||||
# include <scheduler.hpp>
|
||||
# define LM_SCHEDULABLE(type) ::CoSched::AwaitableTask<type>
|
||||
# define LM_CORETURN co_return
|
||||
# define LM_COAWAIT co_await
|
||||
# define LM_TASKYIELD (co_await ::CoSched::Task::get_current().yield())
|
||||
#else
|
||||
# define LM_SCHEDULABLE(type) type
|
||||
# define LM_CORETURN return
|
||||
# define LM_COAWAIT
|
||||
# define LM_TASKYIELD (true)
|
||||
#endif
|
||||
|
||||
#ifdef LM_NOEXCEPT
|
||||
# define LM_NOEXCEPTDECL noexcept
|
||||
# define LM_THROW(t, r) this->last_error = (t); return r;
|
||||
# define LM_COTHROW(t, r) this->last_error = (t); LM_CORETURN r;
|
||||
# define LM_THROW(t, r) do {this->last_error = (t); return r;} while (0)
|
||||
# define LM_LAST_ERROR_STORAGE mutable std::string last_error;
|
||||
# define LM_LAST_ERROR_GETTER const std::string& get_last_error() const {return last_error;}
|
||||
# define LM_ERRBOOL bool
|
||||
# define LM_BOOL_ERROR false
|
||||
# define LM_BOOL_SUCCESS true
|
||||
# define LM_RETHROW(x) LM_CORETURN x;
|
||||
# define LM_RETHROW(x) return x
|
||||
# define LM_ERROR_CATCH(x, errval, ...) {auto v = x; if (v == (errval)) __VA_ARGS__}
|
||||
# define LM_ERROR_FORWARD(x, errval) {auto v = x; if (v == (errval)) LM_CORETURN x;} 0
|
||||
# define LM_ERROR_FORWARD(x, errval) do {auto v = x; if (v == (errval)) return x;} while (0)
|
||||
#else
|
||||
# define LM_NOEXCEPTDECL
|
||||
# define LM_THROW(t, r) throw Exception(t)
|
||||
# define LM_COTHROW(t, r) throw Exception(t)
|
||||
# define LM_LAST_ERROR_STORAGE
|
||||
# define LM_LAST_ERROR_GETTER
|
||||
# define LM_ERRBOOL void
|
||||
|
@ -46,12 +31,6 @@
|
|||
# define LM_ERROR_FORWARD(x, errval) {x;}
|
||||
#endif
|
||||
|
||||
#ifdef LM_COSCHED
|
||||
#ifndef LM_NOEXCEPT
|
||||
#warning Exceptions should not be enabled in combination with CoSched. Any exceptions thrown will lead to a std::terminate() call
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#if _MSC_VER
|
||||
#include <BaseTsd.h>
|
||||
#endif
|
||||
|
@ -134,24 +113,24 @@ public:
|
|||
}
|
||||
|
||||
// This must be called with a non-empty prompt!
|
||||
virtual LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const AppendCallback& on_tick = nullptr) LM_NOEXCEPTDECL = 0;
|
||||
virtual LM_ERRBOOL append(const std::string& prompt, const AppendCallback& on_tick = nullptr) LM_NOEXCEPTDECL = 0;
|
||||
|
||||
// append() must have been called at least once before calling this!
|
||||
virtual LM_SCHEDULABLE(std::string) run(std::string_view end = "", const GenerateCallback& on_tick = nullptr, const GenerateCallback& pre_tick = nullptr) LM_NOEXCEPTDECL = 0;
|
||||
virtual std::string run(std::string_view end = "", const GenerateCallback& on_tick = nullptr, const GenerateCallback& pre_tick = nullptr) LM_NOEXCEPTDECL = 0;
|
||||
|
||||
virtual unsigned get_context_size() const noexcept = 0;
|
||||
|
||||
virtual LM_SCHEDULABLE(LM_ERRBOOL) create_savestate(Savestate&) const LM_NOEXCEPTDECL = 0;
|
||||
virtual LM_SCHEDULABLE(LM_ERRBOOL) restore_savestate(const Savestate&) LM_NOEXCEPTDECL = 0;
|
||||
virtual LM_ERRBOOL create_savestate(Savestate&) const LM_NOEXCEPTDECL = 0;
|
||||
virtual LM_ERRBOOL restore_savestate(const Savestate&) LM_NOEXCEPTDECL = 0;
|
||||
|
||||
virtual LM_SCHEDULABLE(LM_ERRBOOL) serialize(std::ostream&) const LM_NOEXCEPTDECL = 0;
|
||||
virtual LM_SCHEDULABLE(LM_ERRBOOL) deserialize(std::istream&) LM_NOEXCEPTDECL = 0;
|
||||
virtual LM_ERRBOOL serialize(std::ostream&) const LM_NOEXCEPTDECL = 0;
|
||||
virtual LM_ERRBOOL deserialize(std::istream&) LM_NOEXCEPTDECL = 0;
|
||||
|
||||
virtual LM_SCHEDULABLE(LM_ERRBOOL) load_grammar(const std::string&, bool override_temperature [[maybe_unused]] = false) LM_NOEXCEPTDECL {
|
||||
LM_COTHROW("Grammar is not available for this models backend", LM_BOOL_ERROR);
|
||||
virtual LM_ERRBOOL load_grammar(const std::string&, bool override_temperature [[maybe_unused]] = false) LM_NOEXCEPTDECL {
|
||||
LM_THROW("Grammar is not available for this models backend", LM_BOOL_ERROR);
|
||||
}
|
||||
virtual LM_SCHEDULABLE(LM_ERRBOOL) unload_grammar() LM_NOEXCEPTDECL {
|
||||
LM_COTHROW("Grammar is not available for this models backend", LM_BOOL_ERROR);
|
||||
virtual LM_ERRBOOL unload_grammar() LM_NOEXCEPTDECL {
|
||||
LM_THROW("Grammar is not available for this models backend", LM_BOOL_ERROR);
|
||||
}
|
||||
|
||||
virtual const std::string& get_prompt() const LM_NOEXCEPTDECL = 0;
|
||||
|
|
|
@ -63,21 +63,21 @@ class InferencePool {
|
|||
}
|
||||
|
||||
// Returns false on error
|
||||
LM_SCHEDULABLE(bool) store_slot(Slot& slot);
|
||||
bool store_slot(Slot& slot);
|
||||
// Returns nullptr on error
|
||||
LM_SCHEDULABLE(Slot*) load_slot(size_t id, Slot *suggested_slot = nullptr);
|
||||
Slot *load_slot(size_t id, Slot *suggested_slot = nullptr);
|
||||
|
||||
LM_SCHEDULABLE(void) store_and_reset_slot(Slot& slot) {
|
||||
LM_COAWAIT store_slot(slot); //TODO: Should handle errors somehow
|
||||
void store_and_reset_slot(Slot& slot) {
|
||||
store_slot(slot); //TODO: Should handle errors somehow
|
||||
slot.reset();
|
||||
LM_CORETURN;
|
||||
return;
|
||||
}
|
||||
|
||||
// Doesn't fail
|
||||
LM_SCHEDULABLE(Slot*) get_free_slot();
|
||||
Slot *get_free_slot();
|
||||
|
||||
// Returns nullptr if not found
|
||||
LM_SCHEDULABLE(Slot*) find_slot_by_id(size_t id, bool deserialize = true);
|
||||
Slot *find_slot_by_id(size_t id, bool deserialize = true);
|
||||
|
||||
public:
|
||||
// The pool_name must be unique amonst all applications in cwd
|
||||
|
@ -93,14 +93,14 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
LM_SCHEDULABLE(std::shared_ptr<Inference>) create_inference(size_t id, const std::string& weights_path, const Inference::Params& p) {
|
||||
auto slot = LM_COAWAIT get_free_slot();
|
||||
LM_CORETURN slot->create_inference(id, weights_path, p);
|
||||
std::shared_ptr<Inference> create_inference(size_t id, const std::string& weights_path, const Inference::Params& p) {
|
||||
auto slot = get_free_slot();
|
||||
return slot->create_inference(id, weights_path, p);
|
||||
}
|
||||
LM_SCHEDULABLE(std::shared_ptr<Inference>) get_inference(size_t id);
|
||||
LM_SCHEDULABLE(std::shared_ptr<Inference>) get_or_create_inference(size_t id, const std::string& weights_path, const Inference::Params& p);
|
||||
LM_SCHEDULABLE(void) delete_inference(size_t id);
|
||||
LM_SCHEDULABLE(void) store_all();
|
||||
std::shared_ptr<Inference> get_inference(size_t id);
|
||||
std::shared_ptr<Inference> get_or_create_inference(size_t id, const std::string& weights_path, const Inference::Params& p);
|
||||
void delete_inference(size_t id);
|
||||
void store_all();
|
||||
std::vector<size_t> get_active_slot_ids() const;
|
||||
|
||||
void cleanup();
|
||||
|
|
|
@ -59,12 +59,12 @@ class GPTJInference final : public Inference {
|
|||
|
||||
// This function reduces the size of our tokens vector according to some parameters
|
||||
// All tokens will be evaluated if scrolling was needed and true will be returned
|
||||
LM_SCHEDULABLE(bool) window_scroll() LM_NOEXCEPTDECL {
|
||||
bool window_scroll() LM_NOEXCEPTDECL {
|
||||
auto &state = get_state();
|
||||
// Check that we actually need to scroll
|
||||
if (state->tokens.size() <= params.n_ctx) {
|
||||
// Nope
|
||||
LM_CORETURN false;
|
||||
return false;
|
||||
}
|
||||
// Start scrolling
|
||||
if (params.scroll_keep > 0.0f) {
|
||||
|
@ -81,11 +81,11 @@ class GPTJInference final : public Inference {
|
|||
state->tokens.resize(params.n_ctx_window_top_bar);
|
||||
}
|
||||
// Evaluate tokens
|
||||
LM_ERROR_FORWARD(LM_COAWAIT evaluate_tokens(0, on_scroll), LM_BOOL_ERROR);
|
||||
LM_CORETURN true;
|
||||
LM_ERROR_FORWARD(evaluate_tokens(0, on_scroll), LM_BOOL_ERROR);
|
||||
return true;
|
||||
}
|
||||
|
||||
LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const AppendCallback &on_tick = nullptr) LM_NOEXCEPTDECL {
|
||||
LM_ERRBOOL evaluate_tokens(size_t starting_offset, const AppendCallback &on_tick = nullptr) LM_NOEXCEPTDECL {
|
||||
auto& state = get_state();
|
||||
|
||||
// Evaluate tokens in batches
|
||||
|
@ -96,7 +96,7 @@ class GPTJInference final : public Inference {
|
|||
// Evaluate
|
||||
std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+params.n_batch);
|
||||
if (!gptj_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token)) {
|
||||
LM_COTHROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR);
|
||||
}
|
||||
|
||||
// Tick
|
||||
|
@ -104,8 +104,7 @@ class GPTJInference final : public Inference {
|
|||
// Calculate progress
|
||||
auto progress = float(it-starting_offset) / (state->tokens.size()-starting_offset) * 100.f;
|
||||
// Tick and yield
|
||||
if (!on_tick(progress)) LM_CORETURN LM_BOOL_SUCCESS;
|
||||
else if (!LM_TASKYIELD) LM_CORETURN LM_BOOL_SUCCESS;
|
||||
if (!on_tick(progress)) return LM_BOOL_SUCCESS;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -115,7 +114,7 @@ class GPTJInference final : public Inference {
|
|||
//TODO: This is extremely inefficient! Don't do that...
|
||||
std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+1);
|
||||
if (!gptj_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token)) {
|
||||
LM_COTHROW("Failed to evaluate individual tokens", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to evaluate individual tokens", LM_BOOL_ERROR);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -123,7 +122,7 @@ class GPTJInference final : public Inference {
|
|||
// Notify about completion
|
||||
if (on_tick) on_tick(100.f);
|
||||
|
||||
LM_CORETURN LM_BOOL_SUCCESS;
|
||||
return LM_BOOL_SUCCESS;
|
||||
}
|
||||
|
||||
public:
|
||||
|
@ -134,7 +133,7 @@ public:
|
|||
deinit();
|
||||
}
|
||||
|
||||
LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const AppendCallback &on_tick) LM_NOEXCEPTDECL override {
|
||||
LM_ERRBOOL append(const std::string& prompt, const AppendCallback &on_tick) LM_NOEXCEPTDECL override {
|
||||
auto& state = get_state();
|
||||
|
||||
// Append to current prompt
|
||||
|
@ -152,16 +151,16 @@ public:
|
|||
);
|
||||
|
||||
// Make sure token limit isn't being hit
|
||||
if (LM_COAWAIT window_scroll()) {
|
||||
if (window_scroll()) {
|
||||
// That function already has evaluated our tokens since scrolling was needed
|
||||
LM_CORETURN LM_BOOL_SUCCESS;
|
||||
return LM_BOOL_SUCCESS;
|
||||
}
|
||||
|
||||
// Evaluate new tokens
|
||||
LM_CORETURN LM_COAWAIT evaluate_tokens(old_token_count, on_tick);
|
||||
return evaluate_tokens(old_token_count, on_tick);
|
||||
}
|
||||
|
||||
LM_SCHEDULABLE(std::string) run(std::string_view end, const GenerateCallback &on_tick, const GenerateCallback& pre_tick) LM_NOEXCEPTDECL override {
|
||||
std::string run(std::string_view end, const GenerateCallback &on_tick, const GenerateCallback& pre_tick) LM_NOEXCEPTDECL override {
|
||||
auto& state = get_state();
|
||||
std::string fres;
|
||||
|
||||
|
@ -187,7 +186,7 @@ public:
|
|||
state->tokens.push_back(id);
|
||||
|
||||
// Make sure token limit isn't being hit
|
||||
LM_COAWAIT window_scroll();
|
||||
window_scroll();
|
||||
|
||||
// Get token as string
|
||||
const std::string_view str = state->vocab.id_to_token[id];
|
||||
|
@ -202,13 +201,12 @@ public:
|
|||
// TODO: Respect batch size
|
||||
std::vector<int> batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size());
|
||||
if (!gptj_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) {
|
||||
LM_COTHROW("Failed to evaluate new tokens", "");
|
||||
LM_THROW("Failed to evaluate new tokens", "");
|
||||
}
|
||||
}
|
||||
|
||||
// Tick
|
||||
if (on_tick && !on_tick(str.data())) abort = true;
|
||||
else if (!LM_TASKYIELD) abort = true;
|
||||
}
|
||||
|
||||
// Create final string TODO: Could be optimized
|
||||
|
@ -217,59 +215,59 @@ public:
|
|||
}
|
||||
|
||||
// Return final string
|
||||
LM_CORETURN fres;
|
||||
return fres;
|
||||
}
|
||||
|
||||
unsigned get_context_size() const noexcept override {
|
||||
return get_state()->tokens.size();
|
||||
}
|
||||
|
||||
LM_SCHEDULABLE(LM_ERRBOOL) create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override {
|
||||
LM_ERRBOOL create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override {
|
||||
auto& state = get_state();
|
||||
sv.buf.resize(gptj_get_state_size(state->model));
|
||||
gptj_copy_state_data(state->model, state->rng, sv.buf.data());
|
||||
sv.tokens = state->tokens;
|
||||
sv.prompt = state->prompt;
|
||||
sv.ctx = generic_state;
|
||||
LM_CORETURN LM_BOOL_SUCCESS;
|
||||
return LM_BOOL_SUCCESS;
|
||||
}
|
||||
LM_SCHEDULABLE(LM_ERRBOOL) restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override {
|
||||
LM_ERRBOOL restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override {
|
||||
auto& state = get_state();
|
||||
if (sv.ctx != generic_state)
|
||||
LM_COTHROW("Savestate does not match context", LM_BOOL_ERROR);
|
||||
LM_THROW("Savestate does not match context", LM_BOOL_ERROR);
|
||||
gptj_set_state_data(&state->model, &state->rng, sv.buf.data());
|
||||
state->tokens = sv.tokens;
|
||||
state->prompt = sv.prompt;
|
||||
LM_CORETURN LM_BOOL_SUCCESS;
|
||||
return LM_BOOL_SUCCESS;
|
||||
}
|
||||
|
||||
LM_SCHEDULABLE(LM_ERRBOOL) serialize(std::ostream &o) const LM_NOEXCEPTDECL override {
|
||||
LM_ERRBOOL serialize(std::ostream &o) const LM_NOEXCEPTDECL override {
|
||||
auto& state = get_state();
|
||||
// Get state size
|
||||
auto state_size = gptj_get_state_size(state->model);
|
||||
// Write sizes
|
||||
for (const uint32_t s : {state->tokens.size(), state->prompt.size(), state_size}) {
|
||||
if (!o.write(reinterpret_cast<const char*>(&s), sizeof(s))) {
|
||||
LM_COTHROW("Failed to serialize data sizes", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to serialize data sizes", LM_BOOL_ERROR);
|
||||
}
|
||||
}
|
||||
// Write tokens
|
||||
if (!o.write(reinterpret_cast<const char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
|
||||
LM_COTHROW("Failed to serialize tokens", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to serialize tokens", LM_BOOL_ERROR);
|
||||
}
|
||||
// Write prompt
|
||||
if (!o.write(state->prompt.data(), state->prompt.size())) {
|
||||
LM_COTHROW("Failed to serialize prompt", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to serialize prompt", LM_BOOL_ERROR);
|
||||
}
|
||||
// Write state
|
||||
std::vector<uint8_t> state_buf(state_size);
|
||||
gptj_copy_state_data(state->model, state->rng, state_buf.data());
|
||||
if (!o.write(reinterpret_cast<const char*>(state_buf.data()), state_size)) {
|
||||
LM_COTHROW("Failed to serialize state", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to serialize state", LM_BOOL_ERROR);
|
||||
}
|
||||
LM_CORETURN LM_BOOL_SUCCESS;
|
||||
return LM_BOOL_SUCCESS;
|
||||
}
|
||||
LM_SCHEDULABLE(LM_ERRBOOL) deserialize(std::istream &i) LM_NOEXCEPTDECL override {
|
||||
LM_ERRBOOL deserialize(std::istream &i) LM_NOEXCEPTDECL override {
|
||||
auto& state = get_state();
|
||||
uint32_t embd_size, prompt_size, state_size;
|
||||
// Initialization to prevent compiler complaints
|
||||
|
@ -277,26 +275,26 @@ public:
|
|||
// Read sizes
|
||||
for (uint32_t *s : {&embd_size, &prompt_size, &state_size}) {
|
||||
if (!i.read(reinterpret_cast<char*>(s), sizeof(*s))) {
|
||||
LM_COTHROW("Failed to deserialize data sizes", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to deserialize data sizes", LM_BOOL_ERROR);
|
||||
}
|
||||
}
|
||||
// Read tokens
|
||||
state->tokens.resize(embd_size);
|
||||
if (!i.read(reinterpret_cast<char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
|
||||
LM_COTHROW("Failed to deserialize tokens", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to deserialize tokens", LM_BOOL_ERROR);
|
||||
}
|
||||
// Read prompt
|
||||
state->prompt.resize(prompt_size);
|
||||
if (!i.read(state->prompt.data(), state->prompt.size())) {
|
||||
LM_COTHROW("Failed to deserialize prompt", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to deserialize prompt", LM_BOOL_ERROR);
|
||||
}
|
||||
// Read state
|
||||
std::vector<uint8_t> state_buf(state_size);
|
||||
if (!i.read(reinterpret_cast<char*>(state_buf.data()), state_buf.size())) {
|
||||
LM_COTHROW("Failed to deserialize state", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to deserialize state", LM_BOOL_ERROR);
|
||||
}
|
||||
gptj_set_state_data(&state->model, &state->rng, state_buf.data());
|
||||
LM_CORETURN LM_BOOL_SUCCESS;
|
||||
return LM_BOOL_SUCCESS;
|
||||
}
|
||||
const std::string &get_prompt() const LM_NOEXCEPTDECL override {
|
||||
return get_state()->prompt;
|
||||
|
|
|
@ -64,12 +64,12 @@ class LLaMAInference final : public Inference {
|
|||
|
||||
// This function reduces the size of our tokens vector according to some parameters
|
||||
// All tokens will be evaluated if scrolling was needed and true will be returned
|
||||
LM_SCHEDULABLE(bool) window_scroll() LM_NOEXCEPTDECL {
|
||||
bool window_scroll() LM_NOEXCEPTDECL {
|
||||
auto &state = get_state();
|
||||
// Check that we actually need to scroll
|
||||
if (state->tokens.size() <= state->n_ctx) {
|
||||
// Nope
|
||||
LM_CORETURN false;
|
||||
return false;
|
||||
}
|
||||
// Start scrolling
|
||||
if (params.scroll_keep > 0.0f) {
|
||||
|
@ -86,11 +86,11 @@ class LLaMAInference final : public Inference {
|
|||
state->tokens.resize(params.n_ctx_window_top_bar);
|
||||
}
|
||||
// Evaluate tokens
|
||||
LM_ERROR_FORWARD(LM_COAWAIT evaluate_tokens(0, on_scroll), LM_BOOL_ERROR);
|
||||
LM_CORETURN true;
|
||||
LM_ERROR_FORWARD(evaluate_tokens(0, on_scroll), LM_BOOL_ERROR);
|
||||
return true;
|
||||
}
|
||||
|
||||
LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const AppendCallback &on_tick = nullptr) LM_NOEXCEPTDECL {
|
||||
LM_ERRBOOL evaluate_tokens(size_t starting_offset, const AppendCallback &on_tick = nullptr) LM_NOEXCEPTDECL {
|
||||
auto& state = get_state();
|
||||
|
||||
// Evaluate tokens in batches
|
||||
|
@ -101,7 +101,7 @@ class LLaMAInference final : public Inference {
|
|||
// Evaluate
|
||||
const auto batch = llama_batch_get_one(state->tokens.data()+it, params.n_batch, it, 0);
|
||||
if (llama_decode(state->ctx, batch)) {
|
||||
LM_COTHROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR);
|
||||
}
|
||||
|
||||
// Tick
|
||||
|
@ -109,8 +109,7 @@ class LLaMAInference final : public Inference {
|
|||
// Calculate progress
|
||||
auto progress = float(it-starting_offset) / (state->tokens.size()-starting_offset) * 100.f;
|
||||
// Tick and yield
|
||||
if (!on_tick(progress)) LM_CORETURN LM_BOOL_SUCCESS;
|
||||
else if (!LM_TASKYIELD) LM_CORETURN LM_BOOL_SUCCESS;
|
||||
if (!on_tick(progress)) return LM_BOOL_SUCCESS;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -119,7 +118,7 @@ class LLaMAInference final : public Inference {
|
|||
for (; it != state->tokens.size(); it++) {
|
||||
const auto batch = llama_batch_get_one(state->tokens.data()+it, 1, it, 0);
|
||||
if (llama_decode(state->ctx, batch)) {
|
||||
LM_COTHROW("Failed to evaluate individual tokens", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to evaluate individual tokens", LM_BOOL_ERROR);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -127,7 +126,7 @@ class LLaMAInference final : public Inference {
|
|||
// Notify about completion
|
||||
if (on_tick) on_tick(100.f);
|
||||
|
||||
LM_CORETURN LM_BOOL_SUCCESS;
|
||||
return LM_BOOL_SUCCESS;
|
||||
}
|
||||
|
||||
int accept_token(int t) {
|
||||
|
@ -198,7 +197,7 @@ public:
|
|||
}
|
||||
}
|
||||
|
||||
LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const AppendCallback &on_tick) LM_NOEXCEPTDECL override {
|
||||
LM_ERRBOOL append(const std::string& prompt, const AppendCallback &on_tick) LM_NOEXCEPTDECL override {
|
||||
auto& state = get_state();
|
||||
|
||||
// Check if prompt was empty
|
||||
|
@ -216,16 +215,16 @@ public:
|
|||
state->tokens.resize(old_token_count+token_count);
|
||||
|
||||
// Make sure token limit isn't being hit
|
||||
if (LM_COAWAIT window_scroll()) {
|
||||
if (window_scroll()) {
|
||||
// That function already has evaluated our tokens since scrolling was needed
|
||||
LM_CORETURN LM_BOOL_SUCCESS;
|
||||
return LM_BOOL_SUCCESS;
|
||||
}
|
||||
|
||||
// Evaluate new tokens
|
||||
LM_CORETURN LM_COAWAIT evaluate_tokens(old_token_count, on_tick);
|
||||
return evaluate_tokens(old_token_count, on_tick);
|
||||
}
|
||||
|
||||
LM_SCHEDULABLE(std::string) run(std::string_view end, const GenerateCallback &on_tick, const GenerateCallback& pre_tick) LM_NOEXCEPTDECL override {
|
||||
std::string run(std::string_view end, const GenerateCallback &on_tick, const GenerateCallback& pre_tick) LM_NOEXCEPTDECL override {
|
||||
auto& state = get_state();
|
||||
std::string fres;
|
||||
|
||||
|
@ -240,7 +239,7 @@ public:
|
|||
try {
|
||||
id = llama_sample_top_p_top_k();
|
||||
} catch (const std::exception& e) {
|
||||
LM_COTHROW(e.what(), "");
|
||||
LM_THROW(e.what(), "");
|
||||
}
|
||||
|
||||
if (id == llama_token_eos(state->model)) {
|
||||
|
@ -257,7 +256,7 @@ public:
|
|||
}
|
||||
|
||||
// Make sure token limit isn't hit
|
||||
LM_COAWAIT window_scroll();
|
||||
window_scroll();
|
||||
|
||||
// Get token as string
|
||||
std::string str(14, ' ');
|
||||
|
@ -274,13 +273,12 @@ public:
|
|||
// TODO: Respect batch size
|
||||
const auto batch = llama_batch_get_one(state->tokens.data()+state->tokens.size()-1, 1, state->tokens.size()-1, 0);
|
||||
if (llama_decode(state->ctx, batch)) {
|
||||
LM_COTHROW("Failed to evaluate new tokens", "");
|
||||
LM_THROW("Failed to evaluate new tokens", "");
|
||||
}
|
||||
}
|
||||
|
||||
// Tick and yield
|
||||
if (on_tick && !on_tick(str.data())) abort = true;
|
||||
else if (!LM_TASKYIELD) abort = true;
|
||||
}
|
||||
|
||||
// Create final string TODO: Could be optimized
|
||||
|
@ -289,59 +287,59 @@ public:
|
|||
}
|
||||
|
||||
// Return final string
|
||||
LM_CORETURN fres;
|
||||
return fres;
|
||||
}
|
||||
|
||||
unsigned get_context_size() const noexcept override {
|
||||
return get_state()->tokens.size();
|
||||
}
|
||||
|
||||
LM_SCHEDULABLE(LM_ERRBOOL) create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override {
|
||||
LM_ERRBOOL create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override {
|
||||
auto& state = get_state();
|
||||
sv.buf.resize(llama_get_state_size(state->ctx));
|
||||
llama_copy_state_data(state->ctx, sv.buf.data());
|
||||
sv.tokens = state->tokens;
|
||||
sv.prompt = state->prompt;
|
||||
sv.ctx = generic_state;
|
||||
LM_CORETURN LM_BOOL_SUCCESS;
|
||||
return LM_BOOL_SUCCESS;
|
||||
}
|
||||
LM_SCHEDULABLE(LM_ERRBOOL) restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override {
|
||||
LM_ERRBOOL restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override {
|
||||
auto& state = get_state();
|
||||
if (sv.ctx != generic_state)
|
||||
LM_COTHROW("Savestate does not match context", LM_BOOL_ERROR);
|
||||
LM_THROW("Savestate does not match context", LM_BOOL_ERROR);
|
||||
llama_set_state_data(state->ctx, const_cast<uint8_t*>(sv.buf.data()));
|
||||
state->tokens = sv.tokens;
|
||||
state->prompt = sv.prompt;
|
||||
LM_CORETURN LM_BOOL_SUCCESS;
|
||||
return LM_BOOL_SUCCESS;
|
||||
}
|
||||
|
||||
LM_SCHEDULABLE(LM_ERRBOOL) serialize(std::ostream &o) const LM_NOEXCEPTDECL override {
|
||||
LM_ERRBOOL serialize(std::ostream &o) const LM_NOEXCEPTDECL override {
|
||||
auto& state = get_state();
|
||||
// Get state size
|
||||
auto state_size = llama_get_state_size(state->ctx);
|
||||
// Write sizes
|
||||
for (const uint32_t s : {static_cast<size_t>(state->n_ctx), state->tokens.size(), state->prompt.size(), state_size}) {
|
||||
if (!o.write(reinterpret_cast<const char*>(&s), sizeof(s))) {
|
||||
LM_COTHROW("Failed to serialize data sizes", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to serialize data sizes", LM_BOOL_ERROR);
|
||||
}
|
||||
}
|
||||
// Write tokens
|
||||
if (!o.write(reinterpret_cast<const char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
|
||||
LM_COTHROW("Failed to serialize tokens", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to serialize tokens", LM_BOOL_ERROR);
|
||||
}
|
||||
// Write prompt
|
||||
if (!o.write(state->prompt.data(), state->prompt.size())) {
|
||||
LM_COTHROW("Failed to serialize prompt", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to serialize prompt", LM_BOOL_ERROR);
|
||||
}
|
||||
// Write state
|
||||
std::vector<uint8_t> state_buf(state_size);
|
||||
llama_copy_state_data(state->ctx, state_buf.data());
|
||||
if (!o.write(reinterpret_cast<const char*>(state_buf.data()), state_size)) {
|
||||
LM_COTHROW("Failed to serialize state", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to serialize state", LM_BOOL_ERROR);
|
||||
}
|
||||
LM_CORETURN LM_BOOL_SUCCESS;
|
||||
return LM_BOOL_SUCCESS;
|
||||
}
|
||||
LM_SCHEDULABLE(LM_ERRBOOL) deserialize(std::istream &i) LM_NOEXCEPTDECL override {
|
||||
LM_ERRBOOL deserialize(std::istream &i) LM_NOEXCEPTDECL override {
|
||||
auto& state = get_state();
|
||||
uint32_t n_ctx, embd_size, prompt_size, state_size;
|
||||
// Initialization to prevent compiler complaints
|
||||
|
@ -349,53 +347,53 @@ public:
|
|||
// Read sizes
|
||||
for (uint32_t *s : {&n_ctx, &embd_size, &prompt_size, &state_size}) {
|
||||
if (!i.read(reinterpret_cast<char*>(s), sizeof(*s))) {
|
||||
LM_COTHROW("Failed to deserialize data sizes", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to deserialize data sizes", LM_BOOL_ERROR);
|
||||
}
|
||||
}
|
||||
if (state->n_ctx != n_ctx) {
|
||||
LM_COTHROW("Context length differs (My "+std::to_string(state->n_ctx)+" vs. files "+std::to_string(n_ctx)+')', LM_BOOL_ERROR);
|
||||
LM_THROW("Context length differs (My "+std::to_string(state->n_ctx)+" vs. files "+std::to_string(n_ctx)+')', LM_BOOL_ERROR);
|
||||
}
|
||||
// Read tokens
|
||||
state->tokens.resize(embd_size);
|
||||
if (!i.read(reinterpret_cast<char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
|
||||
LM_COTHROW("Failed to deserialize tokens", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to deserialize tokens", LM_BOOL_ERROR);
|
||||
}
|
||||
// Read prompt
|
||||
state->prompt.resize(prompt_size);
|
||||
if (!i.read(state->prompt.data(), state->prompt.size())) {
|
||||
LM_COTHROW("Failed to deserialize prompt", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to deserialize prompt", LM_BOOL_ERROR);
|
||||
}
|
||||
// Read state
|
||||
std::vector<uint8_t> state_buf(state_size);
|
||||
if (!i.read(reinterpret_cast<char*>(state_buf.data()), state_buf.size())) {
|
||||
LM_COTHROW("Failed to deserialize state", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to deserialize state", LM_BOOL_ERROR);
|
||||
}
|
||||
llama_set_state_data(state->ctx, state_buf.data());
|
||||
LM_CORETURN LM_BOOL_SUCCESS;
|
||||
return LM_BOOL_SUCCESS;
|
||||
}
|
||||
|
||||
LM_SCHEDULABLE(LM_ERRBOOL) load_grammar(const std::string& src, bool override_temperature) LM_NOEXCEPTDECL override {
|
||||
LM_ERRBOOL load_grammar(const std::string& src, bool override_temperature) LM_NOEXCEPTDECL override {
|
||||
auto& state = get_state();
|
||||
|
||||
state->parsed_grammar = grammar_parser::parse(src.c_str());
|
||||
if (state->parsed_grammar.rules.empty()) {
|
||||
LM_COTHROW("Failed to parse grammar (or no rules)", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to parse grammar (or no rules)", LM_BOOL_ERROR);
|
||||
}
|
||||
|
||||
auto rules = state->parsed_grammar.c_rules();
|
||||
state->grammar = llama_grammar_init(rules.data(), rules.size(), state->parsed_grammar.symbol_ids.at("root"));
|
||||
if (!state->grammar) {
|
||||
LM_COTHROW("Failed to generate llama grammar", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to generate llama grammar", LM_BOOL_ERROR);
|
||||
}
|
||||
|
||||
state->grammar_override_temp = override_temperature;
|
||||
|
||||
LM_CORETURN LM_BOOL_SUCCESS;
|
||||
return LM_BOOL_SUCCESS;
|
||||
}
|
||||
LM_SCHEDULABLE(LM_ERRBOOL) unload_grammar() LM_NOEXCEPTDECL override {
|
||||
LM_ERRBOOL unload_grammar() LM_NOEXCEPTDECL override {
|
||||
get_state()->grammar = nullptr;
|
||||
|
||||
LM_CORETURN LM_BOOL_SUCCESS;
|
||||
return LM_BOOL_SUCCESS;
|
||||
}
|
||||
|
||||
const std::string &get_prompt() const LM_NOEXCEPTDECL override {
|
||||
|
|
|
@ -68,12 +68,12 @@ class MPTInference final : public Inference {
|
|||
|
||||
// This function reduces the size of our tokens vector according to some parameters
|
||||
// All tokens will be evaluated if scrolling was needed and true will be returned
|
||||
LM_SCHEDULABLE(bool) window_scroll() LM_NOEXCEPTDECL {
|
||||
bool window_scroll() LM_NOEXCEPTDECL {
|
||||
auto &state = get_state();
|
||||
// Check that we actually need to scroll
|
||||
if (state->tokens.size() <= params.n_ctx) {
|
||||
// Nope
|
||||
LM_CORETURN false;
|
||||
return false;
|
||||
}
|
||||
// Start scrolling
|
||||
if (params.scroll_keep > 0.0f) {
|
||||
|
@ -90,11 +90,11 @@ class MPTInference final : public Inference {
|
|||
state->tokens.resize(params.n_ctx_window_top_bar);
|
||||
}
|
||||
// Evaluate tokens
|
||||
LM_ERROR_FORWARD(LM_COAWAIT evaluate_tokens(0, on_scroll), LM_BOOL_ERROR);
|
||||
LM_CORETURN true;
|
||||
LM_ERROR_FORWARD(evaluate_tokens(0, on_scroll), LM_BOOL_ERROR);
|
||||
return true;
|
||||
}
|
||||
|
||||
LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const AppendCallback &on_tick) LM_NOEXCEPTDECL {
|
||||
LM_ERRBOOL evaluate_tokens(size_t starting_offset, const AppendCallback &on_tick) LM_NOEXCEPTDECL {
|
||||
auto& state = get_state();
|
||||
|
||||
// Evaluate tokens in batches
|
||||
|
@ -105,7 +105,7 @@ class MPTInference final : public Inference {
|
|||
// Evaluate
|
||||
std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+params.n_batch);
|
||||
if (!mpt_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token)) {
|
||||
LM_COTHROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR);
|
||||
}
|
||||
|
||||
// Tick
|
||||
|
@ -113,8 +113,7 @@ class MPTInference final : public Inference {
|
|||
// Calculate progress
|
||||
auto progress = float(it-starting_offset) / (state->tokens.size()-starting_offset) * 100.f;
|
||||
// Tick and yield
|
||||
if (!on_tick(progress)) LM_CORETURN LM_BOOL_SUCCESS;
|
||||
else if (!LM_TASKYIELD) LM_CORETURN LM_BOOL_SUCCESS;
|
||||
if (!on_tick(progress)) return LM_BOOL_SUCCESS;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -124,7 +123,7 @@ class MPTInference final : public Inference {
|
|||
//TODO: This is extremely inefficient! Don't do that...
|
||||
std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+1);
|
||||
if (!mpt_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token)) {
|
||||
LM_COTHROW("Failed to evaluate individual tokens", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to evaluate individual tokens", LM_BOOL_ERROR);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -132,7 +131,7 @@ class MPTInference final : public Inference {
|
|||
// Notify about completion
|
||||
if (on_tick) on_tick(100.f);
|
||||
|
||||
LM_CORETURN LM_BOOL_SUCCESS;
|
||||
return LM_BOOL_SUCCESS;
|
||||
}
|
||||
|
||||
public:
|
||||
|
@ -143,7 +142,7 @@ public:
|
|||
deinit();
|
||||
}
|
||||
|
||||
LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const AppendCallback &on_tick) LM_NOEXCEPTDECL override {
|
||||
LM_ERRBOOL append(const std::string& prompt, const AppendCallback &on_tick) LM_NOEXCEPTDECL override {
|
||||
auto& state = get_state();
|
||||
|
||||
// Append to current prompt
|
||||
|
@ -161,16 +160,16 @@ public:
|
|||
);
|
||||
|
||||
// Make sure token limit isn't being hit
|
||||
if (LM_COAWAIT window_scroll()) {
|
||||
if (window_scroll()) {
|
||||
// That function already has evaluated our tokens since scrolling was needed
|
||||
LM_CORETURN LM_BOOL_SUCCESS;
|
||||
return LM_BOOL_SUCCESS;
|
||||
}
|
||||
|
||||
// Evaluate new tokens
|
||||
LM_CORETURN LM_COAWAIT evaluate_tokens(old_token_count, on_tick);
|
||||
return evaluate_tokens(old_token_count, on_tick);
|
||||
}
|
||||
|
||||
LM_SCHEDULABLE(std::string) run(std::string_view end, const GenerateCallback &on_tick, const GenerateCallback& pre_tick) LM_NOEXCEPTDECL override {
|
||||
std::string run(std::string_view end, const GenerateCallback &on_tick, const GenerateCallback& pre_tick) LM_NOEXCEPTDECL override {
|
||||
auto& state = get_state();
|
||||
std::string fres;
|
||||
|
||||
|
@ -202,7 +201,7 @@ public:
|
|||
state->tokens.push_back(id);
|
||||
|
||||
// Make sure token limit isn't being hit
|
||||
LM_COAWAIT window_scroll();
|
||||
window_scroll();
|
||||
|
||||
// Get token as string
|
||||
const std::string_view str = state->vocab.id_to_token[id];
|
||||
|
@ -218,13 +217,12 @@ public:
|
|||
// TODO: Respect batch size
|
||||
std::vector<int> batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size());
|
||||
if (!mpt_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) {
|
||||
LM_COTHROW("Failed to evaluate new tokens", "");
|
||||
LM_THROW("Failed to evaluate new tokens", "");
|
||||
}
|
||||
}
|
||||
|
||||
// Tick
|
||||
if (on_tick && !on_tick(str.data())) abort = true;
|
||||
else if (!LM_TASKYIELD) abort = true;
|
||||
}
|
||||
|
||||
// Create final string TODO: Could be optimized
|
||||
|
@ -233,59 +231,59 @@ public:
|
|||
}
|
||||
|
||||
// Return final string
|
||||
LM_CORETURN fres;
|
||||
return fres;
|
||||
}
|
||||
|
||||
unsigned get_context_size() const noexcept override {
|
||||
return get_state()->tokens.size();
|
||||
}
|
||||
|
||||
LM_SCHEDULABLE(LM_ERRBOOL) create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override {
|
||||
LM_ERRBOOL create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override {
|
||||
auto& state = get_state();
|
||||
sv.buf.resize(mpt_get_state_size(state->model));
|
||||
mpt_copy_state_data(state->model, state->rng, sv.buf.data());
|
||||
sv.tokens = state->tokens;
|
||||
sv.prompt = state->prompt;
|
||||
sv.ctx = generic_state;
|
||||
LM_CORETURN LM_BOOL_SUCCESS ;
|
||||
return LM_BOOL_SUCCESS ;
|
||||
}
|
||||
LM_SCHEDULABLE(LM_ERRBOOL) restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override {
|
||||
LM_ERRBOOL restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override {
|
||||
auto& state = get_state();
|
||||
if (sv.ctx != generic_state)
|
||||
LM_COTHROW("Savestate does not match context", LM_BOOL_ERROR);
|
||||
LM_THROW("Savestate does not match context", LM_BOOL_ERROR);
|
||||
mpt_set_state_data(&state->model, &state->rng, sv.buf.data());
|
||||
state->tokens = sv.tokens;
|
||||
state->prompt = sv.prompt;
|
||||
LM_CORETURN LM_BOOL_SUCCESS;
|
||||
return LM_BOOL_SUCCESS;
|
||||
}
|
||||
|
||||
LM_SCHEDULABLE(LM_ERRBOOL) serialize(std::ostream &o) const LM_NOEXCEPTDECL override {
|
||||
LM_ERRBOOL serialize(std::ostream &o) const LM_NOEXCEPTDECL override {
|
||||
auto& state = get_state();
|
||||
// Get state size
|
||||
auto state_size = mpt_get_state_size(state->model);
|
||||
// Write sizes
|
||||
for (const uint32_t s : {state->tokens.size(), state->prompt.size(), state_size}) {
|
||||
if (!o.write(reinterpret_cast<const char*>(&s), sizeof(s))) {
|
||||
LM_COTHROW("Failed to serialize data sizes", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to serialize data sizes", LM_BOOL_ERROR);
|
||||
}
|
||||
}
|
||||
// Write tokens
|
||||
if (!o.write(reinterpret_cast<const char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
|
||||
LM_COTHROW("Failed to serialize tokens", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to serialize tokens", LM_BOOL_ERROR);
|
||||
}
|
||||
// Write prompt
|
||||
if (!o.write(state->prompt.data(), state->prompt.size())) {
|
||||
LM_COTHROW("Failed to serialize prompt", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to serialize prompt", LM_BOOL_ERROR);
|
||||
}
|
||||
// Write state
|
||||
std::vector<uint8_t> state_buf(state_size);
|
||||
mpt_copy_state_data(state->model, state->rng, state_buf.data());
|
||||
if (!o.write(reinterpret_cast<const char*>(state_buf.data()), state_size)) {
|
||||
LM_COTHROW("Failed to serialize state", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to serialize state", LM_BOOL_ERROR);
|
||||
}
|
||||
LM_CORETURN LM_BOOL_SUCCESS;
|
||||
return LM_BOOL_SUCCESS;
|
||||
}
|
||||
LM_SCHEDULABLE(LM_ERRBOOL) deserialize(std::istream &i) LM_NOEXCEPTDECL override {
|
||||
LM_ERRBOOL deserialize(std::istream &i) LM_NOEXCEPTDECL override {
|
||||
auto& state = get_state();
|
||||
uint32_t embd_size, promptsize, state_size;
|
||||
// Initialization to prevent compiler complaints
|
||||
|
@ -293,26 +291,26 @@ public:
|
|||
// Read sizes
|
||||
for (uint32_t *s : {&embd_size, &promptsize, &state_size}) {
|
||||
if (!i.read(reinterpret_cast<char*>(s), sizeof(*s))) {
|
||||
LM_COTHROW("Failed to deserialize data sizes", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to deserialize data sizes", LM_BOOL_ERROR);
|
||||
}
|
||||
}
|
||||
// Read tokens
|
||||
state->tokens.resize(embd_size);
|
||||
if (!i.read(reinterpret_cast<char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
|
||||
LM_COTHROW("Failed to deserialize tokens", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to deserialize tokens", LM_BOOL_ERROR);
|
||||
}
|
||||
// Read prompt
|
||||
state->prompt.resize(promptsize);
|
||||
if (!i.read(state->prompt.data(), state->prompt.size())) {
|
||||
LM_COTHROW("Failed to deserialize prompt", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to deserialize prompt", LM_BOOL_ERROR);
|
||||
}
|
||||
// Read state
|
||||
std::vector<uint8_t> state_buf(state_size);
|
||||
if (!i.read(reinterpret_cast<char*>(state_buf.data()), state_buf.size())) {
|
||||
LM_COTHROW("Failed to deserialize state", LM_BOOL_ERROR);
|
||||
LM_THROW("Failed to deserialize state", LM_BOOL_ERROR);
|
||||
}
|
||||
mpt_set_state_data(&state->model, &state->rng, state_buf.data());
|
||||
LM_CORETURN LM_BOOL_SUCCESS;
|
||||
return LM_BOOL_SUCCESS;
|
||||
}
|
||||
const std::string &get_prompt() const LM_NOEXCEPTDECL override {
|
||||
return get_state()->prompt;
|
||||
|
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
|
||||
|
||||
LM_SCHEDULABLE(bool) LM::InferencePool::store_slot(Slot &slot) {
|
||||
bool LM::InferencePool::store_slot(Slot &slot) {
|
||||
auto inference = slot.get_inference();
|
||||
// Open output file
|
||||
std::ofstream f(get_slot_filename(slot.get_id()), std::ios::binary);
|
||||
|
@ -17,61 +17,61 @@ LM_SCHEDULABLE(bool) LM::InferencePool::store_slot(Slot &slot) {
|
|||
f.write(weights_path.data(), weights_path.size());
|
||||
// Write params
|
||||
if (!f.write(reinterpret_cast<const char*>(&inference->params), sizeof(inference->params))) {
|
||||
LM_CORETURN false;
|
||||
return false;
|
||||
}
|
||||
// Serialize instance
|
||||
try {
|
||||
LM_COAWAIT inference->serialize(f);
|
||||
inference->serialize(f);
|
||||
} catch (...) {
|
||||
LM_CORETURN false;
|
||||
return false;
|
||||
}
|
||||
// Return success
|
||||
LM_CORETURN true;
|
||||
return true;
|
||||
}
|
||||
|
||||
LM_SCHEDULABLE(LM::InferencePool::Slot*) LM::InferencePool::load_slot(size_t id, Slot *suggested_slot) {
|
||||
LM::InferencePool::Slot *LM::InferencePool::load_slot(size_t id, Slot *suggested_slot) {
|
||||
// Open input file
|
||||
std::ifstream f(get_slot_filename(id), std::ios::binary);
|
||||
if (!f) {
|
||||
// Does not exist
|
||||
LM_CORETURN nullptr;
|
||||
return nullptr;
|
||||
}
|
||||
// Read weights path
|
||||
std::string weights_path;
|
||||
uint32_t weights_path_len;
|
||||
if (!f.read(reinterpret_cast<char*>(&weights_path_len), sizeof(weights_path_len))) {
|
||||
LM_CORETURN nullptr;
|
||||
return nullptr;
|
||||
}
|
||||
weights_path.resize(weights_path_len);
|
||||
if (!f.read(weights_path.data(), weights_path.size())) {
|
||||
LM_CORETURN nullptr;
|
||||
return nullptr;
|
||||
}
|
||||
// Read params
|
||||
LM::Inference::Params p;
|
||||
if (!f.read(reinterpret_cast<char*>(&p), sizeof(p))) {
|
||||
LM_CORETURN nullptr;
|
||||
return nullptr;
|
||||
}
|
||||
// Create instance
|
||||
auto& slot = suggested_slot?*suggested_slot:*(LM_COAWAIT get_free_slot());
|
||||
auto& slot = suggested_slot?*suggested_slot:*(get_free_slot());
|
||||
auto inference = slot.create_inference(id, weights_path, p);
|
||||
// Deserialize instance
|
||||
try {
|
||||
LM_COAWAIT inference->deserialize(f);
|
||||
inference->deserialize(f);
|
||||
} catch (...) {
|
||||
slot.reset();
|
||||
LM_CORETURN nullptr;
|
||||
return nullptr;
|
||||
}
|
||||
// Return final slot
|
||||
LM_CORETURN &slot;
|
||||
return &slot;
|
||||
}
|
||||
|
||||
LM_SCHEDULABLE(LM::InferencePool::Slot*) LM::InferencePool::get_free_slot() {
|
||||
LM::InferencePool::Slot *LM::InferencePool::get_free_slot() {
|
||||
// Attempt to find free slot while finding oldest one
|
||||
Slot *oldest = nullptr;
|
||||
for (auto& slot : slots) {
|
||||
// Take free slot
|
||||
if (slot.is_free()) {
|
||||
LM_CORETURN &slot;
|
||||
return &slot;
|
||||
}
|
||||
// Update oldest
|
||||
if (oldest == nullptr || slot.get_last_access() < oldest->get_last_access()) {
|
||||
|
@ -80,17 +80,17 @@ LM_SCHEDULABLE(LM::InferencePool::Slot*) LM::InferencePool::get_free_slot() {
|
|||
}
|
||||
// Free up oldest slot and take that one
|
||||
// Note: Since there has to be at least 1 slot, oldest is never going to be a nullptr
|
||||
LM_COAWAIT store_and_reset_slot(*oldest);
|
||||
LM_CORETURN oldest;
|
||||
store_and_reset_slot(*oldest);
|
||||
return oldest;
|
||||
}
|
||||
|
||||
LM_SCHEDULABLE(LM::InferencePool::Slot*) LM::InferencePool::find_slot_by_id(size_t id, bool deserialize) {
|
||||
LM::InferencePool::Slot *LM::InferencePool::find_slot_by_id(size_t id, bool deserialize) {
|
||||
// Attempt to find given slot while finding oldest one
|
||||
Slot *oldest = nullptr;
|
||||
for (auto& slot : slots) {
|
||||
// Take slot with ID
|
||||
if (slot.get_id() == id) {
|
||||
LM_CORETURN &slot;
|
||||
return &slot;
|
||||
}
|
||||
// Update oldest
|
||||
if (oldest == nullptr || slot.get_last_access() < oldest->get_last_access()) {
|
||||
|
@ -99,38 +99,38 @@ LM_SCHEDULABLE(LM::InferencePool::Slot*) LM::InferencePool::find_slot_by_id(size
|
|||
}
|
||||
// Slot not found, attempt to load it
|
||||
if (deserialize) {
|
||||
if (!oldest->is_free()) LM_COAWAIT store_slot(*oldest);
|
||||
if (!LM_COAWAIT load_slot(id, oldest)) {
|
||||
if (!oldest->is_free()) store_slot(*oldest);
|
||||
if (!load_slot(id, oldest)) {
|
||||
// In case slot loading failed, still reset slot for later use
|
||||
//TODO: Make this configurable
|
||||
oldest->reset();
|
||||
} else {
|
||||
LM_CORETURN oldest;
|
||||
return oldest;
|
||||
}
|
||||
}
|
||||
// Slot not found
|
||||
LM_CORETURN nullptr;
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
LM_SCHEDULABLE(std::shared_ptr<LM::Inference>) LM::InferencePool::get_inference(size_t id) {
|
||||
auto slot = LM_COAWAIT find_slot_by_id(id);
|
||||
std::shared_ptr<LM::Inference> LM::InferencePool::get_inference(size_t id) {
|
||||
auto slot = find_slot_by_id(id);
|
||||
if (slot) {
|
||||
LM_CORETURN slot->get_inference(true);
|
||||
return slot->get_inference(true);
|
||||
}
|
||||
LM_CORETURN {};
|
||||
return {};
|
||||
}
|
||||
|
||||
LM_SCHEDULABLE(std::shared_ptr<LM::Inference>) LM::InferencePool::get_or_create_inference(size_t id, const std::string &weights_path, const Inference::Params &p) {
|
||||
auto slot = LM_COAWAIT find_slot_by_id(id);
|
||||
std::shared_ptr<LM::Inference> LM::InferencePool::get_or_create_inference(size_t id, const std::string &weights_path, const Inference::Params &p) {
|
||||
auto slot = find_slot_by_id(id);
|
||||
if (slot) {
|
||||
LM_CORETURN slot->get_inference(true);
|
||||
return slot->get_inference(true);
|
||||
}
|
||||
slot = LM_COAWAIT get_free_slot();
|
||||
LM_CORETURN slot->create_inference(id, weights_path, p);
|
||||
slot = get_free_slot();
|
||||
return slot->create_inference(id, weights_path, p);
|
||||
}
|
||||
|
||||
LM_SCHEDULABLE(void) LM::InferencePool::delete_inference(size_t id) {
|
||||
auto slot = LM_COAWAIT find_slot_by_id(id, false);
|
||||
void LM::InferencePool::delete_inference(size_t id) {
|
||||
auto slot = find_slot_by_id(id, false);
|
||||
// Reset slot
|
||||
if (slot) {
|
||||
slot->reset();
|
||||
|
@ -140,12 +140,12 @@ LM_SCHEDULABLE(void) LM::InferencePool::delete_inference(size_t id) {
|
|||
std::filesystem::remove(get_slot_filename(id), ec);
|
||||
}
|
||||
|
||||
LM_SCHEDULABLE(void) LM::InferencePool::store_all() {
|
||||
void LM::InferencePool::store_all() {
|
||||
for (auto& slot : slots) {
|
||||
if (slot.is_free()) continue;
|
||||
LM_COAWAIT store_slot(slot);
|
||||
store_slot(slot);
|
||||
}
|
||||
LM_CORETURN;
|
||||
return;
|
||||
}
|
||||
|
||||
std::vector<size_t> LM::InferencePool::get_active_slot_ids() const {
|
||||
|
|
Loading…
Add table
Reference in a new issue