Mirror of https://gitlab.com/niansa/libjustlm.git, synced 2025-03-06 20:49:17 +01:00

Compare commits


No commits in common. "master" and "v0.1" have entirely different histories.
master...v0.1

24 changed files with 592 additions and 1270 deletions

.gitmodules vendored
View file

@ -1,12 +1,6 @@
[submodule "llama.cpp"]
path = llama.cpp
url = https://github.com/ggerganov/llama.cpp.git
[submodule "llama.cpp-alibi"]
path = llama.cpp-alibi
url = https://github.com/manyoso/llama.cpp.git
[submodule "llama.cpp-230511"]
path = llama.cpp-230511
url = https://github.com/ggerganov/llama.cpp.git
[submodule "llama.cpp-230519"]
path = llama.cpp-230519
url = https://github.com/ggerganov/llama.cpp.git
[submodule "llama.cpp-mainline"]
path = llama.cpp-mainline
url = https://github.com/ggerganov/llama.cpp.git

CMakeLists.txt
View file

@ -1,68 +1,52 @@
cmake_minimum_required(VERSION 3.18)
project(justlm LANGUAGES C CXX)
cmake_minimum_required(VERSION 3.14)
project(libjustlm LANGUAGES C CXX)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
option(LM_PYBIND "If justlm Python bindings should be build" OFF)
option(LM_NOEXCEPT "If justlm exceptions should be disabled" OFF)
option(LM_LLAMA "If LLaMa model support should be built into justlm" ON)
option(LM_GPTJ "If GPT-J model support should be built into justlm" ON)
option(LM_MPT "If MPT model support should be built into justlm" ON)
function(target_justlm_setup TARGET_NAME)
message(STATUS "Configuring model implementation target ${TARGET_NAME}")
target_include_directories(${TARGET_NAME} PUBLIC include/)
if (LM_NOEXCEPT)
target_compile_definitions(${TARGET_NAME} PUBLIC LM_NOEXCEPT)
endif()
endfunction()
include(llama.cpp.cmake)
include_ggml(llama.cpp-mainline _mainline Yes)
include_ggml(llama.cpp-alibi _alibi No)
add_library(justlm_g4a_common SHARED g4a_common.cpp g4a_common.hpp)
set(LM_PYBIND No CACHE BOOL "If Libjustlm Python bindings should be build")
set(LM_COSCHED No CACHE BOOL "If Libjustlm should make use of CoSched")
set(LM_NOEXCEPT No CACHE BOOL "If exceptions should be disabled")
set(LM_MPT No CACHE BOOL "If MPT model support should be built")
if (LM_COSCHED)
set(CMAKE_CXX_STANDARD 20)
endif()
if (LM_MPT)
add_library(justlm_mpt SHARED mpt.cpp justlm_mpt.hpp mpt/mpt.cpp mpt/mpt.hpp)
target_link_libraries(justlm_mpt PRIVATE ggml_alibi justlm_g4a_common)
target_justlm_setup(justlm_mpt)
set(LM_MPT_SOURCES justlm_mpt.hpp mpt/mpt.cpp mpt/mpt.hpp)
add_subdirectory(llama.cpp-alibi)
else()
set(LM_MPT_SOURCES )
add_subdirectory(llama.cpp)
endif()
if (LM_GPTJ)
add_library(justlm_gptj SHARED gptj.cpp justlm_gptj.hpp gptj/gptj.cpp gptj/gptj.hpp)
target_link_libraries(justlm_gptj PRIVATE ggml_alibi justlm_g4a_common)
target_justlm_setup(justlm_gptj)
endif()
if (LM_LLAMA)
add_library(justlm_llama SHARED llama.cpp justlm_llama.hpp)
target_link_libraries(justlm_llama PRIVATE ggml_mainline llama_mainline)
target_compile_definitions(justlm_llama PRIVATE LLAMA_DATE=999999)
target_justlm_setup(justlm_llama)
endif()
add_library(justlm STATIC
add_library(libjustlm STATIC
include/justlm.hpp justlm.cpp
justlm_llama.hpp
g4a-common.cpp g4a-common.hpp
justlm_gptj.hpp gptj/gptj.cpp gptj/gptj.hpp
${LM_MPT_SOURCES}
include/justlm_pool.hpp justlm_pool.cpp
dlhandle.hpp
)
add_library(libjustlm ALIAS justlm)
target_link_libraries(justlm PRIVATE dl)
target_include_directories(justlm PUBLIC include/)
target_compile_definitions(justlm PRIVATE LIB_FILE_EXT="${CMAKE_SHARED_LIBRARY_SUFFIX}")
target_justlm_setup(justlm)
target_link_libraries(libjustlm PRIVATE llama)
if (LM_MPT)
target_compile_definitions(libjustlm PUBLIC LM_MPT)
endif()
if (LM_COSCHED)
target_compile_definitions(libjustlm PUBLIC LM_COSCHED)
target_link_libraries(libjustlm PRIVATE cosched)
set(LM_COSCHED Yes CACHE BOOL "If Libjustlm should make use of CoSched" FORCE)
endif()
if (LM_NOEXCEPT)
target_compile_definitions(libjustlm PUBLIC LM_NOEXCEPT)
endif()
if (LM_PYBIND)
if (LM_COSCHED)
@ -71,6 +55,8 @@ if (LM_PYBIND)
find_package(Python COMPONENTS Interpreter Development)
find_package(pybind11 CONFIG)
pybind11_add_module(justlm_py pybind.cpp)
target_link_libraries(justlm_py PRIVATE justlm)
pybind11_add_module(libjustlm_py pybind.cpp)
target_link_libraries(libjustlm_py PRIVATE libjustlm)
endif()
target_include_directories(libjustlm PUBLIC include/)

README.md
View file

@ -1,8 +1,8 @@
# JustLM
Super easy to use library for doing LLaMA/GPT-J/MPT stuff!
Super easy to use library for doing LLaMA/GPT-J stuff!
## Overview
This library implements an easy to use interface to LLaMa, GPT-J and MPT, with optional Python bindings.
This library implements an easy to use interface to both LLaMa and GPT-J, with optional Python bindings.
Context scrolling is automatic and supports a top window bar.
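Below is a minimal, hedged usage sketch of the public API declared in include/justlm.hpp on the master side of this diff (plain, non-CoSched build). The weights path, prompt text, and parameter values are placeholders, not taken from the repository.

```cpp
#include <justlm.hpp>   // assumes the include/ directory is on the include path
#include <iostream>
#include <memory>
#include <string>

int main() {
    LM::Inference::Params params;
    params.top_k = 40;      // header defaults shown in this diff
    params.top_p = 0.9f;
    params.temp  = 0.72f;

    // construct() inspects the weights file and picks a matching backend
    // (LLaMA, GPT-J or MPT); it returns nullptr if none is found.
    std::unique_ptr<LM::Inference> lm(
        LM::Inference::construct("model.bin", params));   // placeholder path
    if (!lm) return 1;

    // append() must be called with a non-empty prompt before run().
    lm->append("Hello, my name is", [](float progress) {
        std::cerr << "ingest: " << progress << "%\r";
        return true;        // returning false cancels prompt ingestion
    });

    // run() generates text until the end string appears or the callback stops it.
    const std::string out = lm->run("\n", [](const char *piece) {
        std::cout << piece << std::flush;
        return true;        // returning false aborts generation
    });
    return 0;
}
```

In builds without LM_NOEXCEPT, failures throw an exception instead of returning error values, so the calls above would typically be wrapped in a try/catch.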

dlhandle.hpp
View file

@ -1,99 +0,0 @@
#ifndef DLHANDLE_H
#define DLHANDLE_H
#ifndef __WIN32
#include <string>
#include <stdexcept>
#include <utility>
#include <dlfcn.h>
class Dlhandle {
void *chandle;
public:
class Exception : public std::runtime_error {
public:
using std::runtime_error::runtime_error;
};
Dlhandle() : chandle(nullptr) {}
Dlhandle(const std::string& fpath, int flags = RTLD_LAZY) {
chandle = dlopen(fpath.c_str(), flags);
if (!chandle) {
throw Exception("dlopen(\""+fpath+"\"): "+dlerror());
}
}
Dlhandle(const Dlhandle& o) = delete;
Dlhandle(Dlhandle&& o) : chandle(o.chandle) {
o.chandle = nullptr;
}
~Dlhandle() {
if (chandle) dlclose(chandle);
}
auto operator =(Dlhandle&& o) {
chandle = std::exchange(o.chandle, nullptr);
}
bool is_valid() const {
return chandle != nullptr;
}
operator bool() const {
return is_valid();
}
template<typename T>
T* get(const std::string& fname) {
auto fres = reinterpret_cast<T*>(dlsym(chandle, fname.c_str()));
return (dlerror()==NULL)?fres:nullptr;
}
auto get_fnc(const std::string& fname) {
return get<void*(...)>(fname);
}
};
#else
#include <string>
#include <exception>
#include <libloaderapi.h>
class Dlhandle {
HMODULE chandle;
public:
class Exception : public std::runtime_error {
public:
using std::runtime_error::runtime_error;
};
Dlhandle() : chandle(nullptr) {}
Dlhandle(const std::string& fpath) {
chandle = LoadLibraryA(fpath.c_str());
if (!chandle) {
throw Exception("dlopen(\""+fpath+"\"): Error");
}
}
Dlhandle(const Dlhandle& o) = delete;
Dlhandle(Dlhandle&& o) : chandle(o.chandle) {
o.chandle = nullptr;
}
~Dlhandle() {
if (chandle) FreeLibrary(chandle);
}
bool is_valid() const {
return chandle != nullptr;
}
template<typename T>
T* get(const std::string& fname) {
return reinterpret_cast<T*>(GetProcAddress(chandle, fname.c_str()));
}
auto get_fnc(const std::string& fname) {
return get<void*(...)>(fname);
}
};
#endif
#endif // DLHANDLE_H
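For reference, a small sketch of how the Dlhandle wrapper above (present on master, removed on v0.1) can be used. The library path is a placeholder, and the symbol name mirrors the magic_match export used elsewhere in this diff.

```cpp
#include "dlhandle.hpp"
#include <iostream>

int main() {
    try {
        Dlhandle dl("./libjustlm_gptj.so");   // placeholder path (POSIX branch)
        // Typed symbol lookup; returns nullptr if the symbol is missing.
        auto *magic_match = dl.get<bool(std::istream&)>("magic_match");
        if (!magic_match) {
            std::cerr << "magic_match not exported\n";
            return 1;
        }
        // ... open the weights file and call magic_match(stream) here ...
    } catch (const Dlhandle::Exception& e) {   // thrown when dlopen() fails
        std::cerr << e.what() << '\n';
        return 1;
    }
    return 0;
}
```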

g4a_common.cpp → g4a-common.cpp
View file

@ -1,4 +1,4 @@
#include "g4a_common.hpp"
#include "g4a-common.hpp"
#include <fstream>
#include <regex>
@ -102,7 +102,7 @@ std::map<std::string, int32_t> json_parse(const std::string & fname) {
return result;
}
std::vector<gpt_vocab::id> gpt_tokenize_inner(const gpt_vocab & vocab, const std::string & text) {
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
std::vector<std::string> words;
// first split the text into words
@ -157,47 +157,6 @@ std::vector<gpt_vocab::id> gpt_tokenize_inner(const gpt_vocab & vocab, const std
return tokens;
}
std::string regex_escape(const std::string &s) {
static const std::regex metacharacters(R"([\.\^\$\-\+\(\)\[\]\{\}\|\?\*])");
return std::regex_replace(s, metacharacters, "\\$&");
}
std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
// Generate the subpattern from the special_tokens vector if it's not empty
if (!vocab.special_tokens.empty()) {
std::vector<gpt_vocab::id> out;
std::vector<std::string> chunks;
std::string str = text;
std::string special_tokens_subpattern;
for (const auto &token : vocab.special_tokens) {
if (!special_tokens_subpattern.empty()) {
special_tokens_subpattern += "|";
}
special_tokens_subpattern += regex_escape(token);
}
std::regex re(special_tokens_subpattern);
std::smatch m;
while (std::regex_search(str, m, re)) {
auto tok = vocab.token_to_id.find(m.str());
if (tok != vocab.token_to_id.end()) {
auto tokid = tok->second;
auto pfxtoks = gpt_tokenize_inner(vocab, m.prefix());
out.insert(out.end(), pfxtoks.begin(), pfxtoks.end());
out.push_back(tokid);
str = m.suffix();
}
}
if (!str.empty()) {
auto tokrest = gpt_tokenize_inner(vocab, str);
out.insert(out.end(), tokrest.begin(), tokrest.end());
}
return out;
} else {
return gpt_tokenize_inner(vocab, text);
}
}
bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
printf("%s: loading vocab from '%s'\n", __func__, fname.c_str());
@ -218,7 +177,7 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
}
gpt_vocab::id gpt_sample_top_k_top_p(
const size_t actualVocabSize,
const gpt_vocab & vocab,
const int32_t * last_n_tokens_data,
int last_n_tokens_size,
const std::vector<float> logits,
@ -227,7 +186,7 @@ gpt_vocab::id gpt_sample_top_k_top_p(
double temp,
float repeat_penalty,
std::mt19937 & rng) {
int n_logits = actualVocabSize;
int n_logits = vocab.id_to_token.size();
const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
const auto * plogits = logits.data() + logits.size() - n_logits;
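The special-token handling added on the master side above splits the input on tokens registered via gpt_vocab::add_special_token() before falling back to regular tokenization. A hedged sketch, assuming the master-side file name g4a_common.hpp and a placeholder vocabulary file:

```cpp
#include "g4a_common.hpp"
#include <cstdio>

int main() {
    gpt_vocab vocab;
    if (!gpt_vocab_init("vocab.json", vocab)) return 1;   // placeholder vocab file

    // Registered special tokens are matched whole, so (if present in
    // token_to_id) they map to a single id instead of being sub-tokenized.
    vocab.add_special_token("<|im_start|>");
    vocab.add_special_token("<|im_end|>");

    for (const auto id : gpt_tokenize(vocab, "<|im_start|>user\nHello<|im_end|>"))
        std::printf("%d ", static_cast<int>(id));
    std::printf("\n");
    return 0;
}
```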

g4a_common.hpp → g4a-common.hpp
View file

@ -44,11 +44,6 @@ struct gpt_vocab {
std::map<token, id> token_to_id;
std::map<id, token> id_to_token;
std::vector<std::string> special_tokens;
void add_special_token(const std::string &token) {
special_tokens.push_back(token);
}
};
void replace(std::string & str, const std::string & needle, const std::string & replacement);
@ -79,7 +74,7 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
// TODO: not sure if this implementation is correct
//
gpt_vocab::id gpt_sample_top_k_top_p(
const size_t actualVocabSize,
const gpt_vocab & vocab,
const int32_t * last_n_tokens_data,
int last_n_tokens_size,
const std::vector<float> logits,

gptj.cpp
View file

@ -1,26 +0,0 @@
#include "justlm_gptj.hpp"
#include "justlm.hpp"
#include <string>
#include <string_view>
#include <fstream>
#include <cstdint>
extern "C" {
const LM::Implementation *get_justlm_implementation() {
static LM::Implementation fres{false};
return &fres;
}
bool magic_match(std::istream& f) {
uint32_t magic;
f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
return magic == 0x67676d6c;
}
LM::Inference *construct(const std::string &weights_path, std::ifstream& f, const LM::Inference::Params &p) {
return new LM::GPTJInference(weights_path, f, p);
}
}

gptj/gptj.cpp
View file

@ -1,6 +1,6 @@
#include "gptj.hpp"
#include "../g4a_common.hpp"
#include "../g4a-common.hpp"
#include <cassert>
#include <cmath>
@ -11,13 +11,13 @@
#include <string>
#include <vector>
#include <iostream>
#include "../msvc_compat_unistd.h"
#include <unistd.h>
#include <sstream>
#include <unordered_set>
#include <ggml.h>
constexpr inline
unsigned long long operator ""_MiB(unsigned long long bytes) {
unsigned long long operator ""_MB(unsigned long long bytes) {
return bytes*1024*1024;
}
@ -32,7 +32,7 @@ static bool kv_cache_init(
const int64_t n_mem = (int64_t)n_layer*n_ctx;
const int64_t n_elements = n_embd*n_mem;
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2_MiB);
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2_MB);
struct ggml_init_params params;
params.mem_size = cache.buf.size;
@ -394,7 +394,7 @@ bool gptj_eval(
const int n_vocab = hparams.n_vocab;
const int n_rot = hparams.n_rot;
static size_t buf_size = 1024_MiB;
static size_t buf_size = 1024_MB;
if (!model.buf.addr || model.buf.size < buf_size)
model.buf.resize(buf_size);

gptj/gptj.hpp
View file

@ -5,7 +5,7 @@
#include <map>
#include <ggml.h>
#include "../g4a_common.hpp"
#include "../g4a-common.hpp"
// default hparams (GPT-J 6B)

include/justlm.hpp
View file

@ -7,28 +7,45 @@
#include <memory>
#include <thread>
#ifdef LM_COSCHED
# include <scheduler.hpp>
# define LM_SCHEDULABLE(type) ::CoSched::AwaitableTask<type>
# define LM_CORETURN co_return
# define LM_COAWAIT co_await
# define LM_TASKYIELD (co_await ::CoSched::Task::get_current().yield())
#else
# define LM_SCHEDULABLE(type) type
# define LM_CORETURN return
# define LM_COAWAIT
# define LM_TASKYIELD (true)
#endif
#ifdef LM_NOEXCEPT
# define LM_NOEXCEPTDECL noexcept
# define LM_THROW(t, r) do {this->last_error = (t); return r;} while (0)
# define LM_THROW(t, r) this->last_error = (t); return r;
# define LM_COTHROW(t, r) this->last_error = (t); LM_CORETURN r;
# define LM_LAST_ERROR_STORAGE mutable std::string last_error;
# define LM_LAST_ERROR_GETTER const std::string& get_last_error() const {return last_error;}
# define LM_ERRBOOL bool
# define LM_BOOL_ERROR false
# define LM_BOOL_SUCCESS true
# define LM_RETHROW(x) return x
# define LM_ERROR_CATCH(x, errval, ...) {auto v = x; if (v == (errval)) __VA_ARGS__}
# define LM_ERROR_FORWARD(x, errval) do {auto v = x; if (v == (errval)) return x;} while (0)
# define LM_ERROR_FORWARD(x) {auto v = x; if (!v) LM_CORETURN x;} 0
#else
# define LM_NOEXCEPTDECL
# define LM_THROW(t, r) throw Exception(t)
# define LM_COTHROW(t, r) throw Exception(t)
# define LM_LAST_ERROR_STORAGE
# define LM_LAST_ERROR_GETTER
# define LM_ERRBOOL void
# define LM_BOOL_ERROR
# define LM_BOOL_SUCCESS
# define LM_RETHROW(x) std::rethrow_exception(std::current_exception())
# define LM_ERROR_CATCH(x, errval, ...) try {x;} catch (...) __VA_ARGS__
# define LM_ERROR_FORWARD(x, errval) {x;}
# define LM_ERROR_FORWARD(x) {x;}
#endif
#ifdef LM_COSCHED
#ifndef LM_NOEXCEPT
#warning Exceptions should not be enabled in combination with CoSched. Any exceptions thrown will lead to a std::terminate() call
#endif
#endif
#if _MSC_VER
@ -41,15 +58,18 @@ namespace LM {
using ssize_t = SSIZE_T;
#endif
using GenerateCallback = std::function<bool (const char *generated)>;
using AppendCallback = std::function<bool (float progress)>;
class Inference {
protected:
AppendCallback on_scroll = nullptr;
std::function<bool (float)> on_scroll = nullptr;
void *generic_state = nullptr;
static inline
bool ends_with(std::string_view str, std::string_view suffix) noexcept {
if (suffix.empty()) return false;
return str.size() >= suffix.size() && 0 == str.compare(str.size()-suffix.size(), suffix.size(), suffix);
}
LM_LAST_ERROR_STORAGE
public:
@ -59,25 +79,21 @@ public:
struct Params {
int seed = 0; // RNG seed
unsigned n_threads = 0; // Amount of threads to use, immutable after Inference was constructed
unsigned n_ctx = 2024; // Context size
unsigned n_threads = 0;
unsigned n_ctx = 2012; // Context size
unsigned n_ctx_window_top_bar = 0; // Top bar of context window. Must be smaller than context size
unsigned n_batch = 8; // Batch size
unsigned n_repeat_last = 0;
unsigned n_eos_ignores = 0;
unsigned n_repeat_last = 0; // llama.cpp specific
float scroll_keep = 0.0f; // 0.4f to keep 40% of context below top bar when scrolling; 0.0f to remove everything after top bar
unsigned top_k = 40;
float top_p = 0.9f;
float temp = 0.72f;
float mirostat_learning_rate = 0.1f; // mirostat specific
float mirostat_target_entropy = 5.0f; // mirostat specific
float repeat_penalty = 1.0f;
float top_p = 0.9f;
float temp = 0.72f;
float repeat_penalty = 1.0f; // llama.cpp specific
unsigned eos_ignores = 0; // llama.cpp specific
unsigned n_gpu_layers = 38;
bool use_mlock = true; // llama specific
int prefer_mirostat = 0; // Use given mirostat version if available (see is_mirostat_available()); llama specific
bool use_mlock = true; // llama.cpp specific
} params;
struct Savestate {
@ -108,42 +124,27 @@ public:
static
Inference *construct(const std::string& weights_path, const Params& p);
void set_scroll_callback(const AppendCallback& scroll_cb) noexcept {
void set_scroll_callback(const std::function<bool (float)>& scroll_cb) noexcept {
on_scroll = scroll_cb;
}
// This must be called with a non-empty prompt!
virtual LM_ERRBOOL append(const std::string& prompt, const AppendCallback& on_tick = nullptr) LM_NOEXCEPTDECL = 0;
virtual LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const std::function<bool (float progress)>& on_tick = nullptr) LM_NOEXCEPTDECL = 0;
// append() must have been called at least once before calling this!
virtual std::string run(std::string_view end = "", const GenerateCallback& on_tick = nullptr, const GenerateCallback& pre_tick = nullptr) LM_NOEXCEPTDECL = 0;
virtual LM_SCHEDULABLE(std::string) run(std::string_view end = "", const std::function<bool (const char *generated)>& on_tick = nullptr) LM_NOEXCEPTDECL = 0;
virtual unsigned get_context_size() const noexcept = 0;
virtual LM_ERRBOOL create_savestate(Savestate&) const LM_NOEXCEPTDECL = 0;
virtual LM_ERRBOOL restore_savestate(const Savestate&) LM_NOEXCEPTDECL = 0;
virtual LM_SCHEDULABLE(LM_ERRBOOL) create_savestate(Savestate&) const LM_NOEXCEPTDECL = 0;
virtual LM_SCHEDULABLE(LM_ERRBOOL) restore_savestate(const Savestate&) LM_NOEXCEPTDECL = 0;
virtual LM_ERRBOOL serialize(std::ostream&) const LM_NOEXCEPTDECL = 0;
virtual LM_ERRBOOL deserialize(std::istream&) LM_NOEXCEPTDECL = 0;
virtual LM_ERRBOOL load_grammar(const std::string&, bool override_temperature [[maybe_unused]] = false) LM_NOEXCEPTDECL {
LM_THROW("Grammar is not available for this models backend", LM_BOOL_ERROR);
}
virtual LM_ERRBOOL unload_grammar() LM_NOEXCEPTDECL {
LM_THROW("Grammar is not available for this models backend", LM_BOOL_ERROR);
}
virtual LM_SCHEDULABLE(LM_ERRBOOL) serialize(std::ostream&) const LM_NOEXCEPTDECL = 0;
virtual LM_SCHEDULABLE(LM_ERRBOOL) deserialize(std::istream&) LM_NOEXCEPTDECL = 0;
virtual const std::string& get_prompt() const LM_NOEXCEPTDECL = 0;
virtual bool is_mirostat_available() const noexcept {return false;}
virtual bool is_grammar_available() const noexcept {return false;}
LM_LAST_ERROR_GETTER
};
struct Implementation {
bool is_fallback = false;
};
}
#endif // JUSTLM_HPP
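A hedged sketch of the savestate API declared above (master-side names, non-CoSched build): snapshot the context once after a shared prefix, then rewind to it for each new request instead of re-ingesting the prefix. The model path and prompts are placeholders.

```cpp
#include <justlm.hpp>
#include <memory>
#include <string>

static std::string ask(LM::Inference& lm, const LM::Inference::Savestate& base,
                       const std::string& question) {
    lm.restore_savestate(base);   // rewind tokens, prompt and model state
    lm.append(question);          // ingest only the new part
    return lm.run("\n");          // generate until a newline
}

int main() {
    LM::Inference::Params p;
    std::unique_ptr<LM::Inference> lm(
        LM::Inference::construct("model.bin", p));   // placeholder path
    if (!lm) return 1;

    lm->append("You are a helpful assistant.\n");    // shared prefix
    LM::Inference::Savestate base;
    lm->create_savestate(base);                      // snapshot after the prefix

    ask(*lm, base, "Q: What is justlm?\nA:");
    ask(*lm, base, "Q: Which backends does it support?\nA:");
    return 0;
}
```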

include/justlm_pool.hpp
View file

@ -63,21 +63,21 @@ class InferencePool {
}
// Returns false on error
bool store_slot(Slot& slot);
LM_SCHEDULABLE(bool) store_slot(Slot& slot);
// Returns nullptr on error
Slot *load_slot(size_t id, Slot *suggested_slot = nullptr);
LM_SCHEDULABLE(Slot*) load_slot(size_t id, Slot *suggested_slot = nullptr);
void store_and_reset_slot(Slot& slot) {
store_slot(slot); //TODO: Should handle errors somehow
LM_SCHEDULABLE(void) store_and_reset_slot(Slot& slot) {
LM_COAWAIT store_slot(slot); //TODO: Should handle errors somehow
slot.reset();
return;
LM_CORETURN;
}
// Doesn't fail
Slot *get_free_slot();
LM_SCHEDULABLE(Slot*) get_free_slot();
// Returns nullptr if not found
Slot *find_slot_by_id(size_t id, bool deserialize = true);
LM_SCHEDULABLE(Slot*) find_slot_by_id(size_t id, bool deserialize = true);
public:
// The pool_name must be unique amonst all applications in cwd
@ -93,14 +93,14 @@ public:
}
}
std::shared_ptr<Inference> create_inference(size_t id, const std::string& weights_path, const Inference::Params& p) {
auto slot = get_free_slot();
return slot->create_inference(id, weights_path, p);
LM_SCHEDULABLE(std::shared_ptr<Inference>) create_inference(size_t id, const std::string& weights_path, const Inference::Params& p) {
auto slot = LM_COAWAIT get_free_slot();
LM_CORETURN slot->create_inference(id, weights_path, p);
}
std::shared_ptr<Inference> get_inference(size_t id);
std::shared_ptr<Inference> get_or_create_inference(size_t id, const std::string& weights_path, const Inference::Params& p);
void delete_inference(size_t id);
void store_all();
LM_SCHEDULABLE(std::shared_ptr<Inference>) get_inference(size_t id);
LM_SCHEDULABLE(std::shared_ptr<Inference>) get_or_create_inference(size_t id, const std::string& weights_path, const Inference::Params& p);
LM_SCHEDULABLE(void) delete_inference(size_t id);
LM_SCHEDULABLE(void) store_all();
std::vector<size_t> get_active_slot_ids() const;
void cleanup();
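A short sketch of how the master-side (non-CoSched) InferencePool interface shown above might be used to serve multiple sessions. The pool constructor is not part of this hunk, so the pool is taken by reference here; the session id and weights path are placeholders.

```cpp
#include <justlm_pool.hpp>
#include <string>

std::string handle_request(LM::InferencePool& pool, size_t session_id,
                           const std::string& user_text) {
    LM::Inference::Params params;

    // Re-use the session's context if it is still in a slot (or stored on
    // disk); otherwise create a fresh inference in a free slot.
    auto lm = pool.get_or_create_inference(session_id, "model.bin", params);
    if (!lm) return {};

    lm->append(user_text);
    std::string reply = lm->run("\n");

    pool.store_all();   // persist slots so evicted sessions can be reloaded later
    return reply;
}
```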

justlm.cpp
View file

@ -1,66 +1,30 @@
#include "justlm.hpp"
#include "dlhandle.hpp"
#include "justlm_llama.hpp"
#include "justlm_gptj.hpp"
#ifdef LM_MPT
# include "justlm_mpt.hpp"
#endif
#include <string>
#include <vector>
#include <fstream>
#include <filesystem>
static
Dlhandle get_implementation(std::ifstream& input_f) {
Dlhandle matching;
Dlhandle fallback;
// Iterate over all libraries
for (const auto& f : std::filesystem::directory_iterator(".")) {
// Get path
const auto& p = f.path();
// Check extension
if (p.extension() != LIB_FILE_EXT) continue;
// Load library
try {
Dlhandle dl(p);
// Get implementation info getter
auto implementation_getter = dl.get<const LM::Implementation *()>("get_justlm_implementation");
if (!implementation_getter) continue;
// Get implementation info
const auto *implementation_info = implementation_getter();
// Set if fallback
if (implementation_info->is_fallback) {
fallback = std::move(dl);
continue;
}
// Set if matching magic
input_f.seekg(0);
auto magic_match = dl.get<bool(std::ifstream&)>("magic_match");
if (magic_match && magic_match(input_f)) {
matching = std::move(dl);
continue;
}
} catch (...) {}
}
// Return matching if any, fallback otherwise
if (matching) return matching;
return fallback;
}
LM::Inference *LM::Inference::construct(const std::string &weights_path, const Params &p) {
static std::vector<Dlhandle> dls;
// Read magic
std::ifstream f(weights_path, std::ios::binary);
if (!f) {
throw Exception("Failed to open weights file for reading at "+weights_path);
uint32_t magic;
f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
// Create inference instance
if (magic == 0x67676d6c) {
f.seekg(0);
return new GPTJInference(weights_path, f, p);
# ifdef LM_MPT
} else if (magic == 0x67676d6d) {
f.seekg(0);
return new MPTInference(weights_path, f, p);
# endif
} else {
f.close();
return new LLaMaInference(weights_path, p);
}
// Get correct implementation
auto impl = get_implementation(f);
if (!impl) return nullptr;
// Get inference constructor
auto constructor = impl.get<LM::Inference *(const std::string &, std::ifstream&, const LM::Inference::Params &)>("construct");
if (!constructor) return nullptr;
// Back up Dlhandle
dls.push_back(std::move(impl));
// Construct inference
f.seekg(0);
return constructor(weights_path, f, p);
}

justlm_gptj.hpp
View file

@ -4,7 +4,7 @@
#include <random>
#include <cstring>
#include "gptj/gptj.hpp"
#include "g4a_common.hpp"
#include "g4a-common.hpp"
namespace LM {
@ -53,18 +53,19 @@ class GPTJInference final : public Inference {
auto& state = get_state();
if (state) {
if (state->model.ctx) ggml_free(state->model.ctx); //TODO: Is that enough?
delete state;
}
}
// This function reduces the size of our tokens vector according to some parameters
// All tokens will be evaluated if scrolling was needed and true will be returned
bool window_scroll() LM_NOEXCEPTDECL {
LM_SCHEDULABLE(bool) window_scroll() LM_NOEXCEPTDECL {
auto &state = get_state();
// Check that we actually need to scroll
if (state->tokens.size() <= params.n_ctx) {
// Nope
return false;
LM_CORETURN false;
}
// Start scrolling
if (params.scroll_keep > 0.0f) {
@ -81,11 +82,11 @@ class GPTJInference final : public Inference {
state->tokens.resize(params.n_ctx_window_top_bar);
}
// Evaluate tokens
LM_ERROR_FORWARD(evaluate_tokens(0, on_scroll), LM_BOOL_ERROR);
return true;
LM_ERROR_FORWARD(LM_COAWAIT evaluate_tokens(0, on_scroll));
LM_CORETURN true;
}
LM_ERRBOOL evaluate_tokens(size_t starting_offset, const AppendCallback &on_tick = nullptr) LM_NOEXCEPTDECL {
LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL {
auto& state = get_state();
// Evaluate tokens in batches
@ -96,7 +97,7 @@ class GPTJInference final : public Inference {
// Evaluate
std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+params.n_batch);
if (!gptj_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token)) {
LM_THROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR);
LM_COTHROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR);
}
// Tick
@ -104,7 +105,8 @@ class GPTJInference final : public Inference {
// Calculate progress
auto progress = float(it-starting_offset) / (state->tokens.size()-starting_offset) * 100.f;
// Tick and yield
if (!on_tick(progress)) return LM_BOOL_SUCCESS;
if (!on_tick(progress)) LM_CORETURN LM_BOOL_SUCCESS;
else if (!LM_TASKYIELD) LM_CORETURN LM_BOOL_SUCCESS;
}
}
@ -114,7 +116,7 @@ class GPTJInference final : public Inference {
//TODO: This is extremely inefficient! Don't do that...
std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+1);
if (!gptj_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token)) {
LM_THROW("Failed to evaluate individual tokens", LM_BOOL_ERROR);
LM_COTHROW("Failed to evaluate individual tokens", LM_BOOL_ERROR);
}
}
}
@ -122,7 +124,7 @@ class GPTJInference final : public Inference {
// Notify about completion
if (on_tick) on_tick(100.f);
return LM_BOOL_SUCCESS;
LM_CORETURN LM_BOOL_SUCCESS;
}
public:
@ -133,7 +135,7 @@ public:
deinit();
}
LM_ERRBOOL append(const std::string& prompt, const AppendCallback &on_tick) LM_NOEXCEPTDECL override {
LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
auto& state = get_state();
// Append to current prompt
@ -151,123 +153,119 @@ public:
);
// Make sure token limit isn't being hit
if (window_scroll()) {
if (LM_COAWAIT window_scroll()) {
// That function already has evaluated our tokens since scrolling was needed
return LM_BOOL_SUCCESS;
LM_CORETURN LM_BOOL_SUCCESS;
}
// Evaluate new tokens
return evaluate_tokens(old_token_count, on_tick);
LM_CORETURN LM_COAWAIT evaluate_tokens(old_token_count, on_tick);
}
std::string run(std::string_view end, const GenerateCallback &on_tick, const GenerateCallback& pre_tick) LM_NOEXCEPTDECL override {
LM_SCHEDULABLE(std::string) run(std::string_view end, const std::function<bool (const char *)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
auto& state = get_state();
std::string fres;
// Loop until done
bool abort = false;
unsigned eos_count = 0;
size_t last_size = 0;
while (!abort && (end.empty() || fres.find(end) == fres.npos)) {
last_size = fres.size();
while (!abort && !ends_with(fres, end)) {
// Sample top p and top k
const auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
auto id = gpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data()+state->tokens.size()-n_repeat_last, n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
auto id = gpt_sample_top_k_top_p(state->vocab, params.n_repeat_last?(state->tokens.data()+state->tokens.size()-params.n_repeat_last):nullptr, params.n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
if (id == 50256) {
if (eos_count++ == params.n_eos_ignores) {
if (eos_count++ == params.eos_ignores) {
abort = true;
continue;
}
id = gpt_tokenize(state->vocab, "\n")[0];
state->tokens.push_back(id);
} else {
// Add token
state->tokens.push_back(id);
}
// Add token
state->tokens.push_back(id);
// Make sure token limit isn't being hit
window_scroll();
LM_COAWAIT window_scroll();
// Get token as string
const std::string_view str = state->vocab.id_to_token[id];
const auto str = state->vocab.id_to_token[id];
// Append string to function result
state->prompt.append(str);
fres.append(str);
if (pre_tick && !pre_tick(str.data())) abort = true;
else {
// Evaluate token
// TODO: Respect batch size
std::vector<int> batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size());
if (!gptj_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) {
LM_THROW("Failed to evaluate new tokens", "");
}
// Evaluate token
// TODO: Respect batch size
std::vector<int> batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size());
if (!gptj_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) {
LM_COTHROW("Failed to evaluate new tokens", "");
}
// Tick
if (on_tick && !on_tick(str.data())) abort = true;
if (on_tick && !on_tick(str.c_str())) abort = true;
else if (!LM_TASKYIELD) abort = true;
}
// Create final string TODO: Could be optimized
state->prompt.append(fres);
if (!abort) {
fres = std::string(fres.data(), last_size);
fres = std::string(fres.data(), fres.size()-end.size());
}
// Return final string
return fres;
LM_CORETURN fres;
}
unsigned get_context_size() const noexcept override {
return get_state()->tokens.size();
}
LM_ERRBOOL create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override {
LM_SCHEDULABLE(LM_ERRBOOL) create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override {
auto& state = get_state();
sv.buf.resize(gptj_get_state_size(state->model));
gptj_copy_state_data(state->model, state->rng, sv.buf.data());
sv.tokens = state->tokens;
sv.prompt = state->prompt;
sv.ctx = generic_state;
return LM_BOOL_SUCCESS;
LM_CORETURN LM_BOOL_SUCCESS;
}
LM_ERRBOOL restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override {
LM_SCHEDULABLE(LM_ERRBOOL) restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override {
auto& state = get_state();
if (sv.ctx != generic_state)
LM_THROW("Savestate does not match context", LM_BOOL_ERROR);
LM_COTHROW("Savestate does not match context", LM_BOOL_ERROR);
gptj_set_state_data(&state->model, &state->rng, sv.buf.data());
state->tokens = sv.tokens;
state->prompt = sv.prompt;
return LM_BOOL_SUCCESS;
LM_CORETURN LM_BOOL_SUCCESS;
}
LM_ERRBOOL serialize(std::ostream &o) const LM_NOEXCEPTDECL override {
LM_SCHEDULABLE(LM_ERRBOOL) serialize(std::ostream &o) const LM_NOEXCEPTDECL override {
auto& state = get_state();
// Get state size
auto state_size = gptj_get_state_size(state->model);
// Write sizes
for (const uint32_t s : {state->tokens.size(), state->prompt.size(), state_size}) {
if (!o.write(reinterpret_cast<const char*>(&s), sizeof(s))) {
LM_THROW("Failed to serialize data sizes", LM_BOOL_ERROR);
LM_COTHROW("Failed to serialize data sizes", LM_BOOL_ERROR);
}
}
// Write tokens
if (!o.write(reinterpret_cast<const char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
LM_THROW("Failed to serialize tokens", LM_BOOL_ERROR);
LM_COTHROW("Failed to serialize tokens", LM_BOOL_ERROR);
}
// Write prompt
if (!o.write(state->prompt.data(), state->prompt.size())) {
LM_THROW("Failed to serialize prompt", LM_BOOL_ERROR);
LM_COTHROW("Failed to serialize prompt", LM_BOOL_ERROR);
}
// Write state
std::vector<uint8_t> state_buf(state_size);
gptj_copy_state_data(state->model, state->rng, state_buf.data());
if (!o.write(reinterpret_cast<const char*>(state_buf.data()), state_size)) {
LM_THROW("Failed to serialize state", LM_BOOL_ERROR);
LM_COTHROW("Failed to serialize state", LM_BOOL_ERROR);
}
return LM_BOOL_SUCCESS;
LM_CORETURN LM_BOOL_SUCCESS;
}
LM_ERRBOOL deserialize(std::istream &i) LM_NOEXCEPTDECL override {
LM_SCHEDULABLE(LM_ERRBOOL) deserialize(std::istream &i) LM_NOEXCEPTDECL override {
auto& state = get_state();
uint32_t embd_size, prompt_size, state_size;
// Initialization to prevent compiler complaints
@ -275,26 +273,26 @@ public:
// Read sizes
for (uint32_t *s : {&embd_size, &prompt_size, &state_size}) {
if (!i.read(reinterpret_cast<char*>(s), sizeof(*s))) {
LM_THROW("Failed to deserialize data sizes", LM_BOOL_ERROR);
LM_COTHROW("Failed to deserialize data sizes", LM_BOOL_ERROR);
}
}
// Read tokens
state->tokens.resize(embd_size);
if (!i.read(reinterpret_cast<char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
LM_THROW("Failed to deserialize tokens", LM_BOOL_ERROR);
LM_COTHROW("Failed to deserialize tokens", LM_BOOL_ERROR);
}
// Read prompt
state->prompt.resize(prompt_size);
if (!i.read(state->prompt.data(), state->prompt.size())) {
LM_THROW("Failed to deserialize prompt", LM_BOOL_ERROR);
LM_COTHROW("Failed to deserialize prompt", LM_BOOL_ERROR);
}
// Read state
std::vector<uint8_t> state_buf(state_size);
if (!i.read(reinterpret_cast<char*>(state_buf.data()), state_buf.size())) {
LM_THROW("Failed to deserialize state", LM_BOOL_ERROR);
LM_COTHROW("Failed to deserialize state", LM_BOOL_ERROR);
}
gptj_set_state_data(&state->model, &state->rng, state_buf.data());
return LM_BOOL_SUCCESS;
LM_CORETURN LM_BOOL_SUCCESS;
}
const std::string &get_prompt() const LM_NOEXCEPTDECL override {
return get_state()->prompt;
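The window_scroll()/evaluate_tokens() logic above is driven by the n_ctx, n_ctx_window_top_bar and scroll_keep parameters and by the scroll callback from include/justlm.hpp. A hedged configuration sketch (master-side names; the path, prompt and values are illustrative):

```cpp
#include <justlm.hpp>
#include <iostream>
#include <memory>

int main() {
    LM::Inference::Params p;
    p.n_ctx = 2024;               // context size (header default in this diff)
    p.n_ctx_window_top_bar = 64;  // pin the first 64 tokens (e.g. a system prompt)
    p.scroll_keep = 0.4f;         // keep 40% of the tokens below the top bar when scrolling

    std::unique_ptr<LM::Inference> lm(
        LM::Inference::construct("model.bin", p));   // placeholder path
    if (!lm) return 1;

    // Called with re-evaluation progress whenever the context window scrolls.
    lm->set_scroll_callback([](float progress) {
        std::cerr << "rescroll: " << progress << "%\n";
        return true;              // returning false stops the re-evaluation early
    });

    lm->append("(long conversation history goes here)");
    std::cout << lm->run("\n") << '\n';
    return 0;
}
```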

justlm_llama.hpp
View file

@ -3,20 +3,15 @@
#include <cstring>
#include <ggml.h>
#include <llama.h>
#include <common/grammar-parser.h>
namespace LM {
class LLaMAInference final : public Inference {
class LLaMaInference final : public Inference {
struct State {
llama_context *ctx = nullptr;
llama_model *model;
llama_grammar *grammar = nullptr;
bool grammar_override_temp;
grammar_parser::parse_state parsed_grammar;
std::string prompt; // Mostly here for easy "debugging"
std::vector<int> tokens;
unsigned n_ctx;
int n_ctx;
};
State*& get_state() {
@ -36,24 +31,12 @@ class LLaMAInference final : public Inference {
auto lparams = llama_context_default_params();
lparams.seed = params.seed;
lparams.n_ctx = params.n_ctx = params.n_ctx>0?params.n_ctx:2024;
lparams.n_threads = params.n_threads;
//lparams.n_threads_batch = params.n_threads; TODO: Is this sane?
// Get model parameters
auto mparams = llama_model_default_params();
mparams.use_mlock = params.use_mlock;
mparams.n_gpu_layers = params.n_gpu_layers;
// Load model
state->model = llama_load_model_from_file(weights_path.c_str(), mparams);
if (!state->model) {
LM_THROW("Failed to initialize llama model from file", LM_BOOL_ERROR);
}
lparams.use_mlock = params.use_mlock;
// Create context
state->ctx = llama_new_context_with_model(state->model, lparams);
state->ctx = llama_init_from_file(weights_path.c_str(), lparams);
if (!state->ctx) {
LM_THROW("Failed to initialize llama context from model", LM_BOOL_ERROR);
LM_THROW("Failed to initialize llama from file", LM_BOOL_ERROR);
}
// Initialize some variables
@ -64,12 +47,12 @@ class LLaMAInference final : public Inference {
// This function reduces the size of our tokens vector according to some parameters
// All tokens will be evaluated if scrolling was needed and true will be returned
bool window_scroll() LM_NOEXCEPTDECL {
LM_SCHEDULABLE(bool) window_scroll() LM_NOEXCEPTDECL {
auto &state = get_state();
// Check that we actually need to scroll
if (state->tokens.size() <= state->n_ctx) {
// Nope
return false;
LM_CORETURN false;
}
// Start scrolling
if (params.scroll_keep > 0.0f) {
@ -86,11 +69,11 @@ class LLaMAInference final : public Inference {
state->tokens.resize(params.n_ctx_window_top_bar);
}
// Evaluate tokens
LM_ERROR_FORWARD(evaluate_tokens(0, on_scroll), LM_BOOL_ERROR);
return true;
LM_ERROR_FORWARD(LM_COAWAIT evaluate_tokens(0, on_scroll));
LM_CORETURN true;
}
LM_ERRBOOL evaluate_tokens(size_t starting_offset, const AppendCallback &on_tick = nullptr) LM_NOEXCEPTDECL {
LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL {
auto& state = get_state();
// Evaluate tokens in batches
@ -99,9 +82,8 @@ class LLaMAInference final : public Inference {
if (it + params.n_batch >= ssize_t(state->tokens.size())) break;
// Evaluate
const auto batch = llama_batch_get_one(state->tokens.data()+it, params.n_batch, it, 0);
if (llama_decode(state->ctx, batch)) {
LM_THROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR);
if (llama_eval(state->ctx, state->tokens.data()+it, params.n_batch, it, params.n_threads)) {
LM_COTHROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR);
}
// Tick
@ -109,16 +91,16 @@ class LLaMAInference final : public Inference {
// Calculate progress
auto progress = float(it-starting_offset) / (state->tokens.size()-starting_offset) * 100.f;
// Tick and yield
if (!on_tick(progress)) return LM_BOOL_SUCCESS;
if (!on_tick(progress)) LM_BOOL_SUCCESS;
else if (!LM_TASKYIELD) LM_BOOL_SUCCESS;
}
}
// Evaluate remaining tokens
if (it < state->tokens.size()) {
for (; it != state->tokens.size(); it++) {
const auto batch = llama_batch_get_one(state->tokens.data()+it, 1, it, 0);
if (llama_decode(state->ctx, batch)) {
LM_THROW("Failed to evaluate individual tokens", LM_BOOL_ERROR);
if (llama_eval(state->ctx, state->tokens.data()+it, 1, it, params.n_threads)) {
LM_COTHROW("Failed to evaluate individual tokens", LM_BOOL_ERROR);
}
}
}
@ -126,69 +108,14 @@ class LLaMAInference final : public Inference {
// Notify about completion
if (on_tick) on_tick(100.f);
return LM_BOOL_SUCCESS;
}
int accept_token(int t) {
auto& state = get_state();
if (state->grammar)
llama_grammar_accept_token(state->ctx, state->grammar, t);
return t;
}
int llama_sample_top_p_top_k() {
auto& state = get_state();
auto logits = llama_get_logits(state->ctx);
auto n_vocab = llama_n_vocab(state->model);
// Populate initial list of all candidates
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
for (int token_id = 0; token_id < n_vocab; token_id++) {
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
}
llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
// Sample repeat penalty
auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
llama_sample_repetition_penalties(state->ctx, &candidates_p, params.n_repeat_last?(state->tokens.data()+state->tokens.size()-n_repeat_last):nullptr, n_repeat_last, params.repeat_penalty, 1.0f, 1.0f); // Might be wrong
// Grammar sampling
if (state->grammar) {
llama_sample_grammar(state->ctx, &candidates_p, state->grammar);
}
if (!(state->grammar && state->grammar_override_temp) && (params.temp > 0.01f || params.temp < -0.01f)) {
// Temperature sampling
switch (params.prefer_mirostat) {
case 0: {
llama_sample_top_k(state->ctx, &candidates_p, params.top_k, 1);
llama_sample_tail_free(state->ctx, &candidates_p, 1.0f, 1);
llama_sample_typical(state->ctx, &candidates_p, 1.0f, 1);
llama_sample_top_p(state->ctx, &candidates_p, params.top_p, 1);
llama_sample_temp(state->ctx, &candidates_p, params.temp);
return accept_token(llama_sample_token(state->ctx, &candidates_p));
}
case 1: {
float mirostat_mu = 2.0f * params.mirostat_target_entropy;
const int mirostat_m = 100;
llama_sample_temp(state->ctx, &candidates_p, params.temp);
return accept_token(llama_sample_token_mirostat(state->ctx, &candidates_p, params.mirostat_target_entropy, params.mirostat_learning_rate, mirostat_m, &mirostat_mu));
}
case 2: {
float mirostat_mu = 2.0f * params.mirostat_target_entropy;
llama_sample_temp(state->ctx, &candidates_p, params.temp);
return accept_token(llama_sample_token_mirostat_v2(state->ctx, &candidates_p, params.mirostat_target_entropy, params.mirostat_learning_rate, &mirostat_mu));
}
default: LM_THROW("Invalid mirostat version "+std::to_string(params.prefer_mirostat), LM_BOOL_ERROR);
}
} else {
// Greedy sampling
return accept_token(llama_sample_token(state->ctx, &candidates_p));
}
LM_CORETURN LM_BOOL_SUCCESS;
}
public:
LLaMAInference(const std::string& weights_path, const Params& p) : Inference(p) {
LLaMaInference(const std::string& weights_path, const Params& p) : Inference(p) {
init(weights_path);
}
~LLaMAInference() override {
~LLaMaInference() override {
auto& state = get_state();
if (state) {
@ -197,7 +124,7 @@ public:
}
}
LM_ERRBOOL append(const std::string& prompt, const AppendCallback &on_tick) LM_NOEXCEPTDECL override {
LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
auto& state = get_state();
// Check if prompt was empty
@ -211,44 +138,37 @@ public:
state->tokens.resize(old_token_count+state->prompt.size());
// Run tokenizer
const auto token_count = llama_tokenize(state->model, prompt.c_str(), prompt.size(), state->tokens.data()+old_token_count, state->tokens.size()-old_token_count, was_empty, false);
const auto token_count = llama_tokenize(state->ctx, prompt.c_str(), state->tokens.data()+old_token_count, state->tokens.size()-old_token_count, was_empty);
state->tokens.resize(old_token_count+token_count);
// Make sure token limit isn't being hit
if (window_scroll()) {
if (LM_COAWAIT window_scroll()) {
// That function already has evaluated our tokens since scrolling was needed
return LM_BOOL_SUCCESS;
LM_CORETURN LM_BOOL_SUCCESS;
}
// Evaluate new tokens
return evaluate_tokens(old_token_count, on_tick);
LM_CORETURN LM_COAWAIT evaluate_tokens(old_token_count, on_tick);
}
std::string run(std::string_view end, const GenerateCallback &on_tick, const GenerateCallback& pre_tick) LM_NOEXCEPTDECL override {
LM_SCHEDULABLE(std::string) run(std::string_view end, const std::function<bool (const char *)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
auto& state = get_state();
std::string fres;
// Loop until done
bool abort = false;
unsigned eos_count = 0;
size_t last_size = 0;
while (!abort && (end.empty() || fres.find(end) == fres.npos)) {
last_size = fres.size();
while (!abort && !ends_with(fres, end)) {
// Sample top p and top k
int id;
try {
id = llama_sample_top_p_top_k();
} catch (const std::exception& e) {
LM_THROW(e.what(), "");
}
auto id = llama_sample_top_p_top_k(state->ctx, params.n_repeat_last?(state->tokens.data()+state->tokens.size()-params.n_repeat_last):nullptr, params.n_repeat_last, params.top_k, params.top_p, params.temp, params.repeat_penalty);
if (id == llama_token_eos(state->model)) {
if (eos_count++ == params.n_eos_ignores) {
if (id == llama_token_eos()) {
if (eos_count++ == params.eos_ignores) {
abort = true;
continue;
}
state->tokens.push_back(0);
llama_tokenize(state->model, "\n", 1, &state->tokens.back(), 1, false, false);
llama_tokenize(state->ctx, "\n", &state->tokens.back(), 1, false);
id = state->tokens.back();
} else {
// Add token
@ -256,90 +176,85 @@ public:
}
// Make sure token limit isn't hit
window_scroll();
LM_COAWAIT window_scroll();
// Get token as string
std::string str(14, ' ');
str.resize(llama_token_to_piece(state->model, id, str.data(), 14));
const auto str = llama_token_to_str(state->ctx, id);
// Append string to function result
state->prompt.append(str);
fres.append(str);
// Tick
if (pre_tick && !pre_tick(str.data())) abort = true;
else {
// Evaluate token
// TODO: Respect batch size
const auto batch = llama_batch_get_one(state->tokens.data()+state->tokens.size()-1, 1, state->tokens.size()-1, 0);
if (llama_decode(state->ctx, batch)) {
LM_THROW("Failed to evaluate new tokens", "");
}
// Evaluate token
// TODO: Respect batch size
if (llama_eval(state->ctx, state->tokens.data()+state->tokens.size()-1, 1, state->tokens.size()-1, params.n_threads)) {
LM_COTHROW("Failed to evaluate new tokens", "");
}
// Tick and yield
if (on_tick && !on_tick(str.data())) abort = true;
if (on_tick && !on_tick(str)) abort = true;
else if (!LM_TASKYIELD) abort = true;
}
// Create final string TODO: Could be optimized
if (!abort && fres.size() > end.size()) {
fres = std::string(fres.data(), last_size);
state->prompt.append(fres);
if (!abort) {
fres = std::string(fres.data(), fres.size()-end.size());
}
// Return final string
return fres;
LM_CORETURN fres;
}
unsigned get_context_size() const noexcept override {
return get_state()->tokens.size();
}
LM_ERRBOOL create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override {
LM_SCHEDULABLE(LM_ERRBOOL) create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override {
auto& state = get_state();
sv.buf.resize(llama_get_state_size(state->ctx));
llama_copy_state_data(state->ctx, sv.buf.data());
sv.tokens = state->tokens;
sv.prompt = state->prompt;
sv.ctx = generic_state;
return LM_BOOL_SUCCESS;
LM_CORETURN LM_BOOL_SUCCESS;
}
LM_ERRBOOL restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override {
LM_SCHEDULABLE(LM_ERRBOOL) restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override {
auto& state = get_state();
if (sv.ctx != generic_state)
LM_THROW("Savestate does not match context", LM_BOOL_ERROR);
llama_set_state_data(state->ctx, const_cast<uint8_t*>(sv.buf.data()));
LM_COTHROW("Savestate does not match context", LM_BOOL_ERROR);
llama_set_state_data(state->ctx, sv.buf.data());
state->tokens = sv.tokens;
state->prompt = sv.prompt;
return LM_BOOL_SUCCESS;
LM_CORETURN LM_BOOL_SUCCESS;
}
LM_ERRBOOL serialize(std::ostream &o) const LM_NOEXCEPTDECL override {
LM_SCHEDULABLE(LM_ERRBOOL) serialize(std::ostream &o) const LM_NOEXCEPTDECL override {
auto& state = get_state();
// Get state size
auto state_size = llama_get_state_size(state->ctx);
// Write sizes
for (const uint32_t s : {static_cast<size_t>(state->n_ctx), state->tokens.size(), state->prompt.size(), state_size}) {
if (!o.write(reinterpret_cast<const char*>(&s), sizeof(s))) {
LM_THROW("Failed to serialize data sizes", LM_BOOL_ERROR);
LM_COTHROW("Failed to serialize data sizes", LM_BOOL_ERROR);
}
}
// Write tokens
if (!o.write(reinterpret_cast<const char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
LM_THROW("Failed to serialize tokens", LM_BOOL_ERROR);
LM_COTHROW("Failed to serialize tokens", LM_BOOL_ERROR);
}
// Write prompt
if (!o.write(state->prompt.data(), state->prompt.size())) {
LM_THROW("Failed to serialize prompt", LM_BOOL_ERROR);
LM_COTHROW("Failed to serialize prompt", LM_BOOL_ERROR);
}
// Write state
std::vector<uint8_t> state_buf(state_size);
llama_copy_state_data(state->ctx, state_buf.data());
if (!o.write(reinterpret_cast<const char*>(state_buf.data()), state_size)) {
LM_THROW("Failed to serialize state", LM_BOOL_ERROR);
LM_COTHROW("Failed to serialize state", LM_BOOL_ERROR);
}
return LM_BOOL_SUCCESS;
LM_CORETURN LM_BOOL_SUCCESS;
}
LM_ERRBOOL deserialize(std::istream &i) LM_NOEXCEPTDECL override {
LM_SCHEDULABLE(LM_ERRBOOL) deserialize(std::istream &i) LM_NOEXCEPTDECL override {
auto& state = get_state();
uint32_t n_ctx, embd_size, prompt_size, state_size;
// Initialization to prevent compiler complaints
@ -347,65 +262,33 @@ public:
// Read sizes
for (uint32_t *s : {&n_ctx, &embd_size, &prompt_size, &state_size}) {
if (!i.read(reinterpret_cast<char*>(s), sizeof(*s))) {
LM_THROW("Failed to deserialize data sizes", LM_BOOL_ERROR);
LM_COTHROW("Failed to deserialize data sizes", LM_BOOL_ERROR);
}
}
if (state->n_ctx != n_ctx) {
LM_THROW("Context length differs (My "+std::to_string(state->n_ctx)+" vs. files "+std::to_string(n_ctx)+')', LM_BOOL_ERROR);
LM_COTHROW("Context length differs (My "+std::to_string(state->n_ctx)+" vs. files "+std::to_string(n_ctx)+')', LM_BOOL_ERROR);
}
// Read tokens
state->tokens.resize(embd_size);
if (!i.read(reinterpret_cast<char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
LM_THROW("Failed to deserialize tokens", LM_BOOL_ERROR);
LM_COTHROW("Failed to deserialize tokens", LM_BOOL_ERROR);
}
// Read prompt
state->prompt.resize(prompt_size);
if (!i.read(state->prompt.data(), state->prompt.size())) {
LM_THROW("Failed to deserialize prompt", LM_BOOL_ERROR);
LM_COTHROW("Failed to deserialize prompt", LM_BOOL_ERROR);
}
// Read state
std::vector<uint8_t> state_buf(state_size);
if (!i.read(reinterpret_cast<char*>(state_buf.data()), state_buf.size())) {
LM_THROW("Failed to deserialize state", LM_BOOL_ERROR);
LM_COTHROW("Failed to deserialize state", LM_BOOL_ERROR);
}
llama_set_state_data(state->ctx, state_buf.data());
return LM_BOOL_SUCCESS;
}
LM_ERRBOOL load_grammar(const std::string& src, bool override_temperature) LM_NOEXCEPTDECL override {
auto& state = get_state();
state->parsed_grammar = grammar_parser::parse(src.c_str());
if (state->parsed_grammar.rules.empty()) {
LM_THROW("Failed to parse grammar (or no rules)", LM_BOOL_ERROR);
}
auto rules = state->parsed_grammar.c_rules();
state->grammar = llama_grammar_init(rules.data(), rules.size(), state->parsed_grammar.symbol_ids.at("root"));
if (!state->grammar) {
LM_THROW("Failed to generate llama grammar", LM_BOOL_ERROR);
}
state->grammar_override_temp = override_temperature;
return LM_BOOL_SUCCESS;
}
LM_ERRBOOL unload_grammar() LM_NOEXCEPTDECL override {
get_state()->grammar = nullptr;
return LM_BOOL_SUCCESS;
LM_CORETURN LM_BOOL_SUCCESS;
}
const std::string &get_prompt() const LM_NOEXCEPTDECL override {
return get_state()->prompt;
}
bool is_mirostat_available() const noexcept override {
return true;
}
bool is_grammar_available() const noexcept override {
return true;
}
};
}

justlm_mpt.hpp
View file

@ -4,7 +4,7 @@
#include <random>
#include <cstring>
#include "mpt/mpt.hpp"
#include "g4a_common.hpp"
#include "g4a-common.hpp"
namespace LM {
@ -12,14 +12,13 @@ class MPTInference final : public Inference {
std::string weights_path;
struct State {
gpt_vocab vocab;
mpt_vocab vocab;
mpt_model model;
std::string prompt; // Mostly here for easy "debugging"
std::vector<int> tokens;
std::vector<float> logits;
size_t mem_per_token = 0;
std::mt19937 rng;
int im_end = 0;
State(int32_t seed) : rng(seed) {}
};
@ -48,32 +47,25 @@ class MPTInference final : public Inference {
static std::vector<gpt_vocab::id> r_instruct;
mpt_eval(state->model, params.n_threads, 0, { 0, 1, 2, 3 }, state->logits, state->mem_per_token);
// Find im_end token
{
auto res = state->vocab.token_to_id.find("<|im_end|>");
if (res != state->vocab.token_to_id.end()) {
state->im_end = res->second;
}
}
return LM_BOOL_SUCCESS;
}
void deinit() LM_NOEXCEPTDECL {
auto& state = get_state();
if (state) {
if (state->model.ctx) ggml_free(state->model.ctx); //TODO: Is that enough?
delete state;
}
}
// This function reduces the size of our tokens vector according to some parameters
// All tokens will be evaluated if scrolling was needed and true will be returned
bool window_scroll() LM_NOEXCEPTDECL {
LM_SCHEDULABLE(bool) window_scroll() LM_NOEXCEPTDECL {
auto &state = get_state();
// Check that we actually need to scroll
if (state->tokens.size() <= params.n_ctx) {
// Nope
return false;
LM_CORETURN false;
}
// Start scrolling
if (params.scroll_keep > 0.0f) {
@ -90,11 +82,11 @@ class MPTInference final : public Inference {
state->tokens.resize(params.n_ctx_window_top_bar);
}
// Evaluate tokens
LM_ERROR_FORWARD(evaluate_tokens(0, on_scroll), LM_BOOL_ERROR);
return true;
LM_ERROR_FORWARD(LM_COAWAIT evaluate_tokens(0, on_scroll));
LM_CORETURN true;
}
LM_ERRBOOL evaluate_tokens(size_t starting_offset, const AppendCallback &on_tick) LM_NOEXCEPTDECL {
LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL {
auto& state = get_state();
// Evaluate tokens in batches
@ -105,7 +97,7 @@ class MPTInference final : public Inference {
// Evaluate
std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+params.n_batch);
if (!mpt_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token)) {
LM_THROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR);
LM_COTHROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR);
}
// Tick
@ -113,7 +105,8 @@ class MPTInference final : public Inference {
// Calculate progress
auto progress = float(it-starting_offset) / (state->tokens.size()-starting_offset) * 100.f;
// Tick and yield
if (!on_tick(progress)) return LM_BOOL_SUCCESS;
if (!on_tick(progress)) LM_CORETURN LM_BOOL_SUCCESS;
else if (!LM_TASKYIELD) LM_CORETURN LM_BOOL_SUCCESS;
}
}
@ -123,7 +116,7 @@ class MPTInference final : public Inference {
//TODO: This is extremely inefficient! Don't do that...
std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+1);
if (!mpt_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token)) {
LM_THROW("Failed to evaluate individual tokens", LM_BOOL_ERROR);
LM_COTHROW("Failed to evaluate individual tokens", LM_BOOL_ERROR);
}
}
}
@ -131,7 +124,7 @@ class MPTInference final : public Inference {
// Notify about completion
if (on_tick) on_tick(100.f);
return LM_BOOL_SUCCESS;
LM_CORETURN LM_BOOL_SUCCESS;
}
public:
@ -142,7 +135,7 @@ public:
deinit();
}
LM_ERRBOOL append(const std::string& prompt, const AppendCallback &on_tick) LM_NOEXCEPTDECL override {
LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
auto& state = get_state();
// Append to current prompt
@ -152,7 +145,7 @@ public:
const auto old_token_count = state->tokens.size();
// Run tokenizer
const auto tokens = gpt_tokenize(state->vocab, prompt);
const auto tokens = mpt_tokenize(state->vocab, prompt);
state->tokens.insert(
state->tokens.end(),
std::make_move_iterator(tokens.begin()),
@ -160,130 +153,132 @@ public:
);
// Make sure token limit isn't being hit
if (window_scroll()) {
if (LM_COAWAIT window_scroll()) {
// That function already has evaluated our tokens since scrolling was needed
return LM_BOOL_SUCCESS;
LM_CORETURN LM_BOOL_SUCCESS;
}
// Evaluate new tokens
return evaluate_tokens(old_token_count, on_tick);
LM_CORETURN LM_COAWAIT evaluate_tokens(old_token_count, on_tick);
}
std::string run(std::string_view end, const GenerateCallback &on_tick, const GenerateCallback& pre_tick) LM_NOEXCEPTDECL override {
/*mpt_vocab::id mpt_sample_top_k_top_p(
const mpt_vocab & vocab,
const size_t actualVocabSize,
const int32_t * last_n_tokens_data,
int last_n_tokens_size,
const std::vector<float> logits,
int top_k,
double top_p,
double temp,
float repeat_penalty,
std::mt19937 & rng)
*/
LM_SCHEDULABLE(std::string) run(std::string_view end, const std::function<bool (const char *)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
auto& state = get_state();
std::string fres;
// Loop until done
bool abort = false;
unsigned eos_count = 0;
size_t last_size = 0;
while (!abort && (end.empty() || fres.find(end) == fres.npos)) {
last_size = fres.size();
while (!abort && !ends_with(fres, end)) {
// Sample top p and top k
const auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
auto id = gpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data()+state->tokens.size()-n_repeat_last, n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
auto id = mpt_sample_top_k_top_p(state->vocab, state->model.hparams.n_vocab, state->tokens.data(), state->tokens.size(), state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
if (state->im_end && id == state->im_end) {
if (eos_count++ == params.n_eos_ignores) {
if (id == state->vocab.token_to_id["<|im_end|>"]) {
if (eos_count++ == params.eos_ignores) {
abort = true;
continue;
}
id = gpt_tokenize(state->vocab, "\n")[0];
} else if (id == 0) {
if (eos_count++ == params.n_eos_ignores) {
abort = true;
continue;
}
id = gpt_tokenize(state->vocab, "\n")[0];
id = mpt_tokenize(state->vocab, "\n")[0];
state->tokens.push_back(id);
} else {
// Add token
state->tokens.push_back(id);
}
// Add token
state->tokens.push_back(id);
// Make sure token limit isn't being hit
window_scroll();
LM_COAWAIT window_scroll();
// Get token as string
const std::string_view str = state->vocab.id_to_token[id];
const auto str = state->vocab.id_to_token[id];
// Append string to function result
fres.append(str);
state->prompt.append(str);
// Tick
if (pre_tick && !pre_tick(str.data())) abort = true;
else {
// Evaluate token
// TODO: Respect batch size
std::vector<int> batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size());
if (!mpt_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) {
LM_THROW("Failed to evaluate new tokens", "");
}
// Evaluate token
// TODO: Respect batch size
std::vector<int> batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size());
if (!mpt_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) {
LM_COTHROW("Failed to evaluate new tokens", "");
}
// Tick
if (on_tick && !on_tick(str.data())) abort = true;
if (on_tick && !on_tick(str.c_str())) abort = true;
else if (!LM_TASKYIELD) abort = true;
}
// Create final string TODO: Could be optimized
state->prompt.append(fres);
if (!abort) {
fres = std::string(fres.data(), last_size);
fres = std::string(fres.data(), fres.size()-end.size());
}
// Return final string
return fres;
LM_CORETURN fres;
}
unsigned get_context_size() const noexcept override {
return get_state()->tokens.size();
}
LM_ERRBOOL create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override {
LM_SCHEDULABLE(LM_ERRBOOL) create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override {
auto& state = get_state();
sv.buf.resize(mpt_get_state_size(state->model));
mpt_copy_state_data(state->model, state->rng, sv.buf.data());
sv.tokens = state->tokens;
sv.prompt = state->prompt;
sv.ctx = generic_state;
return LM_BOOL_SUCCESS ;
LM_CORETURN LM_BOOL_SUCCESS ;
}
LM_ERRBOOL restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override {
LM_SCHEDULABLE(LM_ERRBOOL) restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override {
auto& state = get_state();
if (sv.ctx != generic_state)
LM_THROW("Savestate does not match context", LM_BOOL_ERROR);
LM_COTHROW("Savestate does not match context", LM_BOOL_ERROR);
mpt_set_state_data(&state->model, &state->rng, sv.buf.data());
state->tokens = sv.tokens;
state->prompt = sv.prompt;
return LM_BOOL_SUCCESS;
LM_CORETURN LM_BOOL_SUCCESS;
}
LM_ERRBOOL serialize(std::ostream &o) const LM_NOEXCEPTDECL override {
LM_SCHEDULABLE(LM_ERRBOOL) serialize(std::ostream &o) const LM_NOEXCEPTDECL override {
auto& state = get_state();
// Get state size
auto state_size = mpt_get_state_size(state->model);
// Write sizes
for (const uint32_t s : {state->tokens.size(), state->prompt.size(), state_size}) {
if (!o.write(reinterpret_cast<const char*>(&s), sizeof(s))) {
LM_THROW("Failed to serialize data sizes", LM_BOOL_ERROR);
LM_COTHROW("Failed to serialize data sizes", LM_BOOL_ERROR);
}
}
// Write tokens
if (!o.write(reinterpret_cast<const char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
LM_THROW("Failed to serialize tokens", LM_BOOL_ERROR);
LM_COTHROW("Failed to serialize tokens", LM_BOOL_ERROR);
}
// Write prompt
if (!o.write(state->prompt.data(), state->prompt.size())) {
LM_THROW("Failed to serialize prompt", LM_BOOL_ERROR);
LM_COTHROW("Failed to serialize prompt", LM_BOOL_ERROR);
}
// Write state
std::vector<uint8_t> state_buf(state_size);
mpt_copy_state_data(state->model, state->rng, state_buf.data());
if (!o.write(reinterpret_cast<const char*>(state_buf.data()), state_size)) {
LM_THROW("Failed to serialize state", LM_BOOL_ERROR);
LM_COTHROW("Failed to serialize state", LM_BOOL_ERROR);
}
return LM_BOOL_SUCCESS;
LM_CORETURN LM_BOOL_SUCCESS;
}
LM_ERRBOOL deserialize(std::istream &i) LM_NOEXCEPTDECL override {
LM_SCHEDULABLE(LM_ERRBOOL) deserialize(std::istream &i) LM_NOEXCEPTDECL override {
auto& state = get_state();
uint32_t embd_size, promptsize, state_size;
// Initialization to prevent compiler complaints
@ -291,26 +286,26 @@ public:
// Read sizes
for (uint32_t *s : {&embd_size, &promptsize, &state_size}) {
if (!i.read(reinterpret_cast<char*>(s), sizeof(*s))) {
LM_THROW("Failed to deserialize data sizes", LM_BOOL_ERROR);
LM_COTHROW("Failed to deserialize data sizes", LM_BOOL_ERROR);
}
}
// Read tokens
state->tokens.resize(embd_size);
if (!i.read(reinterpret_cast<char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
LM_THROW("Failed to deserialize tokens", LM_BOOL_ERROR);
LM_COTHROW("Failed to deserialize tokens", LM_BOOL_ERROR);
}
// Read prompt
state->prompt.resize(promptsize);
if (!i.read(state->prompt.data(), state->prompt.size())) {
LM_THROW("Failed to deserialize prompt", LM_BOOL_ERROR);
LM_COTHROW("Failed to deserialize prompt", LM_BOOL_ERROR);
}
// Read state
std::vector<uint8_t> state_buf(state_size);
if (!i.read(reinterpret_cast<char*>(state_buf.data()), state_buf.size())) {
LM_THROW("Failed to deserialize state", LM_BOOL_ERROR);
LM_COTHROW("Failed to deserialize state", LM_BOOL_ERROR);
}
mpt_set_state_data(&state->model, &state->rng, state_buf.data());
return LM_BOOL_SUCCESS;
LM_CORETURN LM_BOOL_SUCCESS;
}
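For reference, the stream written by serialize() above has a flat layout: three uint32 sizes (token count, prompt length, model/RNG state size) followed by the raw token ids, the prompt text, and the opaque state blob consumed by mpt_set_state_data(). A standalone sketch of reading that layout back; the file name is illustrative only:

#include <cstdint>
#include <fstream>
#include <initializer_list>
#include <string>
#include <vector>

int main() {
    std::ifstream f("savestate.bin", std::ios::binary);   // hypothetical dump of serialize()
    uint32_t n_tokens = 0, prompt_size = 0, state_size = 0;
    for (uint32_t *s : {&n_tokens, &prompt_size, &state_size})
        f.read(reinterpret_cast<char*>(s), sizeof(*s));
    std::vector<int> tokens(n_tokens);                     // raw token ids
    f.read(reinterpret_cast<char*>(tokens.data()), tokens.size()*sizeof(int));
    std::string prompt(prompt_size, '\0');                 // prompt text as plain bytes
    f.read(prompt.data(), prompt.size());
    std::vector<uint8_t> state(state_size);                // opaque model/RNG state blob
    f.read(reinterpret_cast<char*>(state.data()), state.size());
    return f.good() ? 0 : 1;                               // `state` would go to mpt_set_state_data()
}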
const std::string &get_prompt() const LM_NOEXCEPTDECL override {
return get_state()->prompt;

justlm_pool.cpp

@ -6,7 +6,7 @@
bool LM::InferencePool::store_slot(Slot &slot) {
LM_SCHEDULABLE(bool) LM::InferencePool::store_slot(Slot &slot) {
auto inference = slot.get_inference();
// Open output file
std::ofstream f(get_slot_filename(slot.get_id()), std::ios::binary);
@ -17,61 +17,61 @@ bool LM::InferencePool::store_slot(Slot &slot) {
f.write(weights_path.data(), weights_path.size());
// Write params
if (!f.write(reinterpret_cast<const char*>(&inference->params), sizeof(inference->params))) {
return false;
LM_CORETURN false;
}
// Serialize instance
try {
inference->serialize(f);
LM_COAWAIT inference->serialize(f);
} catch (...) {
return false;
LM_CORETURN false;
}
// Return success
return true;
LM_CORETURN true;
}
LM::InferencePool::Slot *LM::InferencePool::load_slot(size_t id, Slot *suggested_slot) {
LM_SCHEDULABLE(LM::InferencePool::Slot*) LM::InferencePool::load_slot(size_t id, Slot *suggested_slot) {
// Open input file
std::ifstream f(get_slot_filename(id), std::ios::binary);
if (!f) {
// Does not exist
return nullptr;
LM_CORETURN nullptr;
}
// Read weights path
std::string weights_path;
uint32_t weights_path_len;
if (!f.read(reinterpret_cast<char*>(&weights_path_len), sizeof(weights_path_len))) {
return nullptr;
LM_CORETURN nullptr;
}
weights_path.resize(weights_path_len);
if (!f.read(weights_path.data(), weights_path.size())) {
return nullptr;
LM_CORETURN nullptr;
}
// Read params
LM::Inference::Params p;
if (!f.read(reinterpret_cast<char*>(&p), sizeof(p))) {
return nullptr;
LM_CORETURN nullptr;
}
// Create instance
auto& slot = suggested_slot?*suggested_slot:*(get_free_slot());
auto& slot = suggested_slot?*suggested_slot:*(LM_COAWAIT get_free_slot());
auto inference = slot.create_inference(id, weights_path, p);
// Deserialize instance
try {
inference->deserialize(f);
LM_COAWAIT inference->deserialize(f);
} catch (...) {
slot.reset();
return nullptr;
LM_CORETURN nullptr;
}
// Return final slot
return &slot;
LM_CORETURN &slot;
}
LM::InferencePool::Slot *LM::InferencePool::get_free_slot() {
LM_SCHEDULABLE(LM::InferencePool::Slot*) LM::InferencePool::get_free_slot() {
// Attempt to find free slot while finding oldest one
Slot *oldest = nullptr;
for (auto& slot : slots) {
// Take free slot
if (slot.is_free()) {
return &slot;
LM_CORETURN &slot;
}
// Update oldest
if (oldest == nullptr || slot.get_last_access() < oldest->get_last_access()) {
@ -80,17 +80,17 @@ LM::InferencePool::Slot *LM::InferencePool::get_free_slot() {
}
// Free up oldest slot and take that one
// Note: Since there has to be at least 1 slot, oldest is never going to be a nullptr
store_and_reset_slot(*oldest);
return oldest;
LM_COAWAIT store_and_reset_slot(*oldest);
LM_CORETURN oldest;
}
LM::InferencePool::Slot *LM::InferencePool::find_slot_by_id(size_t id, bool deserialize) {
LM_SCHEDULABLE(LM::InferencePool::Slot*) LM::InferencePool::find_slot_by_id(size_t id, bool deserialize) {
// Attempt to find given slot while finding oldest one
Slot *oldest = nullptr;
for (auto& slot : slots) {
// Take slot with ID
if (slot.get_id() == id) {
return &slot;
LM_CORETURN &slot;
}
// Update oldest
if (oldest == nullptr || slot.get_last_access() < oldest->get_last_access()) {
@ -99,38 +99,38 @@ LM::InferencePool::Slot *LM::InferencePool::find_slot_by_id(size_t id, bool dese
}
// Slot not found, attempt to load it
if (deserialize) {
if (!oldest->is_free()) store_slot(*oldest);
if (!load_slot(id, oldest)) {
if (!oldest->is_free()) LM_COAWAIT store_slot(*oldest);
if (!LM_COAWAIT load_slot(id, oldest)) {
// In case slot loading failed, still reset slot for later use
//TODO: Make this configurable
oldest->reset();
} else {
return oldest;
LM_CORETURN oldest;
}
}
// Slot not found
return nullptr;
LM_CORETURN nullptr;
}
std::shared_ptr<LM::Inference> LM::InferencePool::get_inference(size_t id) {
auto slot = find_slot_by_id(id);
LM_SCHEDULABLE(std::shared_ptr<LM::Inference>) LM::InferencePool::get_inference(size_t id) {
auto slot = LM_COAWAIT find_slot_by_id(id);
if (slot) {
return slot->get_inference(true);
LM_CORETURN slot->get_inference(true);
}
return {};
LM_CORETURN {};
}
std::shared_ptr<LM::Inference> LM::InferencePool::get_or_create_inference(size_t id, const std::string &weights_path, const Inference::Params &p) {
auto slot = find_slot_by_id(id);
LM_SCHEDULABLE(std::shared_ptr<LM::Inference>) LM::InferencePool::get_or_create_inference(size_t id, const std::string &weights_path, const Inference::Params &p) {
auto slot = LM_COAWAIT find_slot_by_id(id);
if (slot) {
return slot->get_inference(true);
LM_CORETURN slot->get_inference(true);
}
slot = get_free_slot();
return slot->create_inference(id, weights_path, p);
slot = LM_COAWAIT get_free_slot();
LM_CORETURN slot->create_inference(id, weights_path, p);
}
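A rough usage sketch of the pool API above, assuming the default non-CoSched build (where LM_SCHEDULABLE/LM_COAWAIT/LM_CORETURN reduce to plain types, calls and returns) and that the optional callback arguments have defaults; the slot id, model path and prompt are illustrative:

#include "justlm_pool.hpp"
#include <iostream>

void chat_once(LM::InferencePool &pool, size_t conversation_id) {
    LM::Inference::Params params;                          // defaults
    // Reuses the conversation's slot if present (possibly deserialized from disk),
    // otherwise creates a fresh inference from the given weights.
    auto inference = pool.get_or_create_inference(conversation_id,
                                                  "models/mpt-7b-chat.bin", params);
    inference->append("Hello!");                           // evaluate the prompt
    std::cout << inference->run("\n") << std::endl;        // generate until a newline
}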
void LM::InferencePool::delete_inference(size_t id) {
auto slot = find_slot_by_id(id, false);
LM_SCHEDULABLE(void) LM::InferencePool::delete_inference(size_t id) {
auto slot = LM_COAWAIT find_slot_by_id(id, false);
// Reset slot
if (slot) {
slot->reset();
@ -140,12 +140,12 @@ void LM::InferencePool::delete_inference(size_t id) {
std::filesystem::remove(get_slot_filename(id), ec);
}
void LM::InferencePool::store_all() {
LM_SCHEDULABLE(void) LM::InferencePool::store_all() {
for (auto& slot : slots) {
if (slot.is_free()) continue;
store_slot(slot);
LM_COAWAIT store_slot(slot);
}
return;
LM_CORETURN;
}
std::vector<size_t> LM::InferencePool::get_active_slot_ids() const {

llama.cpp

@ -1,39 +0,0 @@
#include "justlm_llama.hpp"
#include "justlm.hpp"
#include <string>
#include <string_view>
#include <fstream>
#include <cstdint>
extern "C" {
const LM::Implementation *get_justlm_implementation() {
static LM::Implementation fres{false};
return &fres;
}
bool magic_match(std::istream& f) {
// Check magic
uint32_t magic = 0;
f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
return magic == 0x46554747;
}
LM::Inference *construct(const std::string &weights_path, std::ifstream& f, const LM::Inference::Params &p) {
f.close();
return new LM::LLaMAInference(weights_path, p);
}
}
__attribute__((constructor))
static void init() {
llama_backend_init(true);
}
__attribute__((destructor))
static void deinit() {
llama_backend_free();
}
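The magic constant checked above, 0x46554747, is the four ASCII bytes "GGUF" read from the start of the file as a little-endian uint32, i.e. the llama.cpp GGUF container magic. A standalone sketch of the same probe (path handling is illustrative):

#include <cstdint>
#include <fstream>

// True when the file starts with the "GGUF" magic (0x46554747 as a little-endian
// uint32 on a little-endian host), i.e. when it looks like a llama.cpp GGUF model.
static bool looks_like_gguf(const char *path) {
    std::ifstream f(path, std::ios::binary);
    uint32_t magic = 0;
    f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
    return f && magic == 0x46554747;
}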

1
llama.cpp Submodule

@ -0,0 +1 @@
Subproject commit 0e018fe008eacebdbcfa2d61b6c988c245c961cd

@ -1 +0,0 @@
Subproject commit b9f47952ffae4e0d3420905526003c23333f6c98

llama.cpp.cmake

@ -1,452 +0,0 @@
cmake_minimum_required(VERSION 3.12) # Don't bump this version for no reason
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
endif()
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
set(LLAMA_STANDALONE ON)
# configure project version
# TODO
else()
set(LLAMA_STANDALONE OFF)
endif()
if (EMSCRIPTEN)
set(BUILD_SHARED_LIBS_DEFAULT OFF)
option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" ON)
else()
if (MINGW)
set(BUILD_SHARED_LIBS_DEFAULT OFF)
else()
set(BUILD_SHARED_LIBS_DEFAULT ON)
endif()
endif()
#
# Option list
#
# general
option(LLAMA_STATIC "llama: static link libraries" OFF)
option(LLAMA_NATIVE "llama: enable -march=native flag" OFF)
option(LLAMA_LTO "llama: enable link time optimization" OFF)
# debug
option(LLAMA_ALL_WARNINGS "llama: enable all compiler warnings" ON)
option(LLAMA_ALL_WARNINGS_3RD_PARTY "llama: enable all compiler warnings in 3rd party libs" OFF)
option(LLAMA_GPROF "llama: enable gprof" OFF)
# sanitizers
option(LLAMA_SANITIZE_THREAD "llama: enable thread sanitizer" OFF)
option(LLAMA_SANITIZE_ADDRESS "llama: enable address sanitizer" OFF)
option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF)
# instruction set specific
option(LLAMA_AVX "llama: enable AVX" ON)
option(LLAMA_AVX2 "llama: enable AVX2" ON)
option(LLAMA_AVX512 "llama: enable AVX512" OFF)
option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
option(LLAMA_FMA "llama: enable FMA" ON)
# in MSVC F16C is implied with AVX2/AVX512
if (NOT MSVC)
option(LLAMA_F16C "llama: enable F16C" ON)
endif()
# 3rd party libs
option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON)
option(LLAMA_OPENBLAS "llama: use OpenBLAS" OFF)
option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
option(LLAMA_METAL "llama: use Metal" OFF)
option(LLAMA_K_QUANTS "llama: use k-quants" ON)
set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels")
#
# Compile flags
#
set(CMAKE_C_STANDARD 11)
set(CMAKE_C_STANDARD_REQUIRED true)
set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads REQUIRED)
if (NOT MSVC)
if (LLAMA_SANITIZE_THREAD)
add_compile_options(-fsanitize=thread)
link_libraries(-fsanitize=thread)
endif()
if (LLAMA_SANITIZE_ADDRESS)
add_compile_options(-fsanitize=address -fno-omit-frame-pointer)
link_libraries(-fsanitize=address)
endif()
if (LLAMA_SANITIZE_UNDEFINED)
add_compile_options(-fsanitize=undefined)
link_libraries(-fsanitize=undefined)
endif()
endif()
if (APPLE AND LLAMA_ACCELERATE)
find_library(ACCELERATE_FRAMEWORK Accelerate)
if (ACCELERATE_FRAMEWORK)
message(STATUS "Accelerate framework found")
add_compile_definitions(GGML_USE_ACCELERATE)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
else()
message(WARNING "Accelerate framework not found")
endif()
endif()
if (LLAMA_OPENBLAS)
if (LLAMA_STATIC)
set(BLA_STATIC ON)
endif()
set(BLA_VENDOR OpenBLAS)
find_package(BLAS)
if (BLAS_FOUND)
message(STATUS "OpenBLAS found")
add_compile_definitions(GGML_USE_OPENBLAS)
add_link_options(${BLAS_LIBRARIES})
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} openblas)
# find header file
set(OPENBLAS_INCLUDE_SEARCH_PATHS
/usr/include
/usr/include/openblas
/usr/include/openblas-base
/usr/local/include
/usr/local/include/openblas
/usr/local/include/openblas-base
/opt/OpenBLAS/include
$ENV{OpenBLAS_HOME}
$ENV{OpenBLAS_HOME}/include
)
find_path(OPENBLAS_INC NAMES cblas.h PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS})
add_compile_options(-I${OPENBLAS_INC})
else()
message(WARNING "OpenBLAS not found")
endif()
endif()
if (LLAMA_ALL_WARNINGS)
if (NOT MSVC)
set(c_flags
-Wall
-Wextra
-Wpedantic
-Wcast-qual
-Wdouble-promotion
-Wshadow
-Wstrict-prototypes
-Wpointer-arith
)
set(cxx_flags
-Wall
-Wextra
-Wpedantic
-Wcast-qual
-Wno-unused-function
-Wno-multichar
)
else()
# todo : msvc
endif()
add_compile_options(
"$<$<COMPILE_LANGUAGE:C>:${c_flags}>"
"$<$<COMPILE_LANGUAGE:CXX>:${cxx_flags}>"
)
endif()
if (MSVC)
add_compile_definitions(_CRT_SECURE_NO_WARNINGS)
if (BUILD_SHARED_LIBS)
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
endif()
endif()
if (LLAMA_LTO)
include(CheckIPOSupported)
check_ipo_supported(RESULT result OUTPUT output)
if (result)
set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE)
else()
message(WARNING "IPO is not supported: ${output}")
endif()
endif()
# Architecture specific
# TODO: probably these flags need to be tweaked on some architectures
# feel free to update the Makefile for your architecture and send a pull request or issue
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
if (NOT MSVC)
if (LLAMA_STATIC)
add_link_options(-static)
if (MINGW)
add_link_options(-static-libgcc -static-libstdc++)
endif()
endif()
if (LLAMA_GPROF)
add_compile_options(-pg)
endif()
if (LLAMA_NATIVE)
add_compile_options(-march=native)
endif()
endif()
function(remove_nonexistent SOURCES)
set(SOURCES_BAK ${${SOURCES}})
set(${SOURCES} )
foreach (FILE ${SOURCES_BAK})
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${FILE})
set(${SOURCES} ${${SOURCES}} ${FILE})
endif()
endforeach()
set(${SOURCES} ${${SOURCES}} PARENT_SCOPE)
endfunction()
function(include_ggml DIRECTORY SUFFIX WITH_LLAMA)
message(STATUS "Configuring ggml implementation target llama${SUFFIX} in ${CMAKE_CURRENT_SOURCE_DIR}/${DIRECTORY}")
#
# Build libraries
#
set(GGML_CUBLAS_USE NO)
if (LLAMA_CUBLAS)
cmake_minimum_required(VERSION 3.17)
find_package(CUDAToolkit)
if (CUDAToolkit_FOUND)
set(GGML_CUBLAS_USE YES)
message(STATUS "cuBLAS found")
enable_language(CUDA)
set(GGML_SOURCES_CUDA ${DIRECTORY}/ggml-cuda.cu ${DIRECTORY}/ggml-cuda.h)
if (LLAMA_STATIC)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
else()
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
endif()
else()
message(WARNING "cuBLAS not found")
endif()
endif()
set(GGML_CLBLAST_USE NO)
if (LLAMA_CLBLAST)
find_package(CLBlast)
if (CLBlast_FOUND)
set(GGML_CLBLAST_USE YES)
message(STATUS "CLBlast found")
set(GGML_OPENCL_SOURCE_FILE ggml-opencl.cpp)
if (NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${DIRECTORY}/${GGML_OPENCL_SOURCE_FILE})
set(GGML_OPENCL_SOURCE_FILE ggml-opencl.c)
endif()
set(GGML_OPENCL_SOURCES ${DIRECTORY}/${GGML_OPENCL_SOURCE_FILE} ${DIRECTORY}/ggml-opencl.h)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} clblast)
else()
message(WARNING "CLBlast not found")
endif()
endif()
set(GGML_METAL_SOURCES )
if (LLAMA_METAL)
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
find_library(METAL_FRAMEWORK Metal REQUIRED)
find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)
set(GGML_METAL_SOURCES ${DIRECTORY}/ggml-metal.m ${DIRECTORY}/ggml-metal.h)
# get full path to the file
#add_compile_definitions(GGML_METAL_DIR_KERNELS="${CMAKE_CURRENT_SOURCE_DIR}/")
# copy ggml-metal.metal to bin directory
configure_file(${DIRECTORY}/ggml-metal.metal bin/ggml-metal.metal COPYONLY)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS}
${FOUNDATION_LIBRARY}
${METAL_FRAMEWORK}
${METALKIT_FRAMEWORK}
${METALPERFORMANCE_FRAMEWORK}
)
endif()
set(GGML_SOURCES
${DIRECTORY}/ggml.c
${DIRECTORY}/ggml.h
${DIRECTORY}/ggml-alloc.c
${DIRECTORY}/ggml-alloc.h
${DIRECTORY}/ggml-quants.c
${DIRECTORY}/ggml-quants.h
${DIRECTORY}/ggml-backend.c
${DIRECTORY}/ggml-backend.h
${GGML_SOURCES_CUDA}
${GGML_METAL_SOURCES}
${GGML_OPENCL_SOURCES})
remove_nonexistent(GGML_SOURCES)
add_library(ggml${SUFFIX} OBJECT ${GGML_SOURCES})
target_compile_definitions(ggml${SUFFIX} PRIVATE _GNU_SOURCE)
if (LLAMA_K_QUANTS)
target_compile_definitions(ggml${SUFFIX} PUBLIC GGML_USE_K_QUANTS)
endif()
if (LLAMA_METAL AND GGML_METAL_SOURCES)
target_compile_definitions(ggml${SUFFIX} PUBLIC GGML_USE_METAL GGML_METAL_NDEBUG)
endif()
target_include_directories(ggml${SUFFIX} PUBLIC ${DIRECTORY})
target_compile_features(ggml${SUFFIX} PUBLIC c_std_11) # don't bump
if (BUILD_SHARED_LIBS)
set_target_properties(ggml${SUFFIX} PROPERTIES POSITION_INDEPENDENT_CODE ON)
endif()
if (WITH_LLAMA)
SET(LLAMA_SOURCES
${DIRECTORY}/llama.cpp
${DIRECTORY}/llama.h
${DIRECTORY}/common/grammar-parser.h
${DIRECTORY}/common/grammar-parser.cpp)
remove_nonexistent(LLAMA_SOURCES)
add_library(llama${SUFFIX} ${LLAMA_SOURCES})
if (LLAMA_METAL AND GGML_METAL_SOURCES)
target_compile_definitions(llama${SUFFIX} PUBLIC GGML_USE_METAL GGML_METAL_NDEBUG)
endif()
target_include_directories(llama${SUFFIX} PUBLIC ${DIRECTORY})
target_compile_features(llama${SUFFIX} PUBLIC cxx_std_11) # don't bump
if (BUILD_SHARED_LIBS)
set_target_properties(llama${SUFFIX} PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(llama${SUFFIX} PRIVATE LLAMA_SHARED LLAMA_BUILD)
endif()
endif()
if (GGML_SOURCES_CUDA)
message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
set_property(TARGET ggml${SUFFIX} PROPERTY CUDA_ARCHITECTURES OFF)
set_property(TARGET ggml${SUFFIX} PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
if (WITH_LLAMA)
set_property(TARGET llama${SUFFIX} PROPERTY CUDA_ARCHITECTURES OFF)
endif()
endif()
if (GGML_CUBLAS_USE)
target_compile_definitions(ggml${SUFFIX} PRIVATE
GGML_USE_CUBLAS
GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}
GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y})
if (WITH_LLAMA)
target_compile_definitions(llama${SUFFIX} PRIVATE
GGML_USE_CUBLAS
GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}
GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y})
endif()
endif()
if (GGML_CLBLAST_USE)
if (WITH_LLAMA)
target_compile_definitions(llama${SUFFIX} PRIVATE GGML_USE_CLBLAST)
endif()
target_compile_definitions(ggml${SUFFIX} PRIVATE GGML_USE_CLBLAST)
endif()
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
message(STATUS "ARM detected")
if (MSVC)
# TODO: arm msvc?
else()
if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
target_compile_options(ggml${SUFFIX} PRIVATE -mcpu=native)
endif()
# TODO: armv6,7,8 version specific flags
endif()
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$")
message(STATUS "x86 detected")
if (MSVC)
if (LLAMA_AVX512)
target_compile_definitions(ggml${SUFFIX} PRIVATE
$<$<COMPILE_LANGUAGE:C>:/arch:AVX512>
$<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)
# MSVC has no compile-time flags enabling specific
# AVX512 extensions, neither it defines the
# macros corresponding to the extensions.
# Do it manually.
if (LLAMA_AVX512_VBMI)
target_compile_definitions(ggml${SUFFIX} PRIVATE
$<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>
$<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>)
endif()
if (LLAMA_AVX512_VNNI)
target_compile_definitions(ggml${SUFFIX} PRIVATE
$<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>
$<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
endif()
elseif (LLAMA_AVX2)
target_compile_definitions(ggml${SUFFIX} PRIVATE
$<$<COMPILE_LANGUAGE:C>:/arch:AVX2>
$<$<COMPILE_LANGUAGE:CXX>:/arch:AVX2>)
elseif (LLAMA_AVX)
target_compile_definitions(ggml${SUFFIX} PRIVATE
$<$<COMPILE_LANGUAGE:C>:/arch:AVX>
$<$<COMPILE_LANGUAGE:CXX>:/arch:AVX>)
endif()
else()
if (LLAMA_F16C)
target_compile_options(ggml${SUFFIX} PRIVATE -mf16c)
endif()
if (LLAMA_FMA)
target_compile_options(ggml${SUFFIX} PRIVATE -mfma)
endif()
if (LLAMA_AVX)
target_compile_options(ggml${SUFFIX} PRIVATE -mavx)
endif()
if (LLAMA_AVX2)
target_compile_options(ggml${SUFFIX} PRIVATE -mavx2)
endif()
if (LLAMA_AVX512)
target_compile_options(ggml${SUFFIX} PRIVATE -mavx512f)
target_compile_options(ggml${SUFFIX} PRIVATE -mavx512bw)
endif()
if (LLAMA_AVX512_VBMI)
target_compile_options(ggml${SUFFIX} PRIVATE -mavx512vbmi)
endif()
if (LLAMA_AVX512_VNNI)
target_compile_options(ggml${SUFFIX} PRIVATE -mavx512vnni)
endif()
endif()
else()
# TODO: support PowerPC
message(STATUS "Unknown architecture")
endif()
target_link_libraries(ggml${SUFFIX} PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})
if (WITH_LLAMA)
target_link_libraries(llama${SUFFIX} PRIVATE ggml${SUFFIX} ${LLAMA_EXTRA_LIBS})
endif()
endfunction()

26
mpt.cpp

@ -1,26 +0,0 @@
#include "justlm_mpt.hpp"
#include "justlm.hpp"
#include <string>
#include <string_view>
#include <fstream>
#include <cstdint>
extern "C" {
const LM::Implementation *get_justlm_implementation() {
static LM::Implementation fres{false};
return &fres;
}
bool magic_match(std::istream& f) {
uint32_t magic;
f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
return magic == 0x67676d6d;
}
LM::Inference *construct(const std::string &weights_path, std::ifstream& f, const LM::Inference::Params &p) {
return new LM::MPTInference(weights_path, f, p);
}
}

mpt/mpt.cpp

@ -1,5 +1,4 @@
#include "mpt.hpp"
#include "../g4a_common.hpp"
#include <cassert>
#include <cmath>
@ -11,7 +10,7 @@
#include <string>
#include <vector>
#include <iostream>
#include "../msvc_compat_unistd.h"
#include <unistd.h>
#include <sstream>
#include <thread>
#include <unordered_set>
@ -19,7 +18,7 @@
#include <ggml.h>
inline
unsigned long long operator ""_MiB(unsigned long long bytes) {
unsigned long long operator ""_MB(unsigned long long bytes) {
return bytes*1024*1024;
}
@ -34,7 +33,7 @@ static bool kv_cache_init(
const int64_t n_mem = (int64_t)n_layer*n_ctx;
const int64_t n_elements = n_embd*n_mem;
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2_MiB);
cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2_MB);
struct ggml_init_params params;
params.mem_size = cache.buf.size;
@ -54,8 +53,13 @@ static bool kv_cache_init(
return true;
}
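To make the sizing above concrete: the cache stores one key row and one value row per layer per context position, so it needs 2 * n_layer * n_ctx * n_embd elements of the chosen type plus a small ggml bookkeeping margin. A worked example with illustrative MPT-7B-like hyperparameters (the real values come from the model file):

#include <cstdint>
#include <cstdio>

int main() {
    // Illustrative hyperparameters only; the real ones are read from the model file.
    const int64_t n_layer = 32, n_ctx = 2048, n_embd = 4096;
    const int64_t type_size  = 2;                          // e.g. an f16 cache entry
    const int64_t n_mem      = n_layer * n_ctx;            // cached positions across all layers
    const int64_t n_elements = n_embd * n_mem;             // per tensor (K or V)
    const int64_t bytes      = 2*n_elements*type_size + 2*1024*1024;  // K + V + 2 MiB margin
    std::printf("KV cache buffer: %lld MiB\n", (long long)(bytes >> 20));  // ~1026 MiB here
}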
std::string regex_escape(const std::string &s) {
static const std::regex metacharacters(R"([\.\^\$\-\+\(\)\[\]\{\}\|\?\*])");
return std::regex_replace(s, metacharacters, "\\$&");
}
// load the model's weights from a stream
bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & model, gpt_vocab & vocab) {
bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & model, mpt_vocab & vocab) {
printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
// verify magic
@ -119,6 +123,8 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
vocab.id_to_token[i] = word;
}
// TODO: this only kind-of works, the gpt_tokenize can still incorrectly
// tokenize special tokens
if(special) {
vocab.add_special_token(word);
}
@ -326,7 +332,7 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
}
// load the model's weights from a file path
bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vocab) {
bool mpt_model_load(const std::string & fname, mpt_model & model, mpt_vocab & vocab) {
auto fin = std::ifstream(fname, std::ios::binary);
if (!fin) {
@ -356,31 +362,30 @@ bool mpt_eval(
const int n_head = hparams.n_head;
const int n_vocab = hparams.n_vocab;
const size_t init_buf_size = 1024_MiB;
if (!model.buf.addr || model.buf.size < init_buf_size)
model.buf.resize(init_buf_size);
static size_t buf_size = 256u*1024*1024;
static void * buf = malloc(buf_size);
if (mem_per_token > 0 && mem_per_token*N > model.buf.size) {
if (mem_per_token > 0 && mem_per_token*N > buf_size) {
const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
// printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, model.buf.size, buf_size_new);
//printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
// reallocate
model.buf.resize(buf_size_new);
if (model.buf.addr == nullptr) {
fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, model.buf.size);
buf_size = buf_size_new;
buf = realloc(buf, buf_size);
if (buf == nullptr) {
fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
return false;
}
}
struct ggml_init_params params = {
model.buf.size,
model.buf.addr,
false
.mem_size = buf_size,
.mem_buffer = buf,
.no_alloc = false,
};
struct ggml_context * ctx0 = ggml_init(params);
struct ggml_cgraph gf{};
gf.n_threads = n_threads;
struct ggml_cgraph gf = { .n_threads = n_threads };
struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
@ -516,12 +521,10 @@ bool mpt_eval(
out = ggml_mul_mat(ctx0, model.wte, out);
}
// run the computation
ggml_build_forward_expand(&gf, out);
ggml_graph_compute (ctx0, &gf);
// return result for just the last token
embd_w.resize(n_vocab);
memcpy(embd_w.data(), (float *) ggml_get_data(out) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
@ -536,6 +539,98 @@ bool mpt_eval(
return true;
}
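A rough sketch of how mpt_eval tends to be driven by its callers: a tiny probe batch first fills in mem_per_token so later calls can size their scratch buffer, then the full prompt is evaluated at position 0, and generation feeds one sampled token back per step. The function and struct names come from the declarations shown further below; the call sequence itself is an assumption, not this library's exact code:

#include "mpt/mpt.hpp"
#include <cstddef>
#include <vector>

void eval_prompt_then_one(mpt_model &model, const std::vector<int> &prompt_tokens,
                          int sampled_token, int n_threads) {
    std::vector<float> logits;
    size_t mem_per_token = 0;
    // Probe with a tiny batch so mem_per_token gets an estimate for buffer sizing.
    mpt_eval(model, n_threads, 0, {0, 1, 2, 3}, logits, mem_per_token);
    // Evaluate the whole prompt starting at position 0.
    mpt_eval(model, n_threads, 0, prompt_tokens, logits, mem_per_token);
    // During generation, feed one token back at the next position.
    mpt_eval(model, n_threads, (int)prompt_tokens.size(), {sampled_token}, logits, mem_per_token);
}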
std::vector<int> mpt_tokenize_inner(const mpt_vocab & vocab, const std::string & text) {
// taken from stablelm example in ggml
// they both use the gpt-neox tokenizer
// not sure if this entirely right?
std::vector<std::string> words;
// first split the text into words
{
std::string str = text;
std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
std::regex re(pat);
std::smatch m;
while (std::regex_search(str, m, re)) {
for (auto x : m) {
words.push_back(x);
}
str = m.suffix();
}
}
// find the longest tokens that form the words:
std::vector<mpt_vocab::id> tokens;
for (const auto & word : words) {
if (word.size() == 0) continue;
int i = 0;
int n = word.size();
while (i < n) {
int j = n;
while (j > i) {
auto it = vocab.token_to_id.find(word.substr(i, j-i));
if (it != vocab.token_to_id.end()) {
tokens.push_back(it->second);
i = j;
break;
}
--j;
}
if (i == n) {
break;
}
if (j == i) {
auto sub = word.substr(i, 1);
if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
tokens.push_back(vocab.token_to_id.at(sub));
} else {
fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
}
++i;
}
}
}
return tokens;
}
std::vector<mpt_vocab::id> mpt_tokenize(const mpt_vocab & vocab, const std::string & text) {
// Generate the subpattern from the special_tokens vector if it's not empty
if (!vocab.special_tokens.empty()) {
std::vector<mpt_vocab::id> out;
std::vector<std::string> chunks;
std::string str = text;
std::string special_tokens_subpattern;
for (const auto &token : vocab.special_tokens) {
if (!special_tokens_subpattern.empty()) {
special_tokens_subpattern += "|";
}
special_tokens_subpattern += regex_escape(token);
}
std::regex re(special_tokens_subpattern);
std::smatch m;
while (std::regex_search(str, m, re)) {
auto tok = vocab.token_to_id.find(m.str());
if (tok != vocab.token_to_id.end()) {
auto tokid = tok->second;
auto pfxtoks = mpt_tokenize_inner(vocab, m.prefix());
out.insert(out.end(), pfxtoks.begin(), pfxtoks.end());
out.push_back(tokid);
str = m.suffix();
}
}
if (!str.empty()) {
auto tokrest = mpt_tokenize_inner(vocab, str);
out.insert(out.end(), tokrest.begin(), tokrest.end());
}
return out;
} else {
return mpt_tokenize_inner(vocab, text);
}
}
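A small usage sketch of the tokenizer above: special tokens registered on the vocabulary are split out first and emitted as single ids, and the surrounding text goes through the greedy longest-match pass in mpt_tokenize_inner. The toy vocabulary and ids here are purely illustrative:

#include "mpt/mpt.hpp"
#include <cstdio>

int main() {
    mpt_vocab vocab;                                       // toy vocabulary, not a real model's
    auto add = [&](mpt_vocab::id id, const std::string &tok) {
        vocab.token_to_id[tok] = id;
        vocab.id_to_token[id]  = tok;
    };
    add(0, "<|im_end|>");
    add(1, "Hello"); add(2, ","); add(3, " world"); add(4, "!");
    vocab.add_special_token("<|im_end|>");                 // matched as a single unit

    for (auto id : mpt_tokenize(vocab, "Hello, world!<|im_end|>"))
        std::printf("%d ", id);                            // expected here: 1 2 3 4 0
    std::printf("\n");
}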
#define MPT_MAX_RNG_STATE 64*1024
@ -596,6 +691,104 @@ size_t mpt_copy_state_data(const mpt_model &model, const std::mt19937 &rng, uint
return written;
}
mpt_vocab::id mpt_sample_top_k_top_p(
const mpt_vocab & vocab,
const size_t actualVocabSize,
const int32_t * last_n_tokens_data,
int last_n_tokens_size,
const std::vector<float> logits,
int top_k,
double top_p,
double temp,
float repeat_penalty,
std::mt19937 & rng) {
int n_logits = actualVocabSize;
const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
const auto * plogits = logits.data() + logits.size() - n_logits;
std::vector<std::pair<double, mpt_vocab::id>> logits_id;
logits_id.reserve(n_logits);
{
const float scale = 1.0f/temp;
for (int i = 0; i < n_logits; ++i) {
// repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
// credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
// if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
if (plogits[i] < 0.0f) {
logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
} else {
logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
}
} else {
logits_id.push_back(std::make_pair(plogits[i]*scale, i));
}
}
}
// find the top K tokens
std::partial_sort(
logits_id.begin(),
logits_id.begin() + top_k, logits_id.end(),
[](const std::pair<double, mpt_vocab::id> & a, const std::pair<double, mpt_vocab::id> & b) {
return a.first > b.first;
});
logits_id.resize(top_k);
double maxl = -INFINITY;
for (const auto & kv : logits_id) {
maxl = std::max(maxl, kv.first);
}
// compute probs for the top K tokens
std::vector<double> probs;
probs.reserve(logits_id.size());
double sum = 0.0;
for (const auto & kv : logits_id) {
double p = exp(kv.first - maxl);
probs.push_back(p);
sum += p;
}
// normalize the probs
for (auto & p : probs) {
p /= sum;
}
if (top_p < 1.0f) {
double cumsum = 0.0f;
for (int i = 0; i < top_k; i++) {
cumsum += probs[i];
if (cumsum >= top_p) {
top_k = i + 1;
probs.resize(top_k);
logits_id.resize(top_k);
break;
}
}
cumsum = 1.0/cumsum;
for (int i = 0; i < (int) probs.size(); i++) {
probs[i] *= cumsum;
}
}
//printf("\n");
//for (int i = 0; i < (int) probs.size(); i++) {
// printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
//}
//exit(0);
std::discrete_distribution<> dist(probs.begin(), probs.end());
int idx = dist(rng);
return logits_id[idx].second;
}
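The sampler above combines three standard steps: a repetition penalty on recently generated tokens, temperature scaling, and top-k followed by nucleus (top-p) truncation before drawing from the surviving distribution. A compact standalone sketch of the same top-k/top-p idea on a toy distribution; it omits the repetition penalty and is not this library's implementation:

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <random>
#include <utility>
#include <vector>

int main() {
    std::vector<float> logits = {2.0f, 1.0f, 0.5f, -1.0f, -3.0f};  // toy values
    int    top_k = 3;
    double top_p = 0.9, temp = 0.8;

    // Scale by temperature and keep (score, id) pairs.
    std::vector<std::pair<double,int>> cand;
    for (int i = 0; i < (int)logits.size(); ++i)
        cand.push_back({logits[i] / temp, i});

    // Top-k: keep only the k highest-scoring candidates.
    std::partial_sort(cand.begin(), cand.begin() + top_k, cand.end(),
                      [](auto &a, auto &b) { return a.first > b.first; });
    cand.resize(top_k);

    // Softmax over the survivors.
    double maxl = cand.front().first, sum = 0.0;
    std::vector<double> probs;
    for (auto &c : cand) { probs.push_back(std::exp(c.first - maxl)); sum += probs.back(); }
    for (auto &p : probs) p /= sum;

    // Top-p: cut the tail once the cumulative probability reaches top_p, then renormalize,
    // just like the 1/cumsum rescaling in the function above.
    double cum = 0.0;
    for (size_t i = 0; i < probs.size(); ++i) {
        cum += probs[i];
        if (cum >= top_p) { probs.resize(i + 1); cand.resize(i + 1); break; }
    }
    for (auto &p : probs) p /= cum;

    std::mt19937 rng(0);
    std::discrete_distribution<> dist(probs.begin(), probs.end());
    std::printf("sampled token id: %d\n", cand[dist(rng)].second);
}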
size_t mpt_set_state_data(mpt_model *model, std::mt19937 *rng, const uint8_t *src)
{
const uint8_t * in = src;

mpt/mpt.hpp

@ -1,7 +1,5 @@
#ifndef MPT_H
#define MPT_H
#include "../g4a_common.hpp"
#include <string>
#include <vector>
#include <map>
@ -85,6 +83,7 @@ struct mpt_model {
struct ggml_context * ctx;
std::map<std::string, struct ggml_tensor *> tensors;
mpt_buffer buf;
~mpt_model() {
@ -94,9 +93,24 @@ struct mpt_model {
}
};
struct mpt_vocab {
using id = int32_t;
using token = std::string;
bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & model, gpt_vocab& vocab);
std::map<token, id> token_to_id;
std::map<id, token> id_to_token;
std::vector<std::string> special_tokens;
void add_special_token(const std::string &token) {
special_tokens.push_back(token);
}
};
bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & model, mpt_vocab & vocab);
bool mpt_eval(mpt_model& model, const int n_threads, const int n_past, const std::vector<int>& embd_inp, std::vector<float>& embd_w, size_t& mem_per_token);
std::vector<mpt_vocab::id> mpt_tokenize(const mpt_vocab & vocab, const std::string & text);
mpt_vocab::id mpt_sample_top_k_top_p(const mpt_vocab& vocab, const size_t actualVocabSize, const int32_t *last_n_tokens_data, int last_n_tokens_size, const std::vector<float> logits, int top_k, double top_p, double temp, float repeat_penalty, std::mt19937& rng);
size_t mpt_get_state_size(const mpt_model &model);
size_t mpt_copy_state_data(const mpt_model &model, const std::mt19937& rng, uint8_t *dest);
size_t mpt_set_state_data(mpt_model *model, std::mt19937 *rng, const uint8_t *src);

msvc_compat_unistd.h

@ -1,11 +0,0 @@
#if defined(_WIN32) && defined(_MSC_VER)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#include <io.h>
#include <stdio.h> // for _fseeki64
#else
#include <unistd.h>
#endif


@ -9,7 +9,7 @@ namespace py = pybind11;
PYBIND11_MODULE(justlm_py, m) {
PYBIND11_MODULE(libjustlm_py, m) {
using namespace LM;
py::class_<Inference::Params>(m, "Params")
.def(py::init<>())
@ -24,23 +24,16 @@ PYBIND11_MODULE(justlm_py, m) {
.def_readwrite("top_p", &Inference::Params::top_p)
.def_readwrite("temp", &Inference::Params::temp)
.def_readwrite("repeat_penalty", &Inference::Params::repeat_penalty)
.def_readwrite("eos_ignores", &Inference::Params::n_eos_ignores)
.def_readwrite("use_mlock", &Inference::Params::use_mlock)
.def_readwrite("prefer_mirostat", &Inference::Params::prefer_mirostat)
.def_readwrite("mirostat_learning_rate", &Inference::Params::mirostat_learning_rate)
.def_readwrite("mirostat_target_entropy", &Inference::Params::mirostat_target_entropy);
.def_readwrite("eos_ignores", &Inference::Params::eos_ignores)
.def_readwrite("use_mlock", &Inference::Params::use_mlock);
py::class_<Inference>(m, "Inference")
.def_static("construct", &Inference::construct, py::arg("weights_path"), py::arg("params") = Inference::Params())
.def("append", &Inference::append, py::arg("prompt"), py::arg("on_tick") = nullptr)
.def("run", &Inference::run, py::arg("end") = "", py::arg("on_tick") = nullptr, py::arg("pre_tick") = nullptr)
.def("run", &Inference::run, py::arg("end") = "", py::arg("on_tick") = nullptr)
.def("create_savestate", &Inference::create_savestate)
.def("restore_savestate", &Inference::restore_savestate)
.def("get_prompt", &Inference::get_prompt)
.def("get_context_size", &Inference::get_context_size)
.def("is_mirostat_available", &Inference::is_mirostat_available)
.def("is_grammar_available", &Inference::is_grammar_available)
.def("load_grammar", &Inference::load_grammar)
.def("unload_grammar", &Inference::unload_grammar)
.def_readwrite("params", &Inference::params);
py::class_<Inference::Savestate>(m, "Savestate")
.def(py::init<>());
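For completeness, the same API these bindings expose can be driven directly from C++. A minimal sketch assuming the default non-CoSched build, that Inference::construct() returns a heap-allocated instance selected by the model file's magic, and that the optional callback arguments have defaults; the header name, model path and prompt are illustrative:

#include "justlm.hpp"
#include <iostream>
#include <memory>

int main() {
    LM::Inference::Params params;                          // same defaults the module exposes
    params.temp = 0.7f;
    // Assumed to return a heap-allocated implementation picked by the file's magic.
    std::unique_ptr<LM::Inference> llm(LM::Inference::construct("models/model.bin", params));
    llm->append("Hello, my name is");                      // evaluate the prompt
    std::cout << llm->run("\n") << std::endl;              // generate until a newline
}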