
Implemented proper scrolling

This commit is contained in:
niansa 2023-04-28 18:04:53 +02:00
parent 524d90e138
commit 31885f6cd2
4 changed files with 82 additions and 2 deletions

@@ -1 +1 @@
-Subproject commit 59e75326813a2c73b94340892de55a16d8ee8d9a
+Subproject commit e820c00e723e6eac2a935ad7b5b91cee9446eb9d


@@ -21,3 +21,4 @@ pool_size 2
threads 4
timeout 120
ctx_size 1012
scroll_keep 20

explained_config.txt (new file, 54 lines)

@@ -0,0 +1,54 @@
token MTA0MDYxMTQzNjUwNzk1OTMyNw.Gl_iMU.jVVM3bRqBJVi8ORVpWHquOivlASGJpRySt8qFg
# The following parameters are set to their defaults here and can be omitted
# Directory the models are located in. For example, see example_models/
models_dir models
# Language everything is translated to (will be disabled if set to "EN" anyway)
language EN
# Whether the bot should respond to pings outside threads
threads_only true
# Whether the bot should update messages periodically while writing them. Incompatible with translation
live_edit false
# Model to use outside threads
default_inference_model 13b-vanilla
# Model to be used for translation
translation_model none
# Few-shot prompt for non-instruct-mode. See example_prompt.txt
prompt_file none
# Prompt for instruct-mode. See example_instruct_prompt.txt
instruct_prompt_file none
# Number of shards ("instances") of this bot. This is NOT Discord sharding
shard_count 1
# Number of this shard. Must be unique in the entire bot
shard_id 0
# Whether context ("chat histories") should persist across restarts
persistance true
# Whether swapping should be prevented using mlock
mlock false
# Number of contexts to keep in RAM at a time
pool_size 2
# Number of CPU threads to use
threads 4
# Response/Evaluation timeout in seconds
timeout 120
# Max. context size
ctx_size 1012
# Percentage of context below prompt to be kept when scrolling. 0 means no context will be kept when scrolling.
scroll_keep 20
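
As an aside, here is a minimal self-contained sketch (not part of the repository; all names in it are hypothetical) of how a scroll_keep percentage can be read, validated and turned into the 0.0-1.0 fraction that the inference parameters expect, mirroring the std::stoi parsing, the range check and the float(config.scroll_keep) * 0.01f conversion shown in the hunks below:

#include <iostream>
#include <string>

int main() {
    std::string value = "20";            // value parsed for the "scroll_keep" key
    int scroll_keep = std::stoi(value);  // integer percentage, as written in the config

    // Reject values that would keep (almost) the entire context when scrolling
    if (scroll_keep < 0 || scroll_keep >= 99) {
        std::cerr << "Error: scroll_keep must be an integer percentage in the range of 0-98." << std::endl;
        return -12;
    }

    // Convert the percentage into the fraction used by the inference parameters
    float scroll_keep_fraction = float(scroll_keep) * 0.01f;
    std::cout << "Keeping " << scroll_keep_fraction * 100.0f
              << "% of the context below the prompt when scrolling" << std::endl;
    return 0;
}

Keeping the config value as a whole-number percentage makes the file easy to edit by hand, while the conversion to a fraction happens exactly once before the value reaches the inference parameters.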


@@ -108,6 +108,7 @@ public:
pool_size = 2,
timeout = 120,
threads = 4,
scroll_keep = 20,
shard_cout = 1,
shard_id = 0;
bool persistance = true,
@@ -205,7 +206,9 @@ private:
if (channel_cfg.instruct_mode && config.instruct_prompt_file == "none") return;
std::ifstream f((*channel_cfg.model_name)+(channel_cfg.instruct_mode?"_instruct_init_cache":"_init_cache"), std::ios::binary);
inference.deserialize(f);
// Set params
inference.params.n_ctx_window_top_bar = inference.get_context_size();
inference.params.scroll_keep = float(config.scroll_keep) * 0.01f;
}
// Must run in llama thread
LM::Inference &llm_start(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
@@ -219,12 +222,20 @@ private:
// Must run in llama thread
LM::Inference &llm_get_inference(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
ENSURE_LLM_THREAD();
// Get inference
auto inference_opt = llm_pool.get_inference(id);
if (!inference_opt.has_value()) {
// Start new inference
inference_opt = llm_start(id, channel_cfg);
}
return inference_opt.value();
auto& fres = inference_opt.value();
// Set scroll callback
fres.get().set_scroll_callback([msg = dpp::message(), channel_id = id] (float progress) {
std::cout << "WARNING: " << channel_id << " is scrolling! " << progress << "% \r" << std::flush;
return true;
});
// Return inference
return fres;
}
// Must run in llama thread
@@ -239,6 +250,12 @@ private:
texts.timeout = llm_translate_from_en(texts.timeout);
texts.translated = true;
}
// Set scroll callback
auto scroll_cb = [] (float) {
std::cerr << "Error: Prompt doesn't fit into max. context size!" << std::endl;
abort();
return false;
};
// Build init caches
std::string filename;
for (const auto& [model_name, model_config] : model_configs) {
@@ -266,6 +283,7 @@ private:
// Append
using namespace fmt::literals;
if (prompt.back() != '\n') prompt.push_back('\n');
llm->set_scroll_callback(scroll_cb);
llm->append(fmt::format(fmt::runtime(prompt), "bot_name"_a=bot.me.username), show_console_progress);
// Serialize end result
std::ofstream f(filename, std::ios::binary);
@@ -294,6 +312,7 @@ private:
// Append
using namespace fmt::literals;
if (prompt.back() != '\n') prompt.push_back('\n');
llm->set_scroll_callback(scroll_cb);
llm->append(fmt::format(fmt::runtime(prompt), "bot_name"_a=bot.me.username)+"\n\n"+model_config.user_prompt, show_console_progress);
// Serialize end result
std::ofstream f(filename, std::ios::binary);
@@ -671,6 +690,8 @@ int main(int argc, char **argv) {
cfg.pool_size = std::stoi(value);
} else if (key == "threads") {
cfg.threads = std::stoi(value);
} else if (key == "scroll_keep") {
cfg.scroll_keep = std::stoi(value);
} else if (key == "shard_cout") {
cfg.shard_cout = std::stoi(value);
} else if (key == "shard_id") {
@@ -789,6 +810,10 @@ int main(int argc, char **argv) {
exit(-7);
}
}
if (cfg.scroll_keep >= 99) {
std::cerr << "Error: Scroll_keep must be a non-float percentage and in a range of 0-99." << std::endl;
exit(-12);
}
// Construct and configure bot
Bot bot(cfg, models);
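
The two scroll callbacks introduced in this commit follow a simple contract: the callback receives the current scroll progress as a percentage and returns whether scrolling may continue. The interactive path (llm_get_inference) logs a warning and returns true; the init-cache path treats scrolling as fatal and returns false. The sketch below is not part of the repository and does not reproduce the LM::Inference API; run_scroll() is a hypothetical stand-in that merely drives a callback of that shape:

#include <functional>
#include <iostream>

using ScrollCallback = std::function<bool(float /*progress in percent*/)>;

// Hypothetical driver: pretends the context overflowed and scrolls in steps,
// stopping early if the callback returns false.
bool run_scroll(const ScrollCallback& on_scroll) {
    for (float progress = 10.0f; progress <= 100.0f; progress += 10.0f) {
        if (!on_scroll(progress)) return false;  // callback cancelled scrolling
    }
    return true;
}

int main() {
    // Interactive use: warn on the console, but let scrolling continue.
    run_scroll([](float progress) {
        std::cout << "WARNING: scrolling! " << progress << "%\r" << std::flush;
        return true;
    });
    std::cout << std::endl;

    // Init-cache use: the few-shot prompt must fit entirely, so scrolling is
    // reported as an error and cancelled by returning false.
    bool completed = run_scroll([](float) {
        std::cerr << "Error: Prompt doesn't fit into max. context size!" << std::endl;
        return false;
    });
    std::cout << "scrolled to completion: " << std::boolalpha << completed << std::endl;
    return 0;
}

Returning false is what lets the init-cache code refuse to continue when the prompt alone would already overflow the context window, instead of silently dropping parts of it.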