mirror of https://gitlab.com/niansa/libjustlm.git synced 2025-03-06 20:49:17 +01:00

Compare commits


32 commits
v1.0 ... master

Author SHA1 Message Date
449a25c360 Use option() in CMake 2024-03-25 01:19:05 +01:00
90e54d66d0 Removed cosched support 2024-03-25 01:18:37 +01:00
niansa
1a04a0e6d9 Updated llama.cpp-mainline 2023-12-25 16:55:30 +01:00
ef5df1dc31 Updated llama.cpp-mainline 2023-11-09 12:51:53 +01:00
niansa
fc5e4f5aa1 Updated llama.cpp-mainline 2023-10-04 22:13:48 +02:00
215db6b9b7 Fully implemented grammar sampling 2023-09-05 10:22:42 +02:00
f5314a0dde Added python bindings for grammar 2023-09-05 09:27:45 +02:00
niansa
79cf49faae Implemented grammar sampling and zero-temperature sampling 2023-08-31 19:37:33 +02:00
niansa
3a953ed13a Convert tokens to text correctly in llama 2023-08-31 18:23:55 +02:00
niansa
907cea7f9d Fixed exception if pre_tick is nullptr 2023-08-31 18:07:42 +02:00
niansa
7cd3899dd0 Check for correct magic value in llama 2023-08-31 17:57:56 +02:00
niansa
cb683aa8fc Updated llama.cpp.cmake 2023-08-31 17:00:50 +02:00
niansa
5d818e31aa Call llama_backend_init()/llama_backend_free() 2023-08-31 16:56:10 +02:00
niansa
e3d52c42b7 Updated llama-mainline and deleted old llama versions 2023-08-31 16:52:38 +02:00
niansa
d8f4efb0c9 Cut off ending from run() result properly 2023-06-25 01:20:56 +02:00
niansa
08ff1e72e7 Update llama.cpp-mainline 2023-06-25 01:18:57 +02:00
niansa
01b0d059ed Added pre_tick 2023-06-15 18:14:09 +02:00
niansa
bcacfc3d54 Minor CMake fixes 2023-06-10 02:04:50 +02:00
niansa
0199db02b7 Added GPU support 2023-06-10 00:49:21 +02:00
niansa
e2f7da65e4 Fixed llama.cpp not generating symbols 2023-06-10 00:38:38 +02:00
niansa
94953cd174 Improve some error handling macros 2023-06-09 23:53:01 +02:00
niansa
24849804b6 Major CMake improvements 2023-06-09 20:01:49 +02:00
niansa
b3bd78b350 Fixups in llama.cpp.cmake 2023-06-09 19:43:29 +02:00
niansa
a03558ae89 Expose options 2023-06-09 19:39:24 +02:00
niansa
38b229dab5 Updated to latest functional llama version 2023-06-09 12:01:41 +02:00
niansa
09e59a9536 Fixed compile errors because of previous commit 2023-05-31 20:22:18 +02:00
niansa
0142db3f7c Renamed operator ""_MB -> operator ""_MiB 2023-05-31 20:20:31 +02:00
niansa
2d57ade1b8 add msvc support -polyfill unistd 2023-05-31 19:56:40 +02:00
4b19bc49a5 Fixed llama.cpp.cmake 2023-05-26 13:44:26 +02:00
niansa
53a4623aef Added mirostat support 2023-05-26 00:43:07 +02:00
ad0b7e3c71 Updated llama.cpp-mainline 2023-05-23 13:41:30 +02:00
niansa
24ff52919f Renamed justlm_llama_old to justlm_llama_230511 2023-05-21 16:13:51 +02:00
16 changed files with 571 additions and 394 deletions


@ -8,26 +8,16 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
set(LM_PYBIND No CACHE BOOL "If justlm Python bindings should be build")
set(LM_COSCHED No CACHE BOOL "If justlm should make use of CoSched")
set(LM_NOEXCEPT No CACHE BOOL "If justlm exceptions should be disabled")
set(LM_LLAMA Yes CACHE BOOL "If LLaMa model support should be built into justlm")
set(LM_LLAMA_OLD Yes CACHE BOOL "If old LLaMa model support should be built into justlm")
set(LM_GPTJ Yes CACHE BOOL "If GPT-J model support should be built into justlm")
set(LM_MPT Yes CACHE BOOL "If MPT model support should be built into justlm")
if (LM_COSCHED)
set(CMAKE_CXX_STANDARD 20)
endif()
option(LM_PYBIND "If justlm Python bindings should be build" OFF)
option(LM_NOEXCEPT "If justlm exceptions should be disabled" OFF)
option(LM_LLAMA "If LLaMa model support should be built into justlm" ON)
option(LM_GPTJ "If GPT-J model support should be built into justlm" ON)
option(LM_MPT "If MPT model support should be built into justlm" ON)
function(target_justlm_setup TARGET_NAME)
message(STATUS "Configuring model implementation target ${TARGET_NAME}")
target_include_directories(${TARGET_NAME} PUBLIC include/)
if (LM_COSCHED)
target_compile_definitions(${TARGET_NAME} PUBLIC LM_COSCHED)
target_link_libraries(${TARGET_NAME} PRIVATE cosched)
endif()
if (LM_NOEXCEPT)
target_compile_definitions(${TARGET_NAME} PUBLIC LM_NOEXCEPT)
endif()
@ -37,8 +27,6 @@ endfunction()
include(llama.cpp.cmake)
include_ggml(llama.cpp-mainline _mainline Yes)
include_ggml(llama.cpp-230511 _230511 Yes)
include_ggml(llama.cpp-230519 _230519 Yes)
include_ggml(llama.cpp-alibi _alibi No)
@ -53,32 +41,17 @@ endif()
if (LM_GPTJ)
add_library(justlm_gptj SHARED gptj.cpp justlm_gptj.hpp gptj/gptj.cpp gptj/gptj.hpp)
target_link_libraries(justlm_gptj PRIVATE ggml_230511 justlm_g4a_common)
target_link_libraries(justlm_gptj PRIVATE ggml_alibi justlm_g4a_common)
target_justlm_setup(justlm_gptj)
endif()
if (LM_LLAMA)
add_library(justlm_llama SHARED llama.cpp justlm_llama.hpp)
target_link_libraries(justlm_llama PRIVATE ggml_mainline llama_mainline)
target_compile_definitions(justlm_llama PRIVATE
LLAMA_VERSIONS=>=3 LLAMA_DATE=999999)
target_compile_definitions(justlm_llama PRIVATE LLAMA_DATE=999999)
target_justlm_setup(justlm_llama)
endif()
if (LM_LLAMA_OLD)
add_library(justlm_llama_old SHARED llama.cpp justlm_llama.hpp)
target_link_libraries(justlm_llama_old PRIVATE ggml_230511 llama_230511)
target_compile_definitions(justlm_llama_old PRIVATE
LLAMA_VERSIONS=<=1 LLAMA_DATE=230511)
target_justlm_setup(justlm_llama_old)
add_library(justlm_llama_230519 SHARED llama.cpp justlm_llama.hpp)
target_link_libraries(justlm_llama_230519 PRIVATE ggml_230519 llama_230519)
target_compile_definitions(justlm_llama_230519 PRIVATE
LLAMA_VERSIONS===2 LLAMA_DATE=230519)
target_justlm_setup(justlm_llama_230519)
endif()
add_library(justlm STATIC
include/justlm.hpp justlm.cpp


@ -11,13 +11,13 @@
#include <string>
#include <vector>
#include <iostream>
#include <unistd.h>
#include "../msvc_compat_unistd.h"
#include <sstream>
#include <unordered_set>
#include <ggml.h>
constexpr inline
unsigned long long operator ""_MB(unsigned long long bytes) {
unsigned long long operator ""_MiB(unsigned long long bytes) {
return bytes*1024*1024;
}
@ -32,7 +32,7 @@ static bool kv_cache_init(
const int64_t n_mem = (int64_t)n_layer*n_ctx;
const int64_t n_elements = n_embd*n_mem;
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2_MB);
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2_MiB);
struct ggml_init_params params;
params.mem_size = cache.buf.size;
@ -394,7 +394,7 @@ bool gptj_eval(
const int n_vocab = hparams.n_vocab;
const int n_rot = hparams.n_rot;
static size_t buf_size = 1024_MB;
static size_t buf_size = 1024_MiB;
if (!model.buf.addr || model.buf.size < buf_size)
model.buf.resize(buf_size);
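Aside (not part of the diff): the rename from _MB to _MiB reflects that the literal multiplies by 1024*1024, i.e. mebibytes rather than SI megabytes. A minimal self-contained sketch of the renamed literal:

constexpr unsigned long long operator ""_MiB(unsigned long long n) {
    return n * 1024 * 1024;           // n mebibytes expressed in bytes
}
static_assert(2_MiB == 2'097'152, "2_MiB is 2*1024*1024 bytes, not 2*1000*1000");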


@ -7,45 +7,28 @@
#include <memory>
#include <thread>
#ifdef LM_COSCHED
# include <scheduler.hpp>
# define LM_SCHEDULABLE(type) ::CoSched::AwaitableTask<type>
# define LM_CORETURN co_return
# define LM_COAWAIT co_await
# define LM_TASKYIELD (co_await ::CoSched::Task::get_current().yield())
#else
# define LM_SCHEDULABLE(type) type
# define LM_CORETURN return
# define LM_COAWAIT
# define LM_TASKYIELD (true)
#endif
#ifdef LM_NOEXCEPT
# define LM_NOEXCEPTDECL noexcept
# define LM_THROW(t, r) this->last_error = (t); return r;
# define LM_COTHROW(t, r) this->last_error = (t); LM_CORETURN r;
# define LM_THROW(t, r) do {this->last_error = (t); return r;} while (0)
# define LM_LAST_ERROR_STORAGE mutable std::string last_error;
# define LM_LAST_ERROR_GETTER const std::string& get_last_error() const {return last_error;}
# define LM_ERRBOOL bool
# define LM_BOOL_ERROR false
# define LM_BOOL_SUCCESS true
# define LM_ERROR_FORWARD(x) {auto v = x; if (!v) LM_CORETURN x;} 0
# define LM_RETHROW(x) return x
# define LM_ERROR_CATCH(x, errval, ...) {auto v = x; if (v == (errval)) __VA_ARGS__}
# define LM_ERROR_FORWARD(x, errval) do {auto v = x; if (v == (errval)) return x;} while (0)
#else
# define LM_NOEXCEPTDECL
# define LM_THROW(t, r) throw Exception(t)
# define LM_COTHROW(t, r) throw Exception(t)
# define LM_LAST_ERROR_STORAGE
# define LM_LAST_ERROR_GETTER
# define LM_ERRBOOL void
# define LM_BOOL_ERROR
# define LM_BOOL_SUCCESS
# define LM_ERROR_FORWARD(x) {x;}
#endif
#ifdef LM_COSCHED
#ifndef LM_NOEXCEPT
#warning Exceptions should not be enabled in combination with CoSched. Any exceptions thrown will lead to a std::terminate() call
#endif
# define LM_RETHROW(x) std::rethrow_exception(std::current_exception())
# define LM_ERROR_CATCH(x, errval, ...) try {x;} catch (...) __VA_ARGS__
# define LM_ERROR_FORWARD(x, errval) {x;}
#endif
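A quick illustrative sketch (not taken from the repository) of how the reworked error macros compose now that the coroutine wrappers are gone; do_step() is a hypothetical helper, assumed defined elsewhere, returning LM_ERRBOOL:

LM_ERRBOOL do_step() LM_NOEXCEPTDECL;   // hypothetical helper

LM_ERRBOOL do_both() LM_NOEXCEPTDECL {
    // With LM_NOEXCEPT: compares against the error value and forwards the failure
    // to the caller. Without LM_NOEXCEPT: just evaluates; exceptions propagate.
    LM_ERROR_FORWARD(do_step(), LM_BOOL_ERROR);
    // Handle the error locally instead of forwarding it.
    LM_ERROR_CATCH(do_step(), LM_BOOL_ERROR, {
        return LM_BOOL_ERROR;
    });
    return LM_BOOL_SUCCESS;
}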
#if _MSC_VER
@ -58,9 +41,12 @@ namespace LM {
using ssize_t = SSIZE_T;
#endif
using GenerateCallback = std::function<bool (const char *generated)>;
using AppendCallback = std::function<bool (float progress)>;
class Inference {
protected:
std::function<bool (float)> on_scroll = nullptr;
AppendCallback on_scroll = nullptr;
void *generic_state = nullptr;
@ -73,21 +59,25 @@ public:
struct Params {
int seed = 0; // RNG seed
unsigned n_threads = 0;
unsigned n_threads = 0; // Amount of threads to use, immutable after Inference was constructed
unsigned n_ctx = 2024; // Context size
unsigned n_ctx_window_top_bar = 0; // Top bar of context window. Must be smaller than context size
unsigned n_batch = 8; // Batch size
unsigned n_repeat_last = 0; // llama.cpp specific
unsigned n_repeat_last = 0;
unsigned n_eos_ignores = 0;
float scroll_keep = 0.0f; // 0.4f to keep 40% of context below top bar when scrolling; 0.0f to remove everything after top bar
unsigned top_k = 40;
float top_p = 0.9f;
float temp = 0.72f;
float repeat_penalty = 1.0f; // llama.cpp specific
unsigned eos_ignores = 0; // llama.cpp specific
float top_p = 0.9f;
float temp = 0.72f;
float mirostat_learning_rate = 0.1f; // mirostat specific
float mirostat_target_entropy = 5.0f; // mirostat specific
float repeat_penalty = 1.0f;
bool use_mlock = true; // llama.cpp specific
unsigned n_gpu_layers = 38;
bool use_mlock = true; // llama specific
int prefer_mirostat = 0; // Use given mirostat version if available (see is_mirostat_available()); llama specific
} params;
struct Savestate {
@ -118,26 +108,36 @@ public:
static
Inference *construct(const std::string& weights_path, const Params& p);
void set_scroll_callback(const std::function<bool (float)>& scroll_cb) noexcept {
void set_scroll_callback(const AppendCallback& scroll_cb) noexcept {
on_scroll = scroll_cb;
}
// This must be called with a non-empty prompt!
virtual LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const std::function<bool (float progress)>& on_tick = nullptr) LM_NOEXCEPTDECL = 0;
virtual LM_ERRBOOL append(const std::string& prompt, const AppendCallback& on_tick = nullptr) LM_NOEXCEPTDECL = 0;
// append() must have been called at least once before calling this!
virtual LM_SCHEDULABLE(std::string) run(std::string_view end = "", const std::function<bool (const char *generated)>& on_tick = nullptr) LM_NOEXCEPTDECL = 0;
virtual std::string run(std::string_view end = "", const GenerateCallback& on_tick = nullptr, const GenerateCallback& pre_tick = nullptr) LM_NOEXCEPTDECL = 0;
virtual unsigned get_context_size() const noexcept = 0;
virtual LM_SCHEDULABLE(LM_ERRBOOL) create_savestate(Savestate&) const LM_NOEXCEPTDECL = 0;
virtual LM_SCHEDULABLE(LM_ERRBOOL) restore_savestate(const Savestate&) LM_NOEXCEPTDECL = 0;
virtual LM_ERRBOOL create_savestate(Savestate&) const LM_NOEXCEPTDECL = 0;
virtual LM_ERRBOOL restore_savestate(const Savestate&) LM_NOEXCEPTDECL = 0;
virtual LM_SCHEDULABLE(LM_ERRBOOL) serialize(std::ostream&) const LM_NOEXCEPTDECL = 0;
virtual LM_SCHEDULABLE(LM_ERRBOOL) deserialize(std::istream&) LM_NOEXCEPTDECL = 0;
virtual LM_ERRBOOL serialize(std::ostream&) const LM_NOEXCEPTDECL = 0;
virtual LM_ERRBOOL deserialize(std::istream&) LM_NOEXCEPTDECL = 0;
virtual LM_ERRBOOL load_grammar(const std::string&, bool override_temperature [[maybe_unused]] = false) LM_NOEXCEPTDECL {
LM_THROW("Grammar is not available for this models backend", LM_BOOL_ERROR);
}
virtual LM_ERRBOOL unload_grammar() LM_NOEXCEPTDECL {
LM_THROW("Grammar is not available for this models backend", LM_BOOL_ERROR);
}
virtual const std::string& get_prompt() const LM_NOEXCEPTDECL = 0;
virtual bool is_mirostat_available() const noexcept {return false;}
virtual bool is_grammar_available() const noexcept {return false;}
LM_LAST_ERROR_GETTER
};
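For orientation, a hedged usage sketch of the de-coroutined interface above, assuming exceptions are enabled (LM_NOEXCEPT off); the model path and prompt are placeholders and error handling is omitted:

#include "justlm.hpp"
#include <iostream>
#include <string>

int main() {
    LM::Inference::Params params;
    params.n_ctx = 2048;           // context size
    params.n_gpu_layers = 38;      // llama-specific: layers offloaded to the GPU
    params.prefer_mirostat = 2;    // ask for mirostat v2 where is_mirostat_available()
    // "model.gguf" is a placeholder path.
    auto *llm = LM::Inference::construct("model.gguf", params);
    // Feed the prompt; the callback reports progress in percent.
    llm->append("The quick brown fox", [] (float progress) {
        std::cerr << "\rappend: " << progress << '%';
        return true;               // returning false would cancel
    });
    // Generate until a newline; on_tick streams tokens, pre_tick runs before evaluation.
    std::string out = llm->run("\n", [] (const char *token) {
        std::cout << token << std::flush;
        return true;
    }, /*pre_tick=*/nullptr);
    delete llm;                    // assumes Inference declares a virtual destructor
    return 0;
}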


@ -63,21 +63,21 @@ class InferencePool {
}
// Returns false on error
LM_SCHEDULABLE(bool) store_slot(Slot& slot);
bool store_slot(Slot& slot);
// Returns nullptr on error
LM_SCHEDULABLE(Slot*) load_slot(size_t id, Slot *suggested_slot = nullptr);
Slot *load_slot(size_t id, Slot *suggested_slot = nullptr);
LM_SCHEDULABLE(void) store_and_reset_slot(Slot& slot) {
LM_COAWAIT store_slot(slot); //TODO: Should handle errors somehow
void store_and_reset_slot(Slot& slot) {
store_slot(slot); //TODO: Should handle errors somehow
slot.reset();
LM_CORETURN;
return;
}
// Doesn't fail
LM_SCHEDULABLE(Slot*) get_free_slot();
Slot *get_free_slot();
// Returns nullptr if not found
LM_SCHEDULABLE(Slot*) find_slot_by_id(size_t id, bool deserialize = true);
Slot *find_slot_by_id(size_t id, bool deserialize = true);
public:
// The pool_name must be unique amonst all applications in cwd
@ -93,14 +93,14 @@ public:
}
}
LM_SCHEDULABLE(std::shared_ptr<Inference>) create_inference(size_t id, const std::string& weights_path, const Inference::Params& p) {
auto slot = LM_COAWAIT get_free_slot();
LM_CORETURN slot->create_inference(id, weights_path, p);
std::shared_ptr<Inference> create_inference(size_t id, const std::string& weights_path, const Inference::Params& p) {
auto slot = get_free_slot();
return slot->create_inference(id, weights_path, p);
}
LM_SCHEDULABLE(std::shared_ptr<Inference>) get_inference(size_t id);
LM_SCHEDULABLE(std::shared_ptr<Inference>) get_or_create_inference(size_t id, const std::string& weights_path, const Inference::Params& p);
LM_SCHEDULABLE(void) delete_inference(size_t id);
LM_SCHEDULABLE(void) store_all();
std::shared_ptr<Inference> get_inference(size_t id);
std::shared_ptr<Inference> get_or_create_inference(size_t id, const std::string& weights_path, const Inference::Params& p);
void delete_inference(size_t id);
void store_all();
std::vector<size_t> get_active_slot_ids() const;
void cleanup();
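A short sketch (assumptions flagged inline) of how the now-synchronous pool API above might be driven; the pool object itself is assumed to be constructed elsewhere, since its constructor is not part of this diff:

#include "justlm_pool.hpp"   // assumed header name for the InferencePool declaration
#include <memory>
#include <string>

// Fetch (or lazily create) the inference instance bound to a given slot id.
std::shared_ptr<LM::Inference> get_session(LM::InferencePool &pool, size_t id,
                                           const std::string &weights_path) {
    LM::Inference::Params params;                       // defaults are fine for the sketch
    return pool.get_or_create_inference(id, weights_path, params);
}

// Persist every occupied slot, e.g. on shutdown.
void persist(LM::InferencePool &pool) {
    pool.store_all();
}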


@ -59,12 +59,12 @@ class GPTJInference final : public Inference {
// This function reduces the size of our tokens vector according to some parameters
// All tokens will be evaluated if scrolling was needed and true will be returned
LM_SCHEDULABLE(bool) window_scroll() LM_NOEXCEPTDECL {
bool window_scroll() LM_NOEXCEPTDECL {
auto &state = get_state();
// Check that we actually need to scroll
if (state->tokens.size() <= params.n_ctx) {
// Nope
LM_CORETURN false;
return false;
}
// Start scrolling
if (params.scroll_keep > 0.0f) {
@ -81,11 +81,11 @@ class GPTJInference final : public Inference {
state->tokens.resize(params.n_ctx_window_top_bar);
}
// Evaluate tokens
LM_ERROR_FORWARD(LM_COAWAIT evaluate_tokens(0, on_scroll));
LM_CORETURN true;
LM_ERROR_FORWARD(evaluate_tokens(0, on_scroll), LM_BOOL_ERROR);
return true;
}
LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL {
LM_ERRBOOL evaluate_tokens(size_t starting_offset, const AppendCallback &on_tick = nullptr) LM_NOEXCEPTDECL {
auto& state = get_state();
// Evaluate tokens in batches
@ -96,7 +96,7 @@ class GPTJInference final : public Inference {
// Evaluate
std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+params.n_batch);
if (!gptj_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token)) {
LM_COTHROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR);
LM_THROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR);
}
// Tick
@ -104,8 +104,7 @@ class GPTJInference final : public Inference {
// Calculate progress
auto progress = float(it-starting_offset) / (state->tokens.size()-starting_offset) * 100.f;
// Tick and yield
if (!on_tick(progress)) LM_CORETURN LM_BOOL_SUCCESS;
else if (!LM_TASKYIELD) LM_CORETURN LM_BOOL_SUCCESS;
if (!on_tick(progress)) return LM_BOOL_SUCCESS;
}
}
@ -115,7 +114,7 @@ class GPTJInference final : public Inference {
//TODO: This is extremely inefficient! Don't do that...
std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+1);
if (!gptj_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token)) {
LM_COTHROW("Failed to evaluate individual tokens", LM_BOOL_ERROR);
LM_THROW("Failed to evaluate individual tokens", LM_BOOL_ERROR);
}
}
}
@ -123,7 +122,7 @@ class GPTJInference final : public Inference {
// Notify about completion
if (on_tick) on_tick(100.f);
LM_CORETURN LM_BOOL_SUCCESS;
return LM_BOOL_SUCCESS;
}
public:
@ -134,7 +133,7 @@ public:
deinit();
}
LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
LM_ERRBOOL append(const std::string& prompt, const AppendCallback &on_tick) LM_NOEXCEPTDECL override {
auto& state = get_state();
// Append to current prompt
@ -152,29 +151,31 @@ public:
);
// Make sure token limit isn't being hit
if (LM_COAWAIT window_scroll()) {
if (window_scroll()) {
// That function already has evaluated our tokens since scrolling was needed
LM_CORETURN LM_BOOL_SUCCESS;
return LM_BOOL_SUCCESS;
}
// Evaluate new tokens
LM_CORETURN LM_COAWAIT evaluate_tokens(old_token_count, on_tick);
return evaluate_tokens(old_token_count, on_tick);
}
LM_SCHEDULABLE(std::string) run(std::string_view end, const std::function<bool (const char *)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
std::string run(std::string_view end, const GenerateCallback &on_tick, const GenerateCallback& pre_tick) LM_NOEXCEPTDECL override {
auto& state = get_state();
std::string fres;
// Loop until done
bool abort = false;
unsigned eos_count = 0;
size_t last_size = 0;
while (!abort && (end.empty() || fres.find(end) == fres.npos)) {
last_size = fres.size();
// Sample top p and top k
const auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
auto id = gpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data()+state->tokens.size()-n_repeat_last, n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
if (id == 50256) {
if (eos_count++ == params.eos_ignores) {
if (eos_count++ == params.n_eos_ignores) {
abort = true;
continue;
}
@ -185,7 +186,7 @@ public:
state->tokens.push_back(id);
// Make sure token limit isn't being hit
LM_COAWAIT window_scroll();
window_scroll();
// Get token as string
const std::string_view str = state->vocab.id_to_token[id];
@ -194,77 +195,79 @@ public:
state->prompt.append(str);
fres.append(str);
// Evaluate token
// TODO: Respect batch size
std::vector<int> batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size());
if (!gptj_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) {
LM_COTHROW("Failed to evaluate new tokens", "");
if (pre_tick && !pre_tick(str.data())) abort = true;
else {
// Evaluate token
// TODO: Respect batch size
std::vector<int> batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size());
if (!gptj_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) {
LM_THROW("Failed to evaluate new tokens", "");
}
}
// Tick
if (on_tick && !on_tick(str.data())) abort = true;
else if (!LM_TASKYIELD) abort = true;
}
// Create final string TODO: Could be optimized
if (!abort) {
fres = std::string(fres.data(), fres.size()-end.size());
fres = std::string(fres.data(), last_size);
}
// Return final string
LM_CORETURN fres;
return fres;
}
unsigned get_context_size() const noexcept override {
return get_state()->tokens.size();
}
LM_SCHEDULABLE(LM_ERRBOOL) create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override {
LM_ERRBOOL create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override {
auto& state = get_state();
sv.buf.resize(gptj_get_state_size(state->model));
gptj_copy_state_data(state->model, state->rng, sv.buf.data());
sv.tokens = state->tokens;
sv.prompt = state->prompt;
sv.ctx = generic_state;
LM_CORETURN LM_BOOL_SUCCESS;
return LM_BOOL_SUCCESS;
}
LM_SCHEDULABLE(LM_ERRBOOL) restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override {
LM_ERRBOOL restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override {
auto& state = get_state();
if (sv.ctx != generic_state)
LM_COTHROW("Savestate does not match context", LM_BOOL_ERROR);
LM_THROW("Savestate does not match context", LM_BOOL_ERROR);
gptj_set_state_data(&state->model, &state->rng, sv.buf.data());
state->tokens = sv.tokens;
state->prompt = sv.prompt;
LM_CORETURN LM_BOOL_SUCCESS;
return LM_BOOL_SUCCESS;
}
LM_SCHEDULABLE(LM_ERRBOOL) serialize(std::ostream &o) const LM_NOEXCEPTDECL override {
LM_ERRBOOL serialize(std::ostream &o) const LM_NOEXCEPTDECL override {
auto& state = get_state();
// Get state size
auto state_size = gptj_get_state_size(state->model);
// Write sizes
for (const uint32_t s : {state->tokens.size(), state->prompt.size(), state_size}) {
if (!o.write(reinterpret_cast<const char*>(&s), sizeof(s))) {
LM_COTHROW("Failed to serialize data sizes", LM_BOOL_ERROR);
LM_THROW("Failed to serialize data sizes", LM_BOOL_ERROR);
}
}
// Write tokens
if (!o.write(reinterpret_cast<const char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
LM_COTHROW("Failed to serialize tokens", LM_BOOL_ERROR);
LM_THROW("Failed to serialize tokens", LM_BOOL_ERROR);
}
// Write prompt
if (!o.write(state->prompt.data(), state->prompt.size())) {
LM_COTHROW("Failed to serialize prompt", LM_BOOL_ERROR);
LM_THROW("Failed to serialize prompt", LM_BOOL_ERROR);
}
// Write state
std::vector<uint8_t> state_buf(state_size);
gptj_copy_state_data(state->model, state->rng, state_buf.data());
if (!o.write(reinterpret_cast<const char*>(state_buf.data()), state_size)) {
LM_COTHROW("Failed to serialize state", LM_BOOL_ERROR);
LM_THROW("Failed to serialize state", LM_BOOL_ERROR);
}
LM_CORETURN LM_BOOL_SUCCESS;
return LM_BOOL_SUCCESS;
}
LM_SCHEDULABLE(LM_ERRBOOL) deserialize(std::istream &i) LM_NOEXCEPTDECL override {
LM_ERRBOOL deserialize(std::istream &i) LM_NOEXCEPTDECL override {
auto& state = get_state();
uint32_t embd_size, prompt_size, state_size;
// Initialization to prevent compiler complaints
@ -272,26 +275,26 @@ public:
// Read sizes
for (uint32_t *s : {&embd_size, &prompt_size, &state_size}) {
if (!i.read(reinterpret_cast<char*>(s), sizeof(*s))) {
LM_COTHROW("Failed to deserialize data sizes", LM_BOOL_ERROR);
LM_THROW("Failed to deserialize data sizes", LM_BOOL_ERROR);
}
}
// Read tokens
state->tokens.resize(embd_size);
if (!i.read(reinterpret_cast<char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
LM_COTHROW("Failed to deserialize tokens", LM_BOOL_ERROR);
LM_THROW("Failed to deserialize tokens", LM_BOOL_ERROR);
}
// Read prompt
state->prompt.resize(prompt_size);
if (!i.read(state->prompt.data(), state->prompt.size())) {
LM_COTHROW("Failed to deserialize prompt", LM_BOOL_ERROR);
LM_THROW("Failed to deserialize prompt", LM_BOOL_ERROR);
}
// Read state
std::vector<uint8_t> state_buf(state_size);
if (!i.read(reinterpret_cast<char*>(state_buf.data()), state_buf.size())) {
LM_COTHROW("Failed to deserialize state", LM_BOOL_ERROR);
LM_THROW("Failed to deserialize state", LM_BOOL_ERROR);
}
gptj_set_state_data(&state->model, &state->rng, state_buf.data());
LM_CORETURN LM_BOOL_SUCCESS;
return LM_BOOL_SUCCESS;
}
const std::string &get_prompt() const LM_NOEXCEPTDECL override {
return get_state()->prompt;


@ -3,12 +3,17 @@
#include <cstring>
#include <ggml.h>
#include <llama.h>
#include <common/grammar-parser.h>
namespace LM {
class LLaMAInference final : public Inference {
struct State {
llama_context *ctx = nullptr;
llama_model *model;
llama_grammar *grammar = nullptr;
bool grammar_override_temp;
grammar_parser::parse_state parsed_grammar;
std::string prompt; // Mostly here for easy "debugging"
std::vector<int> tokens;
unsigned n_ctx;
@ -31,12 +36,24 @@ class LLaMAInference final : public Inference {
auto lparams = llama_context_default_params();
lparams.seed = params.seed;
lparams.n_ctx = params.n_ctx = params.n_ctx>0?params.n_ctx:2024;
lparams.use_mlock = params.use_mlock;
lparams.n_threads = params.n_threads;
//lparams.n_threads_batch = params.n_threads; TODO: Is this sane?
// Get model parameters
auto mparams = llama_model_default_params();
mparams.use_mlock = params.use_mlock;
mparams.n_gpu_layers = params.n_gpu_layers;
// Load model
state->model = llama_load_model_from_file(weights_path.c_str(), mparams);
if (!state->model) {
LM_THROW("Failed to initialize llama model from file", LM_BOOL_ERROR);
}
// Create context
state->ctx = llama_init_from_file(weights_path.c_str(), lparams);
state->ctx = llama_new_context_with_model(state->model, lparams);
if (!state->ctx) {
LM_THROW("Failed to initialize llama from file", LM_BOOL_ERROR);
LM_THROW("Failed to initialize llama context from model", LM_BOOL_ERROR);
}
// Initialize some variables
@ -47,12 +64,12 @@ class LLaMAInference final : public Inference {
// This function reduces the size of our tokens vector according to some parameters
// All tokens will be evaluated if scrolling was needed and true will be returned
LM_SCHEDULABLE(bool) window_scroll() LM_NOEXCEPTDECL {
bool window_scroll() LM_NOEXCEPTDECL {
auto &state = get_state();
// Check that we actually need to scroll
if (state->tokens.size() <= state->n_ctx) {
// Nope
LM_CORETURN false;
return false;
}
// Start scrolling
if (params.scroll_keep > 0.0f) {
@ -69,11 +86,11 @@ class LLaMAInference final : public Inference {
state->tokens.resize(params.n_ctx_window_top_bar);
}
// Evaluate tokens
LM_ERROR_FORWARD(LM_COAWAIT evaluate_tokens(0, on_scroll));
LM_CORETURN true;
LM_ERROR_FORWARD(evaluate_tokens(0, on_scroll), LM_BOOL_ERROR);
return true;
}
LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL {
LM_ERRBOOL evaluate_tokens(size_t starting_offset, const AppendCallback &on_tick = nullptr) LM_NOEXCEPTDECL {
auto& state = get_state();
// Evaluate tokens in batches
@ -82,8 +99,9 @@ class LLaMAInference final : public Inference {
if (it + params.n_batch >= ssize_t(state->tokens.size())) break;
// Evaluate
if (llama_eval(state->ctx, state->tokens.data()+it, params.n_batch, it, params.n_threads)) {
LM_COTHROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR);
const auto batch = llama_batch_get_one(state->tokens.data()+it, params.n_batch, it, 0);
if (llama_decode(state->ctx, batch)) {
LM_THROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR);
}
// Tick
@ -91,16 +109,16 @@ class LLaMAInference final : public Inference {
// Calculate progress
auto progress = float(it-starting_offset) / (state->tokens.size()-starting_offset) * 100.f;
// Tick and yield
if (!on_tick(progress)) LM_CORETURN LM_BOOL_SUCCESS;
else if (!LM_TASKYIELD) LM_CORETURN LM_BOOL_SUCCESS;
if (!on_tick(progress)) return LM_BOOL_SUCCESS;
}
}
// Evaluate remaining tokens
if (it < state->tokens.size()) {
for (; it != state->tokens.size(); it++) {
if (llama_eval(state->ctx, state->tokens.data()+it, 1, it, params.n_threads)) {
LM_COTHROW("Failed to evaluate individual tokens", LM_BOOL_ERROR);
const auto batch = llama_batch_get_one(state->tokens.data()+it, 1, it, 0);
if (llama_decode(state->ctx, batch)) {
LM_THROW("Failed to evaluate individual tokens", LM_BOOL_ERROR);
}
}
}
@ -108,14 +126,20 @@ class LLaMAInference final : public Inference {
// Notify about completion
if (on_tick) on_tick(100.f);
LM_CORETURN LM_BOOL_SUCCESS;
return LM_BOOL_SUCCESS;
}
int accept_token(int t) {
auto& state = get_state();
if (state->grammar)
llama_grammar_accept_token(state->ctx, state->grammar, t);
return t;
}
#if LLAMA_DATE >= 230519
int llama_sample_top_p_top_k() {
auto& state = get_state();
auto logits = llama_get_logits(state->ctx);
auto n_vocab = llama_n_vocab(state->ctx);
auto n_vocab = llama_n_vocab(state->model);
// Populate initial list of all candidates
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
@ -125,22 +149,40 @@ class LLaMAInference final : public Inference {
llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
// Sample repeat penalty
auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
llama_sample_repetition_penalty(state->ctx, &candidates_p, params.n_repeat_last?(state->tokens.data()+state->tokens.size()-n_repeat_last):nullptr, n_repeat_last, params.repeat_penalty);
// Temperature sampling
llama_sample_top_k(state->ctx, &candidates_p, params.top_k, 1);
llama_sample_tail_free(state->ctx, &candidates_p, 1.0f, 1);
llama_sample_typical(state->ctx, &candidates_p, 1.0f, 1);
llama_sample_top_p(state->ctx, &candidates_p, params.top_p, 1);
llama_sample_temperature(state->ctx, &candidates_p, params.temp);
return llama_sample_token(state->ctx, &candidates_p);
llama_sample_repetition_penalties(state->ctx, &candidates_p, params.n_repeat_last?(state->tokens.data()+state->tokens.size()-n_repeat_last):nullptr, n_repeat_last, params.repeat_penalty, 1.0f, 1.0f); // Might be wrong
// Grammar sampling
if (state->grammar) {
llama_sample_grammar(state->ctx, &candidates_p, state->grammar);
}
if (!(state->grammar && state->grammar_override_temp) && (params.temp > 0.01f || params.temp < -0.01f)) {
// Temperature sampling
switch (params.prefer_mirostat) {
case 0: {
llama_sample_top_k(state->ctx, &candidates_p, params.top_k, 1);
llama_sample_tail_free(state->ctx, &candidates_p, 1.0f, 1);
llama_sample_typical(state->ctx, &candidates_p, 1.0f, 1);
llama_sample_top_p(state->ctx, &candidates_p, params.top_p, 1);
llama_sample_temp(state->ctx, &candidates_p, params.temp);
return accept_token(llama_sample_token(state->ctx, &candidates_p));
}
case 1: {
float mirostat_mu = 2.0f * params.mirostat_target_entropy;
const int mirostat_m = 100;
llama_sample_temp(state->ctx, &candidates_p, params.temp);
return accept_token(llama_sample_token_mirostat(state->ctx, &candidates_p, params.mirostat_target_entropy, params.mirostat_learning_rate, mirostat_m, &mirostat_mu));
}
case 2: {
float mirostat_mu = 2.0f * params.mirostat_target_entropy;
llama_sample_temp(state->ctx, &candidates_p, params.temp);
return accept_token(llama_sample_token_mirostat_v2(state->ctx, &candidates_p, params.mirostat_target_entropy, params.mirostat_learning_rate, &mirostat_mu));
}
default: LM_THROW("Invalid mirostat version "+std::to_string(params.prefer_mirostat), LM_BOOL_ERROR);
}
} else {
// Greedy sampling
return accept_token(llama_sample_token(state->ctx, &candidates_p));
}
}
#else
int llama_sample_top_p_top_k() {
auto& state = get_state();
auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
return ::llama_sample_top_p_top_k(state->ctx, params.n_repeat_last?(state->tokens.data()+state->tokens.size()-n_repeat_last):nullptr, n_repeat_last, params.top_k, params.top_p, params.temp, params.repeat_penalty);
}
#endif
public:
LLaMAInference(const std::string& weights_path, const Params& p) : Inference(p) {
@ -155,7 +197,7 @@ public:
}
}
LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
LM_ERRBOOL append(const std::string& prompt, const AppendCallback &on_tick) LM_NOEXCEPTDECL override {
auto& state = get_state();
// Check if prompt was empty
@ -169,37 +211,44 @@ public:
state->tokens.resize(old_token_count+state->prompt.size());
// Run tokenizer
const auto token_count = llama_tokenize(state->ctx, prompt.c_str(), state->tokens.data()+old_token_count, state->tokens.size()-old_token_count, was_empty);
const auto token_count = llama_tokenize(state->model, prompt.c_str(), prompt.size(), state->tokens.data()+old_token_count, state->tokens.size()-old_token_count, was_empty, false);
state->tokens.resize(old_token_count+token_count);
// Make sure token limit isn't being hit
if (LM_COAWAIT window_scroll()) {
if (window_scroll()) {
// That function already has evaluated our tokens since scrolling was needed
LM_CORETURN LM_BOOL_SUCCESS;
return LM_BOOL_SUCCESS;
}
// Evaluate new tokens
LM_CORETURN LM_COAWAIT evaluate_tokens(old_token_count, on_tick);
return evaluate_tokens(old_token_count, on_tick);
}
LM_SCHEDULABLE(std::string) run(std::string_view end, const std::function<bool (const char *)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
std::string run(std::string_view end, const GenerateCallback &on_tick, const GenerateCallback& pre_tick) LM_NOEXCEPTDECL override {
auto& state = get_state();
std::string fres;
// Loop until done
bool abort = false;
unsigned eos_count = 0;
size_t last_size = 0;
while (!abort && (end.empty() || fres.find(end) == fres.npos)) {
last_size = fres.size();
// Sample top p and top k
auto id = llama_sample_top_p_top_k();
int id;
try {
id = llama_sample_top_p_top_k();
} catch (const std::exception& e) {
LM_THROW(e.what(), "");
}
if (id == llama_token_eos()) {
if (eos_count++ == params.eos_ignores) {
if (id == llama_token_eos(state->model)) {
if (eos_count++ == params.n_eos_ignores) {
abort = true;
continue;
}
state->tokens.push_back(0);
llama_tokenize(state->ctx, "\n", &state->tokens.back(), 1, false);
llama_tokenize(state->model, "\n", 1, &state->tokens.back(), 1, false, false);
id = state->tokens.back();
} else {
// Add token
@ -207,85 +256,90 @@ public:
}
// Make sure token limit isn't hit
LM_COAWAIT window_scroll();
window_scroll();
// Get token as string
const std::string_view str = llama_token_to_str(state->ctx, id);
std::string str(14, ' ');
str.resize(llama_token_to_piece(state->model, id, str.data(), 14));
// Append string to function result
state->prompt.append(str);
fres.append(str);
// Evaluate token
// TODO: Respect batch size
if (llama_eval(state->ctx, state->tokens.data()+state->tokens.size()-1, 1, state->tokens.size()-1, params.n_threads)) {
LM_COTHROW("Failed to evaluate new tokens", "");
// Tick
if (pre_tick && !pre_tick(str.data())) abort = true;
else {
// Evaluate token
// TODO: Respect batch size
const auto batch = llama_batch_get_one(state->tokens.data()+state->tokens.size()-1, 1, state->tokens.size()-1, 0);
if (llama_decode(state->ctx, batch)) {
LM_THROW("Failed to evaluate new tokens", "");
}
}
// Tick and yield
if (on_tick && !on_tick(str.data())) abort = true;
else if (!LM_TASKYIELD) abort = true;
}
// Create final string TODO: Could be optimized
if (!abort && fres.size() > end.size()) {
fres = std::string(fres.data(), fres.size()-end.size());
fres = std::string(fres.data(), last_size);
}
// Return final string
LM_CORETURN fres;
return fres;
}
unsigned get_context_size() const noexcept override {
return get_state()->tokens.size();
}
LM_SCHEDULABLE(LM_ERRBOOL) create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override {
LM_ERRBOOL create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override {
auto& state = get_state();
sv.buf.resize(llama_get_state_size(state->ctx));
llama_copy_state_data(state->ctx, sv.buf.data());
sv.tokens = state->tokens;
sv.prompt = state->prompt;
sv.ctx = generic_state;
LM_CORETURN LM_BOOL_SUCCESS;
return LM_BOOL_SUCCESS;
}
LM_SCHEDULABLE(LM_ERRBOOL) restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override {
LM_ERRBOOL restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override {
auto& state = get_state();
if (sv.ctx != generic_state)
LM_COTHROW("Savestate does not match context", LM_BOOL_ERROR);
LM_THROW("Savestate does not match context", LM_BOOL_ERROR);
llama_set_state_data(state->ctx, const_cast<uint8_t*>(sv.buf.data()));
state->tokens = sv.tokens;
state->prompt = sv.prompt;
LM_CORETURN LM_BOOL_SUCCESS;
return LM_BOOL_SUCCESS;
}
LM_SCHEDULABLE(LM_ERRBOOL) serialize(std::ostream &o) const LM_NOEXCEPTDECL override {
LM_ERRBOOL serialize(std::ostream &o) const LM_NOEXCEPTDECL override {
auto& state = get_state();
// Get state size
auto state_size = llama_get_state_size(state->ctx);
// Write sizes
for (const uint32_t s : {static_cast<size_t>(state->n_ctx), state->tokens.size(), state->prompt.size(), state_size}) {
if (!o.write(reinterpret_cast<const char*>(&s), sizeof(s))) {
LM_COTHROW("Failed to serialize data sizes", LM_BOOL_ERROR);
LM_THROW("Failed to serialize data sizes", LM_BOOL_ERROR);
}
}
// Write tokens
if (!o.write(reinterpret_cast<const char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
LM_COTHROW("Failed to serialize tokens", LM_BOOL_ERROR);
LM_THROW("Failed to serialize tokens", LM_BOOL_ERROR);
}
// Write prompt
if (!o.write(state->prompt.data(), state->prompt.size())) {
LM_COTHROW("Failed to serialize prompt", LM_BOOL_ERROR);
LM_THROW("Failed to serialize prompt", LM_BOOL_ERROR);
}
// Write state
std::vector<uint8_t> state_buf(state_size);
llama_copy_state_data(state->ctx, state_buf.data());
if (!o.write(reinterpret_cast<const char*>(state_buf.data()), state_size)) {
LM_COTHROW("Failed to serialize state", LM_BOOL_ERROR);
LM_THROW("Failed to serialize state", LM_BOOL_ERROR);
}
LM_CORETURN LM_BOOL_SUCCESS;
return LM_BOOL_SUCCESS;
}
LM_SCHEDULABLE(LM_ERRBOOL) deserialize(std::istream &i) LM_NOEXCEPTDECL override {
LM_ERRBOOL deserialize(std::istream &i) LM_NOEXCEPTDECL override {
auto& state = get_state();
uint32_t n_ctx, embd_size, prompt_size, state_size;
// Initialization to prevent compiler complaints
@ -293,33 +347,65 @@ public:
// Read sizes
for (uint32_t *s : {&n_ctx, &embd_size, &prompt_size, &state_size}) {
if (!i.read(reinterpret_cast<char*>(s), sizeof(*s))) {
LM_COTHROW("Failed to deserialize data sizes", LM_BOOL_ERROR);
LM_THROW("Failed to deserialize data sizes", LM_BOOL_ERROR);
}
}
if (state->n_ctx != n_ctx) {
LM_COTHROW("Context length differs (My "+std::to_string(state->n_ctx)+" vs. files "+std::to_string(n_ctx)+')', LM_BOOL_ERROR);
LM_THROW("Context length differs (My "+std::to_string(state->n_ctx)+" vs. files "+std::to_string(n_ctx)+')', LM_BOOL_ERROR);
}
// Read tokens
state->tokens.resize(embd_size);
if (!i.read(reinterpret_cast<char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
LM_COTHROW("Failed to deserialize tokens", LM_BOOL_ERROR);
LM_THROW("Failed to deserialize tokens", LM_BOOL_ERROR);
}
// Read prompt
state->prompt.resize(prompt_size);
if (!i.read(state->prompt.data(), state->prompt.size())) {
LM_COTHROW("Failed to deserialize prompt", LM_BOOL_ERROR);
LM_THROW("Failed to deserialize prompt", LM_BOOL_ERROR);
}
// Read state
std::vector<uint8_t> state_buf(state_size);
if (!i.read(reinterpret_cast<char*>(state_buf.data()), state_buf.size())) {
LM_COTHROW("Failed to deserialize state", LM_BOOL_ERROR);
LM_THROW("Failed to deserialize state", LM_BOOL_ERROR);
}
llama_set_state_data(state->ctx, state_buf.data());
LM_CORETURN LM_BOOL_SUCCESS;
return LM_BOOL_SUCCESS;
}
LM_ERRBOOL load_grammar(const std::string& src, bool override_temperature) LM_NOEXCEPTDECL override {
auto& state = get_state();
state->parsed_grammar = grammar_parser::parse(src.c_str());
if (state->parsed_grammar.rules.empty()) {
LM_THROW("Failed to parse grammar (or no rules)", LM_BOOL_ERROR);
}
auto rules = state->parsed_grammar.c_rules();
state->grammar = llama_grammar_init(rules.data(), rules.size(), state->parsed_grammar.symbol_ids.at("root"));
if (!state->grammar) {
LM_THROW("Failed to generate llama grammar", LM_BOOL_ERROR);
}
state->grammar_override_temp = override_temperature;
return LM_BOOL_SUCCESS;
}
LM_ERRBOOL unload_grammar() LM_NOEXCEPTDECL override {
get_state()->grammar = nullptr;
return LM_BOOL_SUCCESS;
}
const std::string &get_prompt() const LM_NOEXCEPTDECL override {
return get_state()->prompt;
}
bool is_mirostat_available() const noexcept override {
return true;
}
bool is_grammar_available() const noexcept override {
return true;
}
};
}
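To make the new grammar entry points concrete, a hedged sketch of constraining output with a GBNF grammar; the grammar string and prompt are illustrative, and the instance is assumed to come from Inference::construct() with a llama backend:

#include "justlm.hpp"
#include <string>

// Constrain the next generation to "yes" or "no" using the new grammar API.
std::string ask_yes_no(LM::Inference &llm, const std::string &question) {
    if (!llm.is_grammar_available())
        return {};                                   // backend without grammar support
    // Minimal GBNF; override_temperature=true skips the temperature/mirostat path
    // so sampling is driven by the grammar (per the sampling code above).
    llm.load_grammar(R"(root ::= "yes" | "no")", /*override_temperature=*/true);
    llm.append(question + " Answer: ");
    std::string answer = llm.run();
    llm.unload_grammar();
    return answer;
}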


@ -68,12 +68,12 @@ class MPTInference final : public Inference {
// This function reduces the size of our tokens vector according to some parameters
// All tokens will be evaluated if scrolling was needed and true will be returned
LM_SCHEDULABLE(bool) window_scroll() LM_NOEXCEPTDECL {
bool window_scroll() LM_NOEXCEPTDECL {
auto &state = get_state();
// Check that we actually need to scroll
if (state->tokens.size() <= params.n_ctx) {
// Nope
LM_CORETURN false;
return false;
}
// Start scrolling
if (params.scroll_keep > 0.0f) {
@ -90,11 +90,11 @@ class MPTInference final : public Inference {
state->tokens.resize(params.n_ctx_window_top_bar);
}
// Evaluate tokens
LM_ERROR_FORWARD(LM_COAWAIT evaluate_tokens(0, on_scroll));
LM_CORETURN true;
LM_ERROR_FORWARD(evaluate_tokens(0, on_scroll), LM_BOOL_ERROR);
return true;
}
LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL {
LM_ERRBOOL evaluate_tokens(size_t starting_offset, const AppendCallback &on_tick) LM_NOEXCEPTDECL {
auto& state = get_state();
// Evaluate tokens in batches
@ -105,7 +105,7 @@ class MPTInference final : public Inference {
// Evaluate
std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+params.n_batch);
if (!mpt_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token)) {
LM_COTHROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR);
LM_THROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR);
}
// Tick
@ -113,8 +113,7 @@ class MPTInference final : public Inference {
// Calculate progress
auto progress = float(it-starting_offset) / (state->tokens.size()-starting_offset) * 100.f;
// Tick and yield
if (!on_tick(progress)) LM_CORETURN LM_BOOL_SUCCESS;
else if (!LM_TASKYIELD) LM_CORETURN LM_BOOL_SUCCESS;
if (!on_tick(progress)) return LM_BOOL_SUCCESS;
}
}
@ -124,7 +123,7 @@ class MPTInference final : public Inference {
//TODO: This is extremely inefficient! Don't do that...
std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+1);
if (!mpt_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token)) {
LM_COTHROW("Failed to evaluate individual tokens", LM_BOOL_ERROR);
LM_THROW("Failed to evaluate individual tokens", LM_BOOL_ERROR);
}
}
}
@ -132,7 +131,7 @@ class MPTInference final : public Inference {
// Notify about completion
if (on_tick) on_tick(100.f);
LM_CORETURN LM_BOOL_SUCCESS;
return LM_BOOL_SUCCESS;
}
public:
@ -143,7 +142,7 @@ public:
deinit();
}
LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
LM_ERRBOOL append(const std::string& prompt, const AppendCallback &on_tick) LM_NOEXCEPTDECL override {
auto& state = get_state();
// Append to current prompt
@ -161,35 +160,37 @@ public:
);
// Make sure token limit isn't being hit
if (LM_COAWAIT window_scroll()) {
if (window_scroll()) {
// That function already has evaluated our tokens since scrolling was needed
LM_CORETURN LM_BOOL_SUCCESS;
return LM_BOOL_SUCCESS;
}
// Evaluate new tokens
LM_CORETURN LM_COAWAIT evaluate_tokens(old_token_count, on_tick);
return evaluate_tokens(old_token_count, on_tick);
}
LM_SCHEDULABLE(std::string) run(std::string_view end, const std::function<bool (const char *)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
std::string run(std::string_view end, const GenerateCallback &on_tick, const GenerateCallback& pre_tick) LM_NOEXCEPTDECL override {
auto& state = get_state();
std::string fres;
// Loop until done
bool abort = false;
unsigned eos_count = 0;
size_t last_size = 0;
while (!abort && (end.empty() || fres.find(end) == fres.npos)) {
last_size = fres.size();
// Sample top p and top k
const auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
auto id = gpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data()+state->tokens.size()-n_repeat_last, n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
if (state->im_end && id == state->im_end) {
if (eos_count++ == params.eos_ignores) {
if (eos_count++ == params.n_eos_ignores) {
abort = true;
continue;
}
id = gpt_tokenize(state->vocab, "\n")[0];
} else if (id == 0) {
if (eos_count++ == params.eos_ignores) {
if (eos_count++ == params.n_eos_ignores) {
abort = true;
continue;
}
@ -200,7 +201,7 @@ public:
state->tokens.push_back(id);
// Make sure token limit isn't being hit
LM_COAWAIT window_scroll();
window_scroll();
// Get token as string
const std::string_view str = state->vocab.id_to_token[id];
@ -209,77 +210,80 @@ public:
fres.append(str);
state->prompt.append(str);
// Evaluate token
// TODO: Respect batch size
std::vector<int> batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size());
if (!mpt_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) {
LM_COTHROW("Failed to evaluate new tokens", "");
// Tick
if (pre_tick && !pre_tick(str.data())) abort = true;
else {
// Evaluate token
// TODO: Respect batch size
std::vector<int> batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size());
if (!mpt_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) {
LM_THROW("Failed to evaluate new tokens", "");
}
}
// Tick
if (on_tick && !on_tick(str.data())) abort = true;
else if (!LM_TASKYIELD) abort = true;
}
// Create final string TODO: Could be optimized
if (!abort) {
fres = std::string(fres.data(), fres.size()-end.size());
fres = std::string(fres.data(), last_size);
}
// Return final string
LM_CORETURN fres;
return fres;
}
unsigned get_context_size() const noexcept override {
return get_state()->tokens.size();
}
LM_SCHEDULABLE(LM_ERRBOOL) create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override {
LM_ERRBOOL create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override {
auto& state = get_state();
sv.buf.resize(mpt_get_state_size(state->model));
mpt_copy_state_data(state->model, state->rng, sv.buf.data());
sv.tokens = state->tokens;
sv.prompt = state->prompt;
sv.ctx = generic_state;
LM_CORETURN LM_BOOL_SUCCESS ;
return LM_BOOL_SUCCESS ;
}
LM_SCHEDULABLE(LM_ERRBOOL) restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override {
LM_ERRBOOL restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override {
auto& state = get_state();
if (sv.ctx != generic_state)
LM_COTHROW("Savestate does not match context", LM_BOOL_ERROR);
LM_THROW("Savestate does not match context", LM_BOOL_ERROR);
mpt_set_state_data(&state->model, &state->rng, sv.buf.data());
state->tokens = sv.tokens;
state->prompt = sv.prompt;
LM_CORETURN LM_BOOL_SUCCESS;
return LM_BOOL_SUCCESS;
}
LM_SCHEDULABLE(LM_ERRBOOL) serialize(std::ostream &o) const LM_NOEXCEPTDECL override {
LM_ERRBOOL serialize(std::ostream &o) const LM_NOEXCEPTDECL override {
auto& state = get_state();
// Get state size
auto state_size = mpt_get_state_size(state->model);
// Write sizes
for (const uint32_t s : {state->tokens.size(), state->prompt.size(), state_size}) {
if (!o.write(reinterpret_cast<const char*>(&s), sizeof(s))) {
LM_COTHROW("Failed to serialize data sizes", LM_BOOL_ERROR);
LM_THROW("Failed to serialize data sizes", LM_BOOL_ERROR);
}
}
// Write tokens
if (!o.write(reinterpret_cast<const char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
LM_COTHROW("Failed to serialize tokens", LM_BOOL_ERROR);
LM_THROW("Failed to serialize tokens", LM_BOOL_ERROR);
}
// Write prompt
if (!o.write(state->prompt.data(), state->prompt.size())) {
LM_COTHROW("Failed to serialize prompt", LM_BOOL_ERROR);
LM_THROW("Failed to serialize prompt", LM_BOOL_ERROR);
}
// Write state
std::vector<uint8_t> state_buf(state_size);
mpt_copy_state_data(state->model, state->rng, state_buf.data());
if (!o.write(reinterpret_cast<const char*>(state_buf.data()), state_size)) {
LM_COTHROW("Failed to serialize state", LM_BOOL_ERROR);
LM_THROW("Failed to serialize state", LM_BOOL_ERROR);
}
LM_CORETURN LM_BOOL_SUCCESS;
return LM_BOOL_SUCCESS;
}
LM_SCHEDULABLE(LM_ERRBOOL) deserialize(std::istream &i) LM_NOEXCEPTDECL override {
LM_ERRBOOL deserialize(std::istream &i) LM_NOEXCEPTDECL override {
auto& state = get_state();
uint32_t embd_size, promptsize, state_size;
// Initialization to prevent compiler complaints
@ -287,26 +291,26 @@ public:
// Read sizes
for (uint32_t *s : {&embd_size, &promptsize, &state_size}) {
if (!i.read(reinterpret_cast<char*>(s), sizeof(*s))) {
LM_COTHROW("Failed to deserialize data sizes", LM_BOOL_ERROR);
LM_THROW("Failed to deserialize data sizes", LM_BOOL_ERROR);
}
}
// Read tokens
state->tokens.resize(embd_size);
if (!i.read(reinterpret_cast<char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
LM_COTHROW("Failed to deserialize tokens", LM_BOOL_ERROR);
LM_THROW("Failed to deserialize tokens", LM_BOOL_ERROR);
}
// Read prompt
state->prompt.resize(promptsize);
if (!i.read(state->prompt.data(), state->prompt.size())) {
LM_COTHROW("Failed to deserialize prompt", LM_BOOL_ERROR);
LM_THROW("Failed to deserialize prompt", LM_BOOL_ERROR);
}
// Read state
std::vector<uint8_t> state_buf(state_size);
if (!i.read(reinterpret_cast<char*>(state_buf.data()), state_buf.size())) {
LM_COTHROW("Failed to deserialize state", LM_BOOL_ERROR);
LM_THROW("Failed to deserialize state", LM_BOOL_ERROR);
}
mpt_set_state_data(&state->model, &state->rng, state_buf.data());
LM_CORETURN LM_BOOL_SUCCESS;
return LM_BOOL_SUCCESS;
}
const std::string &get_prompt() const LM_NOEXCEPTDECL override {
return get_state()->prompt;


@ -6,7 +6,7 @@
LM_SCHEDULABLE(bool) LM::InferencePool::store_slot(Slot &slot) {
bool LM::InferencePool::store_slot(Slot &slot) {
auto inference = slot.get_inference();
// Open output file
std::ofstream f(get_slot_filename(slot.get_id()), std::ios::binary);
@ -17,61 +17,61 @@ LM_SCHEDULABLE(bool) LM::InferencePool::store_slot(Slot &slot) {
f.write(weights_path.data(), weights_path.size());
// Write params
if (!f.write(reinterpret_cast<const char*>(&inference->params), sizeof(inference->params))) {
LM_CORETURN false;
return false;
}
// Serialize instance
try {
LM_COAWAIT inference->serialize(f);
inference->serialize(f);
} catch (...) {
LM_CORETURN false;
return false;
}
// Return success
LM_CORETURN true;
return true;
}
LM_SCHEDULABLE(LM::InferencePool::Slot*) LM::InferencePool::load_slot(size_t id, Slot *suggested_slot) {
LM::InferencePool::Slot *LM::InferencePool::load_slot(size_t id, Slot *suggested_slot) {
// Open input file
std::ifstream f(get_slot_filename(id), std::ios::binary);
if (!f) {
// Does not exist
LM_CORETURN nullptr;
return nullptr;
}
// Read weights path
std::string weights_path;
uint32_t weights_path_len;
if (!f.read(reinterpret_cast<char*>(&weights_path_len), sizeof(weights_path_len))) {
LM_CORETURN nullptr;
return nullptr;
}
weights_path.resize(weights_path_len);
if (!f.read(weights_path.data(), weights_path.size())) {
LM_CORETURN nullptr;
return nullptr;
}
// Read params
LM::Inference::Params p;
if (!f.read(reinterpret_cast<char*>(&p), sizeof(p))) {
LM_CORETURN nullptr;
return nullptr;
}
// Create instance
auto& slot = suggested_slot?*suggested_slot:*(LM_COAWAIT get_free_slot());
auto& slot = suggested_slot?*suggested_slot:*(get_free_slot());
auto inference = slot.create_inference(id, weights_path, p);
// Deserialize instance
try {
LM_COAWAIT inference->deserialize(f);
inference->deserialize(f);
} catch (...) {
slot.reset();
LM_CORETURN nullptr;
return nullptr;
}
// Return final slot
LM_CORETURN &slot;
return &slot;
}
LM_SCHEDULABLE(LM::InferencePool::Slot*) LM::InferencePool::get_free_slot() {
LM::InferencePool::Slot *LM::InferencePool::get_free_slot() {
// Attempt to find free slot while finding oldest one
Slot *oldest = nullptr;
for (auto& slot : slots) {
// Take free slot
if (slot.is_free()) {
LM_CORETURN &slot;
return &slot;
}
// Update oldest
if (oldest == nullptr || slot.get_last_access() < oldest->get_last_access()) {
@ -80,17 +80,17 @@ LM_SCHEDULABLE(LM::InferencePool::Slot*) LM::InferencePool::get_free_slot() {
}
// Free up oldest slot and take that one
// Note: Since there has to be at least 1 slot, oldest is never going to be a nullptr
LM_COAWAIT store_and_reset_slot(*oldest);
LM_CORETURN oldest;
store_and_reset_slot(*oldest);
return oldest;
}
LM_SCHEDULABLE(LM::InferencePool::Slot*) LM::InferencePool::find_slot_by_id(size_t id, bool deserialize) {
LM::InferencePool::Slot *LM::InferencePool::find_slot_by_id(size_t id, bool deserialize) {
// Attempt to find given slot while finding oldest one
Slot *oldest = nullptr;
for (auto& slot : slots) {
// Take slot with ID
if (slot.get_id() == id) {
LM_CORETURN &slot;
return &slot;
}
// Update oldest
if (oldest == nullptr || slot.get_last_access() < oldest->get_last_access()) {
@ -99,38 +99,38 @@ LM_SCHEDULABLE(LM::InferencePool::Slot*) LM::InferencePool::find_slot_by_id(size
}
// Slot not found, attempt to load it
if (deserialize) {
if (!oldest->is_free()) LM_COAWAIT store_slot(*oldest);
if (!LM_COAWAIT load_slot(id, oldest)) {
if (!oldest->is_free()) store_slot(*oldest);
if (!load_slot(id, oldest)) {
// In case slot loading failed, still reset slot for later use
//TODO: Make this configurable
oldest->reset();
} else {
LM_CORETURN oldest;
return oldest;
}
}
// Slot not found
LM_CORETURN nullptr;
return nullptr;
}
LM_SCHEDULABLE(std::shared_ptr<LM::Inference>) LM::InferencePool::get_inference(size_t id) {
auto slot = LM_COAWAIT find_slot_by_id(id);
std::shared_ptr<LM::Inference> LM::InferencePool::get_inference(size_t id) {
auto slot = find_slot_by_id(id);
if (slot) {
LM_CORETURN slot->get_inference(true);
return slot->get_inference(true);
}
LM_CORETURN {};
return {};
}
LM_SCHEDULABLE(std::shared_ptr<LM::Inference>) LM::InferencePool::get_or_create_inference(size_t id, const std::string &weights_path, const Inference::Params &p) {
auto slot = LM_COAWAIT find_slot_by_id(id);
std::shared_ptr<LM::Inference> LM::InferencePool::get_or_create_inference(size_t id, const std::string &weights_path, const Inference::Params &p) {
auto slot = find_slot_by_id(id);
if (slot) {
LM_CORETURN slot->get_inference(true);
return slot->get_inference(true);
}
slot = LM_COAWAIT get_free_slot();
LM_CORETURN slot->create_inference(id, weights_path, p);
slot = get_free_slot();
return slot->create_inference(id, weights_path, p);
}
LM_SCHEDULABLE(void) LM::InferencePool::delete_inference(size_t id) {
auto slot = LM_COAWAIT find_slot_by_id(id, false);
void LM::InferencePool::delete_inference(size_t id) {
auto slot = find_slot_by_id(id, false);
// Reset slot
if (slot) {
slot->reset();
@ -140,12 +140,12 @@ LM_SCHEDULABLE(void) LM::InferencePool::delete_inference(size_t id) {
std::filesystem::remove(get_slot_filename(id), ec);
}
LM_SCHEDULABLE(void) LM::InferencePool::store_all() {
void LM::InferencePool::store_all() {
for (auto& slot : slots) {
if (slot.is_free()) continue;
LM_COAWAIT store_slot(slot);
store_slot(slot);
}
LM_CORETURN;
return;
}
std::vector<size_t> LM::InferencePool::get_active_slot_ids() const {


@ -18,11 +18,7 @@ bool magic_match(std::istream& f) {
// Check magic
uint32_t magic = 0;
f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
if (magic != 0x67676a74) return false;
// Check version
uint32_t version = 0;
f.read(reinterpret_cast<char*>(&version), sizeof(version));
return version LLAMA_VERSIONS;
return magic == 0x46554747;
}
LM::Inference *construct(const std::string &weights_path, std::ifstream& f, const LM::Inference::Params &p) {
@ -30,3 +26,14 @@ LM::Inference *construct(const std::string &weights_path, std::ifstream& f, cons
return new LM::LLaMAInference(weights_path, p);
}
}
__attribute__((constructor))
static void init() {
llama_backend_init(true);
}
__attribute__((destructor))
static void deinit() {
llama_backend_free();
}
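Side note on the new magic check (illustrative, not from the diff): 0x46554747 is simply the four bytes "GGUF" read as a little-endian 32-bit integer, matching the way magic_match() reads the file header above:

#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
    const char header[4] = {'G', 'G', 'U', 'F'};     // start of a GGUF model file
    std::uint32_t magic = 0;
    std::memcpy(&magic, header, sizeof magic);       // same raw read as magic_match()
    assert(magic == 0x46554747);                     // holds on little-endian hosts
    return 0;
}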

@ -1 +0,0 @@
Subproject commit 0e018fe008eacebdbcfa2d61b6c988c245c961cd

@ -1 +0,0 @@
Subproject commit 5ea43392731040b454c293123839b90e159cbb99

@ -1 +1 @@
Subproject commit 29cf5596fe0c37213f9b74e80d8f631193a93f0f
Subproject commit b9f47952ffae4e0d3420905526003c23333f6c98


@ -51,22 +51,27 @@ option(LLAMA_SANITIZE_ADDRESS "llama: enable address sanitizer"
option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF)
# instruction set specific
#option(LLAMA_AVX "llama: enable AVX" ON)
#option(LLAMA_AVX2 "llama: enable AVX2" ON)
#option(LLAMA_AVX512 "llama: enable AVX512" OFF)
#option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
#option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
#option(LLAMA_FMA "llama: enable FMA" ON)
option(LLAMA_AVX "llama: enable AVX" ON)
option(LLAMA_AVX2 "llama: enable AVX2" ON)
option(LLAMA_AVX512 "llama: enable AVX512" OFF)
option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
option(LLAMA_FMA "llama: enable FMA" ON)
# in MSVC F16C is implied with AVX2/AVX512
#if (NOT MSVC)
# option(LLAMA_F16C "llama: enable F16C" ON)
#endif()
if (NOT MSVC)
option(LLAMA_F16C "llama: enable F16C" ON)
endif()
# 3rd party libs
option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON)
option(LLAMA_OPENBLAS "llama: use OpenBLAS" OFF)
option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
option(LLAMA_METAL "llama: use Metal" OFF)
option(LLAMA_K_QUANTS "llama: use k-quants" ON)
set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels")
#
# Compile flags
@ -207,89 +212,36 @@ if (NOT MSVC)
endif()
endif()
function(remove_nonexistent SOURCES)
set(SOURCES_BAK ${${SOURCES}})
set(${SOURCES} )
foreach (FILE ${SOURCES_BAK})
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${FILE})
set(${SOURCES} ${${SOURCES}} ${FILE})
endif()
endforeach()
set(${SOURCES} ${${SOURCES}} PARENT_SCOPE)
endfunction()
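
remove_nonexistent() filters a source-list variable down to the files that actually exist in the given checkout; it lets the single source lists below (ggml-alloc.c, ggml-quants.c, ggml-backend.c, grammar-parser.cpp, and so on) be shared across llama.cpp checkouts of different ages without per-version branching.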
function(include_ggml DIRECTORY SUFFIX WITH_LLAMA)
message(STATUS "Configuring ggml implementation target llama${SUFFIX} in ${CMAKE_CURRENT_SOURCE_DIR}/${DIRECTORY}")
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
message(STATUS "ARM detected")
if (MSVC)
# TODO: arm msvc?
else()
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
add_compile_options(-mcpu=native)
endif()
# TODO: armv6,7,8 version specific flags
endif()
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$")
message(STATUS "x86 detected")
if (MSVC)
if (LLAMA_AVX512)
add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX512>)
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)
# MSVC has no compile-time flags enabling specific
# AVX512 extensions, nor does it define the
# macros corresponding to the extensions.
# Do it manually.
if (LLAMA_AVX512_VBMI)
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>)
endif()
if (LLAMA_AVX512_VNNI)
add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
endif()
elseif (LLAMA_AVX2)
add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX2>)
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX2>)
elseif (LLAMA_AVX)
add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX>)
add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX>)
endif()
else()
if (LLAMA_F16C)
add_compile_options(-mf16c)
endif()
if (LLAMA_FMA)
add_compile_options(-mfma)
endif()
if (LLAMA_AVX)
add_compile_options(-mavx)
endif()
if (LLAMA_AVX2)
add_compile_options(-mavx2)
endif()
if (LLAMA_AVX512)
add_compile_options(-mavx512f)
add_compile_options(-mavx512bw)
endif()
if (LLAMA_AVX512_VBMI)
add_compile_options(-mavx512vbmi)
endif()
if (LLAMA_AVX512_VNNI)
add_compile_options(-mavx512vnni)
endif()
endif()
else()
# TODO: support PowerPC
message(STATUS "Unknown architecture")
endif()
#
# Build libraries
#
set(GGML_CUBLAS_USE NO)
if (LLAMA_CUBLAS)
cmake_minimum_required(VERSION 3.17)
find_package(CUDAToolkit)
if (CUDAToolkit_FOUND)
set(GGML_CUBLAS_USE YES)
message(STATUS "cuBLAS found")
enable_language(CUDA)
set(GGML_CUDA_SOURCES ${DIRECTORY}ggml-cuda.cu ${DIRECTORY}ggml-cuda.h)
add_compile_definitions(GGML_USE_CUBLAS)
set(GGML_SOURCES_CUDA ${DIRECTORY}/ggml-cuda.cu ${DIRECTORY}/ggml-cuda.h)
if (LLAMA_STATIC)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
@ -302,14 +254,19 @@ function(include_ggml DIRECTORY SUFFIX WITH_LLAMA)
endif()
endif()
set(GGML_CLBLAST_USE NO)
if (LLAMA_CLBLAST)
find_package(CLBlast)
if (CLBlast_FOUND)
set(GGML_CLBLAST_USE YES)
message(STATUS "CLBlast found")
set(GGML_OPENCL_SOURCES ${DIRECTORY}ggml-opencl.c ${DIRECTORY}ggml-opencl.h)
set(GGML_OPENCL_SOURCE_FILE ggml-opencl.cpp)
if (NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${DIRECTORY}/${GGML_OPENCL_SOURCE_FILE})
set(GGML_OPENCL_SOURCE_FILE ggml-opencl.c)
endif()
add_compile_definitions(GGML_USE_CLBLAST)
set(GGML_OPENCL_SOURCES ${DIRECTORY}/${GGML_OPENCL_SOURCE_FILE} ${DIRECTORY}/ggml-opencl.h)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} clblast)
else()
@ -317,35 +274,73 @@ function(include_ggml DIRECTORY SUFFIX WITH_LLAMA)
endif()
endif()
add_library(ggml${SUFFIX} OBJECT
${DIRECTORY}/ggml.c
${DIRECTORY}/ggml.h
${GGML_CUDA_SOURCES}
${GGML_OPENCL_SOURCES})
set(GGML_METAL_SOURCES )
if (LLAMA_METAL)
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
find_library(METAL_FRAMEWORK Metal REQUIRED)
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)
set(GGML_METAL_SOURCES ${DIRECTORY}/ggml-metal.m ${DIRECTORY}/ggml-metal.h)
# get full path to the file
#add_compile_definitions(GGML_METAL_DIR_KERNELS="${CMAKE_CURRENT_SOURCE_DIR}/")
# copy ggml-metal.metal to bin directory
configure_file(${DIRECTORY}/ggml-metal.metal bin/ggml-metal.metal COPYONLY)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS}
${FOUNDATION_LIBRARY}
${METAL_FRAMEWORK}
${METALKIT_FRAMEWORK}
${METALPERFORMANCE_FRAMEWORK}
)
endif()
set(GGML_SOURCES
${DIRECTORY}/ggml.c
${DIRECTORY}/ggml.h
${DIRECTORY}/ggml-alloc.c
${DIRECTORY}/ggml-alloc.h
${DIRECTORY}/ggml-quants.c
${DIRECTORY}/ggml-quants.h
${DIRECTORY}/ggml-backend.c
${DIRECTORY}/ggml-backend.h
${GGML_SOURCES_CUDA}
${GGML_METAL_SOURCES}
${GGML_OPENCL_SOURCES})
remove_nonexistent(GGML_SOURCES)
add_library(ggml${SUFFIX} OBJECT ${GGML_SOURCES})
target_compile_definitions(ggml${SUFFIX} PRIVATE _GNU_SOURCE)
if (LLAMA_K_QUANTS)
target_compile_definitions(ggml${SUFFIX} PUBLIC GGML_USE_K_QUANTS)
endif()
if (LLAMA_METAL AND GGML_METAL_SOURCES)
target_compile_definitions(ggml${SUFFIX} PUBLIC GGML_USE_METAL GGML_METAL_NDEBUG)
endif()
target_include_directories(ggml${SUFFIX} PUBLIC ${DIRECTORY})
target_compile_features(ggml${SUFFIX} PUBLIC c_std_11) # don't bump
target_link_libraries(ggml${SUFFIX} PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})
if (BUILD_SHARED_LIBS)
set_target_properties(ggml${SUFFIX} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()
if (WITH_LLAMA)
# Backwards compatibility with old llama.cpp versions
set(LLAMA_UTIL_SOURCE_FILE llama-util.h)
if (NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${DIRECTORY}/${LLAMA_UTIL_SOURCE_FILE})
set(LLAMA_UTIL_SOURCE_FILE llama_util.h)
SET(LLAMA_SOURCES
${DIRECTORY}/llama.cpp
${DIRECTORY}/llama.h
${DIRECTORY}/common/grammar-parser.h
${DIRECTORY}/common/grammar-parser.cpp)
remove_nonexistent(LLAMA_SOURCES)
add_library(llama${SUFFIX} ${LLAMA_SOURCES})
if (LLAMA_METAL AND GGML_METAL_SOURCES)
target_compile_definitions(llama${SUFFIX} PUBLIC GGML_USE_METAL GGML_METAL_NDEBUG)
endif()
add_library(llama${SUFFIX}
${DIRECTORY}/llama.cpp
${DIRECTORY}/llama.h
${DIRECTORY}/${LLAMA_UTIL_SOURCE_FILE})
target_include_directories(llama${SUFFIX} PUBLIC ${DIRECTORY})
target_compile_features(llama${SUFFIX} PUBLIC cxx_std_11) # don't bump
target_link_libraries(llama${SUFFIX} PRIVATE ggml${SUFFIX} ${LLAMA_EXTRA_LIBS})
if (BUILD_SHARED_LIBS)
set_target_properties(llama${SUFFIX} PROPERTIES POSITION_INDEPENDENT_CODE ON)
@ -353,7 +348,7 @@ function(include_ggml DIRECTORY SUFFIX WITH_LLAMA)
endif()
endif()
if (GGML_CUDA_SOURCES)
if (GGML_SOURCES_CUDA)
message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
set_property(TARGET ggml${SUFFIX} PROPERTY CUDA_ARCHITECTURES OFF)
set_property(TARGET ggml${SUFFIX} PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
@ -361,4 +356,97 @@ function(include_ggml DIRECTORY SUFFIX WITH_LLAMA)
set_property(TARGET llama${SUFFIX} PROPERTY CUDA_ARCHITECTURES OFF)
endif()
endif()
if (GGML_CUBLAS_USE)
target_compile_definitions(ggml${SUFFIX} PRIVATE
GGML_USE_CUBLAS
GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}
GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y})
if (WITH_LLAMA)
target_compile_definitions(llama${SUFFIX} PRIVATE
GGML_USE_CUBLAS
GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}
GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y})
endif()
endif()
if (GGML_CLBLAST_USE)
if (WITH_LLAMA)
target_compile_definitions(llama${SUFFIX} PRIVATE GGML_USE_CLBLAST)
endif()
target_compile_definitions(ggml${SUFFIX} PRIVATE GGML_USE_CLBLAST)
endif()
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
message(STATUS "ARM detected")
if (MSVC)
# TODO: arm msvc?
else()
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
target_compile_options(ggml${SUFFIX} PRIVATE -mcpu=native)
endif()
# TODO: armv6,7,8 version specific flags
endif()
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$")
message(STATUS "x86 detected")
if (MSVC)
if (LLAMA_AVX512)
target_compile_definitions(ggml${SUFFIX} PRIVATE
$<$<COMPILE_LANGUAGE:C>:/arch:AVX512>
$<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)
# MSVC has no compile-time flags enabling specific
# AVX512 extensions, nor does it define the
# macros corresponding to the extensions.
# Do it manually.
if (LLAMA_AVX512_VBMI)
target_compile_definitions(ggml${SUFFIX} PRIVATE
$<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>
$<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>)
endif()
if (LLAMA_AVX512_VNNI)
target_compile_definitions(ggml${SUFFIX} PRIVATE
$<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>
$<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
endif()
elseif (LLAMA_AVX2)
target_compile_definitions(ggml${SUFFIX} PRIVATE
$<$<COMPILE_LANGUAGE:C>:/arch:AVX2>
$<$<COMPILE_LANGUAGE:CXX>:/arch:AVX2>)
elseif (LLAMA_AVX)
target_compile_definitions(ggml${SUFFIX} PRIVATE
$<$<COMPILE_LANGUAGE:C>:/arch:AVX>
$<$<COMPILE_LANGUAGE:CXX>:/arch:AVX>)
endif()
else()
if (LLAMA_F16C)
target_compile_options(ggml${SUFFIX} PRIVATE -mf16c)
endif()
if (LLAMA_FMA)
target_compile_options(ggml${SUFFIX} PRIVATE -mfma)
endif()
if (LLAMA_AVX)
target_compile_options(ggml${SUFFIX} PRIVATE -mavx)
endif()
if (LLAMA_AVX2)
target_compile_options(ggml${SUFFIX} PRIVATE -mavx2)
endif()
if (LLAMA_AVX512)
target_compile_options(ggml${SUFFIX} PRIVATE -mavx512f)
target_compile_options(ggml${SUFFIX} PRIVATE -mavx512bw)
endif()
if (LLAMA_AVX512_VBMI)
target_compile_options(ggml${SUFFIX} PRIVATE -mavx512vbmi)
endif()
if (LLAMA_AVX512_VNNI)
target_compile_options(ggml${SUFFIX} PRIVATE -mavx512vnni)
endif()
endif()
else()
# TODO: support PowerPC
message(STATUS "Unknown architecture")
endif()
target_link_libraries(ggml${SUFFIX} PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})
if (WITH_LLAMA)
target_link_libraries(llama${SUFFIX} PRIVATE ggml${SUFFIX} ${LLAMA_EXTRA_LIBS})
endif()
endfunction()
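
include_ggml() is expected to be called once per bundled llama.cpp checkout: it produces a ggml${SUFFIX} object library and, when WITH_LLAMA is set, a llama${SUFFIX} library linked against it, with SUFFIX keeping targets from different checkouts (for example the mainline submodule) from colliding.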


@ -11,7 +11,7 @@
#include <string>
#include <vector>
#include <iostream>
#include <unistd.h>
#include "../msvc_compat_unistd.h"
#include <sstream>
#include <thread>
#include <unordered_set>
@ -19,7 +19,7 @@
#include <ggml.h>
inline
unsigned long long operator ""_MB(unsigned long long bytes) {
unsigned long long operator ""_MiB(unsigned long long bytes) {
return bytes*1024*1024;
}
@ -34,7 +34,7 @@ static bool kv_cache_init(
const int64_t n_mem = (int64_t)n_layer*n_ctx;
const int64_t n_elements = n_embd*n_mem;
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2_MB);
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2_MiB);
struct ggml_init_params params;
params.mem_size = cache.buf.size;
@ -356,7 +356,7 @@ bool mpt_eval(
const int n_head = hparams.n_head;
const int n_vocab = hparams.n_vocab;
const size_t init_buf_size = 1024_MB;
const size_t init_buf_size = 1024_MiB;
if (!model.buf.addr || model.buf.size < init_buf_size)
model.buf.resize(init_buf_size);
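
The renamed literal spells out that these are binary megabytes (1024 * 1024 bytes). A tiny compile-time check of the same definition, with constexpr added here only so static_assert can evaluate it (not part of the original header):

constexpr unsigned long long operator ""_MiB(unsigned long long bytes) {
    return bytes * 1024 * 1024;
}
static_assert(2_MiB == 2u * 1024 * 1024, "2_MiB is 2097152 bytes");
static_assert(1024_MiB == 1ull << 30, "1024 MiB is exactly 1 GiB");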

msvc_compat_unistd.h Normal file

@ -0,0 +1,11 @@
#if defined(_WIN32) && defined(_MSC_VER)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#include <io.h>
#include <stdio.h> // for _fseeki64
#else
#include <unistd.h>
#endif
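
The polyfill exists because MSVC ships no <unistd.h>: on that toolchain the header pulls in <windows.h> and <io.h> instead, defining WIN32_LEAN_AND_MEAN and NOMINMAX first so the Windows headers stay slim and do not inject min/max macros that clash with the STL, while every other platform keeps including <unistd.h> as before.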


@ -24,16 +24,23 @@ PYBIND11_MODULE(justlm_py, m) {
.def_readwrite("top_p", &Inference::Params::top_p)
.def_readwrite("temp", &Inference::Params::temp)
.def_readwrite("repeat_penalty", &Inference::Params::repeat_penalty)
.def_readwrite("eos_ignores", &Inference::Params::eos_ignores)
.def_readwrite("use_mlock", &Inference::Params::use_mlock);
.def_readwrite("eos_ignores", &Inference::Params::n_eos_ignores)
.def_readwrite("use_mlock", &Inference::Params::use_mlock)
.def_readwrite("prefer_mirostat", &Inference::Params::prefer_mirostat)
.def_readwrite("mirostat_learning_rate", &Inference::Params::mirostat_learning_rate)
.def_readwrite("mirostat_target_entropy", &Inference::Params::mirostat_target_entropy);
py::class_<Inference>(m, "Inference")
.def_static("construct", &Inference::construct, py::arg("weights_path"), py::arg("params") = Inference::Params())
.def("append", &Inference::append, py::arg("prompt"), py::arg("on_tick") = nullptr)
.def("run", &Inference::run, py::arg("end") = "", py::arg("on_tick") = nullptr)
.def("run", &Inference::run, py::arg("end") = "", py::arg("on_tick") = nullptr, py::arg("pre_tick") = nullptr)
.def("create_savestate", &Inference::create_savestate)
.def("restore_savestate", &Inference::restore_savestate)
.def("get_prompt", &Inference::get_prompt)
.def("get_context_size", &Inference::get_context_size)
.def("is_mirostat_available", &Inference::is_mirostat_available)
.def("is_grammar_available", &Inference::is_grammar_available)
.def("load_grammar", &Inference::load_grammar)
.def("unload_grammar", &Inference::unload_grammar)
.def_readwrite("params", &Inference::params);
py::class_<Inference::Savestate>(m, "Savestate")
.def(py::init<>());
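
The binding changes mirror the C++ side: run() now accepts the optional pre_tick callback, the grammar API (is_grammar_available, load_grammar, unload_grammar) and the mirostat controls (prefer_mirostat, mirostat_learning_rate, mirostat_target_entropy, plus is_mirostat_available) are exposed to Python, and the eos_ignores property now maps onto the renamed n_eos_ignores field.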