1
0
Fork 0
mirror of https://gitlab.com/niansa/libjustlm.git synced 2025-03-06 20:49:17 +01:00

Fixed linebreaks and support latest llama.cpp

This commit is contained in:
niansa 2023-05-20 02:25:46 +02:00
parent c9dac7cb89
commit 5feca59be7
11 changed files with 48 additions and 13 deletions

3
.gitmodules vendored
View file

@ -5,5 +5,8 @@
path = llama.cpp-alibi
url = https://github.com/manyoso/llama.cpp.git
[submodule "llama.cpp-mainline"]
path = llama.cpp-old2
url = https://github.com/ggerganov/llama.cpp.git
[submodule "llama.cpp-2"]
path = llama.cpp-mainline
url = https://github.com/ggerganov/llama.cpp.git

View file

@ -38,6 +38,7 @@ include(llama.cpp.cmake)
include_ggml(llama.cpp-mainline _mainline Yes)
include_ggml(llama.cpp-old _old Yes)
include_ggml(llama.cpp-old2 _old2 Yes)
include_ggml(llama.cpp-alibi _alibi No)
@ -57,7 +58,7 @@ if (LM_GPTJ)
endif()
if (LM_LLAMA)
add_library(justlm_llama SHARED llama.cpp justlm_llama.hpp)
add_library(justlm_llama SHARED llama_mainline.cpp justlm_llama.hpp)
target_link_libraries(justlm_llama PRIVATE ggml_mainline llama_mainline)
target_justlm_setup(justlm_llama)
endif()
@ -66,6 +67,10 @@ if (LM_LLAMA_OLD)
add_library(justlm_llama_old SHARED llama_old.cpp justlm_llama_old.hpp)
target_link_libraries(justlm_llama_old PRIVATE ggml_old llama_old)
target_justlm_setup(justlm_llama_old)
add_library(justlm_llama_old2 SHARED llama_old2.cpp justlm_llama.hpp)
target_link_libraries(justlm_llama_old2 PRIVATE ggml_old2 llama_old2)
target_justlm_setup(justlm_llama_old2)
endif()

View file

@ -64,12 +64,6 @@ protected:
void *generic_state = nullptr;
static inline
bool ends_with(std::string_view str, std::string_view suffix) noexcept {
if (suffix.empty()) return false;
return str.size() >= suffix.size() && 0 == str.compare(str.size()-suffix.size(), suffix.size(), suffix);
}
LM_LAST_ERROR_STORAGE
public:

View file

@ -168,7 +168,7 @@ public:
// Loop until done
bool abort = false;
unsigned eos_count = 0;
while (!abort && !ends_with(fres, end)) {
while (!abort && fres.find(end) != fres.npos) {
// Sample top p and top k
const auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
auto id = gpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data()+state->tokens.size()-n_repeat_last, n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);

View file

@ -181,7 +181,7 @@ public:
// Loop until done
bool abort = false;
unsigned eos_count = 0;
while (!abort && !ends_with(fres, end)) {
while (!abort && fres.find(end) != fres.npos) {
// Sample top p and top k
auto id = llama_sample_top_p_top_k();

View file

@ -158,7 +158,7 @@ public:
// Loop until done
bool abort = false;
unsigned eos_count = 0;
while (!abort && !ends_with(fres, end)) {
while (!abort && fres.find(end) != fres.npos) {
// Sample top p and top k
const auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
auto id = llama_sample_top_p_top_k(state->ctx, params.n_repeat_last?(state->tokens.data()+state->tokens.size()-n_repeat_last):nullptr, n_repeat_last, params.top_k, params.top_p, params.temp, params.repeat_penalty);

View file

@ -177,7 +177,7 @@ public:
// Loop until done
bool abort = false;
unsigned eos_count = 0;
while (!abort && !ends_with(fres, end)) {
while (!abort && fres.find(end) != fres.npos) {
// Sample top p and top k
const auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
auto id = gpt_sample_top_k_top_p(state->model.hparams.n_vocab, state->tokens.data()+state->tokens.size()-n_repeat_last, n_repeat_last, state->logits, params.top_k, params.top_p, params.temp, params.repeat_penalty, state->rng);

@ -1 +1 @@
Subproject commit 5ea43392731040b454c293123839b90e159cbb99
Subproject commit 2d5db48371052087a83974abda3767d1aedec598

1
llama.cpp-old2 Submodule

@ -0,0 +1 @@
Subproject commit 5ea43392731040b454c293123839b90e159cbb99

View file

@ -22,7 +22,7 @@ bool magic_match(std::istream& f) {
// Check version
uint32_t version = 0;
f.read(reinterpret_cast<char*>(&version), sizeof(version));
return version >= 2;
return version >= 3;
}
LM::Inference *construct(const std::string &weights_path, std::ifstream& f, const LM::Inference::Params &p) {

32
llama_old2.cpp Normal file
View file

@ -0,0 +1,32 @@
#include "justlm_llama.hpp"
#include "justlm.hpp"
#include <string>
#include <string_view>
#include <fstream>
#include <cstdint>
extern "C" {
const LM::Implementation *get_justlm_implementation() {
    // Lazily-constructed, function-local singleton describing this backend.
    // The `false` flag presumably marks it as a non-fallback implementation —
    // confirm against the LM::Implementation definition.
    static LM::Implementation descriptor{false};
    return &descriptor;
}
bool magic_match(std::istream& f) {
    // Returns true iff the stream begins with the 0x67676a74 ("ggjt" in
    // little-endian byte order) magic followed by file-format version 2 —
    // the only revision this legacy backend supports.
    //
    // Zero-initialize so a short or failed read compares deterministically
    // instead of reading an indeterminate value (UB); this matches how
    // `version` below was already handled.
    uint32_t magic = 0;
    f.read(reinterpret_cast<char*>(&magic), sizeof(magic));
    if (magic != 0x67676a74) return false;
    // Check version
    uint32_t version = 0;
    f.read(reinterpret_cast<char*>(&version), sizeof(version));
    return version == 2;
}
LM::Inference *construct(const std::string &weights_path, std::ifstream& f, const LM::Inference::Params &p) {
    // The probe stream used for magic/version sniffing is no longer needed;
    // the backend reopens the weights file from `weights_path` itself.
    f.close();
    auto *backend = new LM::LLaMAInference(weights_path, p);
    return backend;
}
}