Mirror of https://gitlab.com/niansa/libjustlm.git (synced 2025-03-06 20:49:17 +01:00)
Commit 60fe6b9c55 (parent 5b01daa764)

Load implementations as shared objects

17 changed files with 686 additions and 97 deletions
.gitmodules (vendored, 2 changes)

@@ -1,5 +1,5 @@
 [submodule "llama.cpp"]
-	path = llama.cpp
+	path = llama.cpp-mainline
 	url = https://github.com/ggerganov/llama.cpp.git
 [submodule "llama.cpp-alibi"]
 	path = llama.cpp-alibi
CMakeLists.txt

@@ -1,52 +1,76 @@
-cmake_minimum_required(VERSION 3.14)
+cmake_minimum_required(VERSION 3.18)
 
-project(libjustlm LANGUAGES C CXX)
+project(justlm LANGUAGES C CXX)
 
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_POSITION_INDEPENDENT_CODE ON)
 
-set(LM_PYBIND No CACHE BOOL "If Libjustlm Python bindings should be build")
-set(LM_COSCHED No CACHE BOOL "If Libjustlm should make use of CoSched")
-set(LM_NOEXCEPT No CACHE BOOL "If exceptions should be disabled")
-set(LM_MPT No CACHE BOOL "If MPT model support should be built")
+set(LM_PYBIND No CACHE BOOL "If justlm Python bindings should be build")
+set(LM_COSCHED No CACHE BOOL "If justlm should make use of CoSched")
+set(LM_NOEXCEPT No CACHE BOOL "If justlm exceptions should be disabled")
+set(LM_LLAMA Yes CACHE BOOL "If LLaMa model support should be built into justlm")
+set(LM_GPTJ Yes CACHE BOOL "If GPT-J model support should be built into justlm")
+set(LM_MPT Yes CACHE BOOL "If MPT model support should be built into justlm")
 
 if (LM_COSCHED)
     set(CMAKE_CXX_STANDARD 20)
 endif()
 
+function(target_justlm_setup target)
+    target_include_directories(${target} PUBLIC include/)
+    if (LM_COSCHED)
+        target_compile_definitions(${target} PUBLIC LM_COSCHED)
+        target_link_libraries(${target} PRIVATE cosched)
+    endif()
+    if (LM_NOEXCEPT)
+        target_compile_definitions(${target} PUBLIC LM_NOEXCEPT)
+    endif()
+endfunction()
+
+include(llama.cpp.cmake)
+
+include_ggml(llama.cpp-mainline _mainline Yes)
+include_ggml(llama.cpp-alibi _alibi No)
+
+add_library(justlm_g4a_common SHARED g4a-common.cpp g4a-common.hpp)
+
+set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
 
 if (LM_MPT)
-    set(LM_MPT_SOURCES justlm_mpt.hpp mpt/mpt.cpp mpt/mpt.hpp)
-    add_subdirectory(llama.cpp-alibi)
-else()
-    set(LM_MPT_SOURCES )
-    add_subdirectory(llama.cpp)
+    add_library(justlm_mpt SHARED mpt.cpp justlm_mpt.hpp mpt/mpt.cpp mpt/mpt.hpp)
+    target_link_libraries(justlm_mpt PRIVATE ggml_alibi justlm_g4a_common)
+    target_justlm_setup(justlm_mpt)
 endif()
 
-add_library(libjustlm STATIC
+if (LM_GPTJ)
+    add_library(justlm_gptj SHARED gptj.cpp justlm_gptj.hpp gptj/gptj.cpp gptj/gptj.hpp)
+    target_link_libraries(justlm_gptj PRIVATE ggml_mainline justlm_g4a_common)
+    target_justlm_setup(justlm_gptj)
+endif()
+
+if (LM_LLAMA)
+    add_library(justlm_llama SHARED llama.cpp justlm_llama.hpp)
+    target_link_libraries(justlm_llama PRIVATE ggml_mainline llama_mainline)
+    target_justlm_setup(justlm_llama)
+endif()
+
+add_library(justlm STATIC
     include/justlm.hpp justlm.cpp
-    justlm_llama.hpp
-    g4a-common.cpp g4a-common.hpp
-    justlm_gptj.hpp gptj/gptj.cpp gptj/gptj.hpp
-    ${LM_MPT_SOURCES}
     include/justlm_pool.hpp justlm_pool.cpp
+    dlhandle.hpp
 )
-target_link_libraries(libjustlm PRIVATE llama)
-
-if (LM_MPT)
-    target_compile_definitions(libjustlm PUBLIC LM_MPT)
-endif()
-
-if (LM_COSCHED)
-    target_compile_definitions(libjustlm PUBLIC LM_COSCHED)
-    target_link_libraries(libjustlm PRIVATE cosched)
-    set(LM_COSCHED Yes CACHE BOOL "If Libjustlm should make use of CoSched" FORCE)
-endif()
-
-if (LM_NOEXCEPT)
-    target_compile_definitions(libjustlm PUBLIC LM_NOEXCEPT)
-endif()
+add_library(libjustlm ALIAS justlm)
+target_link_libraries(justlm PRIVATE dl)
+target_include_directories(justlm PUBLIC include/)
+target_compile_definitions(justlm PRIVATE LIB_FILE_EXT="${CMAKE_SHARED_LIBRARY_SUFFIX}")
+target_justlm_setup(justlm)
 
 if (LM_PYBIND)
     if (LM_COSCHED)
@@ -55,8 +79,6 @@ if (LM_PYBIND)
     find_package(Python COMPONENTS Interpreter Development)
     find_package(pybind11 CONFIG)
-    pybind11_add_module(libjustlm_py pybind.cpp)
-    target_link_libraries(libjustlm_py PRIVATE libjustlm)
+    pybind11_add_module(justlm_py pybind.cpp)
+    target_link_libraries(justlm_py PRIVATE justlm)
 endif()
-
-target_include_directories(libjustlm PUBLIC include/)
README.md

@@ -1,8 +1,8 @@
 # JustLM
-Super easy to use library for doing LLaMA/GPT-J stuff!
+Super easy to use library for doing LLaMA/GPT-J/MPT stuff!
 
 ## Overview
-This library implements an easy to use interface to both LLaMa and GPT-J, with optional Python bindings.
+This library implements an easy to use interface to LLaMa, GPT-J and MPT, with optional Python bindings.
 
 Context scrolling is automatic and supports a top window bar.
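As orientation before the new files below, here is a minimal sketch of how a consumer uses the library after this change. It is not part of the commit; only construct() and run() appear in the hunks below, so the append() call and the exact Params fields are assumptions.

    // Hypothetical consumer sketch (not part of this commit). Assumes the
    // factory/run interface visible in justlm.cpp below; append() is an assumption.
    #include <justlm.hpp>
    #include <iostream>
    #include <memory>

    int main() {
        LM::Inference::Params params;   // default sampling parameters
        // Picks a backend shared object by the weights file magic (see justlm.cpp below).
        std::unique_ptr<LM::Inference> lm(LM::Inference::construct("ggml-model.bin", params));
        if (!lm) return 1;              // no matching backend library found next to the binary
        lm->append("The capital of France is");      // assumed prompt-feeding call
        std::cout << lm->run("\n") << std::endl;     // generate until "\n" (non-CoSched build)
        return 0;
    }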
dlhandle.hpp (new file, 108 lines)

@@ -0,0 +1,108 @@
#ifndef __WIN32
#include <string>
#include <exception>
#include <utility>
#include <dlfcn.h>


class Dlhandle {
    void *chandle;

public:
    class Exception : public std::exception {
        std::string errmsg;
    public:
        Exception(std::string errmsg) {
            this->errmsg = errmsg;
        }
        virtual const char* what() const throw() {
            return errmsg.c_str();
        }
    };

    Dlhandle() : chandle(nullptr) {}
    Dlhandle(const std::string& fpath, int flags = RTLD_LAZY) {
        chandle = dlopen(fpath.c_str(), flags);
        if (!chandle) {
            throw Exception("dlopen(): "+fpath);
        }
    }
    Dlhandle(const Dlhandle& o) = delete;
    Dlhandle(Dlhandle&& o) : chandle(o.chandle) {
        o.chandle = nullptr;
    }
    ~Dlhandle() {
        if (chandle) dlclose(chandle);
    }

    auto operator =(Dlhandle&& o) {
        chandle = std::exchange(o.chandle, nullptr);
    }

    bool is_valid() const {
        return chandle != nullptr;
    }
    operator bool() const {
        return is_valid();
    }

    template<typename T>
    T* get(const std::string& fname) {
        dlerror(); // Clear error
        auto fres = reinterpret_cast<T*>(dlsym(chandle, fname.c_str()));
        return (dlerror()==NULL)?fres:nullptr;
    }
    auto get_fnc(const std::string& fname) {
        return get<void*(...)>(fname);
    }
};
#else
#include <string>
#include <exception>
#include <libloaderapi.h>


class Dlhandle {
    HMODULE chandle;

public:
    class Exception : public std::exception {
        std::string errmsg;
    public:
        Exception(std::string errmsg) {
            this->errmsg = errmsg;
        }
        virtual const char* what() const throw() {
            return errmsg.c_str();
        }
    };

    Dlhandle() : chandle(nullptr) {}
    Dlhandle(const std::string& fpath) {
        chandle = LoadLibraryA(fpath.c_str());
        if (!chandle) {
            throw Exception("dlopen(): "+fpath);
        }
    }
    Dlhandle(const Dlhandle& o) = delete;
    Dlhandle(Dlhandle&& o) : chandle(o.chandle) {
        o.chandle = nullptr;
    }
    ~Dlhandle() {
        if (chandle) FreeLibrary(chandle);
    }

    bool is_valid() const {
        return chandle != nullptr;
    }

    template<typename T>
    T* get(const std::string& fname) {
        return reinterpret_cast<T*>(GetProcAddress(chandle, fname.c_str()));
    }
    auto get_fnc(const std::string& fname) {
        return get<void*(...)>(fname);
    }
};
#endif
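dlhandle.hpp is the small RAII wrapper the loader uses around dlopen()/LoadLibraryA(). A minimal usage sketch, mirroring what justlm.cpp does further down; the library file name is illustrative only (on Linux the justlm_gptj target would typically produce libjustlm_gptj.so):

    // Illustrative only: open one backend object and resolve an exported symbol.
    #include "dlhandle.hpp"
    #include <cstdint>
    #include <iostream>

    int main() try {
        Dlhandle dl("./libjustlm_gptj.so");   // throws Dlhandle::Exception on failure
        // Resolve one of the extern "C" entry points introduced by this commit.
        auto magic_match = dl.get<bool(uint32_t)>("magic_match");
        if (magic_match && magic_match(0x67676d6c))
            std::cout << "backend accepts ggml files\n";
        return 0;
    } catch (const Dlhandle::Exception& e) {
        std::cerr << e.what() << '\n';
        return 1;
    }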
gptj.cpp (new file, 24 lines)

@@ -0,0 +1,24 @@
#include "justlm_gptj.hpp"
#include "justlm.hpp"

#include <string>
#include <string_view>
#include <fstream>
#include <cstdint>


extern "C" {
const LM::Implementation *get_justlm_implementation() {
    static LM::Implementation fres{false};
    return &fres;
}

bool magic_match(uint32_t magic) {
    return magic == 0x67676d6c;
}

LM::Inference *construct(const std::string &weights_path, std::ifstream& f, const LM::Inference::Params &p) {
    return new LM::GPTJInference(weights_path, f, p);
}
}
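The 0x67676d6c constant is the 32-bit tag whose hex bytes spell "ggml" (GPT-J and plain llama weight files); the MPT backend further down checks 0x67676d6d ("ggmm") instead. A small sketch of the check the loader performs on a weights file, mirroring justlm.cpp; the file name is illustrative:

    // Reads the leading 32-bit magic of a weights file, as justlm.cpp does,
    // and reports which backend family would claim it. "model.bin" is made up.
    #include <cstdint>
    #include <fstream>
    #include <iostream>
    #include <string>

    static uint32_t read_magic(const std::string& path) {
        std::ifstream f(path, std::ios::binary);
        uint32_t magic = 0;
        f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
        return magic;
    }

    int main() {
        const uint32_t magic = read_magic("model.bin");
        if (magic == 0x67676d6c) std::cout << "ggml magic: GPT-J or llama backend\n";
        else if (magic == 0x67676d6d) std::cout << "ggmm magic: MPT backend\n";
        else std::cout << "unknown magic: the fallback backend will be tried\n";
        return 0;
    }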
include/justlm.hpp

@@ -146,5 +146,10 @@ public:
 
     LM_LAST_ERROR_GETTER
 };
+
+
+struct Implementation {
+    bool is_fallback = false;
+};
 }
 #endif // JUSTLM_HPP
justlm.cpp (72 changes)

@@ -1,30 +1,66 @@
 #include "justlm.hpp"
-#include "justlm_llama.hpp"
-#include "justlm_gptj.hpp"
-#ifdef LM_MPT
-#  include "justlm_mpt.hpp"
-#endif
+#include "dlhandle.hpp"
 
+#include <string>
+#include <vector>
 #include <fstream>
+#include <filesystem>
 
 
+static
+Dlhandle get_implementation(uint32_t magic) {
+    Dlhandle matching;
+    Dlhandle fallback;
+    // Iterate over all libraries
+    for (const auto& f : std::filesystem::directory_iterator(".")) {
+        // Get path
+        const auto& p = f.path();
+        // Check extension
+        if (p.extension() != LIB_FILE_EXT) continue;
+        // Load library
+        try {
+            Dlhandle dl(p);
+            // Get implementation info getter
+            auto implementation_getter = dl.get<const LM::Implementation *()>("get_justlm_implementation");
+            if (!implementation_getter) continue;
+            // Get implementation info
+            const auto *implementation_info = implementation_getter();
+            // Set if fallback
+            if (implementation_info->is_fallback) {
+                fallback = std::move(dl);
+                continue;
+            }
+            // Set if matching magic
+            auto magic_match = dl.get<bool(uint32_t)>("magic_match");
+            if (magic_match && magic_match(magic)) {
+                matching = std::move(dl);
+                continue;
+            }
+        } catch (...) {}
+    }
+    // Return matching if any, fallback otherwise
+    if (matching) return matching;
+    return fallback;
+}
 
 LM::Inference *LM::Inference::construct(const std::string &weights_path, const Params &p) {
+    static std::vector<Dlhandle> dls;
     // Read magic
     std::ifstream f(weights_path, std::ios::binary);
     uint32_t magic;
-    f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
-    // Create inference instance
-    if (magic == 0x67676d6c) {
-        f.seekg(0);
-        return new GPTJInference(weights_path, f, p);
-#   ifdef LM_MPT
-    } else if (magic == 0x67676d6d) {
-        f.seekg(0);
-        return new MPTInference(weights_path, f, p);
-#   endif
-    } else {
-        f.close();
-        return new LLaMaInference(weights_path, p);
+    if (!f.read(reinterpret_cast<char*>(&magic), sizeof(magic))) {
+        throw Exception("Failed to open weights file for reading at "+weights_path);
     }
+    f.seekg(0);
+    // Get correct implementation
+    auto impl = get_implementation(magic);
+    if (!impl) return nullptr;
+    // Get inference constructor
+    auto constructor = impl.get<LM::Inference *(const std::string &, std::ifstream&, const LM::Inference::Params &)>("construct");
+    if (!constructor) return nullptr;
+    // Back up Dlhandle
+    dls.push_back(std::move(impl));
+    // Construct inference
+    return constructor(weights_path, f, p);
 }
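With this loader in place, a backend shared object only has to export the three extern "C" symbols looked up above. A hypothetical skeleton for an additional backend, purely for illustration (the magic value is invented, and construct() returns nullptr where a real backend would return its Inference subclass):

    // Placeholder backend skeleton; not part of this commit.
    #include "justlm.hpp"

    #include <cstdint>
    #include <fstream>
    #include <string>

    extern "C" {
    const LM::Implementation *get_justlm_implementation() {
        static LM::Implementation fres{false};   // false: not the fallback implementation
        return &fres;
    }

    bool magic_match(uint32_t magic) {
        return magic == 0x12345678;              // placeholder file magic
    }

    LM::Inference *construct(const std::string &weights_path, std::ifstream &f, const LM::Inference::Params &p) {
        (void)weights_path; (void)f; (void)p;
        return nullptr;                          // a real backend returns its own Inference subclass here
    }
    }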
justlm_gptj.hpp

@@ -53,7 +53,6 @@ class GPTJInference final : public Inference {
         auto& state = get_state();
 
         if (state) {
-            if (state->model.ctx) ggml_free(state->model.ctx); //TODO: Is that enough?
             delete state;
         }
     }
@@ -192,6 +191,7 @@ public:
             const auto str = state->vocab.id_to_token[id];
 
             // Append string to function result
+            state->prompt.append(str);
             fres.append(str);
 
             // Evaluate token
@@ -207,7 +207,6 @@ public:
         }
 
         // Create final string TODO: Could be optimized
-        state->prompt.append(fres);
         if (!abort) {
             fres = std::string(fres.data(), fres.size()-end.size());
         }
justlm_llama.hpp

@@ -11,7 +11,7 @@ class LLaMaInference final : public Inference {
         llama_context *ctx = nullptr;
         std::string prompt; // Mostly here for easy "debugging"
         std::vector<int> tokens;
-        int n_ctx;
+        unsigned n_ctx;
     };
 
     State*& get_state() {
@@ -91,8 +91,8 @@ class LLaMaInference final : public Inference {
             // Calculate progress
             auto progress = float(it-starting_offset) / (state->tokens.size()-starting_offset) * 100.f;
             // Tick and yield
-            if (!on_tick(progress)) LM_BOOL_SUCCESS;
-            else if (!LM_TASKYIELD) LM_BOOL_SUCCESS;
+            if (!on_tick(progress)) LM_CORETURN LM_BOOL_SUCCESS;
+            else if (!LM_TASKYIELD) LM_CORETURN LM_BOOL_SUCCESS;
         }
     }
 
@@ -182,6 +182,7 @@ public:
             const auto str = llama_token_to_str(state->ctx, id);
 
             // Append string to function result
+            state->prompt.append(str);
             fres.append(str);
 
             // Evaluate token
@@ -196,7 +197,6 @@ public:
         }
 
         // Create final string TODO: Could be optimized
-        state->prompt.append(fres);
         if (!abort) {
             fres = std::string(fres.data(), fres.size()-end.size());
         }
justlm_mpt.hpp

@@ -53,7 +53,6 @@ class MPTInference final : public Inference {
         auto& state = get_state();
 
         if (state) {
-            if (state->model.ctx) ggml_free(state->model.ctx); //TODO: Is that enough?
             delete state;
         }
     }
@@ -162,19 +161,6 @@ public:
         LM_CORETURN LM_COAWAIT evaluate_tokens(old_token_count, on_tick);
     }
 
-    /*mpt_vocab::id mpt_sample_top_k_top_p(
-        const mpt_vocab & vocab,
-        const size_t actualVocabSize,
-        const int32_t * last_n_tokens_data,
-        int last_n_tokens_size,
-        const std::vector<float> logits,
-        int top_k,
-        double top_p,
-        double temp,
-        float repeat_penalty,
-        std::mt19937 & rng)
-    */
-
     LM_SCHEDULABLE(std::string) run(std::string_view end, const std::function<bool (const char *)> &on_tick = nullptr) LM_NOEXCEPTDECL override {
         auto& state = get_state();
         std::string fres;
@@ -184,7 +170,7 @@ public:
         unsigned eos_count = 0;
         while (!abort && !ends_with(fres, end)) {
             // Sample top p and top k
-            auto id = mpt_sample_top_k_top_p(state->vocab, state->model.hparams.n_vocab, state->tokens.data(), state->tokens.size(), state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
+            auto id = mpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data(), state->tokens.size(), state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
 
             if (id == state->vocab.token_to_id["<|im_end|>"]) {
                 if (eos_count++ == params.eos_ignores) {
@@ -206,6 +192,7 @@ public:
 
             // Append string to function result
             fres.append(str);
+            state->prompt.append(str);
 
             // Evaluate token
             // TODO: Respect batch size
@@ -220,7 +207,6 @@ public:
         }
 
         // Create final string TODO: Could be optimized
-        state->prompt.append(fres);
         if (!abort) {
             fres = std::string(fres.data(), fres.size()-end.size());
         }
llama.cpp (submodule entry removed)

@@ -1 +0,0 @@
-Subproject commit 0e018fe008eacebdbcfa2d61b6c988c245c961cd

llama.cpp (new file, 25 lines)

@@ -0,0 +1,25 @@
#include "justlm_llama.hpp"
#include "justlm.hpp"

#include <string>
#include <string_view>
#include <fstream>
#include <cstdint>


extern "C" {
const LM::Implementation *get_justlm_implementation() {
    static LM::Implementation fres{true};
    return &fres;
}

bool magic_match(uint32_t magic) {
    return magic == 0x67676d6c;
}

LM::Inference *construct(const std::string &weights_path, std::ifstream& f, const LM::Inference::Params &p) {
    f.close();
    return new LM::LLaMaInference(weights_path, p);
}
}
llama.cpp-mainline (new submodule)

@@ -0,0 +1 @@
+Subproject commit 0e018fe008eacebdbcfa2d61b6c988c245c961cd
llama.cpp.cmake (new file, 356 lines)

@@ -0,0 +1,356 @@
cmake_minimum_required(VERSION 3.12) # Don't bump this version for no reason

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
    set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
    set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
endif()

set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)

if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR)
    set(LLAMA_STANDALONE ON)

    # configure project version
    # TODO
else()
    set(LLAMA_STANDALONE OFF)
endif()

if (EMSCRIPTEN)
    set(BUILD_SHARED_LIBS_DEFAULT OFF)

    option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" ON)
else()
    if (MINGW)
        set(BUILD_SHARED_LIBS_DEFAULT OFF)
    else()
        set(BUILD_SHARED_LIBS_DEFAULT ON)
    endif()
endif()


#
# Option list
#

# general
option(LLAMA_STATIC "llama: static link libraries" OFF)
option(LLAMA_NATIVE "llama: enable -march=native flag" OFF)
option(LLAMA_LTO "llama: enable link time optimization" OFF)

# debug
option(LLAMA_ALL_WARNINGS "llama: enable all compiler warnings" ON)
option(LLAMA_ALL_WARNINGS_3RD_PARTY "llama: enable all compiler warnings in 3rd party libs" OFF)
option(LLAMA_GPROF "llama: enable gprof" OFF)

# sanitizers
option(LLAMA_SANITIZE_THREAD "llama: enable thread sanitizer" OFF)
option(LLAMA_SANITIZE_ADDRESS "llama: enable address sanitizer" OFF)
option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF)

# instruction set specific
option(LLAMA_AVX "llama: enable AVX" ON)
option(LLAMA_AVX2 "llama: enable AVX2" ON)
option(LLAMA_AVX512 "llama: enable AVX512" OFF)
option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF)
option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF)
option(LLAMA_FMA "llama: enable FMA" ON)
# in MSVC F16C is implied with AVX2/AVX512
if (NOT MSVC)
    option(LLAMA_F16C "llama: enable F16C" ON)
endif()

# 3rd party libs
option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON)
option(LLAMA_OPENBLAS "llama: use OpenBLAS" OFF)
option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
option(LLAMA_CLBLAST "llama: use CLBlast" OFF)

#
# Compile flags
#

set(CMAKE_C_STANDARD 11)
set(CMAKE_C_STANDARD_REQUIRED true)
set(THREADS_PREFER_PTHREAD_FLAG ON)
find_package(Threads REQUIRED)

if (NOT MSVC)
    if (LLAMA_SANITIZE_THREAD)
        add_compile_options(-fsanitize=thread)
        link_libraries(-fsanitize=thread)
    endif()

    if (LLAMA_SANITIZE_ADDRESS)
        add_compile_options(-fsanitize=address -fno-omit-frame-pointer)
        link_libraries(-fsanitize=address)
    endif()

    if (LLAMA_SANITIZE_UNDEFINED)
        add_compile_options(-fsanitize=undefined)
        link_libraries(-fsanitize=undefined)
    endif()
endif()

if (APPLE AND LLAMA_ACCELERATE)
    find_library(ACCELERATE_FRAMEWORK Accelerate)
    if (ACCELERATE_FRAMEWORK)
        message(STATUS "Accelerate framework found")

        add_compile_definitions(GGML_USE_ACCELERATE)
        set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
    else()
        message(WARNING "Accelerate framework not found")
    endif()
endif()

if (LLAMA_OPENBLAS)
    if (LLAMA_STATIC)
        set(BLA_STATIC ON)
    endif()

    set(BLA_VENDOR OpenBLAS)
    find_package(BLAS)
    if (BLAS_FOUND)
        message(STATUS "OpenBLAS found")

        add_compile_definitions(GGML_USE_OPENBLAS)
        add_link_options(${BLAS_LIBRARIES})
        set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} openblas)

        # find header file
        set(OPENBLAS_INCLUDE_SEARCH_PATHS
            /usr/include
            /usr/include/openblas
            /usr/include/openblas-base
            /usr/local/include
            /usr/local/include/openblas
            /usr/local/include/openblas-base
            /opt/OpenBLAS/include
            $ENV{OpenBLAS_HOME}
            $ENV{OpenBLAS_HOME}/include
        )
        find_path(OPENBLAS_INC NAMES cblas.h PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS})
        add_compile_options(-I${OPENBLAS_INC})
    else()
        message(WARNING "OpenBLAS not found")
    endif()
endif()

if (LLAMA_ALL_WARNINGS)
    if (NOT MSVC)
        set(c_flags
            -Wall
            -Wextra
            -Wpedantic
            -Wcast-qual
            -Wdouble-promotion
            -Wshadow
            -Wstrict-prototypes
            -Wpointer-arith
        )
        set(cxx_flags
            -Wall
            -Wextra
            -Wpedantic
            -Wcast-qual
            -Wno-unused-function
            -Wno-multichar
        )
    else()
        # todo : msvc
    endif()

    add_compile_options(
            "$<$<COMPILE_LANGUAGE:C>:${c_flags}>"
            "$<$<COMPILE_LANGUAGE:CXX>:${cxx_flags}>"
    )

endif()

if (MSVC)
    add_compile_definitions(_CRT_SECURE_NO_WARNINGS)

    if (BUILD_SHARED_LIBS)
        set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
    endif()
endif()

if (LLAMA_LTO)
    include(CheckIPOSupported)
    check_ipo_supported(RESULT result OUTPUT output)
    if (result)
        set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE)
    else()
        message(WARNING "IPO is not supported: ${output}")
    endif()
endif()

# Architecture specific
# TODO: probably these flags need to be tweaked on some architectures
# feel free to update the Makefile for your architecture and send a pull request or issue
message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}")
if (NOT MSVC)
    if (LLAMA_STATIC)
        add_link_options(-static)
        if (MINGW)
            add_link_options(-static-libgcc -static-libstdc++)
        endif()
    endif()
    if (LLAMA_GPROF)
        add_compile_options(-pg)
    endif()
    if (LLAMA_NATIVE)
        add_compile_options(-march=native)
    endif()
endif()

if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
    message(STATUS "ARM detected")
    if (MSVC)
        # TODO: arm msvc?
    else()
        if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
            add_compile_options(-mcpu=native)
        endif()
        # TODO: armv6,7,8 version specific flags
    endif()
elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$")
    message(STATUS "x86 detected")
    if (MSVC)
        if (LLAMA_AVX512)
            add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX512>)
            add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX512>)
            # MSVC has no compile-time flags enabling specific
            # AVX512 extensions, neither it defines the
            # macros corresponding to the extensions.
            # Do it manually.
            if (LLAMA_AVX512_VBMI)
                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VBMI__>)
                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VBMI__>)
            endif()
            if (LLAMA_AVX512_VNNI)
                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512VNNI__>)
                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512VNNI__>)
            endif()
        elseif (LLAMA_AVX2)
            add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX2>)
            add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX2>)
        elseif (LLAMA_AVX)
            add_compile_options($<$<COMPILE_LANGUAGE:C>:/arch:AVX>)
            add_compile_options($<$<COMPILE_LANGUAGE:CXX>:/arch:AVX>)
        endif()
    else()
        if (LLAMA_F16C)
            add_compile_options(-mf16c)
        endif()
        if (LLAMA_FMA)
            add_compile_options(-mfma)
        endif()
        if (LLAMA_AVX)
            add_compile_options(-mavx)
        endif()
        if (LLAMA_AVX2)
            add_compile_options(-mavx2)
        endif()
        if (LLAMA_AVX512)
            add_compile_options(-mavx512f)
            add_compile_options(-mavx512bw)
        endif()
        if (LLAMA_AVX512_VBMI)
            add_compile_options(-mavx512vbmi)
        endif()
        if (LLAMA_AVX512_VNNI)
            add_compile_options(-mavx512vnni)
        endif()
    endif()
else()
    # TODO: support PowerPC
    message(STATUS "Unknown architecture")
endif()

#
# Build libraries
#

function(include_ggml DIRECTORY SUFFIX WITH_LLAMA)
    if (LLAMA_CUBLAS)
        cmake_minimum_required(VERSION 3.17)

        find_package(CUDAToolkit)
        if (CUDAToolkit_FOUND)
            message(STATUS "cuBLAS found")

            enable_language(CUDA)

            set(GGML_CUDA_SOURCES ${DIRECTORY}ggml-cuda.cu ${DIRECTORY}ggml-cuda.h)

            add_compile_definitions(GGML_USE_CUBLAS)

            if (LLAMA_STATIC)
                set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
            else()
                set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt)
            endif()

        else()
            message(WARNING "cuBLAS not found")
        endif()
    endif()

    if (LLAMA_CLBLAST)
        find_package(CLBlast)
        if (CLBlast_FOUND)
            message(STATUS "CLBlast found")

            set(GGML_OPENCL_SOURCES ${DIRECTORY}ggml-opencl.c ${DIRECTORY}ggml-opencl.h)

            add_compile_definitions(GGML_USE_CLBLAST)

            set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} clblast)
        else()
            message(WARNING "CLBlast not found")
        endif()
    endif()

    add_library(ggml${SUFFIX} OBJECT
                ${DIRECTORY}/ggml.c
                ${DIRECTORY}/ggml.h
                ${GGML_CUDA_SOURCES}
                ${GGML_OPENCL_SOURCES})

    target_include_directories(ggml${SUFFIX} PUBLIC ${DIRECTORY})
    target_compile_features(ggml${SUFFIX} PUBLIC c_std_11) # don't bump
    target_link_libraries(ggml${SUFFIX} PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})

    if (BUILD_SHARED_LIBS)
        set_target_properties(ggml${SUFFIX} PROPERTIES POSITION_INDEPENDENT_CODE ON)
    endif()

    if (WITH_LLAMA)
        add_library(llama${SUFFIX}
                    ${DIRECTORY}/llama.cpp
                    ${DIRECTORY}/llama.h
                    ${DIRECTORY}/llama_util.h)

        target_include_directories(llama${SUFFIX} PUBLIC .)
        target_compile_features(llama${SUFFIX} PUBLIC cxx_std_11) # don't bump
        target_link_libraries(llama${SUFFIX} PRIVATE ggml${SUFFIX} ${LLAMA_EXTRA_LIBS})

        if (BUILD_SHARED_LIBS)
            set_target_properties(llama${SUFFIX} PROPERTIES POSITION_INDEPENDENT_CODE ON)
            target_compile_definitions(llama${SUFFIX} PRIVATE LLAMA_SHARED LLAMA_BUILD)
        endif()
    endif()

    if (GGML_CUDA_SOURCES)
        message(STATUS "GGML CUDA sources found, configuring CUDA architecture")
        set_property(TARGET ggml${SUFFIX} PROPERTY CUDA_ARCHITECTURES OFF)
        set_property(TARGET ggml${SUFFIX} PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto")
        if (WITH_LLAMA)
            set_property(TARGET llama${SUFFIX} PROPERTY CUDA_ARCHITECTURES OFF)
        endif()
    endif()
endfunction()
mpt.cpp (new file, 24 lines)

@@ -0,0 +1,24 @@
#include "justlm_mpt.hpp"
#include "justlm.hpp"

#include <string>
#include <string_view>
#include <fstream>
#include <cstdint>


extern "C" {
const LM::Implementation *get_justlm_implementation() {
    static LM::Implementation fres{false};
    return &fres;
}

bool magic_match(uint32_t magic) {
    return magic == 0x67676d6d;
}

LM::Inference *construct(const std::string &weights_path, std::ifstream& f, const LM::Inference::Params &p) {
    return new LM::MPTInference(weights_path, f, p);
}
}
mpt/mpt.cpp (28 changes)

@@ -187,9 +187,9 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod
     // create the ggml context
     {
         struct ggml_init_params params = {
-            .mem_size = ctx_size,
-            .mem_buffer = NULL,
-            .no_alloc = false,
+            ctx_size,
+            NULL,
+            false,
         };

         model.ctx = ggml_init(params);
@@ -362,30 +362,29 @@ bool mpt_eval(
     const int n_head = hparams.n_head;
     const int n_vocab = hparams.n_vocab;

-    static size_t buf_size = 256u*1024*1024;
-    static void * buf = malloc(buf_size);
-
-    if (mem_per_token > 0 && mem_per_token*N > buf_size) {
+    if (mem_per_token > 0 && mem_per_token*N > model.eval_buf_size) {
         const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead
         //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);

         // reallocate
-        buf_size = buf_size_new;
-        buf = realloc(buf, buf_size);
-        if (buf == nullptr) {
-            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
+        model.eval_buf_size = buf_size_new;
+        model.eval_buf = realloc(model.eval_buf, model.eval_buf_size);
+        if (model.eval_buf == nullptr) {
+            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, model.eval_buf_size);
             return false;
         }
     }

     struct ggml_init_params params = {
-        .mem_size = buf_size,
-        .mem_buffer = buf,
-        .no_alloc = false,
+        model.eval_buf_size,
+        model.eval_buf,
+        false
     };

     struct ggml_context * ctx0 = ggml_init(params);
-    struct ggml_cgraph gf = { .n_threads = n_threads };
+    struct ggml_cgraph gf;
+    gf.n_threads = n_threads;

     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
@@ -692,7 +691,6 @@ size_t mpt_copy_state_data(const mpt_model &model, const std::mt19937 &rng, uint
 }

 mpt_vocab::id mpt_sample_top_k_top_p(
-        const mpt_vocab & vocab,
         const size_t actualVocabSize,
         const int32_t * last_n_tokens_data,
         int last_n_tokens_size,
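A likely reason for dropping the .mem_size = ... style here (my reading, not stated in the commit): C99-style designated initializers only became standard C++ with C++20, and mpt/mpt.cpp is now compiled as C++17 inside the justlm_mpt target. The positional form is equivalent as long as the arguments follow the member order, roughly:

    // Stand-in struct for illustration; member order matches the positional
    // initializer used above (mem_size, mem_buffer, no_alloc).
    #include <cstddef>

    struct init_params_demo {
        size_t mem_size;
        void  *mem_buffer;
        bool   no_alloc;
    };

    int main() {
        char buffer[64];
        init_params_demo a = { sizeof(buffer), buffer, false };   // positional: valid C++17
        // init_params_demo b = { .mem_size = sizeof(buffer), .mem_buffer = buffer, .no_alloc = false };
        // ^ designated initializers: C99, and standard C++ only from C++20 on.
        return a.no_alloc ? 1 : 0;
    }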
mpt/mpt.hpp

@@ -83,10 +83,16 @@ struct mpt_model {
     struct ggml_context * ctx;
     std::map<std::string, struct ggml_tensor *> tensors;
 
+    size_t eval_buf_size = 256u*1024*1024;
+    void *eval_buf;
+
     mpt_buffer buf;
 
+    mpt_model() {
+        eval_buf = malloc(eval_buf_size);
+    }
     ~mpt_model() {
+        free(eval_buf);
         if (ctx) {
             ggml_free(ctx);
         }
@@ -110,7 +116,7 @@ struct mpt_vocab {
 bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & model, mpt_vocab & vocab);
 bool mpt_eval(mpt_model& model, const int n_threads, const int n_past, const std::vector<int>& embd_inp, std::vector<float>& embd_w, size_t& mem_per_token);
 std::vector<mpt_vocab::id> mpt_tokenize(const mpt_vocab & vocab, const std::string & text);
-mpt_vocab::id mpt_sample_top_k_top_p(const mpt_vocab& vocab, const size_t actualVocabSize, const int32_t *last_n_tokens_data, int last_n_tokens_size, const std::vector<float> logits, int top_k, double top_p, double temp, float repeat_penalty, std::mt19937& rng);
+mpt_vocab::id mpt_sample_top_k_top_p(const size_t actualVocabSize, const int32_t *last_n_tokens_data, int last_n_tokens_size, const std::vector<float> logits, int top_k, double top_p, double temp, float repeat_penalty, std::mt19937& rng);
 size_t mpt_get_state_size(const mpt_model &model);
 size_t mpt_copy_state_data(const mpt_model &model, const std::mt19937& rng, uint8_t *dest);
 size_t mpt_set_state_data(mpt_model *model, std::mt19937 *rng, const uint8_t *src);
pybind.cpp

@@ -9,7 +9,7 @@ namespace py = pybind11;
 
 
 
-PYBIND11_MODULE(libjustlm_py, m) {
+PYBIND11_MODULE(justlm_py, m) {
     using namespace LM;
     py::class_<Inference::Params>(m, "Params")
         .def(py::init<>())