From 90e54d66d0dca054895588575d4355293d8ec1fb Mon Sep 17 00:00:00 2001 From: niansa Date: Mon, 25 Mar 2024 01:18:37 +0100 Subject: [PATCH] Removed cosched support --- CMakeLists.txt | 9 ----- include/justlm.hpp | 47 +++++++--------------- include/justlm_pool.hpp | 28 +++++++------- justlm_gptj.hpp | 70 ++++++++++++++++----------------- justlm_llama.hpp | 86 ++++++++++++++++++++--------------------- justlm_mpt.hpp | 70 ++++++++++++++++----------------- justlm_pool.cpp | 76 ++++++++++++++++++------------------ 7 files changed, 175 insertions(+), 211 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4bc303f..a7928cf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,24 +9,15 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ON) set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) set(LM_PYBIND No CACHE BOOL "If justlm Python bindings should be build") -set(LM_COSCHED No CACHE BOOL "If justlm should make use of CoSched") set(LM_NOEXCEPT No CACHE BOOL "If justlm exceptions should be disabled") set(LM_LLAMA Yes CACHE BOOL "If LLaMa model support should be built into justlm") set(LM_GPTJ Yes CACHE BOOL "If GPT-J model support should be built into justlm") set(LM_MPT Yes CACHE BOOL "If MPT model support should be built into justlm") -if (LM_COSCHED) - set(CMAKE_CXX_STANDARD 20) -endif() - function(target_justlm_setup TARGET_NAME) message(STATUS "Configuring model implementation target ${TARGET_NAME}") target_include_directories(${TARGET_NAME} PUBLIC include/) - if (LM_COSCHED) - target_compile_definitions(${TARGET_NAME} PUBLIC LM_COSCHED) - target_link_libraries(${TARGET_NAME} PUBLIC cosched) - endif() if (LM_NOEXCEPT) target_compile_definitions(${TARGET_NAME} PUBLIC LM_NOEXCEPT) endif() diff --git a/include/justlm.hpp b/include/justlm.hpp index 59b9470..ceffb32 100644 --- a/include/justlm.hpp +++ b/include/justlm.hpp @@ -7,35 +7,20 @@ #include #include -#ifdef LM_COSCHED -# include -# define LM_SCHEDULABLE(type) ::CoSched::AwaitableTask -# define LM_CORETURN co_return -# define LM_COAWAIT co_await -# define LM_TASKYIELD (co_await ::CoSched::Task::get_current().yield()) -#else -# define LM_SCHEDULABLE(type) type -# define LM_CORETURN return -# define LM_COAWAIT -# define LM_TASKYIELD (true) -#endif - #ifdef LM_NOEXCEPT # define LM_NOEXCEPTDECL noexcept -# define LM_THROW(t, r) this->last_error = (t); return r; -# define LM_COTHROW(t, r) this->last_error = (t); LM_CORETURN r; +# define LM_THROW(t, r) do {this->last_error = (t); return r;} while (0) # define LM_LAST_ERROR_STORAGE mutable std::string last_error; # define LM_LAST_ERROR_GETTER const std::string& get_last_error() const {return last_error;} # define LM_ERRBOOL bool # define LM_BOOL_ERROR false # define LM_BOOL_SUCCESS true -# define LM_RETHROW(x) LM_CORETURN x; +# define LM_RETHROW(x) return x # define LM_ERROR_CATCH(x, errval, ...) {auto v = x; if (v == (errval)) __VA_ARGS__} -# define LM_ERROR_FORWARD(x, errval) {auto v = x; if (v == (errval)) LM_CORETURN x;} 0 +# define LM_ERROR_FORWARD(x, errval) do {auto v = x; if (v == (errval)) return x;} while (0) #else # define LM_NOEXCEPTDECL # define LM_THROW(t, r) throw Exception(t) -# define LM_COTHROW(t, r) throw Exception(t) # define LM_LAST_ERROR_STORAGE # define LM_LAST_ERROR_GETTER # define LM_ERRBOOL void @@ -46,12 +31,6 @@ # define LM_ERROR_FORWARD(x, errval) {x;} #endif -#ifdef LM_COSCHED -#ifndef LM_NOEXCEPT -#warning Exceptions should not be enabled in combination with CoSched. 
Any exceptions thrown will lead to a std::terminate() call -#endif -#endif - #if _MSC_VER #include #endif @@ -134,24 +113,24 @@ public: } // This must be called with a non-empty prompt! - virtual LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const AppendCallback& on_tick = nullptr) LM_NOEXCEPTDECL = 0; + virtual LM_ERRBOOL append(const std::string& prompt, const AppendCallback& on_tick = nullptr) LM_NOEXCEPTDECL = 0; // append() must have been called at least once before calling this! - virtual LM_SCHEDULABLE(std::string) run(std::string_view end = "", const GenerateCallback& on_tick = nullptr, const GenerateCallback& pre_tick = nullptr) LM_NOEXCEPTDECL = 0; + virtual std::string run(std::string_view end = "", const GenerateCallback& on_tick = nullptr, const GenerateCallback& pre_tick = nullptr) LM_NOEXCEPTDECL = 0; virtual unsigned get_context_size() const noexcept = 0; - virtual LM_SCHEDULABLE(LM_ERRBOOL) create_savestate(Savestate&) const LM_NOEXCEPTDECL = 0; - virtual LM_SCHEDULABLE(LM_ERRBOOL) restore_savestate(const Savestate&) LM_NOEXCEPTDECL = 0; + virtual LM_ERRBOOL create_savestate(Savestate&) const LM_NOEXCEPTDECL = 0; + virtual LM_ERRBOOL restore_savestate(const Savestate&) LM_NOEXCEPTDECL = 0; - virtual LM_SCHEDULABLE(LM_ERRBOOL) serialize(std::ostream&) const LM_NOEXCEPTDECL = 0; - virtual LM_SCHEDULABLE(LM_ERRBOOL) deserialize(std::istream&) LM_NOEXCEPTDECL = 0; + virtual LM_ERRBOOL serialize(std::ostream&) const LM_NOEXCEPTDECL = 0; + virtual LM_ERRBOOL deserialize(std::istream&) LM_NOEXCEPTDECL = 0; - virtual LM_SCHEDULABLE(LM_ERRBOOL) load_grammar(const std::string&, bool override_temperature [[maybe_unused]] = false) LM_NOEXCEPTDECL { - LM_COTHROW("Grammar is not available for this models backend", LM_BOOL_ERROR); + virtual LM_ERRBOOL load_grammar(const std::string&, bool override_temperature [[maybe_unused]] = false) LM_NOEXCEPTDECL { + LM_THROW("Grammar is not available for this models backend", LM_BOOL_ERROR); } - virtual LM_SCHEDULABLE(LM_ERRBOOL) unload_grammar() LM_NOEXCEPTDECL { - LM_COTHROW("Grammar is not available for this models backend", LM_BOOL_ERROR); + virtual LM_ERRBOOL unload_grammar() LM_NOEXCEPTDECL { + LM_THROW("Grammar is not available for this models backend", LM_BOOL_ERROR); } virtual const std::string& get_prompt() const LM_NOEXCEPTDECL = 0; diff --git a/include/justlm_pool.hpp b/include/justlm_pool.hpp index 6a33e22..0cf3c73 100644 --- a/include/justlm_pool.hpp +++ b/include/justlm_pool.hpp @@ -63,21 +63,21 @@ class InferencePool { } // Returns false on error - LM_SCHEDULABLE(bool) store_slot(Slot& slot); + bool store_slot(Slot& slot); // Returns nullptr on error - LM_SCHEDULABLE(Slot*) load_slot(size_t id, Slot *suggested_slot = nullptr); + Slot *load_slot(size_t id, Slot *suggested_slot = nullptr); - LM_SCHEDULABLE(void) store_and_reset_slot(Slot& slot) { - LM_COAWAIT store_slot(slot); //TODO: Should handle errors somehow + void store_and_reset_slot(Slot& slot) { + store_slot(slot); //TODO: Should handle errors somehow slot.reset(); - LM_CORETURN; + return; } // Doesn't fail - LM_SCHEDULABLE(Slot*) get_free_slot(); + Slot *get_free_slot(); // Returns nullptr if not found - LM_SCHEDULABLE(Slot*) find_slot_by_id(size_t id, bool deserialize = true); + Slot *find_slot_by_id(size_t id, bool deserialize = true); public: // The pool_name must be unique amonst all applications in cwd @@ -93,14 +93,14 @@ public: } } - LM_SCHEDULABLE(std::shared_ptr) create_inference(size_t id, const std::string& weights_path, const Inference::Params& p) { 
- auto slot = LM_COAWAIT get_free_slot(); - LM_CORETURN slot->create_inference(id, weights_path, p); + std::shared_ptr create_inference(size_t id, const std::string& weights_path, const Inference::Params& p) { + auto slot = get_free_slot(); + return slot->create_inference(id, weights_path, p); } - LM_SCHEDULABLE(std::shared_ptr) get_inference(size_t id); - LM_SCHEDULABLE(std::shared_ptr) get_or_create_inference(size_t id, const std::string& weights_path, const Inference::Params& p); - LM_SCHEDULABLE(void) delete_inference(size_t id); - LM_SCHEDULABLE(void) store_all(); + std::shared_ptr get_inference(size_t id); + std::shared_ptr get_or_create_inference(size_t id, const std::string& weights_path, const Inference::Params& p); + void delete_inference(size_t id); + void store_all(); std::vector get_active_slot_ids() const; void cleanup(); diff --git a/justlm_gptj.hpp b/justlm_gptj.hpp index 38c889f..bffb20a 100644 --- a/justlm_gptj.hpp +++ b/justlm_gptj.hpp @@ -59,12 +59,12 @@ class GPTJInference final : public Inference { // This function reduces the size of our tokens vector according to some parameters // All tokens will be evaluated if scrolling was needed and true will be returned - LM_SCHEDULABLE(bool) window_scroll() LM_NOEXCEPTDECL { + bool window_scroll() LM_NOEXCEPTDECL { auto &state = get_state(); // Check that we actually need to scroll if (state->tokens.size() <= params.n_ctx) { // Nope - LM_CORETURN false; + return false; } // Start scrolling if (params.scroll_keep > 0.0f) { @@ -81,11 +81,11 @@ class GPTJInference final : public Inference { state->tokens.resize(params.n_ctx_window_top_bar); } // Evaluate tokens - LM_ERROR_FORWARD(LM_COAWAIT evaluate_tokens(0, on_scroll), LM_BOOL_ERROR); - LM_CORETURN true; + LM_ERROR_FORWARD(evaluate_tokens(0, on_scroll), LM_BOOL_ERROR); + return true; } - LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const AppendCallback &on_tick = nullptr) LM_NOEXCEPTDECL { + LM_ERRBOOL evaluate_tokens(size_t starting_offset, const AppendCallback &on_tick = nullptr) LM_NOEXCEPTDECL { auto& state = get_state(); // Evaluate tokens in batches @@ -96,7 +96,7 @@ class GPTJInference final : public Inference { // Evaluate std::vector batch(state->tokens.begin()+it, state->tokens.begin()+it+params.n_batch); if (!gptj_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token)) { - LM_COTHROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR); + LM_THROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR); } // Tick @@ -104,8 +104,7 @@ class GPTJInference final : public Inference { // Calculate progress auto progress = float(it-starting_offset) / (state->tokens.size()-starting_offset) * 100.f; // Tick and yield - if (!on_tick(progress)) LM_CORETURN LM_BOOL_SUCCESS; - else if (!LM_TASKYIELD) LM_CORETURN LM_BOOL_SUCCESS; + if (!on_tick(progress)) return LM_BOOL_SUCCESS; } } @@ -115,7 +114,7 @@ class GPTJInference final : public Inference { //TODO: This is extremely inefficient! Don't do that... 
std::vector batch(state->tokens.begin()+it, state->tokens.begin()+it+1); if (!gptj_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token)) { - LM_COTHROW("Failed to evaluate individual tokens", LM_BOOL_ERROR); + LM_THROW("Failed to evaluate individual tokens", LM_BOOL_ERROR); } } } @@ -123,7 +122,7 @@ class GPTJInference final : public Inference { // Notify about completion if (on_tick) on_tick(100.f); - LM_CORETURN LM_BOOL_SUCCESS; + return LM_BOOL_SUCCESS; } public: @@ -134,7 +133,7 @@ public: deinit(); } - LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const AppendCallback &on_tick) LM_NOEXCEPTDECL override { + LM_ERRBOOL append(const std::string& prompt, const AppendCallback &on_tick) LM_NOEXCEPTDECL override { auto& state = get_state(); // Append to current prompt @@ -152,16 +151,16 @@ public: ); // Make sure token limit isn't being hit - if (LM_COAWAIT window_scroll()) { + if (window_scroll()) { // That function already has evaluated our tokens since scrolling was needed - LM_CORETURN LM_BOOL_SUCCESS; + return LM_BOOL_SUCCESS; } // Evaluate new tokens - LM_CORETURN LM_COAWAIT evaluate_tokens(old_token_count, on_tick); + return evaluate_tokens(old_token_count, on_tick); } - LM_SCHEDULABLE(std::string) run(std::string_view end, const GenerateCallback &on_tick, const GenerateCallback& pre_tick) LM_NOEXCEPTDECL override { + std::string run(std::string_view end, const GenerateCallback &on_tick, const GenerateCallback& pre_tick) LM_NOEXCEPTDECL override { auto& state = get_state(); std::string fres; @@ -187,7 +186,7 @@ public: state->tokens.push_back(id); // Make sure token limit isn't being hit - LM_COAWAIT window_scroll(); + window_scroll(); // Get token as string const std::string_view str = state->vocab.id_to_token[id]; @@ -202,13 +201,12 @@ public: // TODO: Respect batch size std::vector batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size()); if (!gptj_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) { - LM_COTHROW("Failed to evaluate new tokens", ""); + LM_THROW("Failed to evaluate new tokens", ""); } } // Tick if (on_tick && !on_tick(str.data())) abort = true; - else if (!LM_TASKYIELD) abort = true; } // Create final string TODO: Could be optimized @@ -217,59 +215,59 @@ public: } // Return final string - LM_CORETURN fres; + return fres; } unsigned get_context_size() const noexcept override { return get_state()->tokens.size(); } - LM_SCHEDULABLE(LM_ERRBOOL) create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override { + LM_ERRBOOL create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override { auto& state = get_state(); sv.buf.resize(gptj_get_state_size(state->model)); gptj_copy_state_data(state->model, state->rng, sv.buf.data()); sv.tokens = state->tokens; sv.prompt = state->prompt; sv.ctx = generic_state; - LM_CORETURN LM_BOOL_SUCCESS; + return LM_BOOL_SUCCESS; } - LM_SCHEDULABLE(LM_ERRBOOL) restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override { + LM_ERRBOOL restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override { auto& state = get_state(); if (sv.ctx != generic_state) - LM_COTHROW("Savestate does not match context", LM_BOOL_ERROR); + LM_THROW("Savestate does not match context", LM_BOOL_ERROR); gptj_set_state_data(&state->model, &state->rng, sv.buf.data()); state->tokens = sv.tokens; state->prompt = sv.prompt; - LM_CORETURN LM_BOOL_SUCCESS; + return LM_BOOL_SUCCESS; } - LM_SCHEDULABLE(LM_ERRBOOL) serialize(std::ostream 
&o) const LM_NOEXCEPTDECL override { + LM_ERRBOOL serialize(std::ostream &o) const LM_NOEXCEPTDECL override { auto& state = get_state(); // Get state size auto state_size = gptj_get_state_size(state->model); // Write sizes for (const uint32_t s : {state->tokens.size(), state->prompt.size(), state_size}) { if (!o.write(reinterpret_cast(&s), sizeof(s))) { - LM_COTHROW("Failed to serialize data sizes", LM_BOOL_ERROR); + LM_THROW("Failed to serialize data sizes", LM_BOOL_ERROR); } } // Write tokens if (!o.write(reinterpret_cast(state->tokens.data()), state->tokens.size()*sizeof(int))) { - LM_COTHROW("Failed to serialize tokens", LM_BOOL_ERROR); + LM_THROW("Failed to serialize tokens", LM_BOOL_ERROR); } // Write prompt if (!o.write(state->prompt.data(), state->prompt.size())) { - LM_COTHROW("Failed to serialize prompt", LM_BOOL_ERROR); + LM_THROW("Failed to serialize prompt", LM_BOOL_ERROR); } // Write state std::vector state_buf(state_size); gptj_copy_state_data(state->model, state->rng, state_buf.data()); if (!o.write(reinterpret_cast(state_buf.data()), state_size)) { - LM_COTHROW("Failed to serialize state", LM_BOOL_ERROR); + LM_THROW("Failed to serialize state", LM_BOOL_ERROR); } - LM_CORETURN LM_BOOL_SUCCESS; + return LM_BOOL_SUCCESS; } - LM_SCHEDULABLE(LM_ERRBOOL) deserialize(std::istream &i) LM_NOEXCEPTDECL override { + LM_ERRBOOL deserialize(std::istream &i) LM_NOEXCEPTDECL override { auto& state = get_state(); uint32_t embd_size, prompt_size, state_size; // Initialization to prevent compiler complaints @@ -277,26 +275,26 @@ public: // Read sizes for (uint32_t *s : {&embd_size, &prompt_size, &state_size}) { if (!i.read(reinterpret_cast(s), sizeof(*s))) { - LM_COTHROW("Failed to deserialize data sizes", LM_BOOL_ERROR); + LM_THROW("Failed to deserialize data sizes", LM_BOOL_ERROR); } } // Read tokens state->tokens.resize(embd_size); if (!i.read(reinterpret_cast(state->tokens.data()), state->tokens.size()*sizeof(int))) { - LM_COTHROW("Failed to deserialize tokens", LM_BOOL_ERROR); + LM_THROW("Failed to deserialize tokens", LM_BOOL_ERROR); } // Read prompt state->prompt.resize(prompt_size); if (!i.read(state->prompt.data(), state->prompt.size())) { - LM_COTHROW("Failed to deserialize prompt", LM_BOOL_ERROR); + LM_THROW("Failed to deserialize prompt", LM_BOOL_ERROR); } // Read state std::vector state_buf(state_size); if (!i.read(reinterpret_cast(state_buf.data()), state_buf.size())) { - LM_COTHROW("Failed to deserialize state", LM_BOOL_ERROR); + LM_THROW("Failed to deserialize state", LM_BOOL_ERROR); } gptj_set_state_data(&state->model, &state->rng, state_buf.data()); - LM_CORETURN LM_BOOL_SUCCESS; + return LM_BOOL_SUCCESS; } const std::string &get_prompt() const LM_NOEXCEPTDECL override { return get_state()->prompt; diff --git a/justlm_llama.hpp b/justlm_llama.hpp index d3ca43b..52b04ac 100644 --- a/justlm_llama.hpp +++ b/justlm_llama.hpp @@ -64,12 +64,12 @@ class LLaMAInference final : public Inference { // This function reduces the size of our tokens vector according to some parameters // All tokens will be evaluated if scrolling was needed and true will be returned - LM_SCHEDULABLE(bool) window_scroll() LM_NOEXCEPTDECL { + bool window_scroll() LM_NOEXCEPTDECL { auto &state = get_state(); // Check that we actually need to scroll if (state->tokens.size() <= state->n_ctx) { // Nope - LM_CORETURN false; + return false; } // Start scrolling if (params.scroll_keep > 0.0f) { @@ -86,11 +86,11 @@ class LLaMAInference final : public Inference { state->tokens.resize(params.n_ctx_window_top_bar); } 
// Evaluate tokens - LM_ERROR_FORWARD(LM_COAWAIT evaluate_tokens(0, on_scroll), LM_BOOL_ERROR); - LM_CORETURN true; + LM_ERROR_FORWARD(evaluate_tokens(0, on_scroll), LM_BOOL_ERROR); + return true; } - LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const AppendCallback &on_tick = nullptr) LM_NOEXCEPTDECL { + LM_ERRBOOL evaluate_tokens(size_t starting_offset, const AppendCallback &on_tick = nullptr) LM_NOEXCEPTDECL { auto& state = get_state(); // Evaluate tokens in batches @@ -101,7 +101,7 @@ class LLaMAInference final : public Inference { // Evaluate const auto batch = llama_batch_get_one(state->tokens.data()+it, params.n_batch, it, 0); if (llama_decode(state->ctx, batch)) { - LM_COTHROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR); + LM_THROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR); } // Tick @@ -109,8 +109,7 @@ class LLaMAInference final : public Inference { // Calculate progress auto progress = float(it-starting_offset) / (state->tokens.size()-starting_offset) * 100.f; // Tick and yield - if (!on_tick(progress)) LM_CORETURN LM_BOOL_SUCCESS; - else if (!LM_TASKYIELD) LM_CORETURN LM_BOOL_SUCCESS; + if (!on_tick(progress)) return LM_BOOL_SUCCESS; } } @@ -119,7 +118,7 @@ class LLaMAInference final : public Inference { for (; it != state->tokens.size(); it++) { const auto batch = llama_batch_get_one(state->tokens.data()+it, 1, it, 0); if (llama_decode(state->ctx, batch)) { - LM_COTHROW("Failed to evaluate individual tokens", LM_BOOL_ERROR); + LM_THROW("Failed to evaluate individual tokens", LM_BOOL_ERROR); } } } @@ -127,7 +126,7 @@ class LLaMAInference final : public Inference { // Notify about completion if (on_tick) on_tick(100.f); - LM_CORETURN LM_BOOL_SUCCESS; + return LM_BOOL_SUCCESS; } int accept_token(int t) { @@ -198,7 +197,7 @@ public: } } - LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const AppendCallback &on_tick) LM_NOEXCEPTDECL override { + LM_ERRBOOL append(const std::string& prompt, const AppendCallback &on_tick) LM_NOEXCEPTDECL override { auto& state = get_state(); // Check if prompt was empty @@ -216,16 +215,16 @@ public: state->tokens.resize(old_token_count+token_count); // Make sure token limit isn't being hit - if (LM_COAWAIT window_scroll()) { + if (window_scroll()) { // That function already has evaluated our tokens since scrolling was needed - LM_CORETURN LM_BOOL_SUCCESS; + return LM_BOOL_SUCCESS; } // Evaluate new tokens - LM_CORETURN LM_COAWAIT evaluate_tokens(old_token_count, on_tick); + return evaluate_tokens(old_token_count, on_tick); } - LM_SCHEDULABLE(std::string) run(std::string_view end, const GenerateCallback &on_tick, const GenerateCallback& pre_tick) LM_NOEXCEPTDECL override { + std::string run(std::string_view end, const GenerateCallback &on_tick, const GenerateCallback& pre_tick) LM_NOEXCEPTDECL override { auto& state = get_state(); std::string fres; @@ -240,7 +239,7 @@ public: try { id = llama_sample_top_p_top_k(); } catch (const std::exception& e) { - LM_COTHROW(e.what(), ""); + LM_THROW(e.what(), ""); } if (id == llama_token_eos(state->model)) { @@ -257,7 +256,7 @@ public: } // Make sure token limit isn't hit - LM_COAWAIT window_scroll(); + window_scroll(); // Get token as string std::string str(14, ' '); @@ -274,13 +273,12 @@ public: // TODO: Respect batch size const auto batch = llama_batch_get_one(state->tokens.data()+state->tokens.size()-1, 1, state->tokens.size()-1, 0); if (llama_decode(state->ctx, batch)) { - LM_COTHROW("Failed to evaluate new tokens", ""); + LM_THROW("Failed to evaluate 
new tokens", ""); } } // Tick and yield if (on_tick && !on_tick(str.data())) abort = true; - else if (!LM_TASKYIELD) abort = true; } // Create final string TODO: Could be optimized @@ -289,59 +287,59 @@ public: } // Return final string - LM_CORETURN fres; + return fres; } unsigned get_context_size() const noexcept override { return get_state()->tokens.size(); } - LM_SCHEDULABLE(LM_ERRBOOL) create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override { + LM_ERRBOOL create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override { auto& state = get_state(); sv.buf.resize(llama_get_state_size(state->ctx)); llama_copy_state_data(state->ctx, sv.buf.data()); sv.tokens = state->tokens; sv.prompt = state->prompt; sv.ctx = generic_state; - LM_CORETURN LM_BOOL_SUCCESS; + return LM_BOOL_SUCCESS; } - LM_SCHEDULABLE(LM_ERRBOOL) restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override { + LM_ERRBOOL restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override { auto& state = get_state(); if (sv.ctx != generic_state) - LM_COTHROW("Savestate does not match context", LM_BOOL_ERROR); + LM_THROW("Savestate does not match context", LM_BOOL_ERROR); llama_set_state_data(state->ctx, const_cast(sv.buf.data())); state->tokens = sv.tokens; state->prompt = sv.prompt; - LM_CORETURN LM_BOOL_SUCCESS; + return LM_BOOL_SUCCESS; } - LM_SCHEDULABLE(LM_ERRBOOL) serialize(std::ostream &o) const LM_NOEXCEPTDECL override { + LM_ERRBOOL serialize(std::ostream &o) const LM_NOEXCEPTDECL override { auto& state = get_state(); // Get state size auto state_size = llama_get_state_size(state->ctx); // Write sizes for (const uint32_t s : {static_cast(state->n_ctx), state->tokens.size(), state->prompt.size(), state_size}) { if (!o.write(reinterpret_cast(&s), sizeof(s))) { - LM_COTHROW("Failed to serialize data sizes", LM_BOOL_ERROR); + LM_THROW("Failed to serialize data sizes", LM_BOOL_ERROR); } } // Write tokens if (!o.write(reinterpret_cast(state->tokens.data()), state->tokens.size()*sizeof(int))) { - LM_COTHROW("Failed to serialize tokens", LM_BOOL_ERROR); + LM_THROW("Failed to serialize tokens", LM_BOOL_ERROR); } // Write prompt if (!o.write(state->prompt.data(), state->prompt.size())) { - LM_COTHROW("Failed to serialize prompt", LM_BOOL_ERROR); + LM_THROW("Failed to serialize prompt", LM_BOOL_ERROR); } // Write state std::vector state_buf(state_size); llama_copy_state_data(state->ctx, state_buf.data()); if (!o.write(reinterpret_cast(state_buf.data()), state_size)) { - LM_COTHROW("Failed to serialize state", LM_BOOL_ERROR); + LM_THROW("Failed to serialize state", LM_BOOL_ERROR); } - LM_CORETURN LM_BOOL_SUCCESS; + return LM_BOOL_SUCCESS; } - LM_SCHEDULABLE(LM_ERRBOOL) deserialize(std::istream &i) LM_NOEXCEPTDECL override { + LM_ERRBOOL deserialize(std::istream &i) LM_NOEXCEPTDECL override { auto& state = get_state(); uint32_t n_ctx, embd_size, prompt_size, state_size; // Initialization to prevent compiler complaints @@ -349,53 +347,53 @@ public: // Read sizes for (uint32_t *s : {&n_ctx, &embd_size, &prompt_size, &state_size}) { if (!i.read(reinterpret_cast(s), sizeof(*s))) { - LM_COTHROW("Failed to deserialize data sizes", LM_BOOL_ERROR); + LM_THROW("Failed to deserialize data sizes", LM_BOOL_ERROR); } } if (state->n_ctx != n_ctx) { - LM_COTHROW("Context length differs (My "+std::to_string(state->n_ctx)+" vs. files "+std::to_string(n_ctx)+')', LM_BOOL_ERROR); + LM_THROW("Context length differs (My "+std::to_string(state->n_ctx)+" vs. 
files "+std::to_string(n_ctx)+')', LM_BOOL_ERROR); } // Read tokens state->tokens.resize(embd_size); if (!i.read(reinterpret_cast(state->tokens.data()), state->tokens.size()*sizeof(int))) { - LM_COTHROW("Failed to deserialize tokens", LM_BOOL_ERROR); + LM_THROW("Failed to deserialize tokens", LM_BOOL_ERROR); } // Read prompt state->prompt.resize(prompt_size); if (!i.read(state->prompt.data(), state->prompt.size())) { - LM_COTHROW("Failed to deserialize prompt", LM_BOOL_ERROR); + LM_THROW("Failed to deserialize prompt", LM_BOOL_ERROR); } // Read state std::vector state_buf(state_size); if (!i.read(reinterpret_cast(state_buf.data()), state_buf.size())) { - LM_COTHROW("Failed to deserialize state", LM_BOOL_ERROR); + LM_THROW("Failed to deserialize state", LM_BOOL_ERROR); } llama_set_state_data(state->ctx, state_buf.data()); - LM_CORETURN LM_BOOL_SUCCESS; + return LM_BOOL_SUCCESS; } - LM_SCHEDULABLE(LM_ERRBOOL) load_grammar(const std::string& src, bool override_temperature) LM_NOEXCEPTDECL override { + LM_ERRBOOL load_grammar(const std::string& src, bool override_temperature) LM_NOEXCEPTDECL override { auto& state = get_state(); state->parsed_grammar = grammar_parser::parse(src.c_str()); if (state->parsed_grammar.rules.empty()) { - LM_COTHROW("Failed to parse grammar (or no rules)", LM_BOOL_ERROR); + LM_THROW("Failed to parse grammar (or no rules)", LM_BOOL_ERROR); } auto rules = state->parsed_grammar.c_rules(); state->grammar = llama_grammar_init(rules.data(), rules.size(), state->parsed_grammar.symbol_ids.at("root")); if (!state->grammar) { - LM_COTHROW("Failed to generate llama grammar", LM_BOOL_ERROR); + LM_THROW("Failed to generate llama grammar", LM_BOOL_ERROR); } state->grammar_override_temp = override_temperature; - LM_CORETURN LM_BOOL_SUCCESS; + return LM_BOOL_SUCCESS; } - LM_SCHEDULABLE(LM_ERRBOOL) unload_grammar() LM_NOEXCEPTDECL override { + LM_ERRBOOL unload_grammar() LM_NOEXCEPTDECL override { get_state()->grammar = nullptr; - LM_CORETURN LM_BOOL_SUCCESS; + return LM_BOOL_SUCCESS; } const std::string &get_prompt() const LM_NOEXCEPTDECL override { diff --git a/justlm_mpt.hpp b/justlm_mpt.hpp index 1ef9978..0a7d6b6 100644 --- a/justlm_mpt.hpp +++ b/justlm_mpt.hpp @@ -68,12 +68,12 @@ class MPTInference final : public Inference { // This function reduces the size of our tokens vector according to some parameters // All tokens will be evaluated if scrolling was needed and true will be returned - LM_SCHEDULABLE(bool) window_scroll() LM_NOEXCEPTDECL { + bool window_scroll() LM_NOEXCEPTDECL { auto &state = get_state(); // Check that we actually need to scroll if (state->tokens.size() <= params.n_ctx) { // Nope - LM_CORETURN false; + return false; } // Start scrolling if (params.scroll_keep > 0.0f) { @@ -90,11 +90,11 @@ class MPTInference final : public Inference { state->tokens.resize(params.n_ctx_window_top_bar); } // Evaluate tokens - LM_ERROR_FORWARD(LM_COAWAIT evaluate_tokens(0, on_scroll), LM_BOOL_ERROR); - LM_CORETURN true; + LM_ERROR_FORWARD(evaluate_tokens(0, on_scroll), LM_BOOL_ERROR); + return true; } - LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const AppendCallback &on_tick) LM_NOEXCEPTDECL { + LM_ERRBOOL evaluate_tokens(size_t starting_offset, const AppendCallback &on_tick) LM_NOEXCEPTDECL { auto& state = get_state(); // Evaluate tokens in batches @@ -105,7 +105,7 @@ class MPTInference final : public Inference { // Evaluate std::vector batch(state->tokens.begin()+it, state->tokens.begin()+it+params.n_batch); if (!mpt_eval(state->model, 
params.n_threads, it, batch, state->logits, state->mem_per_token)) { - LM_COTHROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR); + LM_THROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR); } // Tick @@ -113,8 +113,7 @@ class MPTInference final : public Inference { // Calculate progress auto progress = float(it-starting_offset) / (state->tokens.size()-starting_offset) * 100.f; // Tick and yield - if (!on_tick(progress)) LM_CORETURN LM_BOOL_SUCCESS; - else if (!LM_TASKYIELD) LM_CORETURN LM_BOOL_SUCCESS; + if (!on_tick(progress)) return LM_BOOL_SUCCESS; } } @@ -124,7 +123,7 @@ class MPTInference final : public Inference { //TODO: This is extremely inefficient! Don't do that... std::vector batch(state->tokens.begin()+it, state->tokens.begin()+it+1); if (!mpt_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token)) { - LM_COTHROW("Failed to evaluate individual tokens", LM_BOOL_ERROR); + LM_THROW("Failed to evaluate individual tokens", LM_BOOL_ERROR); } } } @@ -132,7 +131,7 @@ class MPTInference final : public Inference { // Notify about completion if (on_tick) on_tick(100.f); - LM_CORETURN LM_BOOL_SUCCESS; + return LM_BOOL_SUCCESS; } public: @@ -143,7 +142,7 @@ public: deinit(); } - LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const AppendCallback &on_tick) LM_NOEXCEPTDECL override { + LM_ERRBOOL append(const std::string& prompt, const AppendCallback &on_tick) LM_NOEXCEPTDECL override { auto& state = get_state(); // Append to current prompt @@ -161,16 +160,16 @@ public: ); // Make sure token limit isn't being hit - if (LM_COAWAIT window_scroll()) { + if (window_scroll()) { // That function already has evaluated our tokens since scrolling was needed - LM_CORETURN LM_BOOL_SUCCESS; + return LM_BOOL_SUCCESS; } // Evaluate new tokens - LM_CORETURN LM_COAWAIT evaluate_tokens(old_token_count, on_tick); + return evaluate_tokens(old_token_count, on_tick); } - LM_SCHEDULABLE(std::string) run(std::string_view end, const GenerateCallback &on_tick, const GenerateCallback& pre_tick) LM_NOEXCEPTDECL override { + std::string run(std::string_view end, const GenerateCallback &on_tick, const GenerateCallback& pre_tick) LM_NOEXCEPTDECL override { auto& state = get_state(); std::string fres; @@ -202,7 +201,7 @@ public: state->tokens.push_back(id); // Make sure token limit isn't being hit - LM_COAWAIT window_scroll(); + window_scroll(); // Get token as string const std::string_view str = state->vocab.id_to_token[id]; @@ -218,13 +217,12 @@ public: // TODO: Respect batch size std::vector batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size()); if (!mpt_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) { - LM_COTHROW("Failed to evaluate new tokens", ""); + LM_THROW("Failed to evaluate new tokens", ""); } } // Tick if (on_tick && !on_tick(str.data())) abort = true; - else if (!LM_TASKYIELD) abort = true; } // Create final string TODO: Could be optimized @@ -233,59 +231,59 @@ public: } // Return final string - LM_CORETURN fres; + return fres; } unsigned get_context_size() const noexcept override { return get_state()->tokens.size(); } - LM_SCHEDULABLE(LM_ERRBOOL) create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override { + LM_ERRBOOL create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override { auto& state = get_state(); sv.buf.resize(mpt_get_state_size(state->model)); mpt_copy_state_data(state->model, state->rng, sv.buf.data()); sv.tokens = state->tokens; 
sv.prompt = state->prompt; sv.ctx = generic_state; - LM_CORETURN LM_BOOL_SUCCESS ; + return LM_BOOL_SUCCESS ; } - LM_SCHEDULABLE(LM_ERRBOOL) restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override { + LM_ERRBOOL restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override { auto& state = get_state(); if (sv.ctx != generic_state) - LM_COTHROW("Savestate does not match context", LM_BOOL_ERROR); + LM_THROW("Savestate does not match context", LM_BOOL_ERROR); mpt_set_state_data(&state->model, &state->rng, sv.buf.data()); state->tokens = sv.tokens; state->prompt = sv.prompt; - LM_CORETURN LM_BOOL_SUCCESS; + return LM_BOOL_SUCCESS; } - LM_SCHEDULABLE(LM_ERRBOOL) serialize(std::ostream &o) const LM_NOEXCEPTDECL override { + LM_ERRBOOL serialize(std::ostream &o) const LM_NOEXCEPTDECL override { auto& state = get_state(); // Get state size auto state_size = mpt_get_state_size(state->model); // Write sizes for (const uint32_t s : {state->tokens.size(), state->prompt.size(), state_size}) { if (!o.write(reinterpret_cast(&s), sizeof(s))) { - LM_COTHROW("Failed to serialize data sizes", LM_BOOL_ERROR); + LM_THROW("Failed to serialize data sizes", LM_BOOL_ERROR); } } // Write tokens if (!o.write(reinterpret_cast(state->tokens.data()), state->tokens.size()*sizeof(int))) { - LM_COTHROW("Failed to serialize tokens", LM_BOOL_ERROR); + LM_THROW("Failed to serialize tokens", LM_BOOL_ERROR); } // Write prompt if (!o.write(state->prompt.data(), state->prompt.size())) { - LM_COTHROW("Failed to serialize prompt", LM_BOOL_ERROR); + LM_THROW("Failed to serialize prompt", LM_BOOL_ERROR); } // Write state std::vector state_buf(state_size); mpt_copy_state_data(state->model, state->rng, state_buf.data()); if (!o.write(reinterpret_cast(state_buf.data()), state_size)) { - LM_COTHROW("Failed to serialize state", LM_BOOL_ERROR); + LM_THROW("Failed to serialize state", LM_BOOL_ERROR); } - LM_CORETURN LM_BOOL_SUCCESS; + return LM_BOOL_SUCCESS; } - LM_SCHEDULABLE(LM_ERRBOOL) deserialize(std::istream &i) LM_NOEXCEPTDECL override { + LM_ERRBOOL deserialize(std::istream &i) LM_NOEXCEPTDECL override { auto& state = get_state(); uint32_t embd_size, promptsize, state_size; // Initialization to prevent compiler complaints @@ -293,26 +291,26 @@ public: // Read sizes for (uint32_t *s : {&embd_size, &promptsize, &state_size}) { if (!i.read(reinterpret_cast(s), sizeof(*s))) { - LM_COTHROW("Failed to deserialize data sizes", LM_BOOL_ERROR); + LM_THROW("Failed to deserialize data sizes", LM_BOOL_ERROR); } } // Read tokens state->tokens.resize(embd_size); if (!i.read(reinterpret_cast(state->tokens.data()), state->tokens.size()*sizeof(int))) { - LM_COTHROW("Failed to deserialize tokens", LM_BOOL_ERROR); + LM_THROW("Failed to deserialize tokens", LM_BOOL_ERROR); } // Read prompt state->prompt.resize(promptsize); if (!i.read(state->prompt.data(), state->prompt.size())) { - LM_COTHROW("Failed to deserialize prompt", LM_BOOL_ERROR); + LM_THROW("Failed to deserialize prompt", LM_BOOL_ERROR); } // Read state std::vector state_buf(state_size); if (!i.read(reinterpret_cast(state_buf.data()), state_buf.size())) { - LM_COTHROW("Failed to deserialize state", LM_BOOL_ERROR); + LM_THROW("Failed to deserialize state", LM_BOOL_ERROR); } mpt_set_state_data(&state->model, &state->rng, state_buf.data()); - LM_CORETURN LM_BOOL_SUCCESS; + return LM_BOOL_SUCCESS; } const std::string &get_prompt() const LM_NOEXCEPTDECL override { return get_state()->prompt; diff --git a/justlm_pool.cpp b/justlm_pool.cpp index 02996c5..5d9659e 100644 --- 
a/justlm_pool.cpp +++ b/justlm_pool.cpp @@ -6,7 +6,7 @@ -LM_SCHEDULABLE(bool) LM::InferencePool::store_slot(Slot &slot) { +bool LM::InferencePool::store_slot(Slot &slot) { auto inference = slot.get_inference(); // Open output file std::ofstream f(get_slot_filename(slot.get_id()), std::ios::binary); @@ -17,61 +17,61 @@ LM_SCHEDULABLE(bool) LM::InferencePool::store_slot(Slot &slot) { f.write(weights_path.data(), weights_path.size()); // Write params if (!f.write(reinterpret_cast(&inference->params), sizeof(inference->params))) { - LM_CORETURN false; + return false; } // Serialize instance try { - LM_COAWAIT inference->serialize(f); + inference->serialize(f); } catch (...) { - LM_CORETURN false; + return false; } // Return success - LM_CORETURN true; + return true; } -LM_SCHEDULABLE(LM::InferencePool::Slot*) LM::InferencePool::load_slot(size_t id, Slot *suggested_slot) { +LM::InferencePool::Slot *LM::InferencePool::load_slot(size_t id, Slot *suggested_slot) { // Open input file std::ifstream f(get_slot_filename(id), std::ios::binary); if (!f) { // Does not exist - LM_CORETURN nullptr; + return nullptr; } // Read weights path std::string weights_path; uint32_t weights_path_len; if (!f.read(reinterpret_cast(&weights_path_len), sizeof(weights_path_len))) { - LM_CORETURN nullptr; + return nullptr; } weights_path.resize(weights_path_len); if (!f.read(weights_path.data(), weights_path.size())) { - LM_CORETURN nullptr; + return nullptr; } // Read params LM::Inference::Params p; if (!f.read(reinterpret_cast(&p), sizeof(p))) { - LM_CORETURN nullptr; + return nullptr; } // Create instance - auto& slot = suggested_slot?*suggested_slot:*(LM_COAWAIT get_free_slot()); + auto& slot = suggested_slot?*suggested_slot:*(get_free_slot()); auto inference = slot.create_inference(id, weights_path, p); // Deserialize instance try { - LM_COAWAIT inference->deserialize(f); + inference->deserialize(f); } catch (...) 
{ slot.reset(); - LM_CORETURN nullptr; + return nullptr; } // Return final slot - LM_CORETURN &slot; + return &slot; } -LM_SCHEDULABLE(LM::InferencePool::Slot*) LM::InferencePool::get_free_slot() { +LM::InferencePool::Slot *LM::InferencePool::get_free_slot() { // Attempt to find free slot while finding oldest one Slot *oldest = nullptr; for (auto& slot : slots) { // Take free slot if (slot.is_free()) { - LM_CORETURN &slot; + return &slot; } // Update oldest if (oldest == nullptr || slot.get_last_access() < oldest->get_last_access()) { @@ -80,17 +80,17 @@ LM_SCHEDULABLE(LM::InferencePool::Slot*) LM::InferencePool::get_free_slot() { } // Free up oldest slot and take that one // Note: Since there has to be at least 1 slot, oldest is never going to be a nullptr - LM_COAWAIT store_and_reset_slot(*oldest); - LM_CORETURN oldest; + store_and_reset_slot(*oldest); + return oldest; } -LM_SCHEDULABLE(LM::InferencePool::Slot*) LM::InferencePool::find_slot_by_id(size_t id, bool deserialize) { +LM::InferencePool::Slot *LM::InferencePool::find_slot_by_id(size_t id, bool deserialize) { // Attempt to find given slot while finding oldest one Slot *oldest = nullptr; for (auto& slot : slots) { // Take slot with ID if (slot.get_id() == id) { - LM_CORETURN &slot; + return &slot; } // Update oldest if (oldest == nullptr || slot.get_last_access() < oldest->get_last_access()) { @@ -99,38 +99,38 @@ LM_SCHEDULABLE(LM::InferencePool::Slot*) LM::InferencePool::find_slot_by_id(size } // Slot not found, attempt to load it if (deserialize) { - if (!oldest->is_free()) LM_COAWAIT store_slot(*oldest); - if (!LM_COAWAIT load_slot(id, oldest)) { + if (!oldest->is_free()) store_slot(*oldest); + if (!load_slot(id, oldest)) { // In case slot loading failed, still reset slot for later use //TODO: Make this configurable oldest->reset(); } else { - LM_CORETURN oldest; + return oldest; } } // Slot not found - LM_CORETURN nullptr; + return nullptr; } -LM_SCHEDULABLE(std::shared_ptr) LM::InferencePool::get_inference(size_t id) { - auto slot = LM_COAWAIT find_slot_by_id(id); +std::shared_ptr LM::InferencePool::get_inference(size_t id) { + auto slot = find_slot_by_id(id); if (slot) { - LM_CORETURN slot->get_inference(true); + return slot->get_inference(true); } - LM_CORETURN {}; + return {}; } -LM_SCHEDULABLE(std::shared_ptr) LM::InferencePool::get_or_create_inference(size_t id, const std::string &weights_path, const Inference::Params &p) { - auto slot = LM_COAWAIT find_slot_by_id(id); +std::shared_ptr LM::InferencePool::get_or_create_inference(size_t id, const std::string &weights_path, const Inference::Params &p) { + auto slot = find_slot_by_id(id); if (slot) { - LM_CORETURN slot->get_inference(true); + return slot->get_inference(true); } - slot = LM_COAWAIT get_free_slot(); - LM_CORETURN slot->create_inference(id, weights_path, p); + slot = get_free_slot(); + return slot->create_inference(id, weights_path, p); } -LM_SCHEDULABLE(void) LM::InferencePool::delete_inference(size_t id) { - auto slot = LM_COAWAIT find_slot_by_id(id, false); +void LM::InferencePool::delete_inference(size_t id) { + auto slot = find_slot_by_id(id, false); // Reset slot if (slot) { slot->reset(); @@ -140,12 +140,12 @@ LM_SCHEDULABLE(void) LM::InferencePool::delete_inference(size_t id) { std::filesystem::remove(get_slot_filename(id), ec); } -LM_SCHEDULABLE(void) LM::InferencePool::store_all() { +void LM::InferencePool::store_all() { for (auto& slot : slots) { if (slot.is_free()) continue; - LM_COAWAIT store_slot(slot); + store_slot(slot); } - LM_CORETURN; + 
return; } std::vector LM::InferencePool::get_active_slot_ids() const {
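
A note on the revised error macros in include/justlm.hpp: LM_THROW, LM_RETHROW and LM_ERROR_FORWARD are now wrapped in do { ... } while (0), the standard trick for making a multi-statement macro expand to exactly one statement so it composes safely with an unbraced if/else and still requires a trailing semicolon. The sketch below is illustrative only; Demo, THROW_OLD and THROW_NEW are made-up stand-ins for the noexcept flavour of LM_THROW before and after this patch.

    #include <string>

    // Old style (like the removed definition): two bare statements.
    #define THROW_OLD(msg, ret) last_error = (msg); return ret;
    // New style (as added by this patch): grouped into a single statement.
    #define THROW_NEW(msg, ret) do { last_error = (msg); return ret; } while (0)

    struct Demo {
        std::string last_error;

        bool old_style(bool fail) {
            // Without braces the old macro misbehaves:
            //   if (fail) THROW_OLD("failed", false);
            // expands so that only the assignment is guarded by the if,
            // while `return false;` runs unconditionally.
            if (fail) { THROW_OLD("failed", false) }  // needs explicit braces
            return true;
        }

        bool new_style(bool fail) {
            // The do/while(0) form is one statement, safe in an unbraced if.
            if (fail) THROW_NEW("failed", false);
            return true;
        }
    };

The same grouping is why the new LM_ERROR_FORWARD no longer needs the dangling `0` that terminated its old expansion.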
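
With the CoSched layer gone, Inference::append()/run() and the InferencePool methods are ordinary blocking calls rather than CoSched::AwaitableTask coroutines, so callers drop their LM_COAWAIT/co_await plumbing. A minimal caller-side sketch under assumptions: the pool is constructed elsewhere, the id, model path and prompt handling are placeholders, and the callback signatures (bool(float) for AppendCallback, bool(const char*) for GenerateCallback) are inferred from how the backends invoke them; only the method names come from the headers touched by this patch.

    #include <iostream>
    #include <string>
    #include "justlm.hpp"
    #include "justlm_pool.hpp"

    // Hypothetical helper: id and weights path below are placeholders.
    std::string generate_reply(LM::InferencePool &pool, const std::string &prompt) {
        LM::Inference::Params params;  // defaults; tune n_ctx, n_batch, ... as needed

        // Previously: auto inference = LM_COAWAIT pool.get_or_create_inference(...);
        auto inference = pool.get_or_create_inference(/*id=*/1, "/path/to/model.bin", params);

        // append() reports ingestion progress; returning false from the callback cancels.
        inference->append(prompt, [](float progress) {
            std::cerr << "ingest: " << progress << "%\r";
            return true;
        });

        // run() streams generated tokens through on_tick; returning false stops generation.
        return inference->run("\n", [](const char *token) {
            std::cout << token << std::flush;
            return true;
        });
    }

Under LM_NOEXCEPT both calls still report failure through get_last_error() instead of throwing, exactly as before; only the coroutine wrapper around the return types is removed.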
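
The window_scroll() helpers are behaviourally untouched here apart from losing the coroutine wrappers: once the token count exceeds the context size they keep the first params.n_ctx_window_top_bar tokens, retain a scroll_keep fraction of the most recent tokens, and re-evaluate from offset 0, which is why append() can return early after a scroll. The standalone sketch below is a rough reconstruction for illustration only; the exact keep/erase arithmetic lives in context lines not shown in these hunks, and scroll_window/top_bar are made-up names.

    #include <cstddef>
    #include <vector>

    // Illustrative only: keep the top bar plus the newest share of the budget.
    // top_bar mirrors params.n_ctx_window_top_bar, scroll_keep mirrors
    // params.scroll_keep; the real formula in the backends may differ.
    void scroll_window(std::vector<int> &tokens, size_t n_ctx,
                       size_t top_bar, float scroll_keep) {
        if (tokens.size() <= n_ctx) return;   // nothing to scroll
        if (scroll_keep <= 0.0f) {            // drop everything but the top bar
            tokens.resize(top_bar);
            return;
        }
        // Keep the newest scroll_keep fraction of the non-top-bar budget.
        const size_t keep_tail = static_cast<size_t>((n_ctx - top_bar) * scroll_keep);
        std::vector<int> scrolled(tokens.begin(), tokens.begin() + top_bar);
        scrolled.insert(scrolled.end(), tokens.end() - keep_tail, tokens.end());
        tokens = std::move(scrolled);
    }

After a scroll the backends call evaluate_tokens(0, on_scroll) so the entire reduced window is re-fed to the model, matching the `return true;` early-out kept by this patch.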