Mirror of https://gitlab.com/niansa/libjustlm.git

Updated llama-mainline and deleted old llama versions

niansa 2023-08-31 16:52:38 +02:00
parent d8f4efb0c9
commit e3d52c42b7
5 changed files with 11 additions and 35 deletions
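The core change in justlm_llama.hpp below is the move from the removed llama_init_from_file() call to the newer two-step llama.cpp API: load the model first, then create a context from it. A minimal sketch of that pattern, assuming the mainline llama.cpp API this commit pins; the helper name init_llama, the use of llama_context_default_params(), and the error handling are illustrative and not part of the diff:

    #include <llama.h>   // llama.cpp public API (mainline, mid-2023)
    #include <string>

    // Illustrative helper: load the weights, then build a context from the model.
    static llama_context *init_llama(const std::string &weights_path, int n_ctx) {
        llama_context_params lparams = llama_context_default_params();
        lparams.n_ctx = n_ctx;   // the diff falls back to 2024 when params.n_ctx <= 0
        llama_model *model = llama_load_model_from_file(weights_path.c_str(), lparams);
        if (!model)
            return nullptr;      // "Failed to initialize llama model from file"
        return llama_new_context_with_model(model, lparams);
    }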

CMakeLists.txt

@@ -12,7 +12,6 @@ set(LM_PYBIND No CACHE BOOL "If justlm Python bindings should be build")
set(LM_COSCHED No CACHE BOOL "If justlm should make use of CoSched")
set(LM_NOEXCEPT No CACHE BOOL "If justlm exceptions should be disabled")
set(LM_LLAMA Yes CACHE BOOL "If LLaMa model support should be built into justlm")
-set(LM_LLAMA_OLD Yes CACHE BOOL "If old LLaMa model support should be built into justlm")
set(LM_GPTJ Yes CACHE BOOL "If GPT-J model support should be built into justlm")
set(LM_MPT Yes CACHE BOOL "If MPT model support should be built into justlm")
@@ -37,8 +36,6 @@ endfunction()
include(llama.cpp.cmake)
include_ggml(llama.cpp-mainline _mainline Yes)
-include_ggml(llama.cpp-230511 _230511 Yes)
-include_ggml(llama.cpp-230519 _230519 Yes)
include_ggml(llama.cpp-alibi _alibi No)
@@ -53,7 +50,7 @@ endif()
if (LM_GPTJ)
add_library(justlm_gptj SHARED gptj.cpp justlm_gptj.hpp gptj/gptj.cpp gptj/gptj.hpp)
-target_link_libraries(justlm_gptj PRIVATE ggml_230511 justlm_g4a_common)
+target_link_libraries(justlm_gptj PRIVATE ggml_alibi justlm_g4a_common)
target_justlm_setup(justlm_gptj)
endif()
@@ -65,20 +62,6 @@ if (LM_LLAMA)
target_justlm_setup(justlm_llama)
endif()
-if (LM_LLAMA_OLD)
-add_library(justlm_llama_230511 SHARED llama.cpp justlm_llama.hpp)
-target_link_libraries(justlm_llama_230511 PRIVATE ggml_230511 llama_230511)
-target_compile_definitions(justlm_llama_230511 PRIVATE
-LLAMA_VERSIONS=<=1 LLAMA_DATE=230511)
-target_justlm_setup(justlm_llama_230511)
-add_library(justlm_llama_230519 SHARED llama.cpp justlm_llama.hpp)
-target_link_libraries(justlm_llama_230519 PRIVATE ggml_230519 llama_230519)
-target_compile_definitions(justlm_llama_230519 PRIVATE
-LLAMA_VERSIONS===2 LLAMA_DATE=230519)
-target_justlm_setup(justlm_llama_230519)
-endif()
add_library(justlm STATIC
include/justlm.hpp justlm.cpp

justlm_llama.hpp

@@ -9,6 +9,7 @@ namespace LM {
class LLaMAInference final : public Inference {
struct State {
llama_context *ctx = nullptr;
+struct llama_model *model;
std::string prompt; // Mostly here for easy "debugging"
std::vector<int> tokens;
unsigned n_ctx;
@@ -32,14 +33,16 @@ class LLaMAInference final : public Inference {
lparams.seed = params.seed;
lparams.n_ctx = params.n_ctx = params.n_ctx>0?params.n_ctx:2024;
lparams.use_mlock = params.use_mlock;
-#if LLAMA_DATE >= 230519
lparams.n_gpu_layers = params.n_gpu_layers;
-#endif
// Create context
-state->ctx = llama_init_from_file(weights_path.c_str(), lparams);
+state->model = llama_load_model_from_file(weights_path.c_str(), lparams);
if (!state->ctx) {
LM_THROW("Failed to initialize llama from file", LM_BOOL_ERROR);
LM_THROW("Failed to initialize llama model from file", LM_BOOL_ERROR);
}
+state->ctx = llama_new_context_with_model(state->model, lparams);
+if (!state->ctx) {
+LM_THROW("Failed to initialize llama context from model", LM_BOOL_ERROR);
+}
// Initialize some variables
@@ -114,7 +117,6 @@ class LLaMAInference final : public Inference {
LM_CORETURN LM_BOOL_SUCCESS;
}
-#if LLAMA_DATE >= 230519
int llama_sample_top_p_top_k() {
auto& state = get_state();
auto logits = llama_get_logits(state->ctx);
@@ -153,13 +155,6 @@ class LLaMAInference final : public Inference {
default: LM_THROW("Invalid mirostat version "+std::to_string(params.prefer_mirostat), LM_BOOL_ERROR);
}
}
-#else
-int llama_sample_top_p_top_k() {
-auto& state = get_state();
-auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
-return ::llama_sample_top_p_top_k(state->ctx, params.n_repeat_last?(state->tokens.data()+state->tokens.size()-n_repeat_last):nullptr, n_repeat_last, params.top_k, params.top_p, params.temp, params.repeat_penalty);
-}
-#endif
public:
LLaMAInference(const std::string& weights_path, const Params& p) : Inference(p) {
@@ -219,7 +214,7 @@
LM_COTHROW(e.what(), "");
}
-if (id == llama_token_eos()) {
+if (id == llama_token_eos(state->ctx)) {
if (eos_count++ == params.n_eos_ignores) {
abort = true;
continue;
@@ -236,7 +231,7 @@
LM_COAWAIT window_scroll();
// Get token as string
-const std::string_view str = llama_token_to_str(state->ctx, id);
+const std::string_view str = llama_token_get_text(state->ctx, id);
// Append string to function result
state->prompt.append(str);
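The last two hunks above track token-handling changes in mainline llama.cpp: llama_token_eos() now takes the context, and llama_token_to_str() is replaced by llama_token_get_text(). A hedged sketch of the resulting per-token handling, assuming the pinned mainline API; the helper name append_token is illustrative and not part of the diff:

    #include <llama.h>
    #include <string>

    // Illustrative: stop on the end-of-stream token, otherwise append its text.
    static bool append_token(llama_context *ctx, llama_token id, std::string &out) {
        if (id == llama_token_eos(ctx))
            return false;                           // end of stream reached
        out.append(llama_token_get_text(ctx, id));  // raw vocab text of the token
        return true;
    }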

Deleted llama.cpp submodule (one of llama.cpp-230511 / llama.cpp-230519)
@@ -1 +0,0 @@
-Subproject commit 0e018fe008eacebdbcfa2d61b6c988c245c961cd

Deleted llama.cpp submodule (one of llama.cpp-230511 / llama.cpp-230519)
@@ -1 +0,0 @@
-Subproject commit 5ea43392731040b454c293123839b90e159cbb99

llama.cpp-mainline (submodule updated)
@@ -1 +1 @@
-Subproject commit 5ec8dd5a3c6a9a109351d2257bb9d53869bd0a94
+Subproject commit e8422de39e4aa2f7e50574124b060a80607e654a