Mirror of https://gitlab.com/niansa/discord_llama.git
Commit 31885f6cd2 (parent 524d90e138): Implemented proper scrolling
4 changed files with 82 additions and 2 deletions
anyproc (submodule, 2 changes)
@@ -1 +1 @@
-Subproject commit 59e75326813a2c73b94340892de55a16d8ee8d9a
+Subproject commit e820c00e723e6eac2a935ad7b5b91cee9446eb9d
@@ -21,3 +21,4 @@ pool_size 2
 threads 4
 timeout 120
 ctx_size 1012
+scroll_keep 20
explained_config.txt (new file, 54 lines)
@@ -0,0 +1,54 @@
+token MTA0MDYxMTQzNjUwNzk1OTMyNw.Gl_iMU.jVVM3bRqBJVi8ORVpWHquOivlASGJpRySt8qFg
+
+# The following parameters are set to their defaults here and can be omitted
+
+# Directory the models are located in. For example, see example_models/
+models_dir models
+
+# Language everything is translated to (translation is disabled if set to "EN")
+language EN
+
+# Whether the bot should respond to pings outside threads
+threads_only true
+
+# Whether the bot should update messages periodically while writing them. Incompatible with translation
+live_edit false
+
+# Model to use outside threads
+default_inference_model 13b-vanilla
+
+# Model to be used for translation
+translation_model none
+
+# Few-shot prompt for non-instruct-mode. See example_prompt.txt
+prompt_file none
+
+# Prompt for instruct-mode. See example_instruct_prompt.txt
+instruct_prompt_file none
+
+# Number of shards ("instances") of this bot. This is NOT Discord sharding
+shard_count 1
+
+# Number of this shard. Must be unique in the entire bot
+shard_id 0
+
+# Whether contexts ("chat histories") should persist across restarts
+persistance true
+
+# Whether swapping should be prevented using mlock
+mlock false
+
+# Number of contexts to keep in RAM at a time
+pool_size 2
+
+# Number of CPU threads to use
+threads 4
+
+# Response/Evaluation timeout in seconds
+timeout 120
+
+# Max. context size
+ctx_size 1012
+
+# Percentage of the context below the prompt to be kept when scrolling. 0 means no context will be kept when scrolling.
+scroll_keep 20
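For intuition on the new scroll_keep setting: it is a whole-number percentage that the bot later converts to a fraction (see the main.cpp hunk below that multiplies by 0.01f). A minimal standalone sketch of the arithmetic, using ctx_size 1012 and scroll_keep 20 from the example above (variable names here are illustrative, not the repo's):

#include <cstdio>

int main() {
    int ctx_size = 1012;    // "Max. context size" from the config
    int scroll_keep = 20;   // percentage of context kept below the prompt
    float fraction = float(scroll_keep) * 0.01f;   // same conversion as in main.cpp
    // Rough number of tokens below the prompt that survive one scroll:
    std::printf("keeping ~%d of %d tokens\n", int(ctx_size * fraction), ctx_size);
}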
main.cpp (27 changes)
@@ -108,6 +108,7 @@ public:
     pool_size = 2,
     timeout = 120,
     threads = 4,
+    scroll_keep = 20,
     shard_cout = 1,
     shard_id = 0;
     bool persistance = true,
@@ -205,7 +206,9 @@ private:
         if (channel_cfg.instruct_mode && config.instruct_prompt_file == "none") return;
         std::ifstream f((*channel_cfg.model_name)+(channel_cfg.instruct_mode?"_instruct_init_cache":"_init_cache"), std::ios::binary);
         inference.deserialize(f);
+        // Set params
         inference.params.n_ctx_window_top_bar = inference.get_context_size();
+        inference.params.scroll_keep = float(config.scroll_keep) * 0.01f;
     }
     // Must run in llama thread
     LM::Inference &llm_start(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
@@ -219,12 +222,20 @@ private:
     // Must run in llama thread
     LM::Inference &llm_get_inference(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
         ENSURE_LLM_THREAD();
+        // Get inference
         auto inference_opt = llm_pool.get_inference(id);
         if (!inference_opt.has_value()) {
             // Start new inference
             inference_opt = llm_start(id, channel_cfg);
         }
-        return inference_opt.value();
+        auto& fres = inference_opt.value();
+        // Set scroll callback
+        fres.get().set_scroll_callback([msg = dpp::message(), channel_id = id] (float progress) {
+            std::cout << "WARNING: " << channel_id << " is scrolling! " << progress << "% \r" << std::flush;
+            return true;
+        });
+        // Return inference
+        return fres;
     }

     // Must run in llama thread
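The lambda registered above is the scroll-callback contract this commit introduces: the library reports a progress value, and the callback's boolean return decides whether generation continues. A standalone sketch of that contract; ScrollDemo is a stand-in written for illustration, not the real LM::Inference API:

#include <functional>
#include <iostream>

struct ScrollDemo {
    std::function<bool(float)> on_scroll;
    void set_scroll_callback(std::function<bool(float)> cb) { on_scroll = std::move(cb); }
    // Simulates the library asking what to do when the context overflows:
    bool scroll(float progress) { return on_scroll ? on_scroll(progress) : true; }
};

int main() {
    ScrollDemo demo;
    demo.set_scroll_callback([] (float progress) {
        std::cout << "scrolling: " << progress << "%\n";
        return true;  // keep generating; returning false stops, as the cache-build path below does
    });
    demo.scroll(12.5f);
}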
@@ -239,6 +250,12 @@ private:
             texts.timeout = llm_translate_from_en(texts.timeout);
             texts.translated = true;
         }
+        // Set scroll callback
+        auto scroll_cb = [] (float) {
+            std::cerr << "Error: Prompt doesn't fit into max. context size!" << std::endl;
+            abort();
+            return false;
+        };
         // Build init caches
         std::string filename;
         for (const auto& [model_name, model_config] : model_configs) {
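Note the asymmetry between the two callbacks this commit installs: in live chat (llm_get_inference above) the callback logs a warning and returns true, so the conversation scrolls and generation continues, while the scroll_cb used for init-cache building returns false and aborts, since a base prompt that does not fit into the maximum context size is a configuration error rather than something to scroll past.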
@@ -266,6 +283,7 @@ private:
             // Append
             using namespace fmt::literals;
             if (prompt.back() != '\n') prompt.push_back('\n');
+            llm->set_scroll_callback(scroll_cb);
             llm->append(fmt::format(fmt::runtime(prompt), "bot_name"_a=bot.me.username), show_console_progress);
             // Serialize end result
             std::ofstream f(filename, std::ios::binary);
@@ -294,6 +312,7 @@ private:
             // Append
             using namespace fmt::literals;
             if (prompt.back() != '\n') prompt.push_back('\n');
+            llm->set_scroll_callback(scroll_cb);
             llm->append(fmt::format(fmt::runtime(prompt), "bot_name"_a=bot.me.username)+"\n\n"+model_config.user_prompt, show_console_progress);
             // Serialize end result
             std::ofstream f(filename, std::ios::binary);
@@ -671,6 +690,8 @@ int main(int argc, char **argv) {
         cfg.pool_size = std::stoi(value);
     } else if (key == "threads") {
         cfg.threads = std::stoi(value);
+    } else if (key == "scroll_keep") {
+        cfg.scroll_keep = std::stoi(value);
     } else if (key == "shard_cout") {
         cfg.shard_cout = std::stoi(value);
     } else if (key == "shard_id") {
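The scroll_keep branch slots into main.cpp's existing key/value config parser. The parser's surrounding code is not shown in this diff; a self-contained sketch of the general pattern it implies (one key, whitespace, then the value per line):

#include <iostream>
#include <sstream>
#include <string>

int main() {
    std::string line = "scroll_keep 20";     // one line of the config file
    std::istringstream ls(line);
    std::string key, value;
    ls >> key;                               // first token is the key
    std::getline(ls >> std::ws, value);      // remainder of the line is the value
    if (key == "scroll_keep") {
        int scroll_keep = std::stoi(value);  // same conversion as cfg.scroll_keep = std::stoi(value)
        std::cout << "scroll_keep = " << scroll_keep << '\n';
    }
}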
@@ -789,6 +810,10 @@ int main(int argc, char **argv) {
             exit(-7);
         }
     }
+    if (cfg.scroll_keep >= 99) {
+        std::cerr << "Error: scroll_keep must be an integer percentage in the range of 0-98." << std::endl;
+        exit(-12);
+    }

     // Construct and configure bot
     Bot bot(cfg, models);
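Since the guard rejects anything at or above 99, the accepted range is effectively 0-98; the default of 20 (and any omitted scroll_keep) passes the check.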