1
0
Fork 0
mirror of https://gitlab.com/niansa/libjustlm.git synced 2025-03-06 20:49:17 +01:00

Fully implemented grammar sampling

This commit is contained in:
niansa/tuxifan 2023-09-05 10:22:42 +02:00
parent f5314a0dde
commit 215db6b9b7

View file

@ -121,6 +121,13 @@ class LLaMAInference final : public Inference {
LM_CORETURN LM_BOOL_SUCCESS;
}
// Reports token `t` to the current state's grammar, when one is set, by
// calling llama_grammar_accept_token; without a grammar nothing is called.
// Always returns `t` unchanged, so a sampler's result can be wrapped in
// this call directly inside a return statement.
int accept_token(int t) {
    auto& st = get_state();
    // No grammar attached: hand the token straight back.
    if (!st->grammar) {
        return t;
    }
    llama_grammar_accept_token(st->ctx, st->grammar, t);
    return t;
}
int llama_sample_top_p_top_k() {
auto& state = get_state();
auto logits = llama_get_logits(state->ctx);
@ -148,25 +155,24 @@ class LLaMAInference final : public Inference {
llama_sample_typical(state->ctx, &candidates_p, 1.0f, 1);
llama_sample_top_p(state->ctx, &candidates_p, params.top_p, 1);
llama_sample_temperature(state->ctx, &candidates_p, params.temp);
return llama_sample_token(state->ctx, &candidates_p);
return accept_token(llama_sample_token(state->ctx, &candidates_p));
}
case 1: {
float mirostat_mu = 2.0f * params.mirostat_target_entropy;
const int mirostat_m = 100;
llama_sample_temperature(state->ctx, &candidates_p, params.temp);
return llama_sample_token_mirostat(state->ctx, &candidates_p, params.mirostat_target_entropy, params.mirostat_learning_rate, mirostat_m, &mirostat_mu);
return accept_token(llama_sample_token_mirostat(state->ctx, &candidates_p, params.mirostat_target_entropy, params.mirostat_learning_rate, mirostat_m, &mirostat_mu));
}
case 2: {
float mirostat_mu = 2.0f * params.mirostat_target_entropy;
llama_sample_temperature(state->ctx, &candidates_p, params.temp);
return llama_sample_token_mirostat_v2(state->ctx, &candidates_p, params.mirostat_target_entropy, params.mirostat_learning_rate, &mirostat_mu);
return accept_token(llama_sample_token_mirostat_v2(state->ctx, &candidates_p, params.mirostat_target_entropy, params.mirostat_learning_rate, &mirostat_mu));
}
default: LM_THROW("Invalid mirostat version "+std::to_string(params.prefer_mirostat), LM_BOOL_ERROR);
}
} else {
// Greedy sampling
abort();
return llama_sample_token(state->ctx, &candidates_p);
return accept_token(llama_sample_token(state->ctx, &candidates_p));
}
}