Mirror of https://gitlab.com/niansa/discord_llama.git, synced 2025-03-06 20:48:25 +01:00
Don't apply repeat penalty in instruct mode
parent bf568c3c8a
commit 485a628806

3 changed files with 18 additions and 16 deletions
2 anyproc

@@ -1 +1 @@
-Subproject commit 03b1689ef708b26ac139d09c3fa194195eef7707
+Subproject commit 7ea62dcd2bc26419e81482b4649e824f331def97
@@ -1,4 +1,4 @@
 filename ggml-vicuna-13b-1.1-q4_0.bin
 instruct_mode_policy force
-user_prompt HUMAN:
+user_prompt USER:
 bot_prompt ASSISTANT:
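The new value presumably matches the Vicuna v1.1 prompt template, which uses USER: and ASSISTANT: role tags rather than the older HUMAN:-style tags. With instruct_mode_policy force, every exchange in this channel is wrapped in those tags; a built prompt would look roughly like this (an illustrative sketch assembled from the user_prompt/bot_prompt values above, not output captured from the bot):

    USER: How do I disable the repeat penalty?
    ASSISTANT: Set repeat_penalty to 1.0 so the logits are left untouched.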
30 main.cpp
@@ -198,13 +198,13 @@ private:
         fres.use_mlock = config.mlock;
         return fres;
     }
-    LM::Inference::Params llm_get_params() const {
+    LM::Inference::Params llm_get_params(bool instruct_mode = false) const {
         return {
             .n_threads = int(config.threads),
             .n_ctx = int(config.ctx_size),
-            .n_repeat_last = 256,
+            .n_repeat_last = instruct_mode?0:256,
             .temp = 0.3f,
-            .repeat_penalty = 1.372222224f,
+            .repeat_penalty = instruct_mode?1.0f:1.372222224f,
             .use_mlock = config.mlock
         };
     }
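These two ternaries are what actually implement the commit message: in instruct mode the penalty window shrinks to zero tokens and the penalty factor becomes 1.0, either of which makes the penalty a no-op. For context, a llama.cpp-style repeat penalty divides the logits of recently seen tokens; the helper below is a minimal sketch of that scheme (hypothetical code, not the actual implementation behind LM::Inference) showing why those two values disable it:

    #include <cstddef>
    #include <unordered_set>
    #include <vector>

    // Hypothetical sketch of a llama.cpp-style repetition penalty.
    // With n_repeat_last == 0 the history window is empty, and with
    // penalty == 1.0f the division/multiplication changes nothing,
    // so either value alone already makes this a no-op.
    void apply_repeat_penalty(std::vector<float> &logits,
                              const std::vector<int> &history,
                              int n_repeat_last, float penalty) {
        if (n_repeat_last <= 0 || penalty == 1.0f) return;
        const size_t start = history.size() > size_t(n_repeat_last)
                                 ? history.size() - n_repeat_last
                                 : 0;
        std::unordered_set<int> recent(history.begin() + start, history.end());
        for (int tok : recent) {
            // As in llama.cpp: positive logits are divided by the penalty,
            // negative ones multiplied, pushing repeated tokens down.
            if (logits[tok] > 0.0f) logits[tok] /= penalty;
            else                    logits[tok] *= penalty;
        }
    }

In instruct mode the USER:/ASSISTANT: tags repeat on every turn, so a repetition penalty would punish the prompt format itself; presumably that is why the penalty is kept only for free-form chat.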
@@ -218,10 +218,10 @@ private:
         inference.deserialize(f);
     }
     // Must run in llama thread
-    LM::Inference &llm_restart(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
+    LM::Inference &llm_start(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
         ENSURE_LLM_THREAD();
         // Get or create inference
-        auto& inference = llm_pool.get_or_create_inference(id, channel_cfg.model_config->weight_path, llm_get_params());
+        auto& inference = llm_pool.create_inference(id, channel_cfg.model_config->weight_path, llm_get_params(channel_cfg.instruct_mode));
         llm_restart(inference, channel_cfg);
         return inference;
     }
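The rename from llm_restart to llm_start tracks the behavior change on the line below it: the function now unconditionally creates a fresh inference via create_inference instead of get_or_create_inference, and it forwards channel_cfg.instruct_mode so the instruct-mode parameters above are applied per channel.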
@@ -232,7 +232,7 @@ private:
         auto inference_opt = llm_pool.get_inference(id);
         if (!inference_opt.has_value()) {
             // Start new inference
-            inference_opt = llm_restart(id, channel_cfg);
+            inference_opt = llm_start(id, channel_cfg);
         }
         return inference_opt.value();
     }
@@ -350,9 +350,8 @@ private:
         }
     }
     // Must run in llama thread
-    void prompt_add_trigger(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
+    void prompt_add_trigger(LM::Inference& inference, const BotChannelConfig& channel_cfg) {
         ENSURE_LLM_THREAD();
-        auto& inference = llm_get_inference(id, channel_cfg);
         try {
             if (channel_cfg.instruct_mode) {
                 inference.append('\n'+channel_cfg.model_config->bot_prompt+"\n\n");
@@ -368,10 +367,10 @@ private:
     void reply(dpp::snowflake id, dpp::message msg, const BotChannelConfig& channel_cfg) {
         ENSURE_LLM_THREAD();
         try {
-            // Trigger LLM correctly
-            prompt_add_trigger(id, channel_cfg);
             // Get inference
             auto& inference = llm_get_inference(id, channel_cfg);
+            // Trigger LLM correctly
+            prompt_add_trigger(inference, channel_cfg);
             // Run model
             Timer timeout;
             bool timeout_exceeded = false;
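Together with the previous hunk, this reorders reply() so the inference is fetched once and the reference is handed to prompt_add_trigger directly, instead of prompt_add_trigger doing its own llm_get_inference lookup by snowflake id.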
@@ -556,15 +555,18 @@ public:
             channel_cfg.model_name = &config.default_inference_model;
             channel_cfg.model_config = config.default_inference_model_cfg;
         }
+        // Debug store command
+        if (msg.content == "!store") {
+            llm_pool.store_all(); //DEBUG
+# warning DEBUG CODE!!!
+            return;
+        }
         // Append message
         thread_pool.submit([=, this] () {
             prompt_add_msg(msg, channel_cfg);
         });
         // Handle message somehow...
-        if (msg.content == "!store") {
-            llm_pool.store_all(); //DEBUG
-# warning DEBUG CODE!!!
-        } else if (in_bot_thread) {
+        if (in_bot_thread) {
             // Send a reply
             enqueue_reply(msg.channel_id, channel_cfg);
         } else if (msg.content == "!trigger") {
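The debug !store command now runs before the message is appended and returns early, so a !store no longer gets pushed into the prompt as if it were a normal chat message; the # warning keeps the leftover debug path visible at compile time.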