diff --git a/anyproc b/anyproc
index 59e7532..e820c00 160000
--- a/anyproc
+++ b/anyproc
@@ -1 +1 @@
-Subproject commit 59e75326813a2c73b94340892de55a16d8ee8d9a
+Subproject commit e820c00e723e6eac2a935ad7b5b91cee9446eb9d
diff --git a/example_config.txt b/example_config.txt
index 222250d..1d3e1b9 100644
--- a/example_config.txt
+++ b/example_config.txt
@@ -21,3 +21,4 @@ pool_size 2
 threads 4
 timeout 120
 ctx_size 1012
+scroll_keep 20
diff --git a/explained_config.txt b/explained_config.txt
new file mode 100644
index 0000000..ee6726f
--- /dev/null
+++ b/explained_config.txt
@@ -0,0 +1,54 @@
+token MTA0MDYxMTQzNjUwNzk1OTMyNw.Gl_iMU.jVVM3bRqBJVi8ORVpWHquOivlASGJpRySt8qFg
+
+# The following parameters are set to their defaults here and can be omitted
+
+# Directory the models are located in. For an example, see example_models/
+models_dir models
+
+# Language everything is translated to (translation is disabled if set to "EN")
+language EN
+
+# Whether the bot should respond to pings outside threads
+threads_only true
+
+# Whether the bot should update messages periodically while writing them. Incompatible with translation
+live_edit false
+
+# Model to use outside threads
+default_inference_model 13b-vanilla
+
+# Model to be used for translation
+translation_model none
+
+# Few-shot prompt for non-instruct mode. See example_prompt.txt
+prompt_file none
+
+# Prompt for instruct mode. See example_instruct_prompt.txt
+instruct_prompt_file none
+
+# Number of shards ("instances") of this bot. This is NOT Discord sharding
+shard_count 1
+
+# Number of this shard. Must be unique across the entire bot
+shard_id 0
+
+# Whether contexts ("chat histories") should persist across restarts
+persistance true
+
+# Whether swapping should be prevented using mlock
+mlock false
+
+# Number of contexts to keep in RAM at a time
+pool_size 2
+
+# Number of CPU threads to use
+threads 4
+
+# Response/evaluation timeout in seconds
+timeout 120
+
+# Max. context size
+ctx_size 1012
+
+# Percentage of the context below the prompt to keep when scrolling. 0 means no context is kept when scrolling.
+scroll_keep 20
diff --git a/main.cpp b/main.cpp
index 2563660..c991d9e 100644
--- a/main.cpp
+++ b/main.cpp
@@ -108,6 +108,7 @@ public:
         pool_size = 2,
         timeout = 120,
         threads = 4,
+        scroll_keep = 20,
         shard_cout = 1,
         shard_id = 0;
     bool persistance = true,
@@ -205,7 +206,9 @@ private:
         if (channel_cfg.instruct_mode && config.instruct_prompt_file == "none") return;
         std::ifstream f((*channel_cfg.model_name)+(channel_cfg.instruct_mode?"_instruct_init_cache":"_init_cache"), std::ios::binary);
         inference.deserialize(f);
+        // Set params
         inference.params.n_ctx_window_top_bar = inference.get_context_size();
+        inference.params.scroll_keep = float(config.scroll_keep) * 0.01f;
     }
     // Must run in llama thread
     LM::Inference &llm_start(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
@@ -219,12 +222,20 @@ private:
     // Must run in llama thread
     LM::Inference &llm_get_inference(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
         ENSURE_LLM_THREAD();
+        // Get inference
         auto inference_opt = llm_pool.get_inference(id);
         if (!inference_opt.has_value()) {
             // Start new inference
             inference_opt = llm_start(id, channel_cfg);
         }
-        return inference_opt.value();
+        auto& fres = inference_opt.value();
+        // Set scroll callback: warn on console whenever this channel's context scrolls
+        fres.get().set_scroll_callback([msg = dpp::message(), channel_id = id] (float progress) {
+            std::cout << "WARNING: " << channel_id << " is scrolling! " << progress << "% \r" << std::flush;
+            return true;
+        });
+        // Return inference
+        return fres;
     }
 
     // Must run in llama thread
@@ -239,6 +250,12 @@
         texts.timeout = llm_translate_from_en(texts.timeout);
         texts.translated = true;
     }
+    // Set scroll callback: the init prompts must fit into the context, so scrolling here is fatal
+    auto scroll_cb = [] (float) {
+        std::cerr << "Error: Prompt doesn't fit into max. context size!" << std::endl;
+        abort();
+        return false;
+    };
     // Build init caches
     std::string filename;
     for (const auto& [model_name, model_config] : model_configs) {
@@ -266,6 +283,7 @@
             // Append
             using namespace fmt::literals;
             if (prompt.back() != '\n') prompt.push_back('\n');
+            llm->set_scroll_callback(scroll_cb);
             llm->append(fmt::format(fmt::runtime(prompt), "bot_name"_a=bot.me.username), show_console_progress);
             // Serialize end result
             std::ofstream f(filename, std::ios::binary);
@@ -294,6 +312,7 @@
             // Append
             using namespace fmt::literals;
             if (prompt.back() != '\n') prompt.push_back('\n');
+            llm->set_scroll_callback(scroll_cb);
             llm->append(fmt::format(fmt::runtime(prompt), "bot_name"_a=bot.me.username)+"\n\n"+model_config.user_prompt, show_console_progress);
             // Serialize end result
             std::ofstream f(filename, std::ios::binary);
@@ -671,6 +690,8 @@ int main(int argc, char **argv) {
             cfg.pool_size = std::stoi(value);
         } else if (key == "threads") {
             cfg.threads = std::stoi(value);
+        } else if (key == "scroll_keep") {
+            cfg.scroll_keep = std::stoi(value);
         } else if (key == "shard_cout") {
             cfg.shard_cout = std::stoi(value);
         } else if (key == "shard_id") {
@@ -789,6 +810,10 @@
             exit(-7);
         }
     }
+    if (cfg.scroll_keep < 0 || cfg.scroll_keep > 99) {
+        std::cerr << "Error: scroll_keep must be an integer percentage in the range 0-99." << std::endl;
+        exit(-12);
+    }
     // Construct and configure bot
     Bot bot(cfg, models);