Mirror of https://gitlab.com/niansa/libjustlm.git

Implemented grammar sampling and zero-temperature sampling

niansa 2023-08-31 19:37:33 +02:00
parent 3a953ed13a
commit 79cf49faae
3 changed files with 75 additions and 23 deletions


@@ -147,9 +147,17 @@ public:
     virtual LM_SCHEDULABLE(LM_ERRBOOL) serialize(std::ostream&) const LM_NOEXCEPTDECL = 0;
     virtual LM_SCHEDULABLE(LM_ERRBOOL) deserialize(std::istream&) LM_NOEXCEPTDECL = 0;
+    virtual LM_SCHEDULABLE(LM_ERRBOOL) load_grammar(const std::string&, bool override_temperature [[maybe_unused]] = false) LM_NOEXCEPTDECL {
+        LM_COTHROW("Grammar is not available for this model's backend", LM_BOOL_ERROR);
+    }
+    virtual LM_SCHEDULABLE(LM_ERRBOOL) unload_grammar() LM_NOEXCEPTDECL {
+        LM_COTHROW("Grammar is not available for this model's backend", LM_BOOL_ERROR);
+    }
     virtual const std::string& get_prompt() const LM_NOEXCEPTDECL = 0;
     virtual bool is_mirostat_available() const noexcept {return false;}
+    virtual bool is_grammar_available() const noexcept {return false;}
     LM_LAST_ERROR_GETTER
 };
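The two defaulted virtuals above let backends without grammar support reject grammars at runtime rather than at compile time. A minimal caller-side sketch (not part of this commit) of the intended flow, assuming the public header is justlm.hpp and using llama.cpp's GBNF rule syntax; the exact call syntax depends on how LM_SCHEDULABLE is configured:

#include <justlm.hpp>

// Constrain generation to a strict yes/no answer, then restore normal sampling.
void answer_yes_no(LM::Inference& inference) {
    if (!inference.is_grammar_available()) return;  // backend has no grammar support
    // true = override temperature: sample greedily while the grammar is active
    inference.load_grammar("root ::= \"yes\" | \"no\"", true);
    // ... append the prompt and run inference as usual; every sampled token
    //     must now be derivable from the grammar's root rule ...
    inference.unload_grammar();
}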


@@ -3,13 +3,17 @@
 #include <cstring>
 #include <ggml.h>
 #include <llama.h>
+#include <common/grammar-parser.h>
 namespace LM {
 class LLaMAInference final : public Inference {
     struct State {
         llama_context *ctx = nullptr;
-        struct llama_model *model;
+        llama_model *model;
+        llama_grammar *grammar = nullptr;
+        bool grammar_override_temp;
+        grammar_parser::parse_state parsed_grammar;
         std::string prompt; // Mostly here for easy "debugging"
         std::vector<int> tokens;
         unsigned n_ctx;
@@ -131,29 +135,39 @@ class LLaMAInference final : public Inference {
         // Sample repeat penalty
         auto n_repeat_last = std::min<size_t>(state->tokens.size(), params.n_repeat_last);
         llama_sample_repetition_penalty(state->ctx, &candidates_p, params.n_repeat_last?(state->tokens.data()+state->tokens.size()-n_repeat_last):nullptr, n_repeat_last, params.repeat_penalty);
-        // Temperature sampling
-        switch (params.prefer_mirostat) {
-        case 0: {
-            llama_sample_top_k(state->ctx, &candidates_p, params.top_k, 1);
-            llama_sample_tail_free(state->ctx, &candidates_p, 1.0f, 1);
-            llama_sample_typical(state->ctx, &candidates_p, 1.0f, 1);
-            llama_sample_top_p(state->ctx, &candidates_p, params.top_p, 1);
-            llama_sample_temperature(state->ctx, &candidates_p, params.temp);
-            return llama_sample_token(state->ctx, &candidates_p);
-        }
-        case 1: {
-            float mirostat_mu = 2.0f * params.mirostat_target_entropy;
-            const int mirostat_m = 100;
-            llama_sample_temperature(state->ctx, &candidates_p, params.temp);
-            return llama_sample_token_mirostat(state->ctx, &candidates_p, params.mirostat_target_entropy, params.mirostat_learning_rate, mirostat_m, &mirostat_mu);
-        }
-        case 2: {
-            float mirostat_mu = 2.0f * params.mirostat_target_entropy;
-            llama_sample_temperature(state->ctx, &candidates_p, params.temp);
-            return llama_sample_token_mirostat_v2(state->ctx, &candidates_p, params.mirostat_target_entropy, params.mirostat_learning_rate, &mirostat_mu);
-        }
-        default: LM_THROW("Invalid mirostat version "+std::to_string(params.prefer_mirostat), LM_BOOL_ERROR);
-        }
+        // Grammar sampling
+        if (state->grammar) {
+            llama_sample_grammar(state->ctx, &candidates_p, state->grammar);
+        }
+        if (!(state->grammar && state->grammar_override_temp) && (params.temp > 0.01f || params.temp < -0.01f)) {
+            // Temperature sampling
+            switch (params.prefer_mirostat) {
+            case 0: {
+                llama_sample_top_k(state->ctx, &candidates_p, params.top_k, 1);
+                llama_sample_tail_free(state->ctx, &candidates_p, 1.0f, 1);
+                llama_sample_typical(state->ctx, &candidates_p, 1.0f, 1);
+                llama_sample_top_p(state->ctx, &candidates_p, params.top_p, 1);
+                llama_sample_temperature(state->ctx, &candidates_p, params.temp);
+                return llama_sample_token(state->ctx, &candidates_p);
+            }
+            case 1: {
+                float mirostat_mu = 2.0f * params.mirostat_target_entropy;
+                const int mirostat_m = 100;
+                llama_sample_temperature(state->ctx, &candidates_p, params.temp);
+                return llama_sample_token_mirostat(state->ctx, &candidates_p, params.mirostat_target_entropy, params.mirostat_learning_rate, mirostat_m, &mirostat_mu);
+            }
+            case 2: {
+                float mirostat_mu = 2.0f * params.mirostat_target_entropy;
+                llama_sample_temperature(state->ctx, &candidates_p, params.temp);
+                return llama_sample_token_mirostat_v2(state->ctx, &candidates_p, params.mirostat_target_entropy, params.mirostat_learning_rate, &mirostat_mu);
+            }
+            default: LM_THROW("Invalid mirostat version "+std::to_string(params.prefer_mirostat), LM_BOOL_ERROR);
+            }
+        } else {
+            // Greedy sampling
+            abort();
+            return llama_sample_token(state->ctx, &candidates_p);
+        }
     }
 public:
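One catch in the new greedy branch: abort() runs before the return, so the return is dead code and a zero-temperature request terminates the process; this looks like a debugging leftover. A sketch of what the branch presumably intends, using llama_sample_token_greedy from the same generation of the llama.cpp sampling API:

        } else {
            // Greedy (zero-temperature) sampling: deterministically take the
            // highest-probability candidate instead of drawing from the distribution
            return llama_sample_token_greedy(state->ctx, &candidates_p);
        }

The gate itself reads: unless a grammar was loaded with override_temperature set, any params.temp outside ±0.01 takes the temperature/mirostat path; a near-zero temperature, or an overriding grammar, falls through to the greedy branch.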
@@ -344,6 +358,30 @@ public:
         LM_CORETURN LM_BOOL_SUCCESS;
     }
+    LM_SCHEDULABLE(LM_ERRBOOL) load_grammar(const std::string& src, bool override_temperature) LM_NOEXCEPTDECL override {
+        auto& state = get_state();
+        state->parsed_grammar = grammar_parser::parse(src.c_str());
+        if (state->parsed_grammar.rules.empty()) {
+            LM_COTHROW("Failed to parse grammar (or no rules)", LM_BOOL_ERROR);
+        }
+        auto rules = state->parsed_grammar.c_rules();
+        state->grammar = llama_grammar_init(rules.data(), rules.size(), state->parsed_grammar.symbol_ids.at("root"));
+        if (!state->grammar) {
+            LM_COTHROW("Failed to generate llama grammar", LM_BOOL_ERROR);
+        }
+        state->grammar_override_temp = override_temperature;
+        LM_CORETURN LM_BOOL_SUCCESS;
+    }
+    LM_SCHEDULABLE(LM_ERRBOOL) unload_grammar() LM_NOEXCEPTDECL override {
+        get_state()->grammar = nullptr;
+        LM_CORETURN LM_BOOL_SUCCESS;
+    }
     const std::string &get_prompt() const LM_NOEXCEPTDECL override {
         return get_state()->prompt;
     }
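Note that unload_grammar() only clears the pointer, and load_grammar() overwrites any previous one, so the llama_grammar object allocated by llama_grammar_init is never released. A leak-free variant, sketched under the assumption that the pinned llama.cpp exposes llama_grammar_free (it does in builds of this period):

    LM_SCHEDULABLE(LM_ERRBOOL) unload_grammar() LM_NOEXCEPTDECL override {
        auto& state = get_state();
        if (state->grammar) {
            llama_grammar_free(state->grammar);  // free what llama_grammar_init allocated
            state->grammar = nullptr;
        }
        LM_CORETURN LM_BOOL_SUCCESS;
    }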
@@ -351,5 +389,9 @@ public:
     bool is_mirostat_available() const noexcept override {
         return true;
     }
+    bool is_grammar_available() const noexcept override {
+        return true;
+    }
 };
 }


@@ -332,7 +332,9 @@ function(include_ggml DIRECTORY SUFFIX WITH_LLAMA)
     if (WITH_LLAMA)
         SET(LLAMA_SOURCES
             ${DIRECTORY}/llama.cpp
-            ${DIRECTORY}/llama.h)
+            ${DIRECTORY}/llama.h
+            ${DIRECTORY}/common/grammar-parser.h
+            ${DIRECTORY}/common/grammar-parser.cpp)
         remove_nonexistent(LLAMA_SOURCES)
         add_library(llama${SUFFIX} ${LLAMA_SOURCES})