mirror of
https://gitlab.com/niansa/libjustlm.git
synced 2025-03-06 20:49:17 +01:00
Fixed linebreaks and support latest llama.cpp
This commit is contained in:
parent
c9dac7cb89
commit
5feca59be7
11 changed files with 48 additions and 13 deletions
3
.gitmodules
vendored
3
.gitmodules
vendored
|
@ -5,5 +5,8 @@
|
|||
path = llama.cpp-alibi
|
||||
url = https://github.com/manyoso/llama.cpp.git
|
||||
[submodule "llama.cpp-mainline"]
|
||||
path = llama.cpp-old2
|
||||
url = https://github.com/ggerganov/llama.cpp.git
|
||||
[submodule "llama.cpp-2"]
|
||||
path = llama.cpp-mainline
|
||||
url = https://github.com/ggerganov/llama.cpp.git
|
||||
|
|
|
@ -38,6 +38,7 @@ include(llama.cpp.cmake)
|
|||
|
||||
include_ggml(llama.cpp-mainline _mainline Yes)
|
||||
include_ggml(llama.cpp-old _old Yes)
|
||||
include_ggml(llama.cpp-old2 _old2 Yes)
|
||||
include_ggml(llama.cpp-alibi _alibi No)
|
||||
|
||||
|
||||
|
@ -57,7 +58,7 @@ if (LM_GPTJ)
|
|||
endif()
|
||||
|
||||
if (LM_LLAMA)
|
||||
add_library(justlm_llama SHARED llama.cpp justlm_llama.hpp)
|
||||
add_library(justlm_llama SHARED llama_mainline.cpp justlm_llama.hpp)
|
||||
target_link_libraries(justlm_llama PRIVATE ggml_mainline llama_mainline)
|
||||
target_justlm_setup(justlm_llama)
|
||||
endif()
|
||||
|
@ -66,6 +67,10 @@ if (LM_LLAMA_OLD)
|
|||
add_library(justlm_llama_old SHARED llama_old.cpp justlm_llama_old.hpp)
|
||||
target_link_libraries(justlm_llama_old PRIVATE ggml_old llama_old)
|
||||
target_justlm_setup(justlm_llama_old)
|
||||
|
||||
add_library(justlm_llama_old2 SHARED llama_old2.cpp justlm_llama.hpp)
|
||||
target_link_libraries(justlm_llama_old2 PRIVATE ggml_old2 llama_old2)
|
||||
target_justlm_setup(justlm_llama_old2)
|
||||
endif()
|
||||
|
||||
|
||||
|
|
|
@ -64,12 +64,6 @@ protected:
|
|||
|
||||
void *generic_state = nullptr;
|
||||
|
||||
static inline
|
||||
bool ends_with(std::string_view str, std::string_view suffix) noexcept {
|
||||
if (suffix.empty()) return false;
|
||||
return str.size() >= suffix.size() && 0 == str.compare(str.size()-suffix.size(), suffix.size(), suffix);
|
||||
}
|
||||
|
||||
LM_LAST_ERROR_STORAGE
|
||||
|
||||
public:
|
||||
|
|
|
@ -168,7 +168,7 @@ public:
|
|||
// Loop until done
|
||||
bool abort = false;
|
||||
unsigned eos_count = 0;
|
||||
while (!abort && !ends_with(fres, end)) {
|
||||
while (!abort && fres.find(end) != fres.npos) {
|
||||
// Sample top p and top k
|
||||
const auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
|
||||
auto id = gpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data()+state->tokens.size()-n_repeat_last, n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
|
||||
|
|
|
@ -181,7 +181,7 @@ public:
|
|||
// Loop until done
|
||||
bool abort = false;
|
||||
unsigned eos_count = 0;
|
||||
while (!abort && !ends_with(fres, end)) {
|
||||
while (!abort && fres.find(end) != fres.npos) {
|
||||
// Sample top p and top k
|
||||
auto id = llama_sample_top_p_top_k();
|
||||
|
||||
|
|
|
@ -158,7 +158,7 @@ public:
|
|||
// Loop until done
|
||||
bool abort = false;
|
||||
unsigned eos_count = 0;
|
||||
while (!abort && !ends_with(fres, end)) {
|
||||
while (!abort && fres.find(end) != fres.npos) {
|
||||
// Sample top p and top k
|
||||
const auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
|
||||
auto id = llama_sample_top_p_top_k(state->ctx, params.n_repeat_last?(state->tokens.data()+state->tokens.size()-n_repeat_last):nullptr, n_repeat_last, params.top_k, params.top_p, params.temp, params.repeat_penalty);
|
||||
|
|
|
@ -177,7 +177,7 @@ public:
|
|||
// Loop until done
|
||||
bool abort = false;
|
||||
unsigned eos_count = 0;
|
||||
while (!abort && !ends_with(fres, end)) {
|
||||
while (!abort && fres.find(end) != fres.npos) {
|
||||
// Sample top p and top k
|
||||
const auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
|
||||
auto id = gpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data()+state->tokens.size()-n_repeat_last, n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);
|
||||
|
|
|
@ -1 +1 @@
|
|||
Subproject commit 5ea43392731040b454c293123839b90e159cbb99
|
||||
Subproject commit 2d5db48371052087a83974abda3767d1aedec598
|
1
llama.cpp-old2
Submodule
1
llama.cpp-old2
Submodule
|
@ -0,0 +1 @@
|
|||
Subproject commit 5ea43392731040b454c293123839b90e159cbb99
|
|
@ -22,7 +22,7 @@ bool magic_match(std::istream& f) {
|
|||
// Check version
|
||||
uint32_t version = 0;
|
||||
f.read(reinterpret_cast<char*>(&version), sizeof(version));
|
||||
return version >= 2;
|
||||
return version >= 3;
|
||||
}
|
||||
|
||||
LM::Inference *construct(const std::string &weights_path, std::ifstream& f, const LM::Inference::Params &p) {
|
32
llama_old2.cpp
Normal file
32
llama_old2.cpp
Normal file
|
@ -0,0 +1,32 @@
|
|||
#include "justlm_llama.hpp"
|
||||
#include "justlm.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <string_view>
|
||||
#include <fstream>
|
||||
#include <cstdint>
|
||||
|
||||
|
||||
|
||||
extern "C" {
|
||||
const LM::Implementation *get_justlm_implementation() {
|
||||
static LM::Implementation fres{false};
|
||||
return &fres;
|
||||
}
|
||||
|
||||
bool magic_match(std::istream& f) {
|
||||
// Check magic
|
||||
uint32_t magic;
|
||||
f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
|
||||
if (magic != 0x67676a74) return false;
|
||||
// Check version
|
||||
uint32_t version = 0;
|
||||
f.read(reinterpret_cast<char*>(&version), sizeof(version));
|
||||
return version == 2;
|
||||
}
|
||||
|
||||
LM::Inference *construct(const std::string &weights_path, std::ifstream& f, const LM::Inference::Params &p) {
|
||||
f.close();
|
||||
return new LM::LLaMAInference(weights_path, p);
|
||||
}
|
||||
}
|
Loading…
Add table
Reference in a new issue