mirror of https://gitlab.com/niansa/libjustlm.git
synced 2025-03-06 20:49:17 +01:00

Updated llama-mainline and deleted old llama versions

parent d8f4efb0c9
commit e3d52c42b7

5 changed files with 11 additions and 35 deletions
CMakeLists.txt

@@ -12,7 +12,6 @@ set(LM_PYBIND No CACHE BOOL "If justlm Python bindings should be build")
 set(LM_COSCHED No CACHE BOOL "If justlm should make use of CoSched")
 set(LM_NOEXCEPT No CACHE BOOL "If justlm exceptions should be disabled")
 set(LM_LLAMA Yes CACHE BOOL "If LLaMa model support should be built into justlm")
-set(LM_LLAMA_OLD Yes CACHE BOOL "If old LLaMa model support should be built into justlm")
 set(LM_GPTJ Yes CACHE BOOL "If GPT-J model support should be built into justlm")
 set(LM_MPT Yes CACHE BOOL "If MPT model support should be built into justlm")
@@ -37,8 +36,6 @@ endfunction()
 include(llama.cpp.cmake)

 include_ggml(llama.cpp-mainline _mainline Yes)
-include_ggml(llama.cpp-230511 _230511 Yes)
-include_ggml(llama.cpp-230519 _230519 Yes)
 include_ggml(llama.cpp-alibi _alibi No)
@@ -53,7 +50,7 @@ endif()

 if (LM_GPTJ)
 add_library(justlm_gptj SHARED gptj.cpp justlm_gptj.hpp gptj/gptj.cpp gptj/gptj.hpp)
-target_link_libraries(justlm_gptj PRIVATE ggml_230511 justlm_g4a_common)
+target_link_libraries(justlm_gptj PRIVATE ggml_alibi justlm_g4a_common)
 target_justlm_setup(justlm_gptj)
 endif()
@@ -65,20 +62,6 @@ if (LM_LLAMA)
 target_justlm_setup(justlm_llama)
 endif()

-if (LM_LLAMA_OLD)
-add_library(justlm_llama_230511 SHARED llama.cpp justlm_llama.hpp)
-target_link_libraries(justlm_llama_230511 PRIVATE ggml_230511 llama_230511)
-target_compile_definitions(justlm_llama_230511 PRIVATE
-LLAMA_VERSIONS=<=1 LLAMA_DATE=230511)
-target_justlm_setup(justlm_llama_230511)
-
-add_library(justlm_llama_230519 SHARED llama.cpp justlm_llama.hpp)
-target_link_libraries(justlm_llama_230519 PRIVATE ggml_230519 llama_230519)
-target_compile_definitions(justlm_llama_230519 PRIVATE
-LLAMA_VERSIONS===2 LLAMA_DATE=230519)
-target_justlm_setup(justlm_llama_230519)
-endif()

 add_library(justlm STATIC
 include/justlm.hpp justlm.cpp
justlm_llama.hpp

@@ -9,6 +9,7 @@ namespace LM {
 class LLaMAInference final : public Inference {
 struct State {
 llama_context *ctx = nullptr;
+struct llama_model *model;
 std::string prompt; // Mostly here for easy "debugging"
 std::vector<int> tokens;
 unsigned n_ctx;
@@ -32,14 +33,16 @@ class LLaMAInference final : public Inference {
 lparams.seed = params.seed;
 lparams.n_ctx = params.n_ctx = params.n_ctx>0?params.n_ctx:2024;
 lparams.use_mlock = params.use_mlock;
 #if LLAMA_DATE >= 230519
 lparams.n_gpu_layers = params.n_gpu_layers;
 #endif

 // Create context
-state->ctx = llama_init_from_file(weights_path.c_str(), lparams);
+state->model = llama_load_model_from_file(weights_path.c_str(), lparams);
 if (!state->ctx) {
-LM_THROW("Failed to initialize llama from file", LM_BOOL_ERROR);
+LM_THROW("Failed to initialize llama model from file", LM_BOOL_ERROR);
 }
+state->ctx = llama_new_context_with_model(state->model, lparams);
 if (!state->ctx) {
+LM_THROW("Failed to initialize llama context from model", LM_BOOL_ERROR);
 }

 // Initialize some variables
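For reference, a minimal standalone sketch of the two-step initialization the hunk above switches to, assuming the llama.cpp C API of that period: the weights are loaded once with llama_load_model_from_file() and an inference context is then created from the model with llama_new_context_with_model(). Only those two function names and the 2024-token fallback come from the diff; everything else (paths, error handling, main()) is illustrative.

// Hypothetical sketch, not libjustlm code: two-step llama.cpp initialization.
#include <cstdio>
#include <llama.h>

int main(int argc, char **argv) {
    if (argc < 2) return 1;

    llama_context_params lparams = llama_context_default_params();
    lparams.n_ctx = 2024;  // same fallback context size the wrapper uses

    // Step 1: load the weights into a llama_model.
    llama_model *model = llama_load_model_from_file(argv[1], lparams);
    if (!model) {
        std::fprintf(stderr, "failed to load model\n");
        return 1;
    }

    // Step 2: create an inference context on top of the loaded model.
    llama_context *ctx = llama_new_context_with_model(model, lparams);
    if (!ctx) {
        std::fprintf(stderr, "failed to create context\n");
        llama_free_model(model);
        return 1;
    }

    // ... tokenize, evaluate and sample here ...

    llama_free(ctx);
    llama_free_model(model);
    return 0;
}

Splitting model loading from context creation is presumably why State gains a separate llama_model pointer: several contexts can share one loaded model.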
@@ -114,7 +117,6 @@ class LLaMAInference final : public Inference {
 LM_CORETURN LM_BOOL_SUCCESS;
 }

-#if LLAMA_DATE >= 230519
 int llama_sample_top_p_top_k() {
 auto& state = get_state();
 auto logits = llama_get_logits(state->ctx);
@@ -153,13 +155,6 @@ class LLaMAInference final : public Inference {
 default: LM_THROW("Invalid mirostat version "+std::to_string(params.prefer_mirostat), LM_BOOL_ERROR);
 }
 }
-#else
-int llama_sample_top_p_top_k() {
-auto& state = get_state();
-auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
-return ::llama_sample_top_p_top_k(state->ctx, params.n_repeat_last?(state->tokens.data()+state->tokens.size()-n_repeat_last):nullptr, n_repeat_last, params.top_k, params.top_p, params.temp, params.repeat_penalty);
-}
-#endif

 public:
 LLaMAInference(const std::string& weights_path, const Params& p) : Inference(p) {
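The retained sampler starts from llama_get_logits(). As a rough sketch of how the candidates-based llama.cpp sampling API of that period is typically driven (this is not this repository's implementation, and the parameter values are illustrative), the plain top-k/top-p/temperature path looks roughly like this, with the mirostat variants mentioned above working through the same llama_token_data_array:

// Hypothetical sketch of top-k/top-p/temperature sampling over raw logits.
#include <vector>
#include <llama.h>

llama_token sample_top_p_top_k(llama_context *ctx, int top_k, float top_p, float temp) {
    const int n_vocab = llama_n_vocab(ctx);
    float *logits = llama_get_logits(ctx);

    // Wrap the raw logits in the llama_token_data_array the samplers operate on.
    std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
    for (llama_token id = 0; id < n_vocab; ++id)
        candidates.push_back({id, logits[id], 0.0f});
    llama_token_data_array arr = {candidates.data(), candidates.size(), false};

    // Filter and sharpen the distribution, then draw a token.
    llama_sample_top_k(ctx, &arr, top_k, 1);
    llama_sample_top_p(ctx, &arr, top_p, 1);
    llama_sample_temperature(ctx, &arr, temp);
    return llama_sample_token(ctx, &arr);
}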
@@ -219,7 +214,7 @@ public:
 LM_COTHROW(e.what(), "");
 }

-if (id == llama_token_eos()) {
+if (id == llama_token_eos(state->ctx)) {
 if (eos_count++ == params.n_eos_ignores) {
 abort = true;
 continue;
@@ -236,7 +231,7 @@ public:
 LM_COAWAIT window_scroll();

 // Get token as string
-const std::string_view str = llama_token_to_str(state->ctx, id);
+const std::string_view str = llama_token_get_text(state->ctx, id);

 // Append string to function result
 state->prompt.append(str);
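Putting the two renamed calls together, here is a hedged fragment of what the tail of a generation loop looks like with the newer accessors; llama_token_eos(ctx) and llama_token_get_text(ctx, id) are taken from the diff, while ctx, sample_next_token() and the loop itself are simplified stand-ins for the wrapper's real state:

// Hypothetical fragment; sample_next_token() is a placeholder for the sampling step above.
#include <string>
#include <llama.h>

llama_token sample_next_token(llama_context *ctx);  // assumed to be defined elsewhere

std::string generate_some(llama_context *ctx, int max_tokens) {
    std::string result;
    for (int i = 0; i < max_tokens; ++i) {
        const llama_token id = sample_next_token(ctx);

        // End-of-sequence is now queried through the context.
        if (id == llama_token_eos(ctx))
            break;

        // Token text is fetched with llama_token_get_text() rather than llama_token_to_str().
        result += llama_token_get_text(ctx, id);
    }
    return result;
}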
Deleted old llama.cpp submodules:

@@ -1 +0,0 @@
-Subproject commit 0e018fe008eacebdbcfa2d61b6c988c245c961cd

@@ -1 +0,0 @@
-Subproject commit 5ea43392731040b454c293123839b90e159cbb99

llama.cpp-mainline submodule update:

@@ -1 +1 @@
-Subproject commit 5ec8dd5a3c6a9a109351d2257bb9d53869bd0a94
+Subproject commit e8422de39e4aa2f7e50574124b060a80607e654a