mirror of
https://gitlab.com/niansa/libjustlm.git
synced 2025-03-06 20:49:17 +01:00
Fully implemented grammar sampling
This commit is contained in:
parent
f5314a0dde
commit
215db6b9b7
1 changed files with 11 additions and 5 deletions
|
@ -121,6 +121,13 @@ class LLaMAInference final : public Inference {
|
|||
LM_CORETURN LM_BOOL_SUCCESS;
|
||||
}
|
||||
|
||||
int accept_token(int t) {
|
||||
auto& state = get_state();
|
||||
if (state->grammar)
|
||||
llama_grammar_accept_token(state->ctx, state->grammar, t);
|
||||
return t;
|
||||
}
|
||||
|
||||
int llama_sample_top_p_top_k() {
|
||||
auto& state = get_state();
|
||||
auto logits = llama_get_logits(state->ctx);
|
||||
|
@ -148,25 +155,24 @@ class LLaMAInference final : public Inference {
|
|||
llama_sample_typical(state->ctx, &candidates_p, 1.0f, 1);
|
||||
llama_sample_top_p(state->ctx, &candidates_p, params.top_p, 1);
|
||||
llama_sample_temperature(state->ctx, &candidates_p, params.temp);
|
||||
return llama_sample_token(state->ctx, &candidates_p);
|
||||
return accept_token(llama_sample_token(state->ctx, &candidates_p));
|
||||
}
|
||||
case 1: {
|
||||
float mirostat_mu = 2.0f * params.mirostat_target_entropy;
|
||||
const int mirostat_m = 100;
|
||||
llama_sample_temperature(state->ctx, &candidates_p, params.temp);
|
||||
return llama_sample_token_mirostat(state->ctx, &candidates_p, params.mirostat_target_entropy, params.mirostat_learning_rate, mirostat_m, &mirostat_mu);
|
||||
return accept_token(llama_sample_token_mirostat(state->ctx, &candidates_p, params.mirostat_target_entropy, params.mirostat_learning_rate, mirostat_m, &mirostat_mu));
|
||||
}
|
||||
case 2: {
|
||||
float mirostat_mu = 2.0f * params.mirostat_target_entropy;
|
||||
llama_sample_temperature(state->ctx, &candidates_p, params.temp);
|
||||
return llama_sample_token_mirostat_v2(state->ctx, &candidates_p, params.mirostat_target_entropy, params.mirostat_learning_rate, &mirostat_mu);
|
||||
return accept_token(llama_sample_token_mirostat_v2(state->ctx, &candidates_p, params.mirostat_target_entropy, params.mirostat_learning_rate, &mirostat_mu));
|
||||
}
|
||||
default: LM_THROW("Invalid mirostat version "+std::to_string(params.prefer_mirostat), LM_BOOL_ERROR);
|
||||
}
|
||||
} else {
|
||||
// Greedy sampling
|
||||
abort();
|
||||
return llama_sample_token(state->ctx, &candidates_p);
|
||||
return accept_token(llama_sample_token(state->ctx, &candidates_p));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue