
Don't apply repeat penalty in instruct mode

niansa 2023-04-25 17:23:24 +02:00
parent bf568c3c8a
commit 485a628806
3 changed files with 18 additions and 16 deletions
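
Context for the change: in instruct mode the conversation is driven by fixed USER:/ASSISTANT: prompt markers that legitimately repeat on every turn, so a repeat penalty works against the template. The commit disables it by passing n_repeat_last = 0 and repeat_penalty = 1.0f. Below is a minimal sketch of common llama.cpp-style repeat-penalty semantics (an assumption; apply_repeat_penalty is illustrative and not the libjustlm API) showing why 1.0f is a no-op:

#include <cstdint>
#include <vector>

// Hypothetical helper mirroring common llama.cpp-style sampling:
// positive logits are divided by the penalty, negative ones multiplied,
// so penalty == 1.0f (or an empty window) changes nothing.
void apply_repeat_penalty(std::vector<float>& logits,
                          const std::vector<int32_t>& last_tokens,
                          float penalty) {
    if (penalty == 1.0f || last_tokens.empty()) return; // penalty disabled
    for (int32_t tok : last_tokens) {
        float& l = logits[tok];
        // Push recently seen tokens away from being sampled again.
        l = (l > 0.0f) ? l / penalty : l * penalty;
    }
}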

@@ -1 +1 @@
-Subproject commit 03b1689ef708b26ac139d09c3fa194195eef7707
+Subproject commit 7ea62dcd2bc26419e81482b4649e824f331def97


@@ -1,4 +1,4 @@
 filename ggml-vicuna-13b-1.1-q4_0.bin
 instruct_mode_policy force
-user_prompt HUMAN:
+user_prompt USER:
 bot_prompt ASSISTANT:


@@ -198,13 +198,13 @@ private:
         fres.use_mlock = config.mlock;
         return fres;
     }
-    LM::Inference::Params llm_get_params() const {
+    LM::Inference::Params llm_get_params(bool instruct_mode = false) const {
         return {
             .n_threads = int(config.threads),
             .n_ctx = int(config.ctx_size),
-            .n_repeat_last = 256,
+            .n_repeat_last = instruct_mode?0:256,
             .temp = 0.3f,
-            .repeat_penalty = 1.372222224f,
+            .repeat_penalty = instruct_mode?1.0f:1.372222224f,
             .use_mlock = config.mlock
         };
     }
@@ -218,10 +218,10 @@ private:
         inference.deserialize(f);
     }
     // Must run in llama thread
-    LM::Inference &llm_restart(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
+    LM::Inference &llm_start(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
         ENSURE_LLM_THREAD();
         // Get or create inference
-        auto& inference = llm_pool.get_or_create_inference(id, channel_cfg.model_config->weight_path, llm_get_params());
+        auto& inference = llm_pool.create_inference(id, channel_cfg.model_config->weight_path, llm_get_params(channel_cfg.instruct_mode));
         llm_restart(inference, channel_cfg);
         return inference;
     }
@@ -232,7 +232,7 @@ private:
         auto inference_opt = llm_pool.get_inference(id);
         if (!inference_opt.has_value()) {
             // Start new inference
-            inference_opt = llm_restart(id, channel_cfg);
+            inference_opt = llm_start(id, channel_cfg);
         }
         return inference_opt.value();
     }
@@ -350,9 +350,8 @@ private:
         }
     }
     // Must run in llama thread
-    void prompt_add_trigger(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
+    void prompt_add_trigger(LM::Inference& inference, const BotChannelConfig& channel_cfg) {
         ENSURE_LLM_THREAD();
-        auto& inference = llm_get_inference(id, channel_cfg);
         try {
             if (channel_cfg.instruct_mode) {
                 inference.append('\n'+channel_cfg.model_config->bot_prompt+"\n\n");
@@ -368,10 +367,10 @@ private:
     void reply(dpp::snowflake id, dpp::message msg, const BotChannelConfig& channel_cfg) {
         ENSURE_LLM_THREAD();
         try {
-            // Trigger LLM correctly
-            prompt_add_trigger(id, channel_cfg);
             // Get inference
             auto& inference = llm_get_inference(id, channel_cfg);
+            // Trigger LLM correctly
+            prompt_add_trigger(inference, channel_cfg);
             // Run model
             Timer timeout;
             bool timeout_exceeded = false;
@@ -556,15 +555,18 @@ public:
             channel_cfg.model_name = &config.default_inference_model;
             channel_cfg.model_config = config.default_inference_model_cfg;
         }
+        // Debug store command
+        if (msg.content == "!store") {
+            llm_pool.store_all(); //DEBUG
+# warning DEBUG CODE!!!
+            return;
+        }
         // Append message
         thread_pool.submit([=, this] () {
            prompt_add_msg(msg, channel_cfg);
         });
         // Handle message somehow...
-        if (msg.content == "!store") {
-            llm_pool.store_all(); //DEBUG
-# warning DEBUG CODE!!!
-        } else if (in_bot_thread) {
+        if (in_bot_thread) {
             // Send a reply
             enqueue_reply(msg.channel_id, channel_cfg);
         } else if (msg.content == "!trigger") {