mirror of
https://gitlab.com/niansa/libjustlm.git
synced 2025-03-06 20:49:17 +01:00
Completed mainline llama implementation
This commit is contained in:
parent
b17cc6ffbd
commit
ad1e8a3368
2 changed files with 29 additions and 6 deletions
|
@ -56,18 +56,18 @@ if (LM_GPTJ)
|
|||
target_justlm_setup(justlm_gptj)
|
||||
endif()
|
||||
|
||||
# Legacy llama backend: built only when LM_LLAMA_OLD is enabled.
# Links against the old ggml/llama targets and applies the shared
# justlm target setup (output dirs, compile options, etc.).
if (LM_LLAMA_OLD)
    add_library(justlm_llama_old SHARED llama_old.cpp justlm_llama_old.hpp)
    target_link_libraries(justlm_llama_old PRIVATE ggml_old llama_old)
    target_justlm_setup(justlm_llama_old)
endif()
|
||||
|
||||
# Mainline llama backend: built only when LM_LLAMA is enabled.
# Uses the mainline ggml/llama targets, then applies the shared
# justlm target setup helper.
if (LM_LLAMA)
    add_library(justlm_llama SHARED llama.cpp justlm_llama.hpp)
    target_link_libraries(justlm_llama PRIVATE ggml_mainline llama_mainline)
    target_justlm_setup(justlm_llama)
endif()
|
||||
|
||||
# Legacy llama backend, guarded by LM_LLAMA_OLD.
# NOTE(review): an identical LM_LLAMA_OLD section appears earlier in this
# file — possibly an artifact of how this diff was captured; confirm against
# the repository and drop one copy if it is a genuine duplicate.
if (LM_LLAMA_OLD)
    add_library(justlm_llama_old SHARED llama_old.cpp justlm_llama_old.hpp)
    target_link_libraries(justlm_llama_old PRIVATE ggml_old llama_old)
    target_justlm_setup(justlm_llama_old)
endif()
|
||||
|
||||
|
||||
add_library(justlm STATIC
|
||||
include/justlm.hpp justlm.cpp
|
||||
|
|
|
@ -111,6 +111,29 @@ class LLaMAInference final : public Inference {
|
|||
LM_CORETURN LM_BOOL_SUCCESS;
|
||||
}
|
||||
|
||||
int llama_sample_top_p_top_k() {
|
||||
auto& state = get_state();
|
||||
auto logits = llama_get_logits(state->ctx);
|
||||
auto n_vocab = llama_n_vocab(state->ctx);
|
||||
// Populate initial list of all candidates
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (int token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||
}
|
||||
llama_token_data_array candidates_p = {candidates.data(), candidates.size(), false};
|
||||
// Sample repeat penalty
|
||||
auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
|
||||
llama_sample_repetition_penalty(nullptr, &candidates_p, params.n_repeat_last?(state->tokens.data()+state->tokens.size()-n_repeat_last):nullptr, n_repeat_last, params.repeat_penalty);
|
||||
// Temperature sampling
|
||||
llama_sample_top_k(state->ctx, &candidates_p, params.top_k, 1);
|
||||
llama_sample_tail_free(state->ctx, &candidates_p, 1.0f, 1);
|
||||
llama_sample_typical(state->ctx, &candidates_p, 1.0f, 1);
|
||||
llama_sample_top_p(state->ctx, &candidates_p, params.top_p, 1);
|
||||
llama_sample_temperature(state->ctx, &candidates_p, params.temp);
|
||||
return llama_sample_token(state->ctx, &candidates_p);
|
||||
}
|
||||
|
||||
public:
|
||||
LLaMAInference(const std::string& weights_path, const Params& p) : Inference(p) {
|
||||
init(weights_path);
|
||||
|
|
Loading…
Add table
Reference in a new issue