From 60fe6b9c558c284d72aa8e5ae181ca803f8ef58f Mon Sep 17 00:00:00 2001
From: niansa
Date: Tue, 16 May 2023 19:10:05 +0000
Subject: [PATCH] Load implementations as shared objects

---
 .gitmodules        |   2 +-
 CMakeLists.txt     |  94 +++++++-----
 README.md          |   4 +-
 dlhandle.hpp       | 108 ++++++++++++++
 gptj.cpp           |  24 +++
 include/justlm.hpp |   5 +
 justlm.cpp         |  72 ++++++---
 justlm_gptj.hpp    |   3 +-
 justlm_llama.hpp   |   8 +-
 justlm_mpt.hpp     |  18 +--
 llama.cpp          |  26 +++-
 llama.cpp-mainline |   1 +
 llama.cpp.cmake    | 356 +++++++++++++++++++++++++++++++++++++++++++++
 mpt.cpp            |  24 +++
 mpt/mpt.cpp        |  28 ++--
 mpt/mpt.hpp        |   8 +-
 pybind.cpp         |   2 +-
 17 files changed, 686 insertions(+), 97 deletions(-)
 create mode 100644 dlhandle.hpp
 create mode 100644 gptj.cpp
 mode change 160000 => 100644 llama.cpp
 create mode 160000 llama.cpp-mainline
 create mode 100644 llama.cpp.cmake
 create mode 100644 mpt.cpp

diff --git a/.gitmodules b/.gitmodules
index 3f9cc66..6575f7f 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,5 +1,5 @@
 [submodule "llama.cpp"]
-	path = llama.cpp
+	path = llama.cpp-mainline
 	url = https://github.com/ggerganov/llama.cpp.git
 [submodule "llama.cpp-alibi"]
 	path = llama.cpp-alibi
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 38b50af..72bcb28 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,52 +1,76 @@
-cmake_minimum_required(VERSION 3.14)
+cmake_minimum_required(VERSION 3.18)
+
+project(justlm LANGUAGES C CXX)
 
-project(libjustlm LANGUAGES C CXX)
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_POSITION_INDEPENDENT_CODE ON)
 
-set(LM_PYBIND No CACHE BOOL "If Libjustlm Python bindings should be build")
-set(LM_COSCHED No CACHE BOOL "If Libjustlm should make use of CoSched")
-set(LM_NOEXCEPT No CACHE BOOL "If exceptions should be disabled")
-set(LM_MPT No CACHE BOOL "If MPT model support should be built")
+set(LM_PYBIND No CACHE BOOL "If justlm Python bindings should be built")
+set(LM_COSCHED No CACHE BOOL "If justlm should make use of CoSched")
+set(LM_NOEXCEPT No CACHE BOOL "If justlm exceptions should be disabled")
+set(LM_LLAMA Yes CACHE BOOL "If LLaMa model support should be built into justlm")
+set(LM_GPTJ Yes CACHE BOOL "If GPT-J model support should be built into justlm")
+set(LM_MPT Yes CACHE BOOL "If MPT model support should be built into justlm")
 
 if (LM_COSCHED)
     set(CMAKE_CXX_STANDARD 20)
 endif()
 
+
+function(target_justlm_setup target)
+    target_include_directories(${target} PUBLIC include/)
+    if (LM_COSCHED)
+        target_compile_definitions(${target} PUBLIC LM_COSCHED)
+        target_link_libraries(${target} PRIVATE cosched)
+    endif()
+    if (LM_NOEXCEPT)
+        target_compile_definitions(${target} PUBLIC LM_NOEXCEPT)
+    endif()
+endfunction()
+
+
+include(llama.cpp.cmake)
+
+include_ggml(llama.cpp-mainline _mainline Yes)
+include_ggml(llama.cpp-alibi _alibi No)
+
+
+add_library(justlm_g4a_common SHARED g4a-common.cpp g4a-common.hpp)
+
+
+set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
+
 if (LM_MPT)
-    set(LM_MPT_SOURCES justlm_mpt.hpp mpt/mpt.cpp mpt/mpt.hpp)
-    add_subdirectory(llama.cpp-alibi)
-else()
-    set(LM_MPT_SOURCES )
-    add_subdirectory(llama.cpp)
+    add_library(justlm_mpt SHARED mpt.cpp justlm_mpt.hpp mpt/mpt.cpp mpt/mpt.hpp)
+    target_link_libraries(justlm_mpt PRIVATE ggml_alibi justlm_g4a_common)
+    target_justlm_setup(justlm_mpt)
 endif()
 
-add_library(libjustlm STATIC
+if (LM_GPTJ)
+    add_library(justlm_gptj SHARED gptj.cpp justlm_gptj.hpp gptj/gptj.cpp gptj/gptj.hpp)
+    target_link_libraries(justlm_gptj PRIVATE ggml_mainline justlm_g4a_common)
+    target_justlm_setup(justlm_gptj)
+endif()
+
+if (LM_LLAMA)
+    add_library(justlm_llama SHARED llama.cpp justlm_llama.hpp)
+    target_link_libraries(justlm_llama PRIVATE ggml_mainline llama_mainline)
+    target_justlm_setup(justlm_llama)
+endif()
+
+
+add_library(justlm STATIC
     include/justlm.hpp justlm.cpp
-    justlm_llama.hpp
-    g4a-common.cpp g4a-common.hpp
-    justlm_gptj.hpp gptj/gptj.cpp gptj/gptj.hpp
-    ${LM_MPT_SOURCES}
     include/justlm_pool.hpp justlm_pool.cpp
+    dlhandle.hpp
 )
-target_link_libraries(libjustlm PRIVATE llama)
-
-if (LM_MPT)
-    target_compile_definitions(libjustlm PUBLIC LM_MPT)
-endif()
-
-if (LM_COSCHED)
-    target_compile_definitions(libjustlm PUBLIC LM_COSCHED)
-    target_link_libraries(libjustlm PRIVATE cosched)
-
-    set(LM_COSCHED Yes CACHE BOOL "If Libjustlm should make use of CoSched" FORCE)
-endif()
-
-if (LM_NOEXCEPT)
-    target_compile_definitions(libjustlm PUBLIC LM_NOEXCEPT)
-endif()
+add_library(libjustlm ALIAS justlm)
+target_link_libraries(justlm PRIVATE dl)
+target_include_directories(justlm PUBLIC include/)
+target_compile_definitions(justlm PRIVATE LIB_FILE_EXT="${CMAKE_SHARED_LIBRARY_SUFFIX}")
+target_justlm_setup(justlm)
 
 if (LM_PYBIND)
     if (LM_COSCHED)
@@ -55,8 +79,6 @@ if (LM_PYBIND)
     find_package(Python COMPONENTS Interpreter Development)
     find_package(pybind11 CONFIG)
 
-    pybind11_add_module(libjustlm_py pybind.cpp)
-    target_link_libraries(libjustlm_py PRIVATE libjustlm)
+    pybind11_add_module(justlm_py pybind.cpp)
+    target_link_libraries(justlm_py PRIVATE justlm)
 endif()
-
-target_include_directories(libjustlm PUBLIC include/)
diff --git a/README.md b/README.md
index 39622f0..6155872 100644
--- a/README.md
+++ b/README.md
@@ -1,8 +1,8 @@
 # JustLM
-Super easy to use library for doing LLaMA/GPT-J stuff!
+Super easy to use library for doing LLaMA/GPT-J/MPT stuff!
 
 ## Overview
-This library implements an easy to use interface to both LLaMa and GPT-J, with optional Python bindings.
+This library implements an easy to use interface to LLaMa, GPT-J and MPT, with optional Python bindings.
 
 Context scrolling is automatic and supports a top window bar.
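The public API is unchanged by this patch: callers still go through LM::Inference::construct(), which now selects and dlopen()s a matching justlm_* module at runtime instead of linking every backend into one library. A minimal consumer sketch follows (include style, model path and error handling are illustrative, not part of the patch):

#include <justlm.hpp>

#include <iostream>

int main() {
    LM::Inference::Params params; // defaults from include/justlm.hpp
    // construct() reads the file magic, scans the working directory for a
    // justlm_* shared object whose magic_match() accepts it, and falls back
    // to the LLaMa module otherwise; it returns nullptr if nothing fits.
    auto *inference = LM::Inference::construct("/path/to/model.bin", params);
    if (!inference) {
        std::cerr << "No justlm implementation accepted this model\n";
        return 1;
    }
    // ... use the inference instance as before ...
    delete inference;
    return 0;
}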
diff --git a/dlhandle.hpp b/dlhandle.hpp new file mode 100644 index 0000000..af58034 --- /dev/null +++ b/dlhandle.hpp @@ -0,0 +1,108 @@ +#ifndef __WIN32 +#include +#include +#include +#include + + +class Dlhandle { + void *chandle; + +public: + class Exception : public std::exception { + std::string errmsg; + public: + Exception(std::string errmsg) { + this->errmsg = errmsg; + } + virtual const char* what() const throw() { + return errmsg.c_str(); + } + }; + + Dlhandle() : chandle(nullptr) {} + Dlhandle(const std::string& fpath, int flags = RTLD_LAZY) { + chandle = dlopen(fpath.c_str(), flags); + if (!chandle) { + throw Exception("dlopen(): "+fpath); + } + } + Dlhandle(const Dlhandle& o) = delete; + Dlhandle(Dlhandle&& o) : chandle(o.chandle) { + o.chandle = nullptr; + } + ~Dlhandle() { + if (chandle) dlclose(chandle); + } + + auto operator =(Dlhandle&& o) { + chandle = std::exchange(o.chandle, nullptr); + } + + bool is_valid() const { + return chandle != nullptr; + } + operator bool() const { + return is_valid(); + } + + template + T* get(const std::string& fname) { + dlerror(); // Clear error + auto fres = reinterpret_cast(dlsym(chandle, fname.c_str())); + return (dlerror()==NULL)?fres:nullptr; + } + auto get_fnc(const std::string& fname) { + return get(fname); + } +}; +#else +#include +#include +#include + + + +class Dlhandle { + HMODULE chandle; + +public: + class Exception : public std::exception { + std::string errmsg; + public: + Exception(std::string errmsg) { + this->errmsg = errmsg; + } + virtual const char* what() const throw() { + return errmsg.c_str(); + } + }; + + Dlhandle() : chandle(nullptr) {} + Dlhandle(const std::string& fpath) { + chandle = LoadLibraryA(fpath.c_str()); + if (!chandle) { + throw Exception("dlopen(): "+fpath); + } + } + Dlhandle(const Dlhandle& o) = delete; + Dlhandle(Dlhandle&& o) : chandle(o.chandle) { + o.chandle = nullptr; + } + ~Dlhandle() { + if (chandle) FreeLibrary(chandle); + } + + bool is_valid() const { + return chandle != nullptr; + } + + template + T* get(const std::string& fname) { + return reinterpret_cast(GetProcAddress(chandle, fname.c_str())); + } + auto get_fnc(const std::string& fname) { + return get(fname); + } +}; +#endif diff --git a/gptj.cpp b/gptj.cpp new file mode 100644 index 0000000..6bdd7c3 --- /dev/null +++ b/gptj.cpp @@ -0,0 +1,24 @@ +#include "justlm_gptj.hpp" +#include "justlm.hpp" + +#include +#include +#include +#include + + + +extern "C" { +const LM::Implementation *get_justlm_implementation() { + static LM::Implementation fres{false}; + return &fres; +} + +bool magic_match(uint32_t magic) { + return magic == 0x67676d6c; +} + +LM::Inference *construct(const std::string &weights_path, std::ifstream& f, const LM::Inference::Params &p) { + return new LM::GPTJInference(weights_path, f, p); +} +} diff --git a/include/justlm.hpp b/include/justlm.hpp index ad02104..ea79162 100644 --- a/include/justlm.hpp +++ b/include/justlm.hpp @@ -146,5 +146,10 @@ public: LM_LAST_ERROR_GETTER }; + + +struct Implementation { + bool is_fallback = false; +}; } #endif // JUSTLM_HPP diff --git a/justlm.cpp b/justlm.cpp index 4f68b6e..ad05ea9 100644 --- a/justlm.cpp +++ b/justlm.cpp @@ -1,30 +1,66 @@ #include "justlm.hpp" -#include "justlm_llama.hpp" -#include "justlm_gptj.hpp" -#ifdef LM_MPT -# include "justlm_mpt.hpp" -#endif +#include "dlhandle.hpp" +#include +#include #include +#include +static +Dlhandle get_implementation(uint32_t magic) { + Dlhandle matching; + Dlhandle fallback; + // Iterate over all libraries + for (const auto& f : 
std::filesystem::directory_iterator(".")) { + // Get path + const auto& p = f.path(); + // Check extension + if (p.extension() != LIB_FILE_EXT) continue; + // Load library + try { + Dlhandle dl(p); + // Get implementation info getter + auto implementation_getter = dl.get("get_justlm_implementation"); + if (!implementation_getter) continue; + // Get implementation info + const auto *implementation_info = implementation_getter(); + // Set if fallback + if (implementation_info->is_fallback) { + fallback = std::move(dl); + continue; + } + // Set if matching magic + auto magic_match = dl.get("magic_match"); + if (magic_match && magic_match(magic)) { + matching = std::move(dl); + continue; + } + } catch (...) {} + } + // Return matching if any, fallback otherwise + if (matching) return matching; + return fallback; +} + LM::Inference *LM::Inference::construct(const std::string &weights_path, const Params &p) { + static std::vector dls; // Read magic std::ifstream f(weights_path, std::ios::binary); uint32_t magic; - f.read(reinterpret_cast(&magic), sizeof(magic)); - // Create inference instance - if (magic == 0x67676d6c) { - f.seekg(0); - return new GPTJInference(weights_path, f, p); -# ifdef LM_MPT - } else if (magic == 0x67676d6d) { - f.seekg(0); - return new MPTInference(weights_path, f, p); -# endif - } else { - f.close(); - return new LLaMaInference(weights_path, p); + if (!f.read(reinterpret_cast(&magic), sizeof(magic))) { + throw Exception("Failed to open weights file for reading at "+weights_path); } + f.seekg(0); + // Get correct implementation + auto impl = get_implementation(magic); + if (!impl) return nullptr; + // Get inference constructor + auto constructor = impl.get("construct"); + if (!constructor) return nullptr; + // Back up Dlhandle + dls.push_back(std::move(impl)); + // Construct inference + return constructor(weights_path, f, p); } diff --git a/justlm_gptj.hpp b/justlm_gptj.hpp index c6de499..fa6514c 100644 --- a/justlm_gptj.hpp +++ b/justlm_gptj.hpp @@ -53,7 +53,6 @@ class GPTJInference final : public Inference { auto& state = get_state(); if (state) { - if (state->model.ctx) ggml_free(state->model.ctx); //TODO: Is that enough? 
delete state; } } @@ -192,6 +191,7 @@ public: const auto str = state->vocab.id_to_token[id]; // Append string to function result + state->prompt.append(str); fres.append(str); // Evaluate token @@ -207,7 +207,6 @@ public: } // Create final string TODO: Could be optimized - state->prompt.append(fres); if (!abort) { fres = std::string(fres.data(), fres.size()-end.size()); } diff --git a/justlm_llama.hpp b/justlm_llama.hpp index 04cdc1c..cb7c9a1 100644 --- a/justlm_llama.hpp +++ b/justlm_llama.hpp @@ -11,7 +11,7 @@ class LLaMaInference final : public Inference { llama_context *ctx = nullptr; std::string prompt; // Mostly here for easy "debugging" std::vector tokens; - int n_ctx; + unsigned n_ctx; }; State*& get_state() { @@ -91,8 +91,8 @@ class LLaMaInference final : public Inference { // Calculate progress auto progress = float(it-starting_offset) / (state->tokens.size()-starting_offset) * 100.f; // Tick and yield - if (!on_tick(progress)) LM_BOOL_SUCCESS; - else if (!LM_TASKYIELD) LM_BOOL_SUCCESS; + if (!on_tick(progress)) LM_CORETURN LM_BOOL_SUCCESS; + else if (!LM_TASKYIELD) LM_CORETURN LM_BOOL_SUCCESS; } } @@ -182,6 +182,7 @@ public: const auto str = llama_token_to_str(state->ctx, id); // Append string to function result + state->prompt.append(str); fres.append(str); // Evaluate token @@ -196,7 +197,6 @@ public: } // Create final string TODO: Could be optimized - state->prompt.append(fres); if (!abort) { fres = std::string(fres.data(), fres.size()-end.size()); } diff --git a/justlm_mpt.hpp b/justlm_mpt.hpp index a504f5a..1e01a89 100644 --- a/justlm_mpt.hpp +++ b/justlm_mpt.hpp @@ -53,7 +53,6 @@ class MPTInference final : public Inference { auto& state = get_state(); if (state) { - if (state->model.ctx) ggml_free(state->model.ctx); //TODO: Is that enough? 
delete state; } } @@ -162,19 +161,6 @@ public: LM_CORETURN LM_COAWAIT evaluate_tokens(old_token_count, on_tick); } - /*mpt_vocab::id mpt_sample_top_k_top_p( - const mpt_vocab & vocab, - const size_t actualVocabSize, - const int32_t * last_n_tokens_data, - int last_n_tokens_size, - const std::vector logits, - int top_k, - double top_p, - double temp, - float repeat_penalty, - std::mt19937 & rng) -*/ - LM_SCHEDULABLE(std::string) run(std::string_view end, const std::function &on_tick = nullptr) LM_NOEXCEPTDECL override { auto& state = get_state(); std::string fres; @@ -184,7 +170,7 @@ public: unsigned eos_count = 0; while (!abort && !ends_with(fres, end)) { // Sample top p and top k - auto id = mpt_sample_top_k_top_p(state->vocab, state->model.hparams.n_vocab, state->tokens.data(), state->tokens.size(), state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng); + auto id = mpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data(), state->tokens.size(), state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng); if (id == state->vocab.token_to_id["<|im_end|>"]) { if (eos_count++ == params.eos_ignores) { @@ -206,6 +192,7 @@ public: // Append string to function result fres.append(str); + state->prompt.append(str); // Evaluate token // TODO: Respect batch size @@ -220,7 +207,6 @@ public: } // Create final string TODO: Could be optimized - state->prompt.append(fres); if (!abort) { fres = std::string(fres.data(), fres.size()-end.size()); } diff --git a/llama.cpp b/llama.cpp deleted file mode 160000 index 0e018fe..0000000 --- a/llama.cpp +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 0e018fe008eacebdbcfa2d61b6c988c245c961cd diff --git a/llama.cpp b/llama.cpp new file mode 100644 index 0000000..0c51fe9 --- /dev/null +++ b/llama.cpp @@ -0,0 +1,25 @@ +#include "justlm_llama.hpp" +#include "justlm.hpp" + +#include +#include +#include +#include + + + +extern "C" { +const LM::Implementation *get_justlm_implementation() { + static LM::Implementation fres{true}; + return &fres; +} + +bool magic_match(uint32_t magic) { + return magic == 0x67676d6c; +} + +LM::Inference *construct(const std::string &weights_path, std::ifstream& f, const LM::Inference::Params &p) { + f.close(); + return new LM::LLaMaInference(weights_path, p); +} +} diff --git a/llama.cpp-mainline b/llama.cpp-mainline new file mode 160000 index 0000000..0e018fe --- /dev/null +++ b/llama.cpp-mainline @@ -0,0 +1 @@ +Subproject commit 0e018fe008eacebdbcfa2d61b6c988c245c961cd diff --git a/llama.cpp.cmake b/llama.cpp.cmake new file mode 100644 index 0000000..8d0cb86 --- /dev/null +++ b/llama.cpp.cmake @@ -0,0 +1,356 @@ +cmake_minimum_required(VERSION 3.12) # Don't bump this version for no reason + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE) + set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE) + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo") +endif() + +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) + +if(CMAKE_SOURCE_DIR STREQUAL CMAKE_CURRENT_SOURCE_DIR) + set(LLAMA_STANDALONE ON) + + # configure project version + # TODO +else() + set(LLAMA_STANDALONE OFF) +endif() + +if (EMSCRIPTEN) + set(BUILD_SHARED_LIBS_DEFAULT OFF) + + option(LLAMA_WASM_SINGLE_FILE "llama: embed WASM inside the generated llama.js" ON) +else() + if (MINGW) + set(BUILD_SHARED_LIBS_DEFAULT OFF) + else() + set(BUILD_SHARED_LIBS_DEFAULT ON) + endif() +endif() + + +# +# 
Option list +# + +# general +option(LLAMA_STATIC "llama: static link libraries" OFF) +option(LLAMA_NATIVE "llama: enable -march=native flag" OFF) +option(LLAMA_LTO "llama: enable link time optimization" OFF) + +# debug +option(LLAMA_ALL_WARNINGS "llama: enable all compiler warnings" ON) +option(LLAMA_ALL_WARNINGS_3RD_PARTY "llama: enable all compiler warnings in 3rd party libs" OFF) +option(LLAMA_GPROF "llama: enable gprof" OFF) + +# sanitizers +option(LLAMA_SANITIZE_THREAD "llama: enable thread sanitizer" OFF) +option(LLAMA_SANITIZE_ADDRESS "llama: enable address sanitizer" OFF) +option(LLAMA_SANITIZE_UNDEFINED "llama: enable undefined sanitizer" OFF) + +# instruction set specific +option(LLAMA_AVX "llama: enable AVX" ON) +option(LLAMA_AVX2 "llama: enable AVX2" ON) +option(LLAMA_AVX512 "llama: enable AVX512" OFF) +option(LLAMA_AVX512_VBMI "llama: enable AVX512-VBMI" OFF) +option(LLAMA_AVX512_VNNI "llama: enable AVX512-VNNI" OFF) +option(LLAMA_FMA "llama: enable FMA" ON) +# in MSVC F16C is implied with AVX2/AVX512 +if (NOT MSVC) + option(LLAMA_F16C "llama: enable F16C" ON) +endif() + +# 3rd party libs +option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON) +option(LLAMA_OPENBLAS "llama: use OpenBLAS" OFF) +option(LLAMA_CUBLAS "llama: use cuBLAS" OFF) +option(LLAMA_CLBLAST "llama: use CLBlast" OFF) + +# +# Compile flags +# + +set(CMAKE_C_STANDARD 11) +set(CMAKE_C_STANDARD_REQUIRED true) +set(THREADS_PREFER_PTHREAD_FLAG ON) +find_package(Threads REQUIRED) + +if (NOT MSVC) + if (LLAMA_SANITIZE_THREAD) + add_compile_options(-fsanitize=thread) + link_libraries(-fsanitize=thread) + endif() + + if (LLAMA_SANITIZE_ADDRESS) + add_compile_options(-fsanitize=address -fno-omit-frame-pointer) + link_libraries(-fsanitize=address) + endif() + + if (LLAMA_SANITIZE_UNDEFINED) + add_compile_options(-fsanitize=undefined) + link_libraries(-fsanitize=undefined) + endif() +endif() + +if (APPLE AND LLAMA_ACCELERATE) + find_library(ACCELERATE_FRAMEWORK Accelerate) + if (ACCELERATE_FRAMEWORK) + message(STATUS "Accelerate framework found") + + add_compile_definitions(GGML_USE_ACCELERATE) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK}) + else() + message(WARNING "Accelerate framework not found") + endif() +endif() + +if (LLAMA_OPENBLAS) + if (LLAMA_STATIC) + set(BLA_STATIC ON) + endif() + + set(BLA_VENDOR OpenBLAS) + find_package(BLAS) + if (BLAS_FOUND) + message(STATUS "OpenBLAS found") + + add_compile_definitions(GGML_USE_OPENBLAS) + add_link_options(${BLAS_LIBRARIES}) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} openblas) + + # find header file + set(OPENBLAS_INCLUDE_SEARCH_PATHS + /usr/include + /usr/include/openblas + /usr/include/openblas-base + /usr/local/include + /usr/local/include/openblas + /usr/local/include/openblas-base + /opt/OpenBLAS/include + $ENV{OpenBLAS_HOME} + $ENV{OpenBLAS_HOME}/include + ) + find_path(OPENBLAS_INC NAMES cblas.h PATHS ${OPENBLAS_INCLUDE_SEARCH_PATHS}) + add_compile_options(-I${OPENBLAS_INC}) + else() + message(WARNING "OpenBLAS not found") + endif() +endif() + +if (LLAMA_ALL_WARNINGS) + if (NOT MSVC) + set(c_flags + -Wall + -Wextra + -Wpedantic + -Wcast-qual + -Wdouble-promotion + -Wshadow + -Wstrict-prototypes + -Wpointer-arith + ) + set(cxx_flags + -Wall + -Wextra + -Wpedantic + -Wcast-qual + -Wno-unused-function + -Wno-multichar + ) + else() + # todo : msvc + endif() + + add_compile_options( + "$<$:${c_flags}>" + "$<$:${cxx_flags}>" + ) + +endif() + +if (MSVC) + add_compile_definitions(_CRT_SECURE_NO_WARNINGS) + + if (BUILD_SHARED_LIBS) + 
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON) + endif() +endif() + +if (LLAMA_LTO) + include(CheckIPOSupported) + check_ipo_supported(RESULT result OUTPUT output) + if (result) + set(CMAKE_INTERPROCEDURAL_OPTIMIZATION TRUE) + else() + message(WARNING "IPO is not supported: ${output}") + endif() +endif() + +# Architecture specific +# TODO: probably these flags need to be tweaked on some architectures +# feel free to update the Makefile for your architecture and send a pull request or issue +message(STATUS "CMAKE_SYSTEM_PROCESSOR: ${CMAKE_SYSTEM_PROCESSOR}") +if (NOT MSVC) + if (LLAMA_STATIC) + add_link_options(-static) + if (MINGW) + add_link_options(-static-libgcc -static-libstdc++) + endif() + endif() + if (LLAMA_GPROF) + add_compile_options(-pg) + endif() + if (LLAMA_NATIVE) + add_compile_options(-march=native) + endif() +endif() + +if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") + message(STATUS "ARM detected") + if (MSVC) + # TODO: arm msvc? + else() + if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64") + add_compile_options(-mcpu=native) + endif() + # TODO: armv6,7,8 version specific flags + endif() +elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "^(x86_64|i686|AMD64)$") + message(STATUS "x86 detected") + if (MSVC) + if (LLAMA_AVX512) + add_compile_options($<$:/arch:AVX512>) + add_compile_options($<$:/arch:AVX512>) + # MSVC has no compile-time flags enabling specific + # AVX512 extensions, neither it defines the + # macros corresponding to the extensions. + # Do it manually. + if (LLAMA_AVX512_VBMI) + add_compile_definitions($<$:__AVX512VBMI__>) + add_compile_definitions($<$:__AVX512VBMI__>) + endif() + if (LLAMA_AVX512_VNNI) + add_compile_definitions($<$:__AVX512VNNI__>) + add_compile_definitions($<$:__AVX512VNNI__>) + endif() + elseif (LLAMA_AVX2) + add_compile_options($<$:/arch:AVX2>) + add_compile_options($<$:/arch:AVX2>) + elseif (LLAMA_AVX) + add_compile_options($<$:/arch:AVX>) + add_compile_options($<$:/arch:AVX>) + endif() + else() + if (LLAMA_F16C) + add_compile_options(-mf16c) + endif() + if (LLAMA_FMA) + add_compile_options(-mfma) + endif() + if (LLAMA_AVX) + add_compile_options(-mavx) + endif() + if (LLAMA_AVX2) + add_compile_options(-mavx2) + endif() + if (LLAMA_AVX512) + add_compile_options(-mavx512f) + add_compile_options(-mavx512bw) + endif() + if (LLAMA_AVX512_VBMI) + add_compile_options(-mavx512vbmi) + endif() + if (LLAMA_AVX512_VNNI) + add_compile_options(-mavx512vnni) + endif() + endif() +else() + # TODO: support PowerPC + message(STATUS "Unknown architecture") +endif() + +# +# Build libraries +# + +function(include_ggml DIRECTORY SUFFIX WITH_LLAMA) + if (LLAMA_CUBLAS) + cmake_minimum_required(VERSION 3.17) + + find_package(CUDAToolkit) + if (CUDAToolkit_FOUND) + message(STATUS "cuBLAS found") + + enable_language(CUDA) + + set(GGML_CUDA_SOURCES ${DIRECTORY}ggml-cuda.cu ${DIRECTORY}ggml-cuda.h) + + add_compile_definitions(GGML_USE_CUBLAS) + + if (LLAMA_STATIC) + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) + else() + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart CUDA::cublas CUDA::cublasLt) + endif() + + else() + message(WARNING "cuBLAS not found") + endif() + endif() + + if (LLAMA_CLBLAST) + find_package(CLBlast) + if (CLBlast_FOUND) + message(STATUS "CLBlast found") + + set(GGML_OPENCL_SOURCES ${DIRECTORY}ggml-opencl.c ${DIRECTORY}ggml-opencl.h) + + add_compile_definitions(GGML_USE_CLBLAST) + + set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} clblast) + else() + 
message(WARNING "CLBlast not found") + endif() + endif() + + add_library(ggml${SUFFIX} OBJECT + ${DIRECTORY}/ggml.c + ${DIRECTORY}/ggml.h + ${GGML_CUDA_SOURCES} + ${GGML_OPENCL_SOURCES}) + + target_include_directories(ggml${SUFFIX} PUBLIC ${DIRECTORY}) + target_compile_features(ggml${SUFFIX} PUBLIC c_std_11) # don't bump + target_link_libraries(ggml${SUFFIX} PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS}) + + if (BUILD_SHARED_LIBS) + set_target_properties(ggml${SUFFIX} PROPERTIES POSITION_INDEPENDENT_CODE ON) + endif() + + if (WITH_LLAMA) + add_library(llama${SUFFIX} + ${DIRECTORY}/llama.cpp + ${DIRECTORY}/llama.h + ${DIRECTORY}/llama_util.h) + + target_include_directories(llama${SUFFIX} PUBLIC .) + target_compile_features(llama${SUFFIX} PUBLIC cxx_std_11) # don't bump + target_link_libraries(llama${SUFFIX} PRIVATE ggml${SUFFIX} ${LLAMA_EXTRA_LIBS}) + + if (BUILD_SHARED_LIBS) + set_target_properties(llama${SUFFIX} PROPERTIES POSITION_INDEPENDENT_CODE ON) + target_compile_definitions(llama${SUFFIX} PRIVATE LLAMA_SHARED LLAMA_BUILD) + endif() + endif() + + if (GGML_CUDA_SOURCES) + message(STATUS "GGML CUDA sources found, configuring CUDA architecture") + set_property(TARGET ggml${SUFFIX} PROPERTY CUDA_ARCHITECTURES OFF) + set_property(TARGET ggml${SUFFIX} PROPERTY CUDA_SELECT_NVCC_ARCH_FLAGS "Auto") + if (WITH_LLAMA) + set_property(TARGET llama${SUFFIX} PROPERTY CUDA_ARCHITECTURES OFF) + endif() + endif() +endfunction() diff --git a/mpt.cpp b/mpt.cpp new file mode 100644 index 0000000..e952c99 --- /dev/null +++ b/mpt.cpp @@ -0,0 +1,24 @@ +#include "justlm_mpt.hpp" +#include "justlm.hpp" + +#include +#include +#include +#include + + + +extern "C" { +const LM::Implementation *get_justlm_implementation() { + static LM::Implementation fres{false}; + return &fres; +} + +bool magic_match(uint32_t magic) { + return magic == 0x67676d6d; +} + +LM::Inference *construct(const std::string &weights_path, std::ifstream& f, const LM::Inference::Params &p) { + return new LM::MPTInference(weights_path, f, p); +} +} diff --git a/mpt/mpt.cpp b/mpt/mpt.cpp index ef972b3..bda6cdf 100644 --- a/mpt/mpt.cpp +++ b/mpt/mpt.cpp @@ -187,9 +187,9 @@ bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & mod // create the ggml context { struct ggml_init_params params = { - .mem_size = ctx_size, - .mem_buffer = NULL, - .no_alloc = false, + ctx_size, + NULL, + false, }; model.ctx = ggml_init(params); @@ -362,30 +362,29 @@ bool mpt_eval( const int n_head = hparams.n_head; const int n_vocab = hparams.n_vocab; - static size_t buf_size = 256u*1024*1024; - static void * buf = malloc(buf_size); - if (mem_per_token > 0 && mem_per_token*N > buf_size) { + if (mem_per_token > 0 && mem_per_token*N > model.eval_buf_size) { const size_t buf_size_new = 1.1*(mem_per_token*N); // add 10% to account for ggml object overhead //printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new); // reallocate - buf_size = buf_size_new; - buf = realloc(buf, buf_size); - if (buf == nullptr) { - fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size); + model.eval_buf_size = buf_size_new; + model.eval_buf = realloc(model.eval_buf, model.eval_buf_size); + if (model.eval_buf == nullptr) { + fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, model.eval_buf_size); return false; } } struct ggml_init_params params = { - .mem_size = buf_size, - .mem_buffer = buf, - .no_alloc = false, + model.eval_buf_size, + model.eval_buf, + false }; struct ggml_context * ctx0 = 
ggml_init(params);
 
-    struct ggml_cgraph gf = { .n_threads = n_threads };
+    struct ggml_cgraph gf;
+    gf.n_threads = n_threads;
 
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, embd_inp.data(), N*ggml_element_size(embd));
@@ -692,7 +691,6 @@ size_t mpt_copy_state_data(const mpt_model &model, const std::mt19937 &rng, uint
 }
 
 mpt_vocab::id mpt_sample_top_k_top_p(
-        const mpt_vocab & vocab,
         const size_t actualVocabSize,
         const int32_t * last_n_tokens_data,
         int last_n_tokens_size,
diff --git a/mpt/mpt.hpp b/mpt/mpt.hpp
index 5169259..8895a0c 100644
--- a/mpt/mpt.hpp
+++ b/mpt/mpt.hpp
@@ -83,10 +83,16 @@ struct mpt_model {
     struct ggml_context * ctx;
     std::map<std::string, struct ggml_tensor *> tensors;
 
+    size_t eval_buf_size = 256u*1024*1024;
+    void *eval_buf;
     mpt_buffer buf;
 
+    mpt_model() {
+        eval_buf = malloc(eval_buf_size);
+    }
     ~mpt_model() {
+        free(eval_buf);
         if (ctx) {
             ggml_free(ctx);
         }
@@ -110,7 +116,7 @@ struct mpt_vocab {
 bool mpt_model_load(const std::string &fname, std::istream &fin, mpt_model & model, mpt_vocab & vocab);
 bool mpt_eval(mpt_model& model, const int n_threads, const int n_past, const std::vector<mpt_vocab::id>& embd_inp, std::vector<float>& embd_w, size_t& mem_per_token);
 std::vector<mpt_vocab::id> mpt_tokenize(const mpt_vocab & vocab, const std::string & text);
-mpt_vocab::id mpt_sample_top_k_top_p(const mpt_vocab& vocab, const size_t actualVocabSize, const int32_t *last_n_tokens_data, int last_n_tokens_size, const std::vector<float> logits, int top_k, double top_p, double temp, float repeat_penalty, std::mt19937& rng);
+mpt_vocab::id mpt_sample_top_k_top_p(const size_t actualVocabSize, const int32_t *last_n_tokens_data, int last_n_tokens_size, const std::vector<float> logits, int top_k, double top_p, double temp, float repeat_penalty, std::mt19937& rng);
 size_t mpt_get_state_size(const mpt_model &model);
 size_t mpt_copy_state_data(const mpt_model &model, const std::mt19937& rng, uint8_t *dest);
 size_t mpt_set_state_data(mpt_model *model, std::mt19937 *rng, const uint8_t *src);
diff --git a/pybind.cpp b/pybind.cpp
index a1acdae..a1ecc56 100644
--- a/pybind.cpp
+++ b/pybind.cpp
@@ -9,7 +9,7 @@
 namespace py = pybind11;
 
 
-PYBIND11_MODULE(libjustlm_py, m) {
+PYBIND11_MODULE(justlm_py, m) {
     using namespace LM;
     py::class_<Inference::Params>(m, "Params")
         .def(py::init<>())
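Each backend module (gptj.cpp, mpt.cpp and llama.cpp above) exports the same three extern "C" symbols, and justlm.cpp resolves them through Dlhandle::get<T>(). A sketch of that lookup, with the function types taken from those extern "C" definitions (the helper name and module-path handling are illustrative; the real code keeps its handles in a static vector of Dlhandles, as justlm.cpp does):

#include "dlhandle.hpp"
#include "justlm.hpp"

#include <cstdint>
#include <fstream>
#include <string>
#include <utility>
#include <vector>

// Resolve one module's entry points and, if it accepts the magic (or is the
// fallback implementation), construct an inference instance through it.
static LM::Inference *load_via_module(const std::string &module_path,
                                      const std::string &weights_path,
                                      std::ifstream &f, uint32_t magic,
                                      const LM::Inference::Params &p) {
    Dlhandle dl(module_path); // e.g. "justlm_gptj" plus LIB_FILE_EXT

    // const LM::Implementation *get_justlm_implementation()
    auto *get_impl = dl.get<const LM::Implementation *()>("get_justlm_implementation");
    // bool magic_match(uint32_t magic)
    auto *match = dl.get<bool(uint32_t)>("magic_match");
    // LM::Inference *construct(const std::string &, std::ifstream &,
    //                          const LM::Inference::Params &)
    auto *construct = dl.get<LM::Inference *(const std::string &, std::ifstream &,
                                             const LM::Inference::Params &)>("construct");
    if (!get_impl || !match || !construct) return nullptr;

    // Skip modules that neither match the magic nor declare themselves a fallback.
    if (!match(magic) && !get_impl()->is_fallback) return nullptr;

    // Keep the handle alive so the shared object is not unloaded while the
    // returned instance is still in use.
    static std::vector<Dlhandle> loaded;
    loaded.push_back(std::move(dl));
    return construct(weights_path, f, p);
}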