
Don't apply repeat penalty in instruct mode

commit 485a628806
parent bf568c3c8a
Author: niansa
Date:   2023-04-25 17:23:24 +02:00

3 changed files with 18 additions and 16 deletions

@@ -1 +1 @@
-Subproject commit 03b1689ef708b26ac139d09c3fa194195eef7707
+Subproject commit 7ea62dcd2bc26419e81482b4649e824f331def97


@@ -1,4 +1,4 @@
 filename ggml-vicuna-13b-1.1-q4_0.bin
 instruct_mode_policy force
-user_prompt HUMAN:
+user_prompt USER:
 bot_prompt ASSISTANT:
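Side note: Vicuna v1.1 switched its dialogue markers from the earlier HUMAN:-style prefix to USER:/ASSISTANT:, which is what this prompt change tracks. As a rough illustration (the exact spacing is an assumption for illustration, not taken from the repository), a conversation built from these two prompts looks like:

    USER: Hello!
    ASSISTANT: Hi! How can I help you today?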


@@ -198,13 +198,13 @@ private:
         fres.use_mlock = config.mlock;
         return fres;
     }
-    LM::Inference::Params llm_get_params() const {
+    LM::Inference::Params llm_get_params(bool instruct_mode = false) const {
         return {
             .n_threads = int(config.threads),
             .n_ctx = int(config.ctx_size),
-            .n_repeat_last = 256,
+            .n_repeat_last = instruct_mode?0:256,
             .temp = 0.3f,
-            .repeat_penalty = 1.372222224f,
+            .repeat_penalty = instruct_mode?1.0f:1.372222224f,
             .use_mlock = config.mlock
         };
     }
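For context, here is a minimal sketch of how a llama.cpp-style repeat penalty is commonly applied to the logits. The helper below is hypothetical (not code from this repository or its submodule), but it shows why the instruct-mode values above disable the penalty: with repeat_penalty == 1.0f every division or multiplication is a no-op, and with n_repeat_last == 0 no recent tokens are inspected at all.

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    using llama_token = int32_t;

    // Hypothetical sketch of a llama.cpp-style repeat penalty: positive logits
    // of recently seen tokens are divided by the penalty, negative ones
    // multiplied, pushing the sampler away from repetition.
    void apply_repeat_penalty(std::vector<float>& logits,
                              const std::vector<llama_token>& last_tokens,
                              std::size_t n_repeat_last, float repeat_penalty) {
        // Only the most recent n_repeat_last tokens count; n_repeat_last == 0
        // makes start == last_tokens.size(), so the loop never runs.
        std::size_t start = last_tokens.size() > n_repeat_last
                                ? last_tokens.size() - n_repeat_last
                                : 0;
        for (std::size_t i = start; i < last_tokens.size(); i++) {
            float& logit = logits[last_tokens[i]];
            // repeat_penalty == 1.0f leaves the logit unchanged either way.
            if (logit > 0.0f) logit /= repeat_penalty;
            else              logit *= repeat_penalty;
        }
    }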
@@ -218,10 +218,10 @@ private:
         inference.deserialize(f);
     }
     // Must run in llama thread
-    LM::Inference &llm_restart(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
+    LM::Inference &llm_start(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
         ENSURE_LLM_THREAD();
         // Get or create inference
-        auto& inference = llm_pool.get_or_create_inference(id, channel_cfg.model_config->weight_path, llm_get_params());
+        auto& inference = llm_pool.create_inference(id, channel_cfg.model_config->weight_path, llm_get_params(channel_cfg.instruct_mode));
         llm_restart(inference, channel_cfg);
         return inference;
     }
@@ -232,7 +232,7 @@ private:
         auto inference_opt = llm_pool.get_inference(id);
         if (!inference_opt.has_value()) {
             // Start new inference
-            inference_opt = llm_restart(id, channel_cfg);
+            inference_opt = llm_start(id, channel_cfg);
         }
         return inference_opt.value();
     }
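The helper above is a get-or-start lookup: reuse the pooled inference for this channel if it exists, otherwise create one with the channel's parameters. Below is a self-contained sketch of the same pattern; every type and name here is a simplified stand-in, since the pool's real interface is only partially visible in this diff.

    #include <cstdint>
    #include <functional>
    #include <optional>
    #include <string>
    #include <unordered_map>

    struct Inference { std::string weight_path; };

    class InferencePool {
        std::unordered_map<uint64_t, Inference> inferences;
    public:
        // Returns the existing inference for id, if any.
        std::optional<std::reference_wrapper<Inference>> get_inference(uint64_t id) {
            auto it = inferences.find(id);
            if (it == inferences.end()) return std::nullopt;
            return it->second;
        }
        // Creates (or replaces) the inference for id.
        Inference& create_inference(uint64_t id, std::string weight_path) {
            return inferences[id] = Inference{std::move(weight_path)};
        }
    };

    // Get-or-start, mirroring llm_get_inference() falling back to llm_start().
    Inference& get_or_start(InferencePool& pool, uint64_t id) {
        auto inference_opt = pool.get_inference(id);
        if (!inference_opt.has_value())
            return pool.create_inference(id, "model.bin");
        return inference_opt->get();
    }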
@@ -350,9 +350,8 @@ private:
         }
     }
     // Must run in llama thread
-    void prompt_add_trigger(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
+    void prompt_add_trigger(LM::Inference& inference, const BotChannelConfig& channel_cfg) {
         ENSURE_LLM_THREAD();
-        auto& inference = llm_get_inference(id, channel_cfg);
         try {
             if (channel_cfg.instruct_mode) {
                 inference.append('\n'+channel_cfg.model_config->bot_prompt+"\n\n");
@@ -368,10 +367,10 @@ private:
     void reply(dpp::snowflake id, dpp::message msg, const BotChannelConfig& channel_cfg) {
         ENSURE_LLM_THREAD();
         try {
-            // Trigger LLM correctly
-            prompt_add_trigger(id, channel_cfg);
             // Get inference
             auto& inference = llm_get_inference(id, channel_cfg);
+            // Trigger LLM correctly
+            prompt_add_trigger(inference, channel_cfg);
             // Run model
             Timer timeout;
             bool timeout_exceeded = false;
@@ -556,15 +555,18 @@ public:
             channel_cfg.model_name = &config.default_inference_model;
             channel_cfg.model_config = config.default_inference_model_cfg;
         }
+        // Debug store command
+        if (msg.content == "!store") {
+            llm_pool.store_all(); //DEBUG
+#           warning DEBUG CODE!!!
+            return;
+        }
         // Append message
         thread_pool.submit([=, this] () {
             prompt_add_msg(msg, channel_cfg);
         });
         // Handle message somehow...
-        if (msg.content == "!store") {
-            llm_pool.store_all(); //DEBUG
-#           warning DEBUG CODE!!!
-        } else if (in_bot_thread) {
+        if (in_bot_thread) {
             // Send a reply
             enqueue_reply(msg.channel_id, channel_cfg);
         } else if (msg.content == "!trigger") {