Mirror of https://gitlab.com/niansa/libjustlm.git (synced 2025-03-06 20:49:17 +01:00)

Implemented grammar sampling and zero-temperature sampling

parent 3a953ed13a
commit 79cf49faae

3 changed files with 75 additions and 23 deletions
@@ -147,9 +147,17 @@ public:
     virtual LM_SCHEDULABLE(LM_ERRBOOL) serialize(std::ostream&) const LM_NOEXCEPTDECL = 0;
     virtual LM_SCHEDULABLE(LM_ERRBOOL) deserialize(std::istream&) LM_NOEXCEPTDECL = 0;
 
+    virtual LM_SCHEDULABLE(LM_ERRBOOL) load_grammar(const std::string&, bool override_temperature [[maybe_unused]] = false) LM_NOEXCEPTDECL {
+        LM_COTHROW("Grammar is not available for this model's backend", LM_BOOL_ERROR);
+    }
+    virtual LM_SCHEDULABLE(LM_ERRBOOL) unload_grammar() LM_NOEXCEPTDECL {
+        LM_COTHROW("Grammar is not available for this model's backend", LM_BOOL_ERROR);
+    }
+
     virtual const std::string& get_prompt() const LM_NOEXCEPTDECL = 0;
 
     virtual bool is_mirostat_available() const noexcept {return false;}
+    virtual bool is_grammar_available() const noexcept {return false;}
 
     LM_LAST_ERROR_GETTER
 };
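These base-class defaults make grammar support an opt-in capability: a backend that overrides nothing keeps is_grammar_available() == false, and calling load_grammar() on it raises the "not available" error above. A minimal caller-side sketch of that contract; the helper name, the header path, and a build in which LM_SCHEDULABLE(LM_ERRBOOL) reduces to a plain synchronous LM_ERRBOOL are assumptions for illustration, not part of the commit:

    #include <string>
    #include <justlm.hpp>  // assumed location of the LM::Inference interface

    // Hypothetical helper: enable grammar-constrained sampling only where
    // the backend supports it. Assumes a non-coroutine build, so
    // load_grammar() can be called directly.
    bool try_enable_grammar(LM::Inference& inference, const std::string& gbnf) {
        if (!inference.is_grammar_available())
            return false;  // default backend: load_grammar() would just error out
        // Second argument: also force the new zero-temperature sampling
        // path while the grammar is active.
        inference.load_grammar(gbnf, /*override_temperature=*/true);
        return true;
    }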
@@ -3,13 +3,17 @@
 #include <cstring>
 #include <ggml.h>
 #include <llama.h>
+#include <common/grammar-parser.h>
 
 
 namespace LM {
 class LLaMAInference final : public Inference {
     struct State {
         llama_context *ctx = nullptr;
-        struct llama_model *model;
+        llama_model *model;
+        llama_grammar *grammar = nullptr;
+        bool grammar_override_temp;
+        grammar_parser::parse_state parsed_grammar;
         std::string prompt; // Mostly here for easy "debugging"
         std::vector<int> tokens;
         unsigned n_ctx;
@@ -131,29 +135,39 @@ class LLaMAInference final : public Inference {
         // Sample repeat penalty
         auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
         llama_sample_repetition_penalty(state->ctx, &candidates_p, params.n_repeat_last?(state->tokens.data()+state->tokens.size()-n_repeat_last):nullptr, n_repeat_last, params.repeat_penalty);
-        // Temperature sampling
-        switch (params.prefer_mirostat) {
-        case 0: {
-            llama_sample_top_k(state->ctx, &candidates_p, params.top_k, 1);
-            llama_sample_tail_free(state->ctx, &candidates_p, 1.0f, 1);
-            llama_sample_typical(state->ctx, &candidates_p, 1.0f, 1);
-            llama_sample_top_p(state->ctx, &candidates_p, params.top_p, 1);
-            llama_sample_temperature(state->ctx, &candidates_p, params.temp);
+        // Grammar sampling
+        if (state->grammar) {
+            llama_sample_grammar(state->ctx, &candidates_p, state->grammar);
+        }
+        if (!(state->grammar && state->grammar_override_temp) && (params.temp > 0.01f || params.temp < -0.01f)) {
+            // Temperature sampling
+            switch (params.prefer_mirostat) {
+            case 0: {
+                llama_sample_top_k(state->ctx, &candidates_p, params.top_k, 1);
+                llama_sample_tail_free(state->ctx, &candidates_p, 1.0f, 1);
+                llama_sample_typical(state->ctx, &candidates_p, 1.0f, 1);
+                llama_sample_top_p(state->ctx, &candidates_p, params.top_p, 1);
+                llama_sample_temperature(state->ctx, &candidates_p, params.temp);
                 return llama_sample_token(state->ctx, &candidates_p);
             }
-        case 1: {
-            float mirostat_mu = 2.0f * params.mirostat_target_entropy;
-            const int mirostat_m = 100;
-            llama_sample_temperature(state->ctx, &candidates_p, params.temp);
-            return llama_sample_token_mirostat(state->ctx, &candidates_p, params.mirostat_target_entropy, params.mirostat_learning_rate, mirostat_m, &mirostat_mu);
-        }
-        case 2: {
-            float mirostat_mu = 2.0f * params.mirostat_target_entropy;
-            llama_sample_temperature(state->ctx, &candidates_p, params.temp);
-            return llama_sample_token_mirostat_v2(state->ctx, &candidates_p, params.mirostat_target_entropy, params.mirostat_learning_rate, &mirostat_mu);
-        }
-        default: LM_THROW("Invalid mirostat version "+std::to_string(params.prefer_mirostat), LM_BOOL_ERROR);
-        }
+            case 1: {
+                float mirostat_mu = 2.0f * params.mirostat_target_entropy;
+                const int mirostat_m = 100;
+                llama_sample_temperature(state->ctx, &candidates_p, params.temp);
+                return llama_sample_token_mirostat(state->ctx, &candidates_p, params.mirostat_target_entropy, params.mirostat_learning_rate, mirostat_m, &mirostat_mu);
+            }
+            case 2: {
+                float mirostat_mu = 2.0f * params.mirostat_target_entropy;
+                llama_sample_temperature(state->ctx, &candidates_p, params.temp);
+                return llama_sample_token_mirostat_v2(state->ctx, &candidates_p, params.mirostat_target_entropy, params.mirostat_learning_rate, &mirostat_mu);
+            }
+            default: LM_THROW("Invalid mirostat version "+std::to_string(params.prefer_mirostat), LM_BOOL_ERROR);
+            }
+        } else {
+            // Greedy sampling
+            abort();
+            return llama_sample_token(state->ctx, &candidates_p);
+        }
     }
 
 public:
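Two things are worth noting about the new control flow. First, grammar filtering runs before any temperature logic, so the constrained candidate set feeds every sampler; upstream llama.cpp additionally expects llama_grammar_accept_token() to be called with the chosen token so the grammar state advances, and that call is not visible in this hunk. Second, the else branch is the new zero-temperature path (|temp| <= 0.01, or a grammar loaded with override_temperature), but as committed it hits abort() before its return statement. A minimal sketch of what a working greedy branch typically looks like against the llama.cpp API of this era, using llama_sample_token_greedy() from llama.h rather than this commit's code:

    #include <llama.h>

    // Zero-temperature sampling: skip the stochastic samplers entirely and
    // take the single highest-logit candidate. Equivalent in spirit to
    // letting temperature tend to 0, without dividing logits by a tiny value.
    static llama_token sample_zero_temperature(llama_context *ctx,
                                               llama_token_data_array *candidates) {
        return llama_sample_token_greedy(ctx, candidates);
    }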
@@ -344,6 +358,30 @@ public:
         LM_CORETURN LM_BOOL_SUCCESS;
     }
 
+    LM_SCHEDULABLE(LM_ERRBOOL) load_grammar(const std::string& src, bool override_temperature) LM_NOEXCEPTDECL override {
+        auto& state = get_state();
+
+        state->parsed_grammar = grammar_parser::parse(src.c_str());
+        if (state->parsed_grammar.rules.empty()) {
+            LM_COTHROW("Failed to parse grammar (or no rules)", LM_BOOL_ERROR);
+        }
+
+        auto rules = state->parsed_grammar.c_rules();
+        state->grammar = llama_grammar_init(rules.data(), rules.size(), state->parsed_grammar.symbol_ids.at("root"));
+        if (!state->grammar) {
+            LM_COTHROW("Failed to generate llama grammar", LM_BOOL_ERROR);
+        }
+
+        state->grammar_override_temp = override_temperature;
+
+        LM_CORETURN LM_BOOL_SUCCESS;
+    }
+    LM_SCHEDULABLE(LM_ERRBOOL) unload_grammar() LM_NOEXCEPTDECL override {
+        get_state()->grammar = nullptr;
+
+        LM_CORETURN LM_BOOL_SUCCESS;
+    }
+
     const std::string &get_prompt() const LM_NOEXCEPTDECL override {
         return get_state()->prompt;
     }
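grammar_parser::parse() consumes llama.cpp's GBNF text format, and load_grammar() looks up the start symbol under the fixed name "root" (the symbol_ids.at("root") call above), so every grammar passed in must define a root rule. An illustrative grammar embedded as a C++ raw string; the grammar itself is an example, not taken from the commit:

    // Example GBNF input for load_grammar(): restrict output to "yes" or
    // "no" followed by a newline. The mandatory "root" rule is the start symbol.
    static const char kYesNoGrammar[] = R"GBNF(
    root   ::= answer "\n"
    answer ::= "yes" | "no"
    )GBNF";

Note also that unload_grammar() above only nulls the State pointer; upstream llama.cpp pairs llama_grammar_init() with llama_grammar_free() for releasing the grammar object.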
@@ -351,5 +389,9 @@ public:
     bool is_mirostat_available() const noexcept override {
         return true;
     }
+
+    bool is_grammar_available() const noexcept override {
+        return true;
+    }
 };
}
@@ -332,7 +332,9 @@ function(include_ggml DIRECTORY SUFFIX WITH_LLAMA)
     if (WITH_LLAMA)
         SET(LLAMA_SOURCES
             ${DIRECTORY}/llama.cpp
-            ${DIRECTORY}/llama.h)
+            ${DIRECTORY}/llama.h
+            ${DIRECTORY}/common/grammar-parser.h
+            ${DIRECTORY}/common/grammar-parser.cpp)
         remove_nonexistent(LLAMA_SOURCES)
         add_library(llama${SUFFIX} ${LLAMA_SOURCES})
 