Mirror of https://gitlab.com/niansa/libjustlm.git (synced 2025-03-06 20:49:17 +01:00)

Compare commits
No commits in common: "master" and "v0.1" have entirely different histories.
24 changed files with 592 additions and 1270 deletions. In the reconstructed diffs below, "-" lines show the master version and "+" lines show the v0.1 version.
.gitmodules (vendored, 12 changed lines)

@@ -1,12 +1,6 @@
+[submodule "llama.cpp"]
+    path = llama.cpp
+    url = https://github.com/ggerganov/llama.cpp.git
 [submodule "llama.cpp-alibi"]
     path = llama.cpp-alibi
     url = https://github.com/manyoso/llama.cpp.git
-[submodule "llama.cpp-230511"]
-    path = llama.cpp-230511
-    url = https://github.com/ggerganov/llama.cpp.git
-[submodule "llama.cpp-230519"]
-    path = llama.cpp-230519
-    url = https://github.com/ggerganov/llama.cpp.git
-[submodule "llama.cpp-mainline"]
-    path = llama.cpp-mainline
-    url = https://github.com/ggerganov/llama.cpp.git
CMakeLists.txt

@@ -1,68 +1,52 @@
-cmake_minimum_required(VERSION 3.18)
+cmake_minimum_required(VERSION 3.14)

-project(justlm LANGUAGES C CXX)
+project(libjustlm LANGUAGES C CXX)

 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_POSITION_INDEPENDENT_CODE ON)
-set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})

-option(LM_PYBIND "If justlm Python bindings should be build" OFF)
-option(LM_NOEXCEPT "If justlm exceptions should be disabled" OFF)
-option(LM_LLAMA "If LLaMa model support should be built into justlm" ON)
-option(LM_GPTJ "If GPT-J model support should be built into justlm" ON)
-option(LM_MPT "If MPT model support should be built into justlm" ON)
+set(LM_PYBIND No CACHE BOOL "If Libjustlm Python bindings should be build")
+set(LM_COSCHED No CACHE BOOL "If Libjustlm should make use of CoSched")
+set(LM_NOEXCEPT No CACHE BOOL "If exceptions should be disabled")
+set(LM_MPT No CACHE BOOL "If MPT model support should be built")

-function(target_justlm_setup TARGET_NAME)
-    message(STATUS "Configuring model implementation target ${TARGET_NAME}")
-    target_include_directories(${TARGET_NAME} PUBLIC include/)
-    if (LM_NOEXCEPT)
-        target_compile_definitions(${TARGET_NAME} PUBLIC LM_NOEXCEPT)
-    endif()
-endfunction()

-include(llama.cpp.cmake)

-include_ggml(llama.cpp-mainline _mainline Yes)
-include_ggml(llama.cpp-alibi _alibi No)

-add_library(justlm_g4a_common SHARED g4a_common.cpp g4a_common.hpp)

+if (LM_COSCHED)
+    set(CMAKE_CXX_STANDARD 20)
+endif()

 if (LM_MPT)
-    add_library(justlm_mpt SHARED mpt.cpp justlm_mpt.hpp mpt/mpt.cpp mpt/mpt.hpp)
-    target_link_libraries(justlm_mpt PRIVATE ggml_alibi justlm_g4a_common)
-    target_justlm_setup(justlm_mpt)
+    set(LM_MPT_SOURCES justlm_mpt.hpp mpt/mpt.cpp mpt/mpt.hpp)
+    add_subdirectory(llama.cpp-alibi)
+else()
+    set(LM_MPT_SOURCES )
+    add_subdirectory(llama.cpp)
 endif()

-if (LM_GPTJ)
-    add_library(justlm_gptj SHARED gptj.cpp justlm_gptj.hpp gptj/gptj.cpp gptj/gptj.hpp)
-    target_link_libraries(justlm_gptj PRIVATE ggml_alibi justlm_g4a_common)
-    target_justlm_setup(justlm_gptj)
-endif()

-if (LM_LLAMA)
-    add_library(justlm_llama SHARED llama.cpp justlm_llama.hpp)
-    target_link_libraries(justlm_llama PRIVATE ggml_mainline llama_mainline)
-    target_compile_definitions(justlm_llama PRIVATE LLAMA_DATE=999999)
-    target_justlm_setup(justlm_llama)
-endif()

-add_library(justlm STATIC
+add_library(libjustlm STATIC
     include/justlm.hpp justlm.cpp
+    justlm_llama.hpp
+    g4a-common.cpp g4a-common.hpp
+    justlm_gptj.hpp gptj/gptj.cpp gptj/gptj.hpp
+    ${LM_MPT_SOURCES}
     include/justlm_pool.hpp justlm_pool.cpp
-    dlhandle.hpp
 )
-add_library(libjustlm ALIAS justlm)
-target_link_libraries(justlm PRIVATE dl)
-target_include_directories(justlm PUBLIC include/)
-target_compile_definitions(justlm PRIVATE LIB_FILE_EXT="${CMAKE_SHARED_LIBRARY_SUFFIX}")
-target_justlm_setup(justlm)
+target_link_libraries(libjustlm PRIVATE llama)
+if (LM_MPT)
+    target_compile_definitions(libjustlm PUBLIC LM_MPT)
+endif()

+if (LM_COSCHED)
+    target_compile_definitions(libjustlm PUBLIC LM_COSCHED)
+    target_link_libraries(libjustlm PRIVATE cosched)
+    set(LM_COSCHED Yes CACHE BOOL "If Libjustlm should make use of CoSched" FORCE)
+endif()

+if (LM_NOEXCEPT)
+    target_compile_definitions(libjustlm PUBLIC LM_NOEXCEPT)
+endif()

 if (LM_PYBIND)
     if (LM_COSCHED)
@@ -71,6 +55,8 @@ if (LM_PYBIND)

     find_package(Python COMPONENTS Interpreter Development)
     find_package(pybind11 CONFIG)
-    pybind11_add_module(justlm_py pybind.cpp)
-    target_link_libraries(justlm_py PRIVATE justlm)
+    pybind11_add_module(libjustlm_py pybind.cpp)
+    target_link_libraries(libjustlm_py PRIVATE libjustlm)
 endif()

+target_include_directories(libjustlm PUBLIC include/)
README.md

@@ -1,8 +1,8 @@
 # JustLM
-Super easy to use library for doing LLaMA/GPT-J/MPT stuff!
+Super easy to use library for doing LLaMA/GPT-J stuff!

 ## Overview
-This library implements an easy to use interface to LLaMa, GPT-J and MPT, with optional Python bindings.
+This library implements an easy to use interface to both LLaMa and GPT-J, with optional Python bindings.

 Context scrolling is automatic and supports a top window bar.
dlhandle.hpp (99 lines; file exists on master only and is removed entirely in v0.1)

@@ -1,99 +0,0 @@
#ifndef DLHANDLE_H
#define DLHANDLE_H
#ifndef __WIN32
#include <string>
#include <stdexcept>
#include <utility>
#include <dlfcn.h>


class Dlhandle {
    void *chandle;

public:
    class Exception : public std::runtime_error {
    public:
        using std::runtime_error::runtime_error;
    };

    Dlhandle() : chandle(nullptr) {}
    Dlhandle(const std::string& fpath, int flags = RTLD_LAZY) {
        chandle = dlopen(fpath.c_str(), flags);
        if (!chandle) {
            throw Exception("dlopen(\""+fpath+"\"): "+dlerror());
        }
    }
    Dlhandle(const Dlhandle& o) = delete;
    Dlhandle(Dlhandle&& o) : chandle(o.chandle) {
        o.chandle = nullptr;
    }
    ~Dlhandle() {
        if (chandle) dlclose(chandle);
    }

    auto operator =(Dlhandle&& o) {
        chandle = std::exchange(o.chandle, nullptr);
    }

    bool is_valid() const {
        return chandle != nullptr;
    }
    operator bool() const {
        return is_valid();
    }

    template<typename T>
    T* get(const std::string& fname) {
        auto fres = reinterpret_cast<T*>(dlsym(chandle, fname.c_str()));
        return (dlerror()==NULL)?fres:nullptr;
    }
    auto get_fnc(const std::string& fname) {
        return get<void*(...)>(fname);
    }
};
#else
#include <string>
#include <exception>
#include <libloaderapi.h>


class Dlhandle {
    HMODULE chandle;

public:
    class Exception : public std::runtime_error {
    public:
        using std::runtime_error::runtime_error;
    };

    Dlhandle() : chandle(nullptr) {}
    Dlhandle(const std::string& fpath) {
        chandle = LoadLibraryA(fpath.c_str());
        if (!chandle) {
            throw Exception("dlopen(\""+fpath+"\"): Error");
        }
    }
    Dlhandle(const Dlhandle& o) = delete;
    Dlhandle(Dlhandle&& o) : chandle(o.chandle) {
        o.chandle = nullptr;
    }
    ~Dlhandle() {
        if (chandle) FreeLibrary(chandle);
    }

    bool is_valid() const {
        return chandle != nullptr;
    }

    template<typename T>
    T* get(const std::string& fname) {
        return reinterpret_cast<T*>(GetProcAddress(chandle, fname.c_str()));
    }
    auto get_fnc(const std::string& fname) {
        return get<void*(...)>(fname);
    }
};
#endif
#endif // DLHANDLE_H
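For context, a minimal usage sketch of the removed Dlhandle class on its POSIX branch; the plugin filename and the null-check flow here are illustrative, not part of the repository:

    #include "dlhandle.hpp"
    #include <iostream>

    int main() {
        try {
            // Hypothetical plugin path; master builds one shared object per backend.
            Dlhandle dl("./libjustlm_gptj.so");
            // Resolve an exported symbol; get<T>() returns nullptr if dlsym() fails.
            auto *entry = dl.get<void()>("get_justlm_implementation");
            std::cout << "plugin usable: " << (dl.is_valid() && entry != nullptr) << '\n';
        } catch (const Dlhandle::Exception &e) {
            std::cerr << e.what() << '\n'; // dlopen() failed
        }
    }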
g4a_common.cpp (named g4a-common.cpp in v0.1)

@@ -1,4 +1,4 @@
-#include "g4a_common.hpp"
+#include "g4a-common.hpp"

 #include <fstream>
 #include <regex>
@@ -102,7 +102,7 @@ std::map<std::string, int32_t> json_parse(const std::string & fname) {
     return result;
 }

-std::vector<gpt_vocab::id> gpt_tokenize_inner(const gpt_vocab & vocab, const std::string & text) {
+std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
     std::vector<std::string> words;

     // first split the text into words
@@ -157,47 +157,6 @@ std::vector<gpt_vocab::id> gpt_tokenize_inner(const gpt_vocab & vocab, const std
     return tokens;
 }

-std::string regex_escape(const std::string &s) {
-    static const std::regex metacharacters(R"([\.\^\$\-\+\(\)\[\]\{\}\|\?\*])");
-    return std::regex_replace(s, metacharacters, "\\$&");
-}
-
-std::vector<gpt_vocab::id> gpt_tokenize(const gpt_vocab & vocab, const std::string & text) {
-    // Generate the subpattern from the special_tokens vector if it's not empty
-    if (!vocab.special_tokens.empty()) {
-        std::vector<gpt_vocab::id> out;
-        std::vector<std::string> chunks;
-        std::string str = text;
-        std::string special_tokens_subpattern;
-        for (const auto &token : vocab.special_tokens) {
-            if (!special_tokens_subpattern.empty()) {
-                special_tokens_subpattern += "|";
-            }
-            special_tokens_subpattern += regex_escape(token);
-        }
-        std::regex re(special_tokens_subpattern);
-        std::smatch m;
-        while (std::regex_search(str, m, re)) {
-            auto tok = vocab.token_to_id.find(m.str());
-            if (tok != vocab.token_to_id.end()) {
-                auto tokid = tok->second;
-                auto pfxtoks = gpt_tokenize_inner(vocab, m.prefix());
-                out.insert(out.end(), pfxtoks.begin(), pfxtoks.end());
-                out.push_back(tokid);
-                str = m.suffix();
-            }
-        }
-        if (!str.empty()) {
-            auto tokrest = gpt_tokenize_inner(vocab, str);
-            out.insert(out.end(), tokrest.begin(), tokrest.end());
-        }
-        return out;
-    } else {
-        return gpt_tokenize_inner(vocab, text);
-    }
-}
-
 bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
     printf("%s: loading vocab from '%s'\n", __func__, fname.c_str());

@@ -218,7 +177,7 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab) {
 }

 gpt_vocab::id gpt_sample_top_k_top_p(
-        const size_t actualVocabSize,
+        const gpt_vocab & vocab,
         const int32_t * last_n_tokens_data,
         int last_n_tokens_size,
         const std::vector<float> logits,
@@ -227,7 +186,7 @@ gpt_vocab::id gpt_sample_top_k_top_p(
         double temp,
         float repeat_penalty,
         std::mt19937 & rng) {
-    int n_logits = actualVocabSize;
+    int n_logits = vocab.id_to_token.size();

     const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
     const auto * plogits = logits.data() + logits.size() - n_logits;

g4a_common.hpp (named g4a-common.hpp in v0.1)

@@ -44,11 +44,6 @@ struct gpt_vocab {

     std::map<token, id> token_to_id;
     std::map<id, token> id_to_token;
-    std::vector<std::string> special_tokens;
-
-    void add_special_token(const std::string &token) {
-        special_tokens.push_back(token);
-    }
 };

 void replace(std::string & str, const std::string & needle, const std::string & replacement);
@@ -79,7 +74,7 @@ bool gpt_vocab_init(const std::string & fname, gpt_vocab & vocab);
 // TODO: not sure if this implementation is correct
 //
 gpt_vocab::id gpt_sample_top_k_top_p(
-        const size_t actualVocabSize,
+        const gpt_vocab & vocab,
         const int32_t * last_n_tokens_data,
         int last_n_tokens_size,
         const std::vector<float> logits,
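A short sketch of how the master-branch special-token support shown above is meant to be used; the token string is only an example:

    #include "g4a_common.hpp"

    #include <string>
    #include <vector>

    std::vector<gpt_vocab::id> tokenize_with_specials(gpt_vocab &vocab, const std::string &text) {
        // Registering a special token makes gpt_tokenize() split the input on it and
        // emit its id directly instead of running the regular word tokenization over it.
        vocab.add_special_token("<|endoftext|>");
        return gpt_tokenize(vocab, text);
    }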
gptj.cpp (26 lines; file exists on master only and is removed entirely in v0.1)

@@ -1,26 +0,0 @@
#include "justlm_gptj.hpp"
#include "justlm.hpp"

#include <string>
#include <string_view>
#include <fstream>
#include <cstdint>


extern "C" {
const LM::Implementation *get_justlm_implementation() {
    static LM::Implementation fres{false};
    return &fres;
}

bool magic_match(std::istream& f) {
    uint32_t magic;
    f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
    return magic == 0x67676d6c;
}

LM::Inference *construct(const std::string &weights_path, std::ifstream& f, const LM::Inference::Params &p) {
    return new LM::GPTJInference(weights_path, f, p);
}
}
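The constant compared against above is simply the first four bytes of a ggml model file; a hedged, standalone version of the same check (the file path handling is illustrative):

    #include <cstdint>
    #include <fstream>
    #include <string>

    bool looks_like_ggml(const std::string &path) {
        std::ifstream f(path, std::ios::binary);
        uint32_t magic = 0;
        f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
        return magic == 0x67676d6c; // byte values 0x67 0x67 0x6d 0x6c spell "ggml"
    }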
gptj/gptj.cpp

@@ -1,6 +1,6 @@
 #include "gptj.hpp"

-#include "../g4a_common.hpp"
+#include "../g4a-common.hpp"

 #include <cassert>
 #include <cmath>
@@ -11,13 +11,13 @@
 #include <string>
 #include <vector>
 #include <iostream>
-#include "../msvc_compat_unistd.h"
+#include <unistd.h>
 #include <sstream>
 #include <unordered_set>
 #include <ggml.h>

 constexpr inline
-unsigned long long operator ""_MiB(unsigned long long bytes) {
+unsigned long long operator ""_MB(unsigned long long bytes) {
     return bytes*1024*1024;
 }

@@ -32,7 +32,7 @@ static bool kv_cache_init(
     const int64_t n_mem = (int64_t)n_layer*n_ctx;
     const int64_t n_elements = n_embd*n_mem;

-    cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2_MiB);
+    cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2_MB);

     struct ggml_init_params params;
     params.mem_size = cache.buf.size;
@@ -394,7 +394,7 @@ bool gptj_eval(
     const int n_vocab = hparams.n_vocab;
     const int n_rot = hparams.n_rot;

-    static size_t buf_size = 1024_MiB;
+    static size_t buf_size = 1024_MB;
     if (!model.buf.addr || model.buf.size < buf_size)
         model.buf.resize(buf_size);
gptj/gptj.hpp

@@ -5,7 +5,7 @@
 #include <map>
 #include <ggml.h>

-#include "../g4a_common.hpp"
+#include "../g4a-common.hpp"


 // default hparams (GPT-J 6B)
include/justlm.hpp

@@ -7,28 +7,45 @@
 #include <memory>
 #include <thread>

+#ifdef LM_COSCHED
+#   include <scheduler.hpp>
+#   define LM_SCHEDULABLE(type) ::CoSched::AwaitableTask<type>
+#   define LM_CORETURN co_return
+#   define LM_COAWAIT co_await
+#   define LM_TASKYIELD (co_await ::CoSched::Task::get_current().yield())
+#else
+#   define LM_SCHEDULABLE(type) type
+#   define LM_CORETURN return
+#   define LM_COAWAIT
+#   define LM_TASKYIELD (true)
+#endif

 #ifdef LM_NOEXCEPT
 #   define LM_NOEXCEPTDECL noexcept
-#   define LM_THROW(t, r) do {this->last_error = (t); return r;} while (0)
+#   define LM_THROW(t, r) this->last_error = (t); return r;
+#   define LM_COTHROW(t, r) this->last_error = (t); LM_CORETURN r;
 #   define LM_LAST_ERROR_STORAGE mutable std::string last_error;
 #   define LM_LAST_ERROR_GETTER const std::string& get_last_error() const {return last_error;}
 #   define LM_ERRBOOL bool
 #   define LM_BOOL_ERROR false
 #   define LM_BOOL_SUCCESS true
-#   define LM_RETHROW(x) return x
-#   define LM_ERROR_CATCH(x, errval, ...) {auto v = x; if (v == (errval)) __VA_ARGS__}
-#   define LM_ERROR_FORWARD(x, errval) do {auto v = x; if (v == (errval)) return x;} while (0)
+#   define LM_ERROR_FORWARD(x) {auto v = x; if (!v) LM_CORETURN x;} 0
 #else
 #   define LM_NOEXCEPTDECL
 #   define LM_THROW(t, r) throw Exception(t)
+#   define LM_COTHROW(t, r) throw Exception(t)
 #   define LM_LAST_ERROR_STORAGE
 #   define LM_LAST_ERROR_GETTER
 #   define LM_ERRBOOL void
 #   define LM_BOOL_ERROR
 #   define LM_BOOL_SUCCESS
-#   define LM_RETHROW(x) std::rethrow_exception(std::current_exception())
-#   define LM_ERROR_CATCH(x, errval, ...) try {x;} catch (...) __VA_ARGS__
-#   define LM_ERROR_FORWARD(x, errval) {x;}
+#   define LM_ERROR_FORWARD(x) {x;}
+#endif

+#ifdef LM_COSCHED
+#ifndef LM_NOEXCEPT
+#warning Exceptions should not be enabled in combination with CoSched. Any exceptions thrown will lead to a std::terminate() call
+#endif
 #endif

 #if _MSC_VER
@@ -41,15 +58,18 @@ namespace LM {
 using ssize_t = SSIZE_T;
 #endif

-using GenerateCallback = std::function<bool (const char *generated)>;
-using AppendCallback = std::function<bool (float progress)>;

 class Inference {
 protected:
-    AppendCallback on_scroll = nullptr;
+    std::function<bool (float)> on_scroll = nullptr;

     void *generic_state = nullptr;

+    static inline
+    bool ends_with(std::string_view str, std::string_view suffix) noexcept {
+        if (suffix.empty()) return false;
+        return str.size() >= suffix.size() && 0 == str.compare(str.size()-suffix.size(), suffix.size(), suffix);
+    }

     LM_LAST_ERROR_STORAGE

 public:
@@ -59,25 +79,21 @@ public:

     struct Params {
         int seed = 0; // RNG seed
-        unsigned n_threads = 0; // Amount of threads to use, immutable after Inference was constructed
-        unsigned n_ctx = 2024; // Context size
+        unsigned n_threads = 0;
+        unsigned n_ctx = 2012; // Context size
         unsigned n_ctx_window_top_bar = 0; // Top bar of context window. Must be smaller than context size
         unsigned n_batch = 8; // Batch size
-        unsigned n_repeat_last = 0;
-        unsigned n_eos_ignores = 0;
+        unsigned n_repeat_last = 0; // llama.cpp specific

         float scroll_keep = 0.0f; // 0.4f to keep 40% of context below top bar when scrolling; 0.0f to remove everything after top bar

         unsigned top_k = 40;
         float top_p = 0.9f;
         float temp = 0.72f;
-        float mirostat_learning_rate = 0.1f; // mirostat specific
-        float mirostat_target_entropy = 5.0f; // mirostat specific
-        float repeat_penalty = 1.0f;
+        float repeat_penalty = 1.0f; // llama.cpp specific
+        unsigned eos_ignores = 0; // llama.cpp specific

-        unsigned n_gpu_layers = 38;
-        bool use_mlock = true; // llama specific
-        int prefer_mirostat = 0; // Use given mirostat version if available (see is_mirostat_available()); llama specific
+        bool use_mlock = true; // llama.cpp specific
     } params;

     struct Savestate {
@@ -108,42 +124,27 @@ public:
     static
     Inference *construct(const std::string& weights_path, const Params& p);

-    void set_scroll_callback(const AppendCallback& scroll_cb) noexcept {
+    void set_scroll_callback(const std::function<bool (float)>& scroll_cb) noexcept {
         on_scroll = scroll_cb;
     }

     // This must be called with a non-empty prompt!
-    virtual LM_ERRBOOL append(const std::string& prompt, const AppendCallback& on_tick = nullptr) LM_NOEXCEPTDECL = 0;
+    virtual LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const std::function<bool (float progress)>& on_tick = nullptr) LM_NOEXCEPTDECL = 0;

     // append() must have been called at least once before calling this!
-    virtual std::string run(std::string_view end = "", const GenerateCallback& on_tick = nullptr, const GenerateCallback& pre_tick = nullptr) LM_NOEXCEPTDECL = 0;
+    virtual LM_SCHEDULABLE(std::string) run(std::string_view end = "", const std::function<bool (const char *generated)>& on_tick = nullptr) LM_NOEXCEPTDECL = 0;

     virtual unsigned get_context_size() const noexcept = 0;

-    virtual LM_ERRBOOL create_savestate(Savestate&) const LM_NOEXCEPTDECL = 0;
-    virtual LM_ERRBOOL restore_savestate(const Savestate&) LM_NOEXCEPTDECL = 0;
+    virtual LM_SCHEDULABLE(LM_ERRBOOL) create_savestate(Savestate&) const LM_NOEXCEPTDECL = 0;
+    virtual LM_SCHEDULABLE(LM_ERRBOOL) restore_savestate(const Savestate&) LM_NOEXCEPTDECL = 0;

-    virtual LM_ERRBOOL serialize(std::ostream&) const LM_NOEXCEPTDECL = 0;
-    virtual LM_ERRBOOL deserialize(std::istream&) LM_NOEXCEPTDECL = 0;
+    virtual LM_SCHEDULABLE(LM_ERRBOOL) serialize(std::ostream&) const LM_NOEXCEPTDECL = 0;
+    virtual LM_SCHEDULABLE(LM_ERRBOOL) deserialize(std::istream&) LM_NOEXCEPTDECL = 0;

-    virtual LM_ERRBOOL load_grammar(const std::string&, bool override_temperature [[maybe_unused]] = false) LM_NOEXCEPTDECL {
-        LM_THROW("Grammar is not available for this models backend", LM_BOOL_ERROR);
-    }
-    virtual LM_ERRBOOL unload_grammar() LM_NOEXCEPTDECL {
-        LM_THROW("Grammar is not available for this models backend", LM_BOOL_ERROR);
-    }

     virtual const std::string& get_prompt() const LM_NOEXCEPTDECL = 0;

-    virtual bool is_mirostat_available() const noexcept {return false;}
-    virtual bool is_grammar_available() const noexcept {return false;}

     LM_LAST_ERROR_GETTER
 };


-struct Implementation {
-    bool is_fallback = false;
-};
 }
 #endif // JUSTLM_HPP
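To make the v0.1 macro layer above concrete, a minimal sketch of how a method written against it compiles in the two build modes; the function itself is invented for illustration:

    #include "justlm.hpp" // provides LM_SCHEDULABLE, LM_CORETURN and LM_TASKYIELD as defined above

    // With LM_COSCHED defined this is a CoSched::AwaitableTask<bool> coroutine that
    // co_returns and can yield; without it, it is an ordinary function returning bool.
    LM_SCHEDULABLE(bool) example_step() {
        if (!LM_TASKYIELD)      // cooperative yield under CoSched, constant true otherwise
            LM_CORETURN false;  // expands to co_return false / return false
        LM_CORETURN true;
    }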
include/justlm_pool.hpp

@@ -63,21 +63,21 @@ class InferencePool {
     }

     // Returns false on error
-    bool store_slot(Slot& slot);
+    LM_SCHEDULABLE(bool) store_slot(Slot& slot);
     // Returns nullptr on error
-    Slot *load_slot(size_t id, Slot *suggested_slot = nullptr);
+    LM_SCHEDULABLE(Slot*) load_slot(size_t id, Slot *suggested_slot = nullptr);

-    void store_and_reset_slot(Slot& slot) {
-        store_slot(slot); //TODO: Should handle errors somehow
+    LM_SCHEDULABLE(void) store_and_reset_slot(Slot& slot) {
+        LM_COAWAIT store_slot(slot); //TODO: Should handle errors somehow
         slot.reset();
-        return;
+        LM_CORETURN;
     }

     // Doesn't fail
-    Slot *get_free_slot();
+    LM_SCHEDULABLE(Slot*) get_free_slot();

     // Returns nullptr if not found
-    Slot *find_slot_by_id(size_t id, bool deserialize = true);
+    LM_SCHEDULABLE(Slot*) find_slot_by_id(size_t id, bool deserialize = true);

 public:
     // The pool_name must be unique amonst all applications in cwd
@@ -93,14 +93,14 @@ public:
         }
     }

-    std::shared_ptr<Inference> create_inference(size_t id, const std::string& weights_path, const Inference::Params& p) {
-        auto slot = get_free_slot();
-        return slot->create_inference(id, weights_path, p);
+    LM_SCHEDULABLE(std::shared_ptr<Inference>) create_inference(size_t id, const std::string& weights_path, const Inference::Params& p) {
+        auto slot = LM_COAWAIT get_free_slot();
+        LM_CORETURN slot->create_inference(id, weights_path, p);
     }
-    std::shared_ptr<Inference> get_inference(size_t id);
-    std::shared_ptr<Inference> get_or_create_inference(size_t id, const std::string& weights_path, const Inference::Params& p);
-    void delete_inference(size_t id);
-    void store_all();
+    LM_SCHEDULABLE(std::shared_ptr<Inference>) get_inference(size_t id);
+    LM_SCHEDULABLE(std::shared_ptr<Inference>) get_or_create_inference(size_t id, const std::string& weights_path, const Inference::Params& p);
+    LM_SCHEDULABLE(void) delete_inference(size_t id);
+    LM_SCHEDULABLE(void) store_all();
     std::vector<size_t> get_active_slot_ids() const;

     void cleanup();
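A small usage sketch of the pool interface above in its plain (non-CoSched) form, assuming the class lives in namespace LM like the rest of the API; the slot id, context size and paths are placeholders:

    #include "justlm_pool.hpp"

    #include <memory>
    #include <string>

    std::shared_ptr<LM::Inference> get_model(LM::InferencePool &pool, const std::string &weights) {
        LM::Inference::Params p;
        p.n_ctx = 2012;                                       // v0.1 default from include/justlm.hpp
        return pool.get_or_create_inference(42, weights, p);  // returns the existing inference for id 42 if present, else creates one
    }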
justlm.cpp (74 changed lines)

@@ -1,66 +1,30 @@
 #include "justlm.hpp"
-#include "dlhandle.hpp"
+#include "justlm_llama.hpp"
+#include "justlm_gptj.hpp"
+#ifdef LM_MPT
+#   include "justlm_mpt.hpp"
+#endif

-#include <string>
-#include <vector>
 #include <fstream>
-#include <filesystem>



-static
-Dlhandle get_implementation(std::ifstream& input_f) {
-    Dlhandle matching;
-    Dlhandle fallback;
-    // Iterate over all libraries
-    for (const auto& f : std::filesystem::directory_iterator(".")) {
-        // Get path
-        const auto& p = f.path();
-        // Check extension
-        if (p.extension() != LIB_FILE_EXT) continue;
-        // Load library
-        try {
-            Dlhandle dl(p);
-            // Get implementation info getter
-            auto implementation_getter = dl.get<const LM::Implementation *()>("get_justlm_implementation");
-            if (!implementation_getter) continue;
-            // Get implementation info
-            const auto *implementation_info = implementation_getter();
-            // Set if fallback
-            if (implementation_info->is_fallback) {
-                fallback = std::move(dl);
-                continue;
-            }
-            // Set if matching magic
-            input_f.seekg(0);
-            auto magic_match = dl.get<bool(std::ifstream&)>("magic_match");
-            if (magic_match && magic_match(input_f)) {
-                matching = std::move(dl);
-                continue;
-            }
-        } catch (...) {}
-    }
-    // Return matching if any, fallback otherwise
-    if (matching) return matching;
-    return fallback;
-}

 LM::Inference *LM::Inference::construct(const std::string &weights_path, const Params &p) {
-    static std::vector<Dlhandle> dls;
     // Read magic
     std::ifstream f(weights_path, std::ios::binary);
-    if (!f) {
-        throw Exception("Failed to open weights file for reading at "+weights_path);
+    uint32_t magic;
+    f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
+    // Create inference instance
+    if (magic == 0x67676d6c) {
+        f.seekg(0);
+        return new GPTJInference(weights_path, f, p);
+#   ifdef LM_MPT
+    } else if (magic == 0x67676d6d) {
+        f.seekg(0);
+        return new MPTInference(weights_path, f, p);
+#   endif
+    } else {
+        f.close();
+        return new LLaMaInference(weights_path, p);
     }
-    // Get correct implementation
-    auto impl = get_implementation(f);
-    if (!impl) return nullptr;
-    // Get inference constructor
-    auto constructor = impl.get<LM::Inference *(const std::string &, std::ifstream&, const LM::Inference::Params &)>("construct");
-    if (!constructor) return nullptr;
-    // Back up Dlhandle
-    dls.push_back(std::move(impl));
-    // Construct inference
-    f.seekg(0);
-    return constructor(weights_path, f, p);
 }
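For orientation, a hedged sketch of how a caller drives the Inference interface that construct() returns on the master branch; the model path and prompt are placeholders and error handling is reduced to a null check:

    #include "justlm.hpp"

    #include <iostream>
    #include <memory>

    int main() {
        LM::Inference::Params params;
        params.n_threads = 4;
        // construct() picks a backend by file magic (dynamically loaded plugins on master).
        std::unique_ptr<LM::Inference> llm(LM::Inference::construct("./model.bin", params));
        if (!llm) return 1; // no matching implementation was found
        llm->append("Hello", [](float /*progress*/) { return true; });
        std::cout << llm->run("\n") << std::endl; // generate until a newline is produced
    }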
justlm_gptj.hpp (108 changed lines)

@@ -4,7 +4,7 @@
 #include <random>
 #include <cstring>
 #include "gptj/gptj.hpp"
-#include "g4a_common.hpp"
+#include "g4a-common.hpp"


 namespace LM {
@@ -53,18 +53,19 @@ class GPTJInference final : public Inference {
         auto& state = get_state();

         if (state) {
+            if (state->model.ctx) ggml_free(state->model.ctx); //TODO: Is that enough?
             delete state;
         }
     }

     // This function reduces the size of our tokens vector according to some parameters
     // All tokens will be evaluated if scrolling was needed and true will be returned
-    bool window_scroll() LM_NOEXCEPTDECL {
+    LM_SCHEDULABLE(bool) window_scroll() LM_NOEXCEPTDECL {
         auto &state = get_state();
         // Check that we actually need to scroll
         if (state->tokens.size() <= params.n_ctx) {
             // Nope
-            return false;
+            LM_CORETURN false;
         }
         // Start scrolling
         if (params.scroll_keep > 0.0f) {
@@ -81,11 +82,11 @@ class GPTJInference final : public Inference {
             state->tokens.resize(params.n_ctx_window_top_bar);
         }
         // Evaluate tokens
-        LM_ERROR_FORWARD(evaluate_tokens(0, on_scroll), LM_BOOL_ERROR);
-        return true;
+        LM_ERROR_FORWARD(LM_COAWAIT evaluate_tokens(0, on_scroll));
+        LM_CORETURN true;
     }

-    LM_ERRBOOL evaluate_tokens(size_t starting_offset, const AppendCallback &on_tick = nullptr) LM_NOEXCEPTDECL {
+    LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL {
         auto& state = get_state();

         // Evaluate tokens in batches
@@ -96,7 +97,7 @@ class GPTJInference final : public Inference {
             // Evaluate
             std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+params.n_batch);
             if (!gptj_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token)) {
-                LM_THROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR);
+                LM_COTHROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR);
             }

             // Tick
@@ -104,7 +105,8 @@ class GPTJInference final : public Inference {
                 // Calculate progress
                 auto progress = float(it-starting_offset) / (state->tokens.size()-starting_offset) * 100.f;
                 // Tick and yield
-                if (!on_tick(progress)) return LM_BOOL_SUCCESS;
+                if (!on_tick(progress)) LM_CORETURN LM_BOOL_SUCCESS;
+                else if (!LM_TASKYIELD) LM_CORETURN LM_BOOL_SUCCESS;
             }
         }

@@ -114,7 +116,7 @@ class GPTJInference final : public Inference {
                 //TODO: This is extremely inefficient! Don't do that...
                 std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+1);
                 if (!gptj_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token)) {
-                    LM_THROW("Failed to evaluate individual tokens", LM_BOOL_ERROR);
+                    LM_COTHROW("Failed to evaluate individual tokens", LM_BOOL_ERROR);
                 }
             }
         }
@@ -122,7 +124,7 @@ class GPTJInference final : public Inference {
         // Notify about completion
         if (on_tick) on_tick(100.f);

-        return LM_BOOL_SUCCESS;
+        LM_CORETURN LM_BOOL_SUCCESS;
     }

 public:
@@ -133,7 +135,7 @@ public:
         deinit();
     }

-    LM_ERRBOOL append(const std::string& prompt, const AppendCallback &on_tick) LM_NOEXCEPTDECL override {
+    LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
         auto& state = get_state();

         // Append to current prompt
@@ -151,123 +153,119 @@ public:
         );

         // Make sure token limit isn't being hit
-        if (window_scroll()) {
+        if (LM_COAWAIT window_scroll()) {
             // That function already has evaluated our tokens since scrolling was needed
-            return LM_BOOL_SUCCESS;
+            LM_CORETURN LM_BOOL_SUCCESS;
         }

         // Evaluate new tokens
-        return evaluate_tokens(old_token_count, on_tick);
+        LM_CORETURN LM_COAWAIT evaluate_tokens(old_token_count, on_tick);
     }

-    std::string run(std::string_view end, const GenerateCallback &on_tick, const GenerateCallback& pre_tick) LM_NOEXCEPTDECL override {
+    LM_SCHEDULABLE(std::string) run(std::string_view end, const std::function<bool (const char *)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
         auto& state = get_state();
         std::string fres;

         // Loop until done
         bool abort = false;
         unsigned eos_count = 0;
-        size_t last_size = 0;
-        while (!abort && (end.empty() || fres.find(end) == fres.npos)) {
-            last_size = fres.size();
+        while (!abort && !ends_with(fres, end)) {
             // Sample top p and top k
-            const auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
-            auto id = gpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data()+state->tokens.size()-n_repeat_last, n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
+            auto id = gpt_sample_top_k_top_p(state->vocab, params.n_repeat_last?(state->tokens.data()+state->tokens.size()-params.n_repeat_last):nullptr, params.n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);

             if (id == 50256) {
-                if (eos_count++ == params.n_eos_ignores) {
+                if (eos_count++ == params.eos_ignores) {
                     abort = true;
                     continue;
                 }
                 id = gpt_tokenize(state->vocab, "\n")[0];
+                state->tokens.push_back(id);
+            } else {
+                // Add token
+                state->tokens.push_back(id);
             }

-            // Add token
-            state->tokens.push_back(id);

             // Make sure token limit isn't being hit
-            window_scroll();
+            LM_COAWAIT window_scroll();

             // Get token as string
-            const std::string_view str = state->vocab.id_to_token[id];
+            const auto str = state->vocab.id_to_token[id];

             // Append string to function result
-            state->prompt.append(str);
             fres.append(str);

-            if (pre_tick && !pre_tick(str.data())) abort = true;
-            else {
-                // Evaluate token
-                // TODO: Respect batch size
-                std::vector<int> batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size());
-                if (!gptj_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) {
-                    LM_THROW("Failed to evaluate new tokens", "");
-                }
+            // Evaluate token
+            // TODO: Respect batch size
+            std::vector<int> batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size());
+            if (!gptj_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) {
+                LM_COTHROW("Failed to evaluate new tokens", "");
             }

             // Tick
-            if (on_tick && !on_tick(str.data())) abort = true;
+            if (on_tick && !on_tick(str.c_str())) abort = true;
+            else if (!LM_TASKYIELD) abort = true;
         }

         // Create final string TODO: Could be optimized
+        state->prompt.append(fres);
         if (!abort) {
-            fres = std::string(fres.data(), last_size);
+            fres = std::string(fres.data(), fres.size()-end.size());
         }

         // Return final string
-        return fres;
+        LM_CORETURN fres;
     }

     unsigned get_context_size() const noexcept override {
         return get_state()->tokens.size();
     }

-    LM_ERRBOOL create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override {
+    LM_SCHEDULABLE(LM_ERRBOOL) create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override {
         auto& state = get_state();
         sv.buf.resize(gptj_get_state_size(state->model));
         gptj_copy_state_data(state->model, state->rng, sv.buf.data());
         sv.tokens = state->tokens;
         sv.prompt = state->prompt;
         sv.ctx = generic_state;
-        return LM_BOOL_SUCCESS;
+        LM_CORETURN LM_BOOL_SUCCESS;
     }
-    LM_ERRBOOL restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override {
+    LM_SCHEDULABLE(LM_ERRBOOL) restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override {
         auto& state = get_state();
         if (sv.ctx != generic_state)
-            LM_THROW("Savestate does not match context", LM_BOOL_ERROR);
+            LM_COTHROW("Savestate does not match context", LM_BOOL_ERROR);
         gptj_set_state_data(&state->model, &state->rng, sv.buf.data());
         state->tokens = sv.tokens;
         state->prompt = sv.prompt;
-        return LM_BOOL_SUCCESS;
+        LM_CORETURN LM_BOOL_SUCCESS;
     }

-    LM_ERRBOOL serialize(std::ostream &o) const LM_NOEXCEPTDECL override {
+    LM_SCHEDULABLE(LM_ERRBOOL) serialize(std::ostream &o) const LM_NOEXCEPTDECL override {
         auto& state = get_state();
         // Get state size
         auto state_size = gptj_get_state_size(state->model);
         // Write sizes
         for (const uint32_t s : {state->tokens.size(), state->prompt.size(), state_size}) {
             if (!o.write(reinterpret_cast<const char*>(&s), sizeof(s))) {
-                LM_THROW("Failed to serialize data sizes", LM_BOOL_ERROR);
+                LM_COTHROW("Failed to serialize data sizes", LM_BOOL_ERROR);
             }
         }
         // Write tokens
         if (!o.write(reinterpret_cast<const char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
-            LM_THROW("Failed to serialize tokens", LM_BOOL_ERROR);
+            LM_COTHROW("Failed to serialize tokens", LM_BOOL_ERROR);
         }
         // Write prompt
         if (!o.write(state->prompt.data(), state->prompt.size())) {
-            LM_THROW("Failed to serialize prompt", LM_BOOL_ERROR);
+            LM_COTHROW("Failed to serialize prompt", LM_BOOL_ERROR);
         }
         // Write state
         std::vector<uint8_t> state_buf(state_size);
         gptj_copy_state_data(state->model, state->rng, state_buf.data());
         if (!o.write(reinterpret_cast<const char*>(state_buf.data()), state_size)) {
-            LM_THROW("Failed to serialize state", LM_BOOL_ERROR);
+            LM_COTHROW("Failed to serialize state", LM_BOOL_ERROR);
         }
-        return LM_BOOL_SUCCESS;
+        LM_CORETURN LM_BOOL_SUCCESS;
     }
-    LM_ERRBOOL deserialize(std::istream &i) LM_NOEXCEPTDECL override {
+    LM_SCHEDULABLE(LM_ERRBOOL) deserialize(std::istream &i) LM_NOEXCEPTDECL override {
         auto& state = get_state();
         uint32_t embd_size, prompt_size, state_size;
         // Initialization to prevent compiler complaints
@@ -275,26 +273,26 @@ public:
         // Read sizes
         for (uint32_t *s : {&embd_size, &prompt_size, &state_size}) {
             if (!i.read(reinterpret_cast<char*>(s), sizeof(*s))) {
-                LM_THROW("Failed to deserialize data sizes", LM_BOOL_ERROR);
+                LM_COTHROW("Failed to deserialize data sizes", LM_BOOL_ERROR);
             }
         }
         // Read tokens
         state->tokens.resize(embd_size);
         if (!i.read(reinterpret_cast<char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
-            LM_THROW("Failed to deserialize tokens", LM_BOOL_ERROR);
+            LM_COTHROW("Failed to deserialize tokens", LM_BOOL_ERROR);
        }
         // Read prompt
         state->prompt.resize(prompt_size);
         if (!i.read(state->prompt.data(), state->prompt.size())) {
-            LM_THROW("Failed to deserialize prompt", LM_BOOL_ERROR);
+            LM_COTHROW("Failed to deserialize prompt", LM_BOOL_ERROR);
         }
         // Read state
         std::vector<uint8_t> state_buf(state_size);
         if (!i.read(reinterpret_cast<char*>(state_buf.data()), state_buf.size())) {
-            LM_THROW("Failed to deserialize state", LM_BOOL_ERROR);
+            LM_COTHROW("Failed to deserialize state", LM_BOOL_ERROR);
         }
         gptj_set_state_data(&state->model, &state->rng, state_buf.data());
-        return LM_BOOL_SUCCESS;
+        LM_CORETURN LM_BOOL_SUCCESS;
     }
     const std::string &get_prompt() const LM_NOEXCEPTDECL override {
         return get_state()->prompt;
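The window_scroll()/scroll_keep behaviour that recurs in both backend diffs can be summarised in a standalone sketch; the names and the exact trimming policy here are illustrative rather than the library's own code:

    #include <cstddef>
    #include <vector>

    // Trim a token window the way the diffs above describe: keep the top bar,
    // keep a scroll_keep fraction of what follows it, and drop the rest.
    void scroll_window(std::vector<int> &tokens, size_t n_ctx, size_t top_bar, float scroll_keep) {
        if (tokens.size() <= n_ctx) return; // nothing to do yet
        if (scroll_keep > 0.0f) {
            const size_t tail = tokens.size() - top_bar;
            const size_t keep = static_cast<size_t>(tail * scroll_keep);
            tokens.erase(tokens.begin() + top_bar, tokens.end() - keep);
        } else {
            tokens.resize(top_bar); // remove everything after the top bar
        }
    }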
justlm_llama.hpp (239 changed lines; diff truncated in this mirror)

@@ -3,20 +3,15 @@
 #include <cstring>
 #include <ggml.h>
 #include <llama.h>
-#include <common/grammar-parser.h>


 namespace LM {
-class LLaMAInference final : public Inference {
+class LLaMaInference final : public Inference {
     struct State {
         llama_context *ctx = nullptr;
-        llama_model *model;
-        llama_grammar *grammar = nullptr;
-        bool grammar_override_temp;
-        grammar_parser::parse_state parsed_grammar;
         std::string prompt; // Mostly here for easy "debugging"
         std::vector<int> tokens;
-        unsigned n_ctx;
+        int n_ctx;
     };

     State*& get_state() {
@@ -36,24 +31,12 @@ class LLaMAInference final : public Inference {
         auto lparams = llama_context_default_params();
         lparams.seed = params.seed;
         lparams.n_ctx = params.n_ctx = params.n_ctx>0?params.n_ctx:2024;
-        lparams.n_threads = params.n_threads;
-        //lparams.n_threads_batch = params.n_threads; TODO: Is this sane?
+        lparams.use_mlock = params.use_mlock;

-        // Get model parameters
-        auto mparams = llama_model_default_params();
-        mparams.use_mlock = params.use_mlock;
-        mparams.n_gpu_layers = params.n_gpu_layers;

-        // Load model
-        state->model = llama_load_model_from_file(weights_path.c_str(), mparams);
-        if (!state->model) {
-            LM_THROW("Failed to initialize llama model from file", LM_BOOL_ERROR);
-        }

         // Create context
-        state->ctx = llama_new_context_with_model(state->model, lparams);
+        state->ctx = llama_init_from_file(weights_path.c_str(), lparams);
         if (!state->ctx) {
-            LM_THROW("Failed to initialize llama context from model", LM_BOOL_ERROR);
+            LM_THROW("Failed to initialize llama from file", LM_BOOL_ERROR);
         }

         // Initialize some variables
@@ -64,12 +47,12 @@ class LLaMAInference final : public Inference {

     // This function reduces the size of our tokens vector according to some parameters
     // All tokens will be evaluated if scrolling was needed and true will be returned
-    bool window_scroll() LM_NOEXCEPTDECL {
+    LM_SCHEDULABLE(bool) window_scroll() LM_NOEXCEPTDECL {
         auto &state = get_state();
         // Check that we actually need to scroll
         if (state->tokens.size() <= state->n_ctx) {
             // Nope
-            return false;
+            LM_CORETURN false;
         }
         // Start scrolling
         if (params.scroll_keep > 0.0f) {
@@ -86,11 +69,11 @@ class LLaMAInference final : public Inference {
             state->tokens.resize(params.n_ctx_window_top_bar);
         }
         // Evaluate tokens
-        LM_ERROR_FORWARD(evaluate_tokens(0, on_scroll), LM_BOOL_ERROR);
-        return true;
+        LM_ERROR_FORWARD(LM_COAWAIT evaluate_tokens(0, on_scroll));
+        LM_CORETURN true;
     }

-    LM_ERRBOOL evaluate_tokens(size_t starting_offset, const AppendCallback &on_tick = nullptr) LM_NOEXCEPTDECL {
+    LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL {
         auto& state = get_state();

         // Evaluate tokens in batches
@@ -99,9 +82,8 @@ class LLaMAInference final : public Inference {
             if (it + params.n_batch >= ssize_t(state->tokens.size())) break;

             // Evaluate
-            const auto batch = llama_batch_get_one(state->tokens.data()+it, params.n_batch, it, 0);
-            if (llama_decode(state->ctx, batch)) {
-                LM_THROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR);
+            if (llama_eval(state->ctx, state->tokens.data()+it, params.n_batch, it, params.n_threads)) {
+                LM_COTHROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR);
             }

             // Tick
@@ -109,16 +91,16 @@ class LLaMAInference final : public Inference {
                 // Calculate progress
                 auto progress = float(it-starting_offset) / (state->tokens.size()-starting_offset) * 100.f;
                 // Tick and yield
-                if (!on_tick(progress)) return LM_BOOL_SUCCESS;
+                if (!on_tick(progress)) LM_BOOL_SUCCESS;
+                else if (!LM_TASKYIELD) LM_BOOL_SUCCESS;
             }
         }

         // Evaluate remaining tokens
         if (it < state->tokens.size()) {
             for (; it != state->tokens.size(); it++) {
-                const auto batch = llama_batch_get_one(state->tokens.data()+it, 1, it, 0);
-                if (llama_decode(state->ctx, batch)) {
-                    LM_THROW("Failed to evaluate individual tokens", LM_BOOL_ERROR);
+                if (llama_eval(state->ctx, state->tokens.data()+it, 1, it, params.n_threads)) {
+                    LM_COTHROW("Failed to evaluate individual tokens", LM_BOOL_ERROR);
                 }
             }
         }
@@ -126,69 +108,14 @@ class LLaMAInference final : public Inference {
         // Notify about completion
         if (on_tick) on_tick(100.f);

-        return LM_BOOL_SUCCESS;
-    }
-
-    int accept_token(int t) {
-        auto& state = get_state();
-        if (state->grammar)
-            llama_grammar_accept_token(state->ctx, state->grammar, t);
-        return t;
-    }
-
-    int llama_sample_top_p_top_k() {
-        auto& state = get_state();
-        auto logits = llama_get_logits(state->ctx);
-        auto n_vocab = llama_n_vocab(state->model);
-        // Populate initial list of all candidates
-        std::vector<llama_token_data> candidates;
-        candidates.reserve(n_vocab);
-        for (int token_id = 0; token_id < n_vocab; token_id++) {
-            candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
-        }
-        llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
-        // Sample repeat penalty
-        auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
-        llama_sample_repetition_penalties(state->ctx, &candidates_p, params.n_repeat_last?(state->tokens.data()+state->tokens.size()-n_repeat_last):nullptr, n_repeat_last, params.repeat_penalty, 1.0f, 1.0f); // Might be wrong
-        // Grammar sampling
-        if (state->grammar) {
-            llama_sample_grammar(state->ctx, &candidates_p, state->grammar);
-        }
-        if (!(state->grammar && state->grammar_override_temp) && (params.temp > 0.01f || params.temp < -0.01f)) {
-            // Temperature sampling
-            switch (params.prefer_mirostat) {
-            case 0: {
-                llama_sample_top_k(state->ctx, &candidates_p, params.top_k, 1);
-                llama_sample_tail_free(state->ctx, &candidates_p, 1.0f, 1);
-                llama_sample_typical(state->ctx, &candidates_p, 1.0f, 1);
-                llama_sample_top_p(state->ctx, &candidates_p, params.top_p, 1);
-                llama_sample_temp(state->ctx, &candidates_p, params.temp);
-                return accept_token(llama_sample_token(state->ctx, &candidates_p));
-            }
-            case 1: {
-                float mirostat_mu = 2.0f * params.mirostat_target_entropy;
-                const int mirostat_m = 100;
-                llama_sample_temp(state->ctx, &candidates_p, params.temp);
-                return accept_token(llama_sample_token_mirostat(state->ctx, &candidates_p, params.mirostat_target_entropy, params.mirostat_learning_rate, mirostat_m, &mirostat_mu));
-            }
-            case 2: {
-                float mirostat_mu = 2.0f * params.mirostat_target_entropy;
-                llama_sample_temp(state->ctx, &candidates_p, params.temp);
-                return accept_token(llama_sample_token_mirostat_v2(state->ctx, &candidates_p, params.mirostat_target_entropy, params.mirostat_learning_rate, &mirostat_mu));
-            }
-            default: LM_THROW("Invalid mirostat version "+std::to_string(params.prefer_mirostat), LM_BOOL_ERROR);
-            }
-        } else {
-            // Greedy sampling
-            return accept_token(llama_sample_token(state->ctx, &candidates_p));
-        }
+        LM_CORETURN LM_BOOL_SUCCESS;
     }

 public:
-    LLaMAInference(const std::string& weights_path, const Params& p) : Inference(p) {
+    LLaMaInference(const std::string& weights_path, const Params& p) : Inference(p) {
         init(weights_path);
     }
-    ~LLaMAInference() override {
+    ~LLaMaInference() override {
         auto& state = get_state();

         if (state) {
@@ -197,7 +124,7 @@ public:
         }
     }

-    LM_ERRBOOL append(const std::string& prompt, const AppendCallback &on_tick) LM_NOEXCEPTDECL override {
+    LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
         auto& state = get_state();
auto& state = get_state();
|
||||||
|
|
||||||
// Check if prompt was empty
|
// Check if prompt was empty
|
||||||
|
@ -211,44 +138,37 @@ public:
|
||||||
state->tokens.resize(old_token_count+state->prompt.size());
|
state->tokens.resize(old_token_count+state->prompt.size());
|
||||||
|
|
||||||
// Run tokenizer
|
// Run tokenizer
|
||||||
const auto token_count = llama_tokenize(state->model, prompt.c_str(), prompt.size(), state->tokens.data()+old_token_count, state->tokens.size()-old_token_count, was_empty, false);
|
const auto token_count = llama_tokenize(state->ctx, prompt.c_str(), state->tokens.data()+old_token_count, state->tokens.size()-old_token_count, was_empty);
|
||||||
state->tokens.resize(old_token_count+token_count);
|
state->tokens.resize(old_token_count+token_count);
|
||||||
|
|
||||||
// Make sure token limit isn't being hit
|
// Make sure token limit isn't being hit
|
||||||
if (window_scroll()) {
|
if (LM_COAWAIT window_scroll()) {
|
||||||
// That function already has evaluated our tokens since scrolling was needed
|
// That function already has evaluated our tokens since scrolling was needed
|
||||||
return LM_BOOL_SUCCESS;
|
LM_CORETURN LM_BOOL_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Evaluate new tokens
|
// Evaluate new tokens
|
||||||
return evaluate_tokens(old_token_count, on_tick);
|
LM_CORETURN LM_COAWAIT evaluate_tokens(old_token_count, on_tick);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string run(std::string_view end, const GenerateCallback &on_tick, const GenerateCallback& pre_tick) LM_NOEXCEPTDECL override {
|
LM_SCHEDULABLE(std::string) run(std::string_view end, const std::function<bool (const char *)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
|
||||||
auto& state = get_state();
|
auto& state = get_state();
|
||||||
std::string fres;
|
std::string fres;
|
||||||
|
|
||||||
// Loop until done
|
// Loop until done
|
||||||
bool abort = false;
|
bool abort = false;
|
||||||
unsigned eos_count = 0;
|
unsigned eos_count = 0;
|
||||||
size_t last_size = 0;
|
while (!abort && !ends_with(fres, end)) {
|
||||||
while (!abort && (end.empty() || fres.find(end) == fres.npos)) {
|
|
||||||
last_size = fres.size();
|
|
||||||
// Sample top p and top k
|
// Sample top p and top k
|
||||||
int id;
|
auto id = llama_sample_top_p_top_k(state->ctx, params.n_repeat_last?(state->tokens.data()+state->tokens.size()-params.n_repeat_last):nullptr, params.n_repeat_last, params.top_k, params.top_p, params.temp, params.repeat_penalty);
|
||||||
try {
|
|
||||||
id = llama_sample_top_p_top_k();
|
|
||||||
} catch (const std::exception& e) {
|
|
||||||
LM_THROW(e.what(), "");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (id == llama_token_eos(state->model)) {
|
if (id == llama_token_eos()) {
|
||||||
if (eos_count++ == params.n_eos_ignores) {
|
if (eos_count++ == params.eos_ignores) {
|
||||||
abort = true;
|
abort = true;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
state->tokens.push_back(0);
|
state->tokens.push_back(0);
|
||||||
llama_tokenize(state->model, "\n", 1, &state->tokens.back(), 1, false, false);
|
llama_tokenize(state->ctx, "\n", &state->tokens.back(), 1, false);
|
||||||
id = state->tokens.back();
|
id = state->tokens.back();
|
||||||
} else {
|
} else {
|
||||||
// Add token
|
// Add token
|
||||||
|
@ -256,90 +176,85 @@ public:
|
||||||
}
|
}
|
||||||
|
|
||||||
// Make sure token limit isn't hit
|
// Make sure token limit isn't hit
|
||||||
window_scroll();
|
LM_COAWAIT window_scroll();
|
||||||
|
|
||||||
// Get token as string
|
// Get token as string
|
||||||
std::string str(14, ' ');
|
const auto str = llama_token_to_str(state->ctx, id);
|
||||||
str.resize(llama_token_to_piece(state->model, id, str.data(), 14));
|
|
||||||
|
|
||||||
// Append string to function result
|
// Append string to function result
|
||||||
state->prompt.append(str);
|
|
||||||
fres.append(str);
|
fres.append(str);
|
||||||
|
|
||||||
// Tick
|
// Evaluate token
|
||||||
if (pre_tick && !pre_tick(str.data())) abort = true;
|
// TODO: Respect batch size
|
||||||
else {
|
if (llama_eval(state->ctx, state->tokens.data()+state->tokens.size()-1, 1, state->tokens.size()-1, params.n_threads)) {
|
||||||
// Evaluate token
|
LM_COTHROW("Failed to evaluate new tokens", "");
|
||||||
// TODO: Respect batch size
|
|
||||||
const auto batch = llama_batch_get_one(state->tokens.data()+state->tokens.size()-1, 1, state->tokens.size()-1, 0);
|
|
||||||
if (llama_decode(state->ctx, batch)) {
|
|
||||||
LM_THROW("Failed to evaluate new tokens", "");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tick and yield
|
// Tick and yield
|
||||||
if (on_tick && !on_tick(str.data())) abort = true;
|
if (on_tick && !on_tick(str)) abort = true;
|
||||||
|
else if (!LM_TASKYIELD) abort = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create final string TODO: Could be optimized
|
// Create final string TODO: Could be optimized
|
||||||
if (!abort && fres.size() > end.size()) {
|
state->prompt.append(fres);
|
||||||
fres = std::string(fres.data(), last_size);
|
if (!abort) {
|
||||||
|
fres = std::string(fres.data(), fres.size()-end.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return final string
|
// Return final string
|
||||||
return fres;
|
LM_CORETURN fres;
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned get_context_size() const noexcept override {
|
unsigned get_context_size() const noexcept override {
|
||||||
return get_state()->tokens.size();
|
return get_state()->tokens.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
LM_ERRBOOL create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override {
|
LM_SCHEDULABLE(LM_ERRBOOL) create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override {
|
||||||
auto& state = get_state();
|
auto& state = get_state();
|
||||||
sv.buf.resize(llama_get_state_size(state->ctx));
|
sv.buf.resize(llama_get_state_size(state->ctx));
|
||||||
llama_copy_state_data(state->ctx, sv.buf.data());
|
llama_copy_state_data(state->ctx, sv.buf.data());
|
||||||
sv.tokens = state->tokens;
|
sv.tokens = state->tokens;
|
||||||
sv.prompt = state->prompt;
|
sv.prompt = state->prompt;
|
||||||
sv.ctx = generic_state;
|
sv.ctx = generic_state;
|
||||||
return LM_BOOL_SUCCESS;
|
LM_CORETURN LM_BOOL_SUCCESS;
|
||||||
}
|
}
|
||||||
LM_ERRBOOL restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override {
|
LM_SCHEDULABLE(LM_ERRBOOL) restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override {
|
||||||
auto& state = get_state();
|
auto& state = get_state();
|
||||||
if (sv.ctx != generic_state)
|
if (sv.ctx != generic_state)
|
||||||
LM_THROW("Savestate does not match context", LM_BOOL_ERROR);
|
LM_COTHROW("Savestate does not match context", LM_BOOL_ERROR);
|
||||||
llama_set_state_data(state->ctx, const_cast<uint8_t*>(sv.buf.data()));
|
llama_set_state_data(state->ctx, sv.buf.data());
|
||||||
state->tokens = sv.tokens;
|
state->tokens = sv.tokens;
|
||||||
state->prompt = sv.prompt;
|
state->prompt = sv.prompt;
|
||||||
return LM_BOOL_SUCCESS;
|
LM_CORETURN LM_BOOL_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
LM_ERRBOOL serialize(std::ostream &o) const LM_NOEXCEPTDECL override {
|
LM_SCHEDULABLE(LM_ERRBOOL) serialize(std::ostream &o) const LM_NOEXCEPTDECL override {
|
||||||
auto& state = get_state();
|
auto& state = get_state();
|
||||||
// Get state size
|
// Get state size
|
||||||
auto state_size = llama_get_state_size(state->ctx);
|
auto state_size = llama_get_state_size(state->ctx);
|
||||||
// Write sizes
|
// Write sizes
|
||||||
for (const uint32_t s : {static_cast<size_t>(state->n_ctx), state->tokens.size(), state->prompt.size(), state_size}) {
|
for (const uint32_t s : {static_cast<size_t>(state->n_ctx), state->tokens.size(), state->prompt.size(), state_size}) {
|
||||||
if (!o.write(reinterpret_cast<const char*>(&s), sizeof(s))) {
|
if (!o.write(reinterpret_cast<const char*>(&s), sizeof(s))) {
|
||||||
LM_THROW("Failed to serialize data sizes", LM_BOOL_ERROR);
|
LM_COTHROW("Failed to serialize data sizes", LM_BOOL_ERROR);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Write tokens
|
// Write tokens
|
||||||
if (!o.write(reinterpret_cast<const char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
|
if (!o.write(reinterpret_cast<const char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
|
||||||
LM_THROW("Failed to serialize tokens", LM_BOOL_ERROR);
|
LM_COTHROW("Failed to serialize tokens", LM_BOOL_ERROR);
|
||||||
}
|
}
|
||||||
// Write prompt
|
// Write prompt
|
||||||
if (!o.write(state->prompt.data(), state->prompt.size())) {
|
if (!o.write(state->prompt.data(), state->prompt.size())) {
|
||||||
LM_THROW("Failed to serialize prompt", LM_BOOL_ERROR);
|
LM_COTHROW("Failed to serialize prompt", LM_BOOL_ERROR);
|
||||||
}
|
}
|
||||||
// Write state
|
// Write state
|
||||||
std::vector<uint8_t> state_buf(state_size);
|
std::vector<uint8_t> state_buf(state_size);
|
||||||
llama_copy_state_data(state->ctx, state_buf.data());
|
llama_copy_state_data(state->ctx, state_buf.data());
|
||||||
if (!o.write(reinterpret_cast<const char*>(state_buf.data()), state_size)) {
|
if (!o.write(reinterpret_cast<const char*>(state_buf.data()), state_size)) {
|
||||||
LM_THROW("Failed to serialize state", LM_BOOL_ERROR);
|
LM_COTHROW("Failed to serialize state", LM_BOOL_ERROR);
|
||||||
}
|
}
|
||||||
return LM_BOOL_SUCCESS;
|
LM_CORETURN LM_BOOL_SUCCESS;
|
||||||
}
|
}
|
||||||
LM_ERRBOOL deserialize(std::istream &i) LM_NOEXCEPTDECL override {
|
LM_SCHEDULABLE(LM_ERRBOOL) deserialize(std::istream &i) LM_NOEXCEPTDECL override {
|
||||||
auto& state = get_state();
|
auto& state = get_state();
|
||||||
uint32_t n_ctx, embd_size, prompt_size, state_size;
|
uint32_t n_ctx, embd_size, prompt_size, state_size;
|
||||||
// Initialization to prevent compiler complaints
|
// Initialization to prevent compiler complaints
|
||||||
|
@ -347,65 +262,33 @@ public:
|
||||||
// Read sizes
|
// Read sizes
|
||||||
for (uint32_t *s : {&n_ctx, &embd_size, &prompt_size, &state_size}) {
|
for (uint32_t *s : {&n_ctx, &embd_size, &prompt_size, &state_size}) {
|
||||||
if (!i.read(reinterpret_cast<char*>(s), sizeof(*s))) {
|
if (!i.read(reinterpret_cast<char*>(s), sizeof(*s))) {
|
||||||
LM_THROW("Failed to deserialize data sizes", LM_BOOL_ERROR);
|
LM_COTHROW("Failed to deserialize data sizes", LM_BOOL_ERROR);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (state->n_ctx != n_ctx) {
|
if (state->n_ctx != n_ctx) {
|
||||||
LM_THROW("Context length differs (My "+std::to_string(state->n_ctx)+" vs. files "+std::to_string(n_ctx)+')', LM_BOOL_ERROR);
|
LM_COTHROW("Context length differs (My "+std::to_string(state->n_ctx)+" vs. files "+std::to_string(n_ctx)+')', LM_BOOL_ERROR);
|
||||||
}
|
}
|
||||||
// Read tokens
|
// Read tokens
|
||||||
state->tokens.resize(embd_size);
|
state->tokens.resize(embd_size);
|
||||||
if (!i.read(reinterpret_cast<char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
|
if (!i.read(reinterpret_cast<char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
|
||||||
LM_THROW("Failed to deserialize tokens", LM_BOOL_ERROR);
|
LM_COTHROW("Failed to deserialize tokens", LM_BOOL_ERROR);
|
||||||
}
|
}
|
||||||
// Read prompt
|
// Read prompt
|
||||||
state->prompt.resize(prompt_size);
|
state->prompt.resize(prompt_size);
|
||||||
if (!i.read(state->prompt.data(), state->prompt.size())) {
|
if (!i.read(state->prompt.data(), state->prompt.size())) {
|
||||||
LM_THROW("Failed to deserialize prompt", LM_BOOL_ERROR);
|
LM_COTHROW("Failed to deserialize prompt", LM_BOOL_ERROR);
|
||||||
}
|
}
|
||||||
// Read state
|
// Read state
|
||||||
std::vector<uint8_t> state_buf(state_size);
|
std::vector<uint8_t> state_buf(state_size);
|
||||||
if (!i.read(reinterpret_cast<char*>(state_buf.data()), state_buf.size())) {
|
if (!i.read(reinterpret_cast<char*>(state_buf.data()), state_buf.size())) {
|
||||||
LM_THROW("Failed to deserialize state", LM_BOOL_ERROR);
|
LM_COTHROW("Failed to deserialize state", LM_BOOL_ERROR);
|
||||||
}
|
}
|
||||||
llama_set_state_data(state->ctx, state_buf.data());
|
llama_set_state_data(state->ctx, state_buf.data());
|
||||||
return LM_BOOL_SUCCESS;
|
LM_CORETURN LM_BOOL_SUCCESS;
|
||||||
}
|
|
||||||
|
|
||||||
LM_ERRBOOL load_grammar(const std::string& src, bool override_temperature) LM_NOEXCEPTDECL override {
|
|
||||||
auto& state = get_state();
|
|
||||||
|
|
||||||
state->parsed_grammar = grammar_parser::parse(src.c_str());
|
|
||||||
if (state->parsed_grammar.rules.empty()) {
|
|
||||||
LM_THROW("Failed to parse grammar (or no rules)", LM_BOOL_ERROR);
|
|
||||||
}
|
|
||||||
|
|
||||||
auto rules = state->parsed_grammar.c_rules();
|
|
||||||
state->grammar = llama_grammar_init(rules.data(), rules.size(), state->parsed_grammar.symbol_ids.at("root"));
|
|
||||||
if (!state->grammar) {
|
|
||||||
LM_THROW("Failed to generate llama grammar", LM_BOOL_ERROR);
|
|
||||||
}
|
|
||||||
|
|
||||||
state->grammar_override_temp = override_temperature;
|
|
||||||
|
|
||||||
return LM_BOOL_SUCCESS;
|
|
||||||
}
|
|
||||||
LM_ERRBOOL unload_grammar() LM_NOEXCEPTDECL override {
|
|
||||||
get_state()->grammar = nullptr;
|
|
||||||
|
|
||||||
return LM_BOOL_SUCCESS;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::string &get_prompt() const LM_NOEXCEPTDECL override {
|
const std::string &get_prompt() const LM_NOEXCEPTDECL override {
|
||||||
return get_state()->prompt;
|
return get_state()->prompt;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_mirostat_available() const noexcept override {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool is_grammar_available() const noexcept override {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
145
justlm_mpt.hpp
145
justlm_mpt.hpp
|
@ -4,7 +4,7 @@
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include "mpt/mpt.hpp"
|
#include "mpt/mpt.hpp"
|
||||||
#include "g4a_common.hpp"
|
#include "g4a-common.hpp"
|
||||||
|
|
||||||
|
|
||||||
namespace LM {
|
namespace LM {
|
||||||
|
@ -12,14 +12,13 @@ class MPTInference final : public Inference {
|
||||||
std::string weights_path;
|
std::string weights_path;
|
||||||
|
|
||||||
struct State {
|
struct State {
|
||||||
gpt_vocab vocab;
|
mpt_vocab vocab;
|
||||||
mpt_model model;
|
mpt_model model;
|
||||||
std::string prompt; // Mostly here for easy "debugging"
|
std::string prompt; // Mostly here for easy "debugging"
|
||||||
std::vector<int> tokens;
|
std::vector<int> tokens;
|
||||||
std::vector<float> logits;
|
std::vector<float> logits;
|
||||||
size_t mem_per_token = 0;
|
size_t mem_per_token = 0;
|
||||||
std::mt19937 rng;
|
std::mt19937 rng;
|
||||||
int im_end = 0;
|
|
||||||
|
|
||||||
State(int32_t seed) : rng(seed) {}
|
State(int32_t seed) : rng(seed) {}
|
||||||
};
|
};
|
||||||
|
@ -48,32 +47,25 @@ class MPTInference final : public Inference {
|
||||||
static std::vector<gpt_vocab::id> r_instruct;
|
static std::vector<gpt_vocab::id> r_instruct;
|
||||||
mpt_eval(state->model, params.n_threads, 0, { 0, 1, 2, 3 }, state->logits, state->mem_per_token);
|
mpt_eval(state->model, params.n_threads, 0, { 0, 1, 2, 3 }, state->logits, state->mem_per_token);
|
||||||
|
|
||||||
// Find im_end token
|
|
||||||
{
|
|
||||||
auto res = state->vocab.token_to_id.find("<|im_end|>");
|
|
||||||
if (res != state->vocab.token_to_id.end()) {
|
|
||||||
state->im_end = res->second;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return LM_BOOL_SUCCESS;
|
return LM_BOOL_SUCCESS;
|
||||||
}
|
}
|
||||||
void deinit() LM_NOEXCEPTDECL {
|
void deinit() LM_NOEXCEPTDECL {
|
||||||
auto& state = get_state();
|
auto& state = get_state();
|
||||||
|
|
||||||
if (state) {
|
if (state) {
|
||||||
|
if (state->model.ctx) ggml_free(state->model.ctx); //TODO: Is that enough?
|
||||||
delete state;
|
delete state;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// This function reduces the size of our tokens vector according to some parameters
|
// This function reduces the size of our tokens vector according to some parameters
|
||||||
// All tokens will be evaluated if scrolling was needed and true will be returned
|
// All tokens will be evaluated if scrolling was needed and true will be returned
|
||||||
bool window_scroll() LM_NOEXCEPTDECL {
|
LM_SCHEDULABLE(bool) window_scroll() LM_NOEXCEPTDECL {
|
||||||
auto &state = get_state();
|
auto &state = get_state();
|
||||||
// Check that we actually need to scroll
|
// Check that we actually need to scroll
|
||||||
if (state->tokens.size() <= params.n_ctx) {
|
if (state->tokens.size() <= params.n_ctx) {
|
||||||
// Nope
|
// Nope
|
||||||
return false;
|
LM_CORETURN false;
|
||||||
}
|
}
|
||||||
// Start scrolling
|
// Start scrolling
|
||||||
if (params.scroll_keep > 0.0f) {
|
if (params.scroll_keep > 0.0f) {
|
||||||
|
@ -90,11 +82,11 @@ class MPTInference final : public Inference {
|
||||||
state->tokens.resize(params.n_ctx_window_top_bar);
|
state->tokens.resize(params.n_ctx_window_top_bar);
|
||||||
}
|
}
|
||||||
// Evaluate tokens
|
// Evaluate tokens
|
||||||
LM_ERROR_FORWARD(evaluate_tokens(0, on_scroll), LM_BOOL_ERROR);
|
LM_ERROR_FORWARD(LM_COAWAIT evaluate_tokens(0, on_scroll));
|
||||||
return true;
|
LM_CORETURN true;
|
||||||
}
|
}
|
||||||
|
|
||||||
LM_ERRBOOL evaluate_tokens(size_t starting_offset, const AppendCallback &on_tick) LM_NOEXCEPTDECL {
|
LM_SCHEDULABLE(LM_ERRBOOL) evaluate_tokens(size_t starting_offset, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL {
|
||||||
auto& state = get_state();
|
auto& state = get_state();
|
||||||
|
|
||||||
// Evaluate tokens in batches
|
// Evaluate tokens in batches
|
||||||
|
@ -105,7 +97,7 @@ class MPTInference final : public Inference {
|
||||||
// Evaluate
|
// Evaluate
|
||||||
std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+params.n_batch);
|
std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+params.n_batch);
|
||||||
if (!mpt_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token)) {
|
if (!mpt_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token)) {
|
||||||
LM_THROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR);
|
LM_COTHROW("Failed to evaluate tokens in batches", LM_BOOL_ERROR);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tick
|
// Tick
|
||||||
|
@ -113,7 +105,8 @@ class MPTInference final : public Inference {
|
||||||
// Calculate progress
|
// Calculate progress
|
||||||
auto progress = float(it-starting_offset) / (state->tokens.size()-starting_offset) * 100.f;
|
auto progress = float(it-starting_offset) / (state->tokens.size()-starting_offset) * 100.f;
|
||||||
// Tick and yield
|
// Tick and yield
|
||||||
if (!on_tick(progress)) return LM_BOOL_SUCCESS;
|
if (!on_tick(progress)) LM_CORETURN LM_BOOL_SUCCESS;
|
||||||
|
else if (!LM_TASKYIELD) LM_CORETURN LM_BOOL_SUCCESS;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -123,7 +116,7 @@ class MPTInference final : public Inference {
|
||||||
//TODO: This is extremely inefficient! Don't do that...
|
//TODO: This is extremely inefficient! Don't do that...
|
||||||
std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+1);
|
std::vector<int> batch(state->tokens.begin()+it, state->tokens.begin()+it+1);
|
||||||
if (!mpt_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token)) {
|
if (!mpt_eval(state->model, params.n_threads, it, batch, state->logits, state->mem_per_token)) {
|
||||||
LM_THROW("Failed to evaluate individual tokens", LM_BOOL_ERROR);
|
LM_COTHROW("Failed to evaluate individual tokens", LM_BOOL_ERROR);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -131,7 +124,7 @@ class MPTInference final : public Inference {
|
||||||
// Notify about completion
|
// Notify about completion
|
||||||
if (on_tick) on_tick(100.f);
|
if (on_tick) on_tick(100.f);
|
||||||
|
|
||||||
return LM_BOOL_SUCCESS;
|
LM_CORETURN LM_BOOL_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
@ -142,7 +135,7 @@ public:
|
||||||
deinit();
|
deinit();
|
||||||
}
|
}
|
||||||
|
|
||||||
LM_ERRBOOL append(const std::string& prompt, const AppendCallback &on_tick) LM_NOEXCEPTDECL override {
|
LM_SCHEDULABLE(LM_ERRBOOL) append(const std::string& prompt, const std::function<bool (float)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
|
||||||
auto& state = get_state();
|
auto& state = get_state();
|
||||||
|
|
||||||
// Append to current prompt
|
// Append to current prompt
|
||||||
|
@ -152,7 +145,7 @@ public:
|
||||||
const auto old_token_count = state->tokens.size();
|
const auto old_token_count = state->tokens.size();
|
||||||
|
|
||||||
// Run tokenizer
|
// Run tokenizer
|
||||||
const auto tokens = gpt_tokenize(state->vocab, prompt);
|
const auto tokens = mpt_tokenize(state->vocab, prompt);
|
||||||
state->tokens.insert(
|
state->tokens.insert(
|
||||||
state->tokens.end(),
|
state->tokens.end(),
|
||||||
std::make_move_iterator(tokens.begin()),
|
std::make_move_iterator(tokens.begin()),
|
||||||
|
@ -160,130 +153,132 @@ public:
|
||||||
);
|
);
|
||||||
|
|
||||||
// Make sure token limit isn't being hit
|
// Make sure token limit isn't being hit
|
||||||
if (window_scroll()) {
|
if (LM_COAWAIT window_scroll()) {
|
||||||
// That function already has evaluated our tokens since scrolling was needed
|
// That function already has evaluated our tokens since scrolling was needed
|
||||||
return LM_BOOL_SUCCESS;
|
LM_CORETURN LM_BOOL_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Evaluate new tokens
|
// Evaluate new tokens
|
||||||
return evaluate_tokens(old_token_count, on_tick);
|
LM_CORETURN LM_COAWAIT evaluate_tokens(old_token_count, on_tick);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string run(std::string_view end, const GenerateCallback &on_tick, const GenerateCallback& pre_tick) LM_NOEXCEPTDECL override {
|
/*mpt_vocab::id mpt_sample_top_k_top_p(
|
||||||
|
const mpt_vocab & vocab,
|
||||||
|
const size_t actualVocabSize,
|
||||||
|
const int32_t * last_n_tokens_data,
|
||||||
|
int last_n_tokens_size,
|
||||||
|
const std::vector<float> logits,
|
||||||
|
int top_k,
|
||||||
|
double top_p,
|
||||||
|
double temp,
|
||||||
|
float repeat_penalty,
|
||||||
|
std::mt19937 & rng)
|
||||||
|
*/
|
||||||
|
|
||||||
|
LM_SCHEDULABLE(std::string) run(std::string_view end, const std::function<bool (const char *)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
|
||||||
auto& state = get_state();
|
auto& state = get_state();
|
||||||
std::string fres;
|
std::string fres;
|
||||||
|
|
||||||
// Loop until done
|
// Loop until done
|
||||||
bool abort = false;
|
bool abort = false;
|
||||||
unsigned eos_count = 0;
|
unsigned eos_count = 0;
|
||||||
size_t last_size = 0;
|
while (!abort && !ends_with(fres, end)) {
|
||||||
while (!abort && (end.empty() || fres.find(end) == fres.npos)) {
|
|
||||||
last_size = fres.size();
|
|
||||||
// Sample top p and top k
|
// Sample top p and top k
|
||||||
const auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
|
auto id = mpt_sample_top_k_top_p(state->vocab, state->model.hparams.n_vocab, state->tokens.data(), state->tokens.size(), state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
|
||||||
auto id = gpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data()+state->tokens.size()-n_repeat_last, n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
|
|
||||||
|
|
||||||
if (state->im_end && id == state->im_end) {
|
if (id == state->vocab.token_to_id["<|im_end|>"]) {
|
||||||
if (eos_count++ == params.n_eos_ignores) {
|
if (eos_count++ == params.eos_ignores) {
|
||||||
abort = true;
|
abort = true;
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
id = gpt_tokenize(state->vocab, "\n")[0];
|
id = mpt_tokenize(state->vocab, "\n")[0];
|
||||||
} else if (id == 0) {
|
state->tokens.push_back(id);
|
||||||
if (eos_count++ == params.n_eos_ignores) {
|
} else {
|
||||||
abort = true;
|
// Add token
|
||||||
continue;
|
state->tokens.push_back(id);
|
||||||
}
|
|
||||||
id = gpt_tokenize(state->vocab, "\n")[0];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add token
|
|
||||||
state->tokens.push_back(id);
|
|
||||||
|
|
||||||
// Make sure token limit isn't being hit
|
// Make sure token limit isn't being hit
|
||||||
window_scroll();
|
LM_COAWAIT window_scroll();
|
||||||
|
|
||||||
// Get token as string
|
// Get token as string
|
||||||
const std::string_view str = state->vocab.id_to_token[id];
|
const auto str = state->vocab.id_to_token[id];
|
||||||
|
|
||||||
// Append string to function result
|
// Append string to function result
|
||||||
fres.append(str);
|
fres.append(str);
|
||||||
state->prompt.append(str);
|
|
||||||
|
|
||||||
// Tick
|
// Evaluate token
|
||||||
if (pre_tick && !pre_tick(str.data())) abort = true;
|
// TODO: Respect batch size
|
||||||
else {
|
std::vector<int> batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size());
|
||||||
// Evaluate token
|
if (!mpt_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) {
|
||||||
// TODO: Respect batch size
|
LM_COTHROW("Failed to evaluate new tokens", "");
|
||||||
std::vector<int> batch(state->tokens.begin()+state->tokens.size()-1, state->tokens.begin()+state->tokens.size());
|
|
||||||
if (!mpt_eval(state->model, params.n_threads, state->tokens.size()-1, batch, state->logits, state->mem_per_token)) {
|
|
||||||
LM_THROW("Failed to evaluate new tokens", "");
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tick
|
// Tick
|
||||||
if (on_tick && !on_tick(str.data())) abort = true;
|
if (on_tick && !on_tick(str.c_str())) abort = true;
|
||||||
|
else if (!LM_TASKYIELD) abort = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create final string TODO: Could be optimized
|
// Create final string TODO: Could be optimized
|
||||||
|
state->prompt.append(fres);
|
||||||
if (!abort) {
|
if (!abort) {
|
||||||
fres = std::string(fres.data(), last_size);
|
fres = std::string(fres.data(), fres.size()-end.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return final string
|
// Return final string
|
||||||
return fres;
|
LM_CORETURN fres;
|
||||||
}
|
}
|
||||||
|
|
||||||
unsigned get_context_size() const noexcept override {
|
unsigned get_context_size() const noexcept override {
|
||||||
return get_state()->tokens.size();
|
return get_state()->tokens.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
LM_ERRBOOL create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override {
|
LM_SCHEDULABLE(LM_ERRBOOL) create_savestate(Savestate &sv) const LM_NOEXCEPTDECL override {
|
||||||
auto& state = get_state();
|
auto& state = get_state();
|
||||||
sv.buf.resize(mpt_get_state_size(state->model));
|
sv.buf.resize(mpt_get_state_size(state->model));
|
||||||
mpt_copy_state_data(state->model, state->rng, sv.buf.data());
|
mpt_copy_state_data(state->model, state->rng, sv.buf.data());
|
||||||
sv.tokens = state->tokens;
|
sv.tokens = state->tokens;
|
||||||
sv.prompt = state->prompt;
|
sv.prompt = state->prompt;
|
||||||
sv.ctx = generic_state;
|
sv.ctx = generic_state;
|
||||||
return LM_BOOL_SUCCESS ;
|
LM_CORETURN LM_BOOL_SUCCESS ;
|
||||||
}
|
}
|
||||||
LM_ERRBOOL restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override {
|
LM_SCHEDULABLE(LM_ERRBOOL) restore_savestate(const Savestate &sv) LM_NOEXCEPTDECL override {
|
||||||
auto& state = get_state();
|
auto& state = get_state();
|
||||||
if (sv.ctx != generic_state)
|
if (sv.ctx != generic_state)
|
||||||
LM_THROW("Savestate does not match context", LM_BOOL_ERROR);
|
LM_COTHROW("Savestate does not match context", LM_BOOL_ERROR);
|
||||||
mpt_set_state_data(&state->model, &state->rng, sv.buf.data());
|
mpt_set_state_data(&state->model, &state->rng, sv.buf.data());
|
||||||
state->tokens = sv.tokens;
|
state->tokens = sv.tokens;
|
||||||
state->prompt = sv.prompt;
|
state->prompt = sv.prompt;
|
||||||
return LM_BOOL_SUCCESS;
|
LM_CORETURN LM_BOOL_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
LM_ERRBOOL serialize(std::ostream &o) const LM_NOEXCEPTDECL override {
|
LM_SCHEDULABLE(LM_ERRBOOL) serialize(std::ostream &o) const LM_NOEXCEPTDECL override {
|
||||||
auto& state = get_state();
|
auto& state = get_state();
|
||||||
// Get state size
|
// Get state size
|
||||||
auto state_size = mpt_get_state_size(state->model);
|
auto state_size = mpt_get_state_size(state->model);
|
||||||
// Write sizes
|
// Write sizes
|
||||||
for (const uint32_t s : {state->tokens.size(), state->prompt.size(), state_size}) {
|
for (const uint32_t s : {state->tokens.size(), state->prompt.size(), state_size}) {
|
||||||
if (!o.write(reinterpret_cast<const char*>(&s), sizeof(s))) {
|
if (!o.write(reinterpret_cast<const char*>(&s), sizeof(s))) {
|
||||||
LM_THROW("Failed to serialize data sizes", LM_BOOL_ERROR);
|
LM_COTHROW("Failed to serialize data sizes", LM_BOOL_ERROR);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Write tokens
|
// Write tokens
|
||||||
if (!o.write(reinterpret_cast<const char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
|
if (!o.write(reinterpret_cast<const char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
|
||||||
LM_THROW("Failed to serialize tokens", LM_BOOL_ERROR);
|
LM_COTHROW("Failed to serialize tokens", LM_BOOL_ERROR);
|
||||||
}
|
}
|
||||||
// Write prompt
|
// Write prompt
|
||||||
if (!o.write(state->prompt.data(), state->prompt.size())) {
|
if (!o.write(state->prompt.data(), state->prompt.size())) {
|
||||||
LM_THROW("Failed to serialize prompt", LM_BOOL_ERROR);
|
LM_COTHROW("Failed to serialize prompt", LM_BOOL_ERROR);
|
||||||
}
|
}
|
||||||
// Write state
|
// Write state
|
||||||
std::vector<uint8_t> state_buf(state_size);
|
std::vector<uint8_t> state_buf(state_size);
|
||||||
mpt_copy_state_data(state->model, state->rng, state_buf.data());
|
mpt_copy_state_data(state->model, state->rng, state_buf.data());
|
||||||
if (!o.write(reinterpret_cast<const char*>(state_buf.data()), state_size)) {
|
if (!o.write(reinterpret_cast<const char*>(state_buf.data()), state_size)) {
|
||||||
LM_THROW("Failed to serialize state", LM_BOOL_ERROR);
|
LM_COTHROW("Failed to serialize state", LM_BOOL_ERROR);
|
||||||
}
|
}
|
||||||
return LM_BOOL_SUCCESS;
|
LM_CORETURN LM_BOOL_SUCCESS;
|
||||||
}
|
}
|
||||||
LM_ERRBOOL deserialize(std::istream &i) LM_NOEXCEPTDECL override {
|
LM_SCHEDULABLE(LM_ERRBOOL) deserialize(std::istream &i) LM_NOEXCEPTDECL override {
|
||||||
auto& state = get_state();
|
auto& state = get_state();
|
||||||
uint32_t embd_size, promptsize, state_size;
|
uint32_t embd_size, promptsize, state_size;
|
||||||
// Initialization to prevent compiler complaints
|
// Initialization to prevent compiler complaints
|
||||||
|
@ -291,26 +286,26 @@ public:
|
||||||
// Read sizes
|
// Read sizes
|
||||||
for (uint32_t *s : {&embd_size, &promptsize, &state_size}) {
|
for (uint32_t *s : {&embd_size, &promptsize, &state_size}) {
|
||||||
if (!i.read(reinterpret_cast<char*>(s), sizeof(*s))) {
|
if (!i.read(reinterpret_cast<char*>(s), sizeof(*s))) {
|
||||||
LM_THROW("Failed to deserialize data sizes", LM_BOOL_ERROR);
|
LM_COTHROW("Failed to deserialize data sizes", LM_BOOL_ERROR);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Read tokens
|
// Read tokens
|
||||||
state->tokens.resize(embd_size);
|
state->tokens.resize(embd_size);
|
||||||
if (!i.read(reinterpret_cast<char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
|
if (!i.read(reinterpret_cast<char*>(state->tokens.data()), state->tokens.size()*sizeof(int))) {
|
||||||
LM_THROW("Failed to deserialize tokens", LM_BOOL_ERROR);
|
LM_COTHROW("Failed to deserialize tokens", LM_BOOL_ERROR);
|
||||||
}
|
}
|
||||||
// Read prompt
|
// Read prompt
|
||||||
state->prompt.resize(promptsize);
|
state->prompt.resize(promptsize);
|
||||||
if (!i.read(state->prompt.data(), state->prompt.size())) {
|
if (!i.read(state->prompt.data(), state->prompt.size())) {
|
||||||
LM_THROW("Failed to deserialize prompt", LM_BOOL_ERROR);
|
LM_COTHROW("Failed to deserialize prompt", LM_BOOL_ERROR);
|
||||||
}
|
}
|
||||||
// Read state
|
// Read state
|
||||||
std::vector<uint8_t> state_buf(state_size);
|
std::vector<uint8_t> state_buf(state_size);
|
||||||
if (!i.read(reinterpret_cast<char*>(state_buf.data()), state_buf.size())) {
|
if (!i.read(reinterpret_cast<char*>(state_buf.data()), state_buf.size())) {
|
||||||
LM_THROW("Failed to deserialize state", LM_BOOL_ERROR);
|
LM_COTHROW("Failed to deserialize state", LM_BOOL_ERROR);
|
||||||
}
|
}
|
||||||
mpt_set_state_data(&state->model, &state->rng, state_buf.data());
|
mpt_set_state_data(&state->model, &state->rng, state_buf.data());
|
||||||
return LM_BOOL_SUCCESS;
|
LM_CORETURN LM_BOOL_SUCCESS;
|
||||||
}
|
}
|
||||||
const std::string &get_prompt() const LM_NOEXCEPTDECL override {
|
const std::string &get_prompt() const LM_NOEXCEPTDECL override {
|
||||||
return get_state()->prompt;
|
return get_state()->prompt;
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
bool LM::InferencePool::store_slot(Slot &slot) {
|
LM_SCHEDULABLE(bool) LM::InferencePool::store_slot(Slot &slot) {
|
||||||
auto inference = slot.get_inference();
|
auto inference = slot.get_inference();
|
||||||
// Open output file
|
// Open output file
|
||||||
std::ofstream f(get_slot_filename(slot.get_id()), std::ios::binary);
|
std::ofstream f(get_slot_filename(slot.get_id()), std::ios::binary);
|
||||||
|
@ -17,61 +17,61 @@ bool LM::InferencePool::store_slot(Slot &slot) {
|
||||||
f.write(weights_path.data(), weights_path.size());
|
f.write(weights_path.data(), weights_path.size());
|
||||||
// Write params
|
// Write params
|
||||||
if (!f.write(reinterpret_cast<const char*>(&inference->params), sizeof(inference->params))) {
|
if (!f.write(reinterpret_cast<const char*>(&inference->params), sizeof(inference->params))) {
|
||||||
return false;
|
LM_CORETURN false;
|
||||||
}
|
}
|
||||||
// Serialize instance
|
// Serialize instance
|
||||||
try {
|
try {
|
||||||
inference->serialize(f);
|
LM_COAWAIT inference->serialize(f);
|
||||||
} catch (...) {
|
} catch (...) {
|
||||||
return false;
|
LM_CORETURN false;
|
||||||
}
|
}
|
||||||
// Return success
|
// Return success
|
||||||
return true;
|
LM_CORETURN true;
|
||||||
}
|
}
|
||||||
|
|
||||||
LM::InferencePool::Slot *LM::InferencePool::load_slot(size_t id, Slot *suggested_slot) {
|
LM_SCHEDULABLE(LM::InferencePool::Slot*) LM::InferencePool::load_slot(size_t id, Slot *suggested_slot) {
|
||||||
// Open input file
|
// Open input file
|
||||||
std::ifstream f(get_slot_filename(id), std::ios::binary);
|
std::ifstream f(get_slot_filename(id), std::ios::binary);
|
||||||
if (!f) {
|
if (!f) {
|
||||||
// Does not exist
|
// Does not exist
|
||||||
return nullptr;
|
LM_CORETURN nullptr;
|
||||||
}
|
}
|
||||||
// Read weights path
|
// Read weights path
|
||||||
std::string weights_path;
|
std::string weights_path;
|
||||||
uint32_t weights_path_len;
|
uint32_t weights_path_len;
|
||||||
if (!f.read(reinterpret_cast<char*>(&weights_path_len), sizeof(weights_path_len))) {
|
if (!f.read(reinterpret_cast<char*>(&weights_path_len), sizeof(weights_path_len))) {
|
||||||
return nullptr;
|
LM_CORETURN nullptr;
|
||||||
}
|
}
|
||||||
weights_path.resize(weights_path_len);
|
weights_path.resize(weights_path_len);
|
||||||
if (!f.read(weights_path.data(), weights_path.size())) {
|
if (!f.read(weights_path.data(), weights_path.size())) {
|
||||||
return nullptr;
|
LM_CORETURN nullptr;
|
||||||
}
|
}
|
||||||
// Read params
|
// Read params
|
||||||
LM::Inference::Params p;
|
LM::Inference::Params p;
|
||||||
if (!f.read(reinterpret_cast<char*>(&p), sizeof(p))) {
|
if (!f.read(reinterpret_cast<char*>(&p), sizeof(p))) {
|
||||||
return nullptr;
|
LM_CORETURN nullptr;
|
||||||
}
|
}
|
||||||
// Create instance
|
// Create instance
|
||||||
auto& slot = suggested_slot?*suggested_slot:*(get_free_slot());
|
auto& slot = suggested_slot?*suggested_slot:*(LM_COAWAIT get_free_slot());
|
||||||
auto inference = slot.create_inference(id, weights_path, p);
|
auto inference = slot.create_inference(id, weights_path, p);
|
||||||
// Deserialize instance
|
// Deserialize instance
|
||||||
try {
|
try {
|
||||||
inference->deserialize(f);
|
LM_COAWAIT inference->deserialize(f);
|
||||||
} catch (...) {
|
} catch (...) {
|
||||||
slot.reset();
|
slot.reset();
|
||||||
return nullptr;
|
LM_CORETURN nullptr;
|
||||||
}
|
}
|
||||||
// Return final slot
|
// Return final slot
|
||||||
return &slot;
|
LM_CORETURN &slot;
|
||||||
}
|
}
|
||||||
|
|
||||||
LM::InferencePool::Slot *LM::InferencePool::get_free_slot() {
|
LM_SCHEDULABLE(LM::InferencePool::Slot*) LM::InferencePool::get_free_slot() {
|
||||||
// Attempt to find free slot while finding oldest one
|
// Attempt to find free slot while finding oldest one
|
||||||
Slot *oldest = nullptr;
|
Slot *oldest = nullptr;
|
||||||
for (auto& slot : slots) {
|
for (auto& slot : slots) {
|
||||||
// Take free slot
|
// Take free slot
|
||||||
if (slot.is_free()) {
|
if (slot.is_free()) {
|
||||||
return &slot;
|
LM_CORETURN &slot;
|
||||||
}
|
}
|
||||||
// Update oldest
|
// Update oldest
|
||||||
if (oldest == nullptr || slot.get_last_access() < oldest->get_last_access()) {
|
if (oldest == nullptr || slot.get_last_access() < oldest->get_last_access()) {
|
||||||
|
@ -80,17 +80,17 @@ LM::InferencePool::Slot *LM::InferencePool::get_free_slot() {
|
||||||
}
|
}
|
||||||
// Free up oldest slot and take that one
|
// Free up oldest slot and take that one
|
||||||
// Note: Since there has to be at least 1 slot, oldest is never going to be a nullptr
|
// Note: Since there has to be at least 1 slot, oldest is never going to be a nullptr
|
||||||
store_and_reset_slot(*oldest);
|
LM_COAWAIT store_and_reset_slot(*oldest);
|
||||||
return oldest;
|
LM_CORETURN oldest;
|
||||||
}
|
}
|
||||||
|
|
||||||
LM::InferencePool::Slot *LM::InferencePool::find_slot_by_id(size_t id, bool deserialize) {
|
LM_SCHEDULABLE(LM::InferencePool::Slot*) LM::InferencePool::find_slot_by_id(size_t id, bool deserialize) {
|
||||||
// Attempt to find given slot while finding oldest one
|
// Attempt to find given slot while finding oldest one
|
||||||
Slot *oldest = nullptr;
|
Slot *oldest = nullptr;
|
||||||
for (auto& slot : slots) {
|
for (auto& slot : slots) {
|
||||||
// Take slot with ID
|
// Take slot with ID
|
||||||
if (slot.get_id() == id) {
|
if (slot.get_id() == id) {
|
||||||
return &slot;
|
LM_CORETURN &slot;
|
||||||
}
|
}
|
||||||
// Update oldest
|
// Update oldest
|
||||||
if (oldest == nullptr || slot.get_last_access() < oldest->get_last_access()) {
|
if (oldest == nullptr || slot.get_last_access() < oldest->get_last_access()) {
|
||||||
|
@ -99,38 +99,38 @@ LM::InferencePool::Slot *LM::InferencePool::find_slot_by_id(size_t id, bool dese
|
||||||
}
|
}
|
||||||
// Slot not found, attempt to load it
|
// Slot not found, attempt to load it
|
||||||
if (deserialize) {
|
if (deserialize) {
|
||||||
if (!oldest->is_free()) store_slot(*oldest);
|
if (!oldest->is_free()) LM_COAWAIT store_slot(*oldest);
|
||||||
if (!load_slot(id, oldest)) {
|
if (!LM_COAWAIT load_slot(id, oldest)) {
|
||||||
// In case slot loading failed, still reset slot for later use
|
// In case slot loading failed, still reset slot for later use
|
||||||
//TODO: Make this configurable
|
//TODO: Make this configurable
|
||||||
oldest->reset();
|
oldest->reset();
|
||||||
} else {
|
} else {
|
||||||
return oldest;
|
LM_CORETURN oldest;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Slot not found
|
// Slot not found
|
||||||
return nullptr;
|
LM_CORETURN nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<LM::Inference> LM::InferencePool::get_inference(size_t id) {
|
LM_SCHEDULABLE(std::shared_ptr<LM::Inference>) LM::InferencePool::get_inference(size_t id) {
|
||||||
auto slot = find_slot_by_id(id);
|
auto slot = LM_COAWAIT find_slot_by_id(id);
|
||||||
if (slot) {
|
if (slot) {
|
||||||
return slot->get_inference(true);
|
LM_CORETURN slot->get_inference(true);
|
||||||
}
|
}
|
||||||
return {};
|
LM_CORETURN {};
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<LM::Inference> LM::InferencePool::get_or_create_inference(size_t id, const std::string &weights_path, const Inference::Params &p) {
|
LM_SCHEDULABLE(std::shared_ptr<LM::Inference>) LM::InferencePool::get_or_create_inference(size_t id, const std::string &weights_path, const Inference::Params &p) {
|
||||||
auto slot = find_slot_by_id(id);
|
auto slot = LM_COAWAIT find_slot_by_id(id);
|
||||||
if (slot) {
|
if (slot) {
|
||||||
return slot->get_inference(true);
|
LM_CORETURN slot->get_inference(true);
|
||||||
}
|
}
|
||||||
slot = get_free_slot();
|
slot = LM_COAWAIT get_free_slot();
|
||||||
return slot->create_inference(id, weights_path, p);
|
LM_CORETURN slot->create_inference(id, weights_path, p);
|
||||||
}
|
}
|
||||||
|
|
||||||
void LM::InferencePool::delete_inference(size_t id) {
|
LM_SCHEDULABLE(void) LM::InferencePool::delete_inference(size_t id) {
|
||||||
auto slot = find_slot_by_id(id, false);
|
auto slot = LM_COAWAIT find_slot_by_id(id, false);
|
||||||
// Reset slot
|
// Reset slot
|
||||||
if (slot) {
|
if (slot) {
|
||||||
slot->reset();
|
slot->reset();
|
||||||
|
@ -140,12 +140,12 @@ void LM::InferencePool::delete_inference(size_t id) {
|
||||||
std::filesystem::remove(get_slot_filename(id), ec);
|
std::filesystem::remove(get_slot_filename(id), ec);
|
||||||
}
|
}
|
||||||
|
|
||||||
void LM::InferencePool::store_all() {
|
LM_SCHEDULABLE(void) LM::InferencePool::store_all() {
|
||||||
for (auto& slot : slots) {
|
for (auto& slot : slots) {
|
||||||
if (slot.is_free()) continue;
|
if (slot.is_free()) continue;
|
||||||
store_slot(slot);
|
LM_COAWAIT store_slot(slot);
|
||||||
}
|
}
|
||||||
return;
|
LM_CORETURN;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<size_t> LM::InferencePool::get_active_slot_ids() const {
|
std::vector<size_t> LM::InferencePool::get_active_slot_ids() const {
|
||||||
|
|
39
llama.cpp
39
llama.cpp
|
@ -1,39 +0,0 @@
|
||||||
#include "justlm_llama.hpp"
|
|
||||||
#include "justlm.hpp"
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
#include <string_view>
|
|
||||||
#include <fstream>
|
|
||||||
#include <cstdint>
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
const LM::Implementation *get_justlm_implementation() {
|
|
||||||
static LM::Implementation fres{false};
|
|
||||||
return &fres;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool magic_match(std::istream& f) {
|
|
||||||
// Check magic
|
|
||||||
uint32_t magic = 0;
|
|
||||||
f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
|
|
||||||
return magic == 0x46554747;
|
|
||||||
}
|
|
||||||
|
|
||||||
LM::Inference *construct(const std::string &weights_path, std::ifstream& f, const LM::Inference::Params &p) {
|
|
||||||
f.close();
|
|
||||||
return new LM::LLaMAInference(weights_path, p);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
__attribute__((constructor))
|
|
||||||
static void init() {
|
|
||||||
llama_backend_init(true);
|
|
||||||
}
|
|
||||||
|
|
||||||
__attribute__((destructor))
|
|
||||||
static void deinit() {
|
|
||||||
llama_backend_free();
|
|
||||||
}
|
|
1
llama.cpp
Submodule
1
llama.cpp
Submodule
|
@ -0,0 +1 @@
|
||||||
|
Subproject commit 0e018fe008eacebdbcfa2d61b6c988c245c961cd
|
|
@ -1 +0,0 @@
|
||||||
Subproject commit b9f47952ffae4e0d3420905526003c23333f6c98
|
|
452
llama.cpp.cmake
452
llama.cpp.cmake
|
@ -1,452 +0,0 @@
|
||||||
cmake_minimum_required(VERSION 3.12) # Don't bump this version for no reason
|
|
||||||
|
|
||||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
|
||||||
|
|
||||||
if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
|
|
||||||
set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
|
|
||||||
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
|
|
||||||
endif()
|
|
||||||
|
|
||||||
set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
|
|
||||||
|
|
||||||
if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
|
|
||||||
set(LLAMA_STANDALONE ON)
|
|
||||||
|
|
||||||
# configure project version
|
|
||||||
# TODO
|
|
||||||
else()
|
|
||||||
set(LLAMA_STANDALONE OFF)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (EMSCRIPTEN)
|
|
||||||
set(BUILD_SHARED_LIBS_DEFAULT OFF)
|
|
||||||
|
|
||||||
option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" ON)
|
|
||||||
else()
|
|
||||||
if (MINGW)
|
|
||||||
set(BUILD_SHARED_LIBS_DEFAULT OFF)
|
|
||||||
else()
|
|
||||||
set(BUILD_SHARED_LIBS_DEFAULT ON)
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
|
|
||||||
#
|
|
||||||
# Option list
|
|
||||||
#
|
|
||||||
|
|
||||||
# general
|
|
||||||
option(LLAMA_STATIC "llama: static link libraries" OFF)
|
|
||||||
option(LLAMA_NATIVE "llama: enable -march=native flag" OFF)
|
|
||||||
option(LLAMA_LTO "llama: enable link time optimization" OFF)
|
|
||||||
|
|
||||||
# debug
|
|
||||||
option(LLAMA_ALL_WARNINGS "llama: enable all compiler warnings" ON)
|
|
||||||
option(LLAMA_ALL_WARNINGS_3RD_PARTY "llama: enable all compiler warnings in 3rd party libs" OFF)
|
|
||||||
option(LLAMA_GPROF "llama: enable gprof" OFF)
|
|
||||||
|
|
||||||
# sanitizers
|
|
||||||
option(LLAMA_SANITIZE_THREAD "llama: enable thread sanitizer" OFF)
|
|
||||||
option(LLAMA_SANITIZE_ADDRESS "llama: enable address sanitizer" OFF)
|
|
||||||
option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF)
|
|
||||||
|
|
||||||
# instruction set specific
|
|
||||||
option(LLAMA_AVX "llama: enable AVX" ON)
|
|
||||||
option(LLAMA_AVX2 "llama: enable AVX2" ON)
|
|
||||||
option(LLAMA_AVX512 "llama: enable AVX512" OFF)
|
|
||||||
option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
|
|
||||||
option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
|
|
||||||
option(LLAMA_FMA "llama: enable FMA" ON)
|
|
||||||
# in MSVC F16C is implied with AVX2/AVX512
|
|
||||||
if (NOT MSVC)
|
|
||||||
option(LLAMA_F16C "llama: enable F16C" ON)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# 3rd party libs
|
|
||||||
option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON)
|
|
||||||
option(LLAMA_OPENBLAS "llama: use OpenBLAS" OFF)
|
|
||||||
option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
|
|
||||||
option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
|
|
||||||
option(LLAMA_METAL "llama: use Metal" OFF)
|
|
||||||
option(LLAMA_K_QUANTS "llama: use k-quants" ON)
|
|
||||||
set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
|
|
||||||
set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
|
|
||||||
set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels")
|
|
||||||
|
|
||||||
#
|
|
||||||
# Compile flags
|
|
||||||
#
|
|
||||||
|
|
||||||
set(CMAKE_C_STANDARD 11)
|
|
||||||
set(CMAKE_C_STANDARD_REQUIRED true)
|
|
||||||
set(THREADS_PREFER_PTHREAD_FLAG ON)
|
|
||||||
find_package(Threads REQUIRED)
|
|
||||||
|
|
||||||
if (NOT MSVC)
|
|
||||||
if (LLAMA_SANITIZE_THREAD)
|
|
||||||
add_compile_options(-fsanitize=thread)
|
|
||||||
link_libraries(-fsanitize=thread)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (LLAMA_SANITIZE_ADDRESS)
|
|
||||||
add_compile_options(-fsanitize=address -fno-omit-frame-pointer)
|
|
||||||
link_libraries(-fsanitize=address)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (LLAMA_SANITIZE_UNDEFINED)
|
|
||||||
add_compile_options(-fsanitize=undefined)
|
|
||||||
link_libraries(-fsanitize=undefined)
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (APPLE AND LLAMA_ACCELERATE)
|
|
||||||
find_library(ACCELERATE_FRAMEWORK Accelerate)
|
|
||||||
if (ACCELERATE_FRAMEWORK)
|
|
||||||
message(STATUS "Accelerate framework found")
|
|
||||||
|
|
||||||
add_compile_definitions(GGML_USE_ACCELERATE)
|
|
||||||
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
|
|
||||||
else()
|
|
||||||
message(WARNING "Accelerate framework not found")
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (LLAMA_OPENBLAS)
|
|
||||||
if (LLAMA_STATIC)
|
|
||||||
set(BLA_STATIC ON)
|
|
||||||
endif()
|
|
||||||
|
|
||||||
set(BLA_VENDOR OpenBLAS)
|
|
||||||
find_package(BLAS)
|
|
||||||
if (BLAS_FOUND)
|
|
||||||
message(STATUS "OpenBLAS found")
|
|
||||||
|
|
||||||
add_compile_definitions(GGML_USE_OPENBLAS)
|
|
||||||
add_link_options(${BLAS_LIBRARIES})
|
|
||||||
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} openblas)
|
|
||||||
|
|
||||||
# find header file
|
|
||||||
set(OPENBLAS_INCLUDE_SEARCH_PATHS
|
|
||||||
/usr/include
|
|
||||||
/usr/include/openblas
|
|
||||||
/usr/include/openblas-base
|
|
||||||
/usr/local/include
|
|
||||||
/usr/local/include/openblas
|
|
||||||
/usr/local/include/openblas-base
|
|
||||||
/opt/OpenBLAS/include
|
|
||||||
$ENV{OpenBLAS_HOME}
|
|
||||||
$ENV{OpenBLAS_HOME}/include
|
|
||||||
)
|
|
||||||
find_path(OPENBLAS_INC NAMES cblas.h PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS})
|
|
||||||
add_compile_options(-I${OPENBLAS_INC})
|
|
||||||
else()
|
|
||||||
message(WARNING "OpenBLAS not found")
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (LLAMA_ALL_WARNINGS)
|
|
||||||
if (NOT MSVC)
|
|
||||||
set(c_flags
|
|
||||||
-Wall
|
|
||||||
-Wextra
|
|
||||||
-Wpedantic
|
|
||||||
-Wcast-qual
|
|
||||||
-Wdouble-promotion
|
|
||||||
-Wshadow
|
|
||||||
-Wstrict-prototypes
|
|
||||||
-Wpointer-arith
|
|
||||||
)
|
|
||||||
set(cxx_flags
|
|
||||||
-Wall
|
|
||||||
-Wextra
|
|
||||||
-Wpedantic
|
|
||||||
-Wcast-qual
|
|
||||||
-Wno-unused-function
|
|
||||||
-Wno-multichar
|
|
||||||
)
|
|
||||||
else()
|
|
||||||
# todo : msvc
|
|
||||||
endif()
|
|
||||||
|
|
||||||
add_compile_options(
|
|
||||||
"$<$<COMPILE_LANGUAGE:C>:${c_flags}>"
|
|
||||||
"$<$<COMPILE_LANGUAGE:CXX>:${cxx_flags}>"
|
|
||||||
)
|
|
||||||
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (MSVC)
|
|
||||||
add_compile_definitions(_CRT_SECURE_NO_WARNINGS)
|
|
||||||
|
|
||||||
if (BUILD_SHARED_LIBS)
|
|
||||||
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
if (LLAMA_LTO)
|
|
||||||
include(CheckIPOSupported)
|
|
||||||
check_ipo_supported(RESULT result OUTPUT output)
|
|
||||||
if (result)
|
|
||||||
set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE)
|
|
||||||
else()
|
|
||||||
message(WARNING "IPO is not supported: ${output}")
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
# Architecture specific
|
|
||||||
# TODO: probably these flags need to be tweaked on some architectures
|
|
||||||
# feel free to update the Makefile for your architecture and send a pull request or issue
|
|
||||||
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
|
|
||||||
if (NOT MSVC)
|
|
||||||
if (LLAMA_STATIC)
|
|
||||||
add_link_options(-static)
|
|
||||||
if (MINGW)
|
|
||||||
add_link_options(-static-libgcc -static-libstdc++)
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
if (LLAMA_GPROF)
|
|
||||||
add_compile_options(-pg)
|
|
||||||
endif()
|
|
||||||
if (LLAMA_NATIVE)
|
|
||||||
add_compile_options(-march=native)
|
|
||||||
endif()
|
|
||||||
endif()
|
|
||||||
|
|
||||||
function(remove_nonexistent SOURCES)
|
|
||||||
set(SOURCES_BAK ${${SOURCES}})
|
|
||||||
set(${SOURCES} )
|
|
||||||
foreach (FILE ${SOURCES_BAK})
|
|
||||||
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${FILE})
|
|
||||||
set(${SOURCES} ${${SOURCES}} ${FILE})
|
|
||||||
endif()
|
|
||||||
endforeach()
|
|
||||||
set(${SOURCES} ${${SOURCES}} PARENT_SCOPE)
|
|
||||||
endfunction()
|
|
||||||
|
|
||||||
function(include_ggml DIRECTORY SUFFIX WITH_LLAMA)
    message(STATUS "Configuring ggml implementation target llama${SUFFIX} in ${CMAKE_CURRENT_SOURCE_DIR}/${DIRECTORY}")

    #
    # Build libraries
    #

    set(GGML_CUBLAS_USE NO)
    if (LLAMA_CUBLAS)
        cmake_minimum_required(VERSION 3.17)

        find_package(CUDAToolkit)
        if (CUDAToolkit_FOUND)
            set(GGML_CUBLAS_USE YES)
            message(STATUS "cuBLAS found")

            enable_language(CUDA)

            set(GGML_SOURCES_CUDA ${DIRECTORY}/ggml-cuda.cu ${DIRECTORY}/ggml-cuda.h)

            if (LLAMA_STATIC)
                set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
            else()
                set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
            endif()
        else()
            message(WARNING "cuBLAS not found")
        endif()
    endif()

    set(GGML_CLBLAST_USE NO)
    if (LLAMA_CLBLAST)
        find_package(CLBlast)
        if (CLBlast_FOUND)
            set(GGML_CLBLAST_USE YES)
            message(STATUS "CLBlast found")

            set(GGML_OPENCL_SOURCE_FILE ggml-opencl.cpp)
            if (NOT EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${DIRECTORY}/${GGML_OPENCL_SOURCE_FILE})
                set(GGML_OPENCL_SOURCE_FILE ggml-opencl.c)
            endif()

            set(GGML_OPENCL_SOURCES ${DIRECTORY}/${GGML_OPENCL_SOURCE_FILE} ${DIRECTORY}/ggml-opencl.h)

            set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} clblast)
        else()
            message(WARNING "CLBlast not found")
        endif()
    endif()

    set(GGML_METAL_SOURCES )

    if (LLAMA_METAL)
        find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
        find_library(METAL_FRAMEWORK Metal REQUIRED)
        find_library(METALKIT_FRAMEWORK MetalKit REQUIRED)
        find_library(METALPERFORMANCE_FRAMEWORK MetalPerformanceShaders REQUIRED)

        set(GGML_METAL_SOURCES ${DIRECTORY}/ggml-metal.m ${DIRECTORY}/ggml-metal.h)
        # get full path to the file
        #add_compile_definitions(GGML_METAL_DIR_KERNELS="${CMAKE_CURRENT_SOURCE_DIR}/")

        # copy ggml-metal.metal to bin directory
        configure_file(${DIRECTORY}/ggml-metal.metal bin/ggml-metal.metal COPYONLY)

        set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS}
            ${FOUNDATION_LIBRARY}
            ${METAL_FRAMEWORK}
            ${METALKIT_FRAMEWORK}
            ${METALPERFORMANCE_FRAMEWORK}
        )
    endif()

    set(GGML_SOURCES
        ${DIRECTORY}/ggml.c
        ${DIRECTORY}/ggml.h
        ${DIRECTORY}/ggml-alloc.c
        ${DIRECTORY}/ggml-alloc.h
        ${DIRECTORY}/ggml-quants.c
        ${DIRECTORY}/ggml-quants.h
        ${DIRECTORY}/ggml-backend.c
        ${DIRECTORY}/ggml-backend.h
        ${GGML_SOURCES_CUDA}
        ${GGML_METAL_SOURCES}
        ${GGML_OPENCL_SOURCES})
    remove_nonexistent(GGML_SOURCES)
    add_library(ggml${SUFFIX} OBJECT ${GGML_SOURCES})
    target_compile_definitions(ggml${SUFFIX} PRIVATE _GNU_SOURCE)

    if (LLAMA_K_QUANTS)
        target_compile_definitions(ggml${SUFFIX} PUBLIC GGML_USE_K_QUANTS)
    endif()

    if (LLAMA_METAL AND GGML_METAL_SOURCES)
        target_compile_definitions(ggml${SUFFIX} PUBLIC GGML_USE_METAL GGML_METAL_NDEBUG)
    endif()
    target_include_directories(ggml${SUFFIX} PUBLIC ${DIRECTORY})
    target_compile_features(ggml${SUFFIX} PUBLIC c_std_11) # don't bump

    if (BUILD_SHARED_LIBS)
        set_target_properties(ggml${SUFFIX} PROPERTIES POSITION_INDEPENDENT_CODE ON)
    endif()

    if (WITH_LLAMA)
        SET(LLAMA_SOURCES
            ${DIRECTORY}/llama.cpp
            ${DIRECTORY}/llama.h
            ${DIRECTORY}/common/grammar-parser.h
            ${DIRECTORY}/common/grammar-parser.cpp)
        remove_nonexistent(LLAMA_SOURCES)
        add_library(llama${SUFFIX} ${LLAMA_SOURCES})

        if (LLAMA_METAL AND GGML_METAL_SOURCES)
            target_compile_definitions(llama${SUFFIX} PUBLIC GGML_USE_METAL GGML_METAL_NDEBUG)
        endif()
        target_include_directories(llama${SUFFIX} PUBLIC ${DIRECTORY})
        target_compile_features(llama${SUFFIX} PUBLIC cxx_std_11) # don't bump

        if (BUILD_SHARED_LIBS)
            set_target_properties(llama${SUFFIX} PROPERTIES POSITION_INDEPENDENT_CODE ON)
            target_compile_definitions(llama${SUFFIX} PRIVATE LLAMA_SHARED LLAMA_BUILD)
        endif()
    endif()

    if (GGML_SOURCES_CUDA)
        message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
        set_property(TARGET ggml${SUFFIX} PROPERTY CUDA_ARCHITECTURES OFF)
        set_property(TARGET ggml${SUFFIX} PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
        if (WITH_LLAMA)
            set_property(TARGET llama${SUFFIX} PROPERTY CUDA_ARCHITECTURES OFF)
        endif()
    endif()

    if (GGML_CUBLAS_USE)
        target_compile_definitions(ggml${SUFFIX} PRIVATE
            GGML_USE_CUBLAS
            GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}
            GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y})
        if (WITH_LLAMA)
            target_compile_definitions(llama${SUFFIX} PRIVATE
                GGML_USE_CUBLAS
                GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}
                GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y})
        endif()
    endif()
    if (GGML_CLBLAST_USE)
        if (WITH_LLAMA)
            target_compile_definitions(llama${SUFFIX} PRIVATE GGML_USE_CLBLAST)
        endif()
        target_compile_definitions(ggml${SUFFIX} PRIVATE GGML_USE_CLBLAST)
    endif()

    if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
        message(STATUS "ARM detected")
        if (MSVC)
            # TODO: arm msvc?
        else()
            if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
                target_compile_options(ggml${SUFFIX} PRIVATE -mcpu=native)
            endif()
            # TODO: armv6,7,8 version specific flags
        endif()
    elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$")
        message(STATUS "x86 detected")
        if (MSVC)
            if (LLAMA_AVX512)
                target_compile_definitions(ggml${SUFFIX} PRIVATE
                    $<$<COMPILE_LANGUAGE:C>:/arch:AVX512>
                    $<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)
                # MSVC has no compile-time flags enabling specific
                # AVX512 extensions, nor does it define the
                # macros corresponding to the extensions.
                # Do it manually.
                if (LLAMA_AVX512_VBMI)
                    target_compile_definitions(ggml${SUFFIX} PRIVATE
                        $<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>
                        $<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>)
                endif()
                if (LLAMA_AVX512_VNNI)
                    target_compile_definitions(ggml${SUFFIX} PRIVATE
                        $<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>
                        $<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
                endif()
            elseif (LLAMA_AVX2)
                target_compile_definitions(ggml${SUFFIX} PRIVATE
                    $<$<COMPILE_LANGUAGE:C>:/arch:AVX2>
                    $<$<COMPILE_LANGUAGE:CXX>:/arch:AVX2>)
            elseif (LLAMA_AVX)
                target_compile_definitions(ggml${SUFFIX} PRIVATE
                    $<$<COMPILE_LANGUAGE:C>:/arch:AVX>
                    $<$<COMPILE_LANGUAGE:CXX>:/arch:AVX>)
            endif()
        else()
            if (LLAMA_F16C)
                target_compile_options(ggml${SUFFIX} PRIVATE -mf16c)
            endif()
            if (LLAMA_FMA)
                target_compile_options(ggml${SUFFIX} PRIVATE -mfma)
            endif()
            if (LLAMA_AVX)
                target_compile_options(ggml${SUFFIX} PRIVATE -mavx)
            endif()
            if (LLAMA_AVX2)
                target_compile_options(ggml${SUFFIX} PRIVATE -mavx2)
            endif()
            if (LLAMA_AVX512)
                target_compile_options(ggml${SUFFIX} PRIVATE -mavx512f)
                target_compile_options(ggml${SUFFIX} PRIVATE -mavx512bw)
            endif()
            if (LLAMA_AVX512_VBMI)
                target_compile_options(ggml${SUFFIX} PRIVATE -mavx512vbmi)
            endif()
            if (LLAMA_AVX512_VNNI)
                target_compile_options(ggml${SUFFIX} PRIVATE -mavx512vnni)
            endif()
        endif()
    else()
        # TODO: support PowerPC
        message(STATUS "Unknown architecture")
    endif()

    target_link_libraries(ggml${SUFFIX} PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})
    if (WITH_LLAMA)
        target_link_libraries(llama${SUFFIX} PRIVATE ggml${SUFFIX} ${LLAMA_EXTRA_LIBS})
    endif()
endfunction()
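Everything backend-specific that include_ggml() wires up reaches the sources only as preprocessor definitions (GGML_USE_CUBLAS, GGML_USE_CLBLAST, GGML_USE_METAL, GGML_USE_K_QUANTS). A minimal, hedged C++ sketch of how code built against one of these ggml targets can branch on them; the report_backend() helper is illustrative and not part of the library:

// Illustrative only: prints which ggml backend this translation unit was built for,
// based on the compile definitions set by include_ggml() above.
#include <cstdio>

static const char *report_backend() {
#if defined(GGML_USE_CUBLAS)
    return "cuBLAS";
#elif defined(GGML_USE_CLBLAST)
    return "CLBlast";
#elif defined(GGML_USE_METAL)
    return "Metal";
#else
    return "CPU only";
#endif
}

int main() {
    std::printf("ggml backend: %s\n", report_backend());
}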
26
mpt.cpp

@ -1,26 +0,0 @@
#include "justlm_mpt.hpp"
#include "justlm.hpp"

#include <string>
#include <string_view>
#include <fstream>
#include <cstdint>


extern "C" {
const LM::Implementation *get_justlm_implementation() {
    static LM::Implementation fres{false};
    return &fres;
}

bool magic_match(std::istream& f) {
    uint32_t magic;
    f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
    return magic == 0x67676d6d;
}

LM::Inference *construct(const std::string &weights_path, std::ifstream& f, const LM::Inference::Params &p) {
    return new LM::MPTInference(weights_path, f, p);
}
}
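The deleted mpt.cpp above exposes the backend through three C symbols; magic_match() decides whether a weights file belongs to the MPT backend by reading its first 32-bit word and comparing it against 0x67676d6d. A hedged, standalone sketch of the same probe (the file name model.bin is illustrative):

// Reads the leading magic word of a ggml weights file, the same check magic_match() performs.
#include <cstdint>
#include <fstream>
#include <iostream>

int main() {
    std::ifstream f("model.bin", std::ios::binary);
    uint32_t magic = 0;
    f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
    if (f && magic == 0x67676d6d)
        std::cout << "MPT ggml model detected\n";
    else
        std::cout << "not an MPT ggml model\n";
}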
235
mpt/mpt.cpp

@ -1,5 +1,4 @@
 #include "mpt.hpp"
-#include "../g4a_common.hpp"
 
 #include <cassert>
 #include <cmath>

@ -11,7 +10,7 @@
 #include <string>
 #include <vector>
 #include <iostream>
-#include "../msvc_compat_unistd.h"
+#include <unistd.h>
 #include <sstream>
 #include <thread>
 #include <unordered_set>

@ -19,7 +18,7 @@
 #include <ggml.h>
 
 inline
-unsigned long long operator ""_MiB(unsigned long long bytes) {
+unsigned long long operator ""_MB(unsigned long long bytes) {
     return bytes*1024*1024;
 }
 
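Both sides of the hunk above define the same user-defined literal, bytes multiplied by 1024*1024; master merely renames the suffix from _MB to the more accurate _MiB. A small, hedged sketch of the literal in isolation:

// The literal used for buffer sizes in this file; 2_MiB is the kv-cache overhead added in the next hunk.
#include <cstdio>

inline unsigned long long operator ""_MiB(unsigned long long bytes) {
    return bytes*1024*1024;
}

int main() {
    std::printf("%llu\n", 2_MiB); // prints 2097152
}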
@ -34,7 +33,7 @@ static bool kv_cache_init(
     const int64_t n_mem = (int64_t)n_layer*n_ctx;
     const int64_t n_elements = n_embd*n_mem;
 
-    cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2_MiB);
+    cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2_MB);
 
     struct ggml_init_params params;
     params.mem_size = cache.buf.size;

@ -54,8 +53,13 @@ static bool kv_cache_init(
     return true;
 }
 
+std::string regex_escape(const std::string &s) {
+    static const std::regex metacharacters(R"([\.\^\$\-\+\(\)\[\]\{\}\|\?\*])");
+    return std::regex_replace(s, metacharacters, "\\$&");
+}
+
 // load the model's weights from a stream
-bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & model, gpt_vocab & vocab) {
+bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & model, mpt_vocab & vocab) {
     printf("%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
 
     // verify magic
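regex_escape(), added on the v0.1 side above, backslash-escapes regex metacharacters so that literal token strings can be joined into a std::regex alternation, which is how mpt_tokenize() later matches special tokens. A hedged usage sketch; the token value <|im_end|> is illustrative:

// Escapes a special token before it is OR-ed into a regular expression.
#include <iostream>
#include <regex>
#include <string>

std::string regex_escape(const std::string &s) {
    static const std::regex metacharacters(R"([\.\^\$\-\+\(\)\[\]\{\}\|\?\*])");
    return std::regex_replace(s, metacharacters, "\\$&");
}

int main() {
    std::cout << regex_escape("<|im_end|>") << "\n"; // prints <\|im_end\|>
}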
@ -119,6 +123,8 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
             vocab.id_to_token[i] = word;
         }
 
+        // TODO: this only kind-of works, the gpt_tokenize can still incorrectly
+        // tokenize special tokens
         if(special) {
             vocab.add_special_token(word);
         }

@ -326,7 +332,7 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
 }
 
 // load the model's weights from a file path
-bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vocab) {
+bool mpt_model_load(const std::string & fname, mpt_model & model, mpt_vocab & vocab) {
 
     auto fin = std::ifstream(fname, std::ios::binary);
     if (!fin) {
@ -356,31 +362,30 @@ bool mpt_eval
     const int n_head = hparams.n_head;
     const int n_vocab = hparams.n_vocab;
 
-    const size_t init_buf_size = 1024_MiB;
-    if (!model.buf.addr || model.buf.size < init_buf_size)
-        model.buf.resize(init_buf_size);
+    static size_t buf_size = 256u*1024*1024;
+    static void * buf = malloc(buf_size);
 
-    if (mem_per_token > 0 && mem_per_token*N > model.buf.size) {
+    if (mem_per_token > 0 && mem_per_token*N > buf_size) {
         const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
-        // printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, model.buf.size, buf_size_new);
+        //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
 
         // reallocate
-        model.buf.resize(buf_size_new);
-        if (model.buf.addr == nullptr) {
-            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, model.buf.size);
+        buf_size = buf_size_new;
+        buf = realloc(buf, buf_size);
+        if (buf == nullptr) {
+            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
             return false;
         }
     }
 
     struct ggml_init_params params = {
-        model.buf.size,
-        model.buf.addr,
-        false
+        .mem_size = buf_size,
+        .mem_buffer = buf,
+        .no_alloc = false,
     };
 
     struct ggml_context * ctx0 = ggml_init(params);
-    struct ggml_cgraph gf{};
-    gf.n_threads = n_threads;
+    struct ggml_cgraph gf = { .n_threads = n_threads };
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
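The hunk above also shows the eval scratch-buffer policy both versions share: once mem_per_token is known, the buffer is regrown to the batch requirement plus roughly 10% for ggml object overhead (v0.1 with a static malloc/realloc buffer, master via model.buf.resize()). A hedged sketch of just that arithmetic, with illustrative numbers:

// Recomputes the scratch buffer size the way mpt_eval() does; all values are illustrative.
#include <cstddef>
#include <cstdio>

int main() {
    std::size_t buf_size = 256u*1024*1024;     // v0.1's initial static buffer
    std::size_t mem_per_token = 3u*1024*1024;  // assumed measurement from a first evaluation
    int N = 128;                               // tokens in the current batch
    if (mem_per_token > 0 && mem_per_token*N > buf_size) {
        buf_size = 1.1*(mem_per_token*N);      // add 10% to account for ggml object overhead
        std::printf("reallocating buffer to %zu bytes\n", buf_size);
    }
}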
@ -516,12 +521,10 @@ bool mpt_eval
         out = ggml_mul_mat(ctx0, model.wte, out);
     }
 
-
     // run the computation
     ggml_build_forward_expand(&gf, out);
     ggml_graph_compute       (ctx0, &gf);
 
-
     // return result for just the last token
     embd_w.resize(n_vocab);
     memcpy(embd_w.data(), (float *) ggml_get_data(out) + (n_vocab*(N-1)), sizeof(float)*n_vocab);

@ -536,6 +539,98 @@ bool mpt_eval
     return true;
 }
 
+std::vector<int> mpt_tokenize_inner(const mpt_vocab & vocab, const std::string & text) {
+    // taken from stablelm example in ggml
+    // they both use the gpt-neox tokenizer
+    // not sure if this entirely right?
+    std::vector<std::string> words;
+
+    // first split the text into words
+    {
+        std::string str = text;
+        std::string pat = R"('s|'t|'re|'ve|'m|'ll|'d| ?[[:alpha:]]+| ?[[:digit:]]+| ?[^\s[:alpha:][:digit:]]+|\s+(?!\S)|\s+)";
+        std::regex re(pat);
+        std::smatch m;
+
+        while (std::regex_search(str, m, re)) {
+            for (auto x : m) {
+                words.push_back(x);
+            }
+            str = m.suffix();
+        }
+    }
+
+    // find the longest tokens that form the words:
+    std::vector<mpt_vocab::id> tokens;
+    for (const auto & word : words) {
+        if (word.size() == 0) continue;
+
+        int i = 0;
+        int n = word.size();
+        while (i < n) {
+            int j = n;
+            while (j > i) {
+                auto it = vocab.token_to_id.find(word.substr(i, j-i));
+                if (it != vocab.token_to_id.end()) {
+                    tokens.push_back(it->second);
+                    i = j;
+                    break;
+                }
+                --j;
+            }
+            if (i == n) {
+                break;
+            }
+            if (j == i) {
+                auto sub = word.substr(i, 1);
+                if (vocab.token_to_id.find(sub) != vocab.token_to_id.end()) {
+                    tokens.push_back(vocab.token_to_id.at(sub));
+                } else {
+                    fprintf(stderr, "%s: unknown token '%s'\n", __func__, sub.data());
+                }
+                ++i;
+            }
+        }
+    }
+
+    return tokens;
+}
+
+std::vector<mpt_vocab::id> mpt_tokenize(const mpt_vocab & vocab, const std::string & text) {
+    // Generate the subpattern from the special_tokens vector if it's not empty
+    if (!vocab.special_tokens.empty()) {
+        std::vector<mpt_vocab::id> out;
+        std::vector<std::string> chunks;
+        std::string str = text;
+        std::string special_tokens_subpattern;
+        for (const auto &token : vocab.special_tokens) {
+            if (!special_tokens_subpattern.empty()) {
+                special_tokens_subpattern += "|";
+            }
+            special_tokens_subpattern += regex_escape(token);
+        }
+        std::regex re(special_tokens_subpattern);
+        std::smatch m;
+        while (std::regex_search(str, m, re)) {
+            auto tok = vocab.token_to_id.find(m.str());
+            if (tok != vocab.token_to_id.end()) {
+                auto tokid = tok->second;
+                auto pfxtoks = mpt_tokenize_inner(vocab, m.prefix());
+                out.insert(out.end(), pfxtoks.begin(), pfxtoks.end());
+                out.push_back(tokid);
+                str = m.suffix();
+            }
+        }
+        if (!str.empty()) {
+            auto tokrest = mpt_tokenize_inner(vocab, str);
+            out.insert(out.end(), tokrest.begin(), tokrest.end());
+        }
+        return out;
+    } else {
+        return mpt_tokenize_inner(vocab, text);
+    }
+}
+
 #define MPT_MAX_RNG_STATE 64*1024
 
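With the tokenizer in place, a hedged usage sketch of the v0.1 API added above: build a tiny mpt_vocab by hand, register one special token, and tokenize a short prompt. The vocabulary entries, ids and prompt are purely illustrative (a real vocabulary is filled in by mpt_model_load()), the include path assumes the repository layout, and the example must be linked against mpt/mpt.cpp:

// Illustrative only: exercises mpt_tokenize() with a hand-made three-entry vocabulary.
#include "mpt/mpt.hpp"
#include <iostream>

int main() {
    mpt_vocab vocab;
    vocab.token_to_id = {{"Hello", 1}, {" world", 2}, {"<|endoftext|>", 0}};
    for (const auto &kv : vocab.token_to_id)
        vocab.id_to_token[kv.second] = kv.first;
    vocab.add_special_token("<|endoftext|>");

    for (auto id : mpt_tokenize(vocab, "Hello world<|endoftext|>"))
        std::cout << id << ' ';   // expected: 1 2 0
    std::cout << '\n';
}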
@ -596,6 +691,104 @@ size_t mpt_copy_state_data(const mpt_model &model, const std::mt19937 &rng, uint
     return written;
 }
 
+mpt_vocab::id mpt_sample_top_k_top_p(
+        const mpt_vocab & vocab,
+        const size_t actualVocabSize,
+        const int32_t * last_n_tokens_data,
+        int last_n_tokens_size,
+        const std::vector<float> logits,
+        int top_k,
+        double top_p,
+        double temp,
+        float repeat_penalty,
+        std::mt19937 & rng) {
+    int n_logits = actualVocabSize;
+
+    const auto last_n_tokens = std::vector<int32_t>(last_n_tokens_data, last_n_tokens_data + last_n_tokens_size);
+    const auto * plogits = logits.data() + logits.size() - n_logits;
+
+    std::vector<std::pair<double, mpt_vocab::id>> logits_id;
+    logits_id.reserve(n_logits);
+
+    {
+        const float scale = 1.0f/temp;
+        for (int i = 0; i < n_logits; ++i) {
+            // repetition penalty from ctrl paper (https://arxiv.org/abs/1909.05858)
+            // credit https://github.com/facebookresearch/llama/compare/main...shawwn:llama:main
+            if (std::find(last_n_tokens.begin(), last_n_tokens.end(), i) != last_n_tokens.end()) {
+                // if score < 0 then repetition penalty has to multiplied to reduce the previous token probability
+                if (plogits[i] < 0.0f) {
+                    logits_id.push_back(std::make_pair(plogits[i]*scale*repeat_penalty, i));
+                } else {
+                    logits_id.push_back(std::make_pair(plogits[i]*scale/repeat_penalty, i));
+                }
+            } else {
+                logits_id.push_back(std::make_pair(plogits[i]*scale, i));
+            }
+        }
+    }
+
+    // find the top K tokens
+    std::partial_sort(
+            logits_id.begin(),
+            logits_id.begin() + top_k, logits_id.end(),
+            [](const std::pair<double, mpt_vocab::id> & a, const std::pair<double, mpt_vocab::id> & b) {
+        return a.first > b.first;
+    });
+
+    logits_id.resize(top_k);
+
+    double maxl = -INFINITY;
+    for (const auto & kv : logits_id) {
+        maxl = std::max(maxl, kv.first);
+    }
+
+    // compute probs for the top K tokens
+    std::vector<double> probs;
+    probs.reserve(logits_id.size());
+
+    double sum = 0.0;
+    for (const auto & kv : logits_id) {
+        double p = exp(kv.first - maxl);
+        probs.push_back(p);
+        sum += p;
+    }
+
+    // normalize the probs
+    for (auto & p : probs) {
+        p /= sum;
+    }
+
+    if (top_p < 1.0f) {
+        double cumsum = 0.0f;
+        for (int i = 0; i < top_k; i++) {
+            cumsum += probs[i];
+            if (cumsum >= top_p) {
+                top_k = i + 1;
+                probs.resize(top_k);
+                logits_id.resize(top_k);
+                break;
+            }
+        }
+
+        cumsum = 1.0/cumsum;
+        for (int i = 0; i < (int) probs.size(); i++) {
+            probs[i] *= cumsum;
+        }
+    }
+
+    //printf("\n");
+    //for (int i = 0; i < (int) probs.size(); i++) {
+    //    printf("%d: '%s' %f\n", i, vocab.id_to_token.at(logits_id[i].second).c_str(), probs[i]);
+    //}
+    //exit(0);
+
+    std::discrete_distribution<> dist(probs.begin(), probs.end());
+    int idx = dist(rng);
+
+    return logits_id[idx].second;
+}
+
 size_t mpt_set_state_data(mpt_model *model, std::mt19937 *rng, const uint8_t *src)
 {
     const uint8_t * in = src;
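A hedged usage sketch for mpt_sample_top_k_top_p() as added above: draw one token id from a toy four-entry logit vector. The logits, penalty window, sampling parameters and seed are illustrative; in the real code path the logits come from mpt_eval() and the vocabulary size from the model hyperparameters. Linking against the v0.1 mpt/mpt.cpp is assumed:

// Illustrative only: samples a single token id from hand-written logits.
#include "mpt/mpt.hpp"
#include <cstdint>
#include <iostream>
#include <random>
#include <vector>

int main() {
    mpt_vocab vocab;                              // contents are not needed for sampling itself
    std::vector<float> logits = {0.1f, 2.5f, 0.3f, 1.7f};
    std::vector<int32_t> last_n_tokens = {1};     // token 1 gets the repetition penalty
    std::mt19937 rng(1234);

    auto id = mpt_sample_top_k_top_p(vocab, logits.size(), last_n_tokens.data(),
                                     last_n_tokens.size(), logits,
                                     /*top_k=*/2, /*top_p=*/0.9, /*temp=*/0.7,
                                     /*repeat_penalty=*/1.2f, rng);
    std::cout << "sampled token id: " << id << "\n";
}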
20
mpt/mpt.hpp

@ -1,7 +1,5 @@
 #ifndef MPT_H
 #define MPT_H
-#include "../g4a_common.hpp"
-
 #include <string>
 #include <vector>
 #include <map>

@ -85,6 +83,7 @@ struct mpt_model {
     struct ggml_context * ctx;
     std::map<std::string, struct ggml_tensor *> tensors;
 
+
     mpt_buffer buf;
 
     ~mpt_model() {

@ -94,9 +93,24 @@
     }
 };
 
+struct mpt_vocab {
+    using id = int32_t;
+    using token = std::string;
+
+    std::map<token, id> token_to_id;
+    std::map<id, token> id_to_token;
+    std::vector<std::string> special_tokens;
+
+    void add_special_token(const std::string &token) {
+        special_tokens.push_back(token);
+    }
+};
+
-bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & model, gpt_vocab& vocab);
+bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & model, mpt_vocab & vocab);
 bool mpt_eval(mpt_model& model, const int n_threads, const int n_past, const std::vector<int>& embd_inp, std::vector<float>& embd_w, size_t& mem_per_token);
+std::vector<mpt_vocab::id> mpt_tokenize(const mpt_vocab & vocab, const std::string & text);
+mpt_vocab::id mpt_sample_top_k_top_p(const mpt_vocab& vocab, const size_t actualVocabSize, const int32_t *last_n_tokens_data, int last_n_tokens_size, const std::vector<float> logits, int top_k, double top_p, double temp, float repeat_penalty, std::mt19937& rng);
 size_t mpt_get_state_size(const mpt_model &model);
 size_t mpt_copy_state_data(const mpt_model &model, const std::mt19937& rng, uint8_t *dest);
 size_t mpt_set_state_data(mpt_model *model, std::mt19937 *rng, const uint8_t *src);
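The three state functions declared above are context lines, so they exist on both branches and back the savestate feature. A hedged sketch of the round-trip, assuming model and rng were already initialized via mpt_model_load(); the helper name snapshot_and_restore() is illustrative:

// Illustrative only: serialize the model state into a byte buffer and restore it later.
#include "mpt/mpt.hpp"
#include <cstdint>
#include <random>
#include <vector>

void snapshot_and_restore(mpt_model &model, std::mt19937 &rng) {
    std::vector<uint8_t> state(mpt_get_state_size(model));
    mpt_copy_state_data(model, rng, state.data());    // take the snapshot
    // ... evaluate more tokens, mutating the kv cache and rng ...
    mpt_set_state_data(&model, &rng, state.data());   // roll back to the snapshot
}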
@ -1,11 +0,0 @@
#if defined(_WIN32) && defined(_MSC_VER)
#define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX
#endif
#include <windows.h>
#include <io.h>
#include <stdio.h> // for _fseeki64
#else
#include <unistd.h>
#endif
15
pybind.cpp

@ -9,7 +9,7 @@ namespace py = pybind11;
 
 
 
-PYBIND11_MODULE(justlm_py, m) {
+PYBIND11_MODULE(libjustlm_py, m) {
     using namespace LM;
     py::class_<Inference::Params>(m, "Params")
         .def(py::init<>())

@ -24,23 +24,16 @@ PYBIND11_MODULE(justlm_py, m) {
         .def_readwrite("top_p", &Inference::Params::top_p)
         .def_readwrite("temp", &Inference::Params::temp)
         .def_readwrite("repeat_penalty", &Inference::Params::repeat_penalty)
-        .def_readwrite("eos_ignores", &Inference::Params::n_eos_ignores)
-        .def_readwrite("use_mlock", &Inference::Params::use_mlock)
-        .def_readwrite("prefer_mirostat", &Inference::Params::prefer_mirostat)
-        .def_readwrite("mirostat_learning_rate", &Inference::Params::mirostat_learning_rate)
-        .def_readwrite("mirostat_target_entropy", &Inference::Params::mirostat_target_entropy);
+        .def_readwrite("eos_ignores", &Inference::Params::eos_ignores)
+        .def_readwrite("use_mlock", &Inference::Params::use_mlock);
     py::class_<Inference>(m, "Inference")
         .def_static("construct", &Inference::construct, py::arg("weights_path"), py::arg("params") = Inference::Params())
        .def("append", &Inference::append, py::arg("prompt"), py::arg("on_tick") = nullptr)
-        .def("run", &Inference::run, py::arg("end") = "", py::arg("on_tick") = nullptr, py::arg("pre_tick") = nullptr)
+        .def("run", &Inference::run, py::arg("end") = "", py::arg("on_tick") = nullptr)
         .def("create_savestate", &Inference::create_savestate)
         .def("restore_savestate", &Inference::restore_savestate)
         .def("get_prompt", &Inference::get_prompt)
         .def("get_context_size", &Inference::get_context_size)
-        .def("is_mirostat_available", &Inference::is_mirostat_available)
-        .def("is_grammar_available", &Inference::is_grammar_available)
-        .def("load_grammar", &Inference::load_grammar)
-        .def("unload_grammar", &Inference::unload_grammar)
         .def_readwrite("params", &Inference::params);
     py::class_<Inference::Savestate>(m, "Savestate")
         .def(py::init<>());