discord_llama (mirror of https://gitlab.com/niansa/discord_llama.git)
Don't apply repeat penalty in instruct mode
commit 485a628806 (parent bf568c3c8a)

3 changed files with 18 additions and 16 deletions
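Besides switching off the repeat penalty for instruct-mode sessions, the commit renames llm_restart() to llm_start(), has prompt_add_trigger() take the inference object directly instead of looking it up by id, changes the Vicuna user prompt from HUMAN: to USER:, makes the !store debug command return early, and bumps the anyproc submodule.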
anyproc (2 changes)

@@ -1 +1 @@
-Subproject commit 03b1689ef708b26ac139d09c3fa194195eef7707
+Subproject commit 7ea62dcd2bc26419e81482b4649e824f331def97
model config (2 changes)

@@ -1,4 +1,4 @@
 filename ggml-vicuna-13b-1.1-q4_0.bin
 instruct_mode_policy force
-user_prompt HUMAN:
+user_prompt USER:
 bot_prompt ASSISTANT:
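The HUMAN: to USER: change matches the prompt markers Vicuna v1.1 was tuned on; with the old prefix the model would be conditioned on a marker it does not expect. As a rough illustration (the helper name and exact spacing are assumptions, not code from this repo), the configured prefixes frame each exchange like so:

#include <string>

// Hypothetical helper: shows how user_prompt/bot_prompt from the config above
// frame one exchange. Ending on the bot prompt leaves the completion point
// exactly where the model should write its reply.
std::string build_turn(const std::string& user_prompt, // "USER:"
                       const std::string& bot_prompt,  // "ASSISTANT:"
                       const std::string& message) {
    return user_prompt + ' ' + message + '\n' + bot_prompt + ' ';
}
// build_turn("USER:", "ASSISTANT:", "Hi!") yields "USER: Hi!\nASSISTANT: "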
main.cpp (30 changes)
@@ -198,13 +198,13 @@ private:
         fres.use_mlock = config.mlock;
         return fres;
     }
-    LM::Inference::Params llm_get_params() const {
+    LM::Inference::Params llm_get_params(bool instruct_mode = false) const {
         return {
             .n_threads = int(config.threads),
             .n_ctx = int(config.ctx_size),
-            .n_repeat_last = 256,
+            .n_repeat_last = instruct_mode?0:256,
             .temp = 0.3f,
-            .repeat_penalty = 1.372222224f,
+            .repeat_penalty = instruct_mode?1.0f:1.372222224f,
             .use_mlock = config.mlock
         };
     }
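Either of the new instruct-mode values neutralizes the penalty on its own. The actual sampler lives in the backing library (the anyproc submodule and its llama.cpp backend) and is not shown in this diff, but a last-n repeat penalty is conventionally applied roughly as in this minimal sketch, which makes clear why repeat_penalty == 1.0f or n_repeat_last == 0 turns it into a no-op:

#include <cstddef>
#include <cstdint>
#include <vector>

// Minimal sketch (not the library's actual sampler): tokens seen in the last
// n_repeat_last positions get their logits scaled by repeat_penalty. With
// n_repeat_last == 0 the loop never runs; with repeat_penalty == 1.0f the
// scaling is the identity. Either way, sampling is untouched.
void apply_repeat_penalty(std::vector<float>& logits,
                          const std::vector<int32_t>& last_tokens,
                          std::size_t n_repeat_last, float repeat_penalty) {
    const std::size_t start = last_tokens.size() > n_repeat_last
                                  ? last_tokens.size() - n_repeat_last : 0;
    for (std::size_t i = start; i < last_tokens.size(); i++) {
        float& l = logits[std::size_t(last_tokens[i])];
        // A positive logit is divided to weaken the token; a negative one is
        // multiplied instead, so the penalty never makes a token more likely.
        l = l > 0.0f ? l / repeat_penalty : l * repeat_penalty;
    }
}

The likely motivation: an instruct-mode context repeats the fixed USER:/ASSISTANT: markers every turn, so penalizing the last 256 tokens would steadily push the model away from the very structure it must emit, while plain chat mode keeps the penalty to damp verbatim loops.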
@@ -218,10 +218,10 @@ private:
         inference.deserialize(f);
     }
     // Must run in llama thread
-    LM::Inference &llm_restart(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
+    LM::Inference &llm_start(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
         ENSURE_LLM_THREAD();
         // Get or create inference
-        auto& inference = llm_pool.get_or_create_inference(id, channel_cfg.model_config->weight_path, llm_get_params());
+        auto& inference = llm_pool.create_inference(id, channel_cfg.model_config->weight_path, llm_get_params(channel_cfg.instruct_mode));
         llm_restart(inference, channel_cfg);
         return inference;
     }
@@ -232,7 +232,7 @@ private:
         auto inference_opt = llm_pool.get_inference(id);
         if (!inference_opt.has_value()) {
             // Start new inference
-            inference_opt = llm_restart(id, channel_cfg);
+            inference_opt = llm_start(id, channel_cfg);
         }
         return inference_opt.value();
     }
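With creation now going through create_inference() only, llm_get_inference() above is the single get-or-create path, and llm_start() is the one place that knows channel_cfg.instruct_mode. A hypothetical sketch of the pool shape these call sites imply (the real API lives in the anyproc submodule and is not shown in this diff):

#include <functional>
#include <optional>
#include <unordered_map>
#include <utility>

// Assumed pool interface, inferred from the calls above: create_inference()
// constructs unconditionally, get_inference() only looks up, and the optional
// reference return lets callers fall back to creating, as llm_get_inference()
// does in the hunk above.
template <typename Key, typename Value>
class InferencePool {
    std::unordered_map<Key, Value> items;
public:
    template <typename... Args>
    Value& create_inference(const Key& key, Args&&... args) {
        return items.try_emplace(key, std::forward<Args>(args)...).first->second;
    }
    std::optional<std::reference_wrapper<Value>> get_inference(const Key& key) {
        auto it = items.find(key);
        if (it == items.end()) return std::nullopt;
        return std::ref(it->second);
    }
};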
@@ -350,9 +350,8 @@ private:
         }
     }
     // Must run in llama thread
-    void prompt_add_trigger(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
+    void prompt_add_trigger(LM::Inference& inference, const BotChannelConfig& channel_cfg) {
         ENSURE_LLM_THREAD();
-        auto& inference = llm_get_inference(id, channel_cfg);
         try {
             if (channel_cfg.instruct_mode) {
                 inference.append('\n'+channel_cfg.model_config->bot_prompt+"\n\n");
@@ -368,10 +367,10 @@ private:
     void reply(dpp::snowflake id, dpp::message msg, const BotChannelConfig& channel_cfg) {
         ENSURE_LLM_THREAD();
         try {
-            // Trigger LLM correctly
-            prompt_add_trigger(id, channel_cfg);
             // Get inference
             auto& inference = llm_get_inference(id, channel_cfg);
+            // Trigger LLM correctly
+            prompt_add_trigger(inference, channel_cfg);
             // Run model
             Timer timeout;
             bool timeout_exceeded = false;
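These two hunks are the same cleanup seen from both sides: reply() now fetches the inference once and hands it to prompt_add_trigger() directly, so the trigger step no longer performs a second pool lookup, and the lookup (which may create the session with the instruct-mode-aware parameters) happens before anything is appended.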
@@ -556,15 +555,18 @@ public:
                 channel_cfg.model_name = &config.default_inference_model;
                 channel_cfg.model_config = config.default_inference_model_cfg;
             }
+            // Debug store command
+            if (msg.content == "!store") {
+                llm_pool.store_all(); //DEBUG
+#               warning DEBUG CODE!!!
+                return;
+            }
             // Append message
             thread_pool.submit([=, this] () {
                 prompt_add_msg(msg, channel_cfg);
             });
             // Handle message somehow...
-            if (msg.content == "!store") {
-                llm_pool.store_all(); //DEBUG
-#               warning DEBUG CODE!!!
-            } else if (in_bot_thread) {
+            if (in_bot_thread) {
                 // Send a reply
                 enqueue_reply(msg.channel_id, channel_cfg);
             } else if (msg.content == "!trigger") {
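The !store debug path now returns before the message is submitted to prompt_add_msg(), so the command text itself no longer ends up appended to the conversation prompt and no reply is triggered for it.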