mirror of https://gitlab.com/niansa/discord_llama.git synced 2025-03-06 20:48:25 +01:00

Make sure to not throw in coroutines

This commit is contained in:
niansa 2023-05-10 21:52:14 +02:00
parent 52288eb0c7
commit 91f5d683bb
4 changed files with 40 additions and 11 deletions
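
This commit makes the fallible LLM coroutines report failure through their return value instead of throwing: with libjustlm built with LM_NOEXCEPT, llm_restart and prompt_add_trigger now return CoSched::AwaitableTask<bool>, and every call site checks the flag before continuing. Below is a minimal, self-contained sketch of the same pattern using plain C++20 coroutines rather than CoSched's actual AwaitableTask; the type and function names are illustrative, not taken from the repository.

// Sketch only: fallible coroutines co_return a success flag and callers
// check it, instead of letting exceptions escape the coroutine body.
#include <coroutine>
#include <fstream>
#include <iostream>
#include <string>

// Tiny eager coroutine type that exposes its bool result.
// Deliberately minimal: relies on copy elision, no move support.
struct BoolTask {
    struct promise_type {
        bool value = false;
        BoolTask get_return_object() {
            return {std::coroutine_handle<promise_type>::from_promise(*this)};
        }
        std::suspend_never initial_suspend() noexcept { return {}; }
        std::suspend_always final_suspend() noexcept { return {}; }
        void return_value(bool v) noexcept { value = v; }
        void unhandled_exception() { std::terminate(); } // exceptions are not part of the contract
    };
    std::coroutine_handle<promise_type> handle;
    ~BoolTask() { if (handle) handle.destroy(); }
    bool result() const { return handle.promise().value; }
};

// A fallible step that previously would have surfaced errors as exceptions;
// here it simply reports failure through its return value.
BoolTask restore_init_cache(const std::string& path) {
    std::ifstream f(path, std::ios::binary);
    if (!f) co_return false;
    // ... deserialize cached state from f ...
    co_return true;
}

int main() {
    auto task = restore_init_cache("model_init_cache");
    if (!task.result()) {
        std::cerr << "Warning: Failed to restore init cache" << std::endl;
        return 1;
    }
}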


@@ -6,8 +6,9 @@ set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(ANYPROC_COSCHED ON CACHE BOOL "" FORCE)
set(LM_COSCHED ON CACHE BOOL "" FORCE)
set(ANYPROC_EXAMPLES OFF CACHE BOOL "" FORCE)
set(LM_COSCHED ON CACHE BOOL "" FORCE)
set(LM_NOEXCEPT ON CACHE BOOL "" FORCE)
add_subdirectory(libjustlm)
add_subdirectory(anyproc)


@@ -76,6 +76,7 @@ public:
model_missing = "Error: The model that was used in this thread could no longer be found.",
timeout = "Error: Timeout",
length_error = "Error: Message length error",
empty_response = "Empty response",
terminated = "Error: Terminated";
bool translated = false;

@@ -1 +1 @@
Subproject commit 05fa44e1e0567cb09e3b026d7963076ac34a676d
Subproject commit 087fe1396b8444f39010822a93bd68ef08d4d00e


@@ -144,22 +144,27 @@ private:
}
// Must run in llama thread
CoSched::AwaitableTask<void> llm_restart(const std::shared_ptr<LM::Inference>& inference, const BotChannelConfig& channel_cfg) {
CoSched::AwaitableTask<bool> llm_restart(const std::shared_ptr<LM::Inference>& inference, const BotChannelConfig& channel_cfg) {
ENSURE_LLM_THREAD();
// Deserialize init cache if not instruct mode without prompt file
if (channel_cfg.instruct_mode && config.instruct_prompt_file == "none") co_return;
if (channel_cfg.instruct_mode && config.instruct_prompt_file == "none") co_return true;
std::ifstream f((*channel_cfg.model_name)+(channel_cfg.instruct_mode?"_instruct_init_cache":"_init_cache"), std::ios::binary);
co_await inference->deserialize(f);
if (!co_await inference->deserialize(f)) {
co_return false;
}
// Set params
inference->params.n_ctx_window_top_bar = inference->get_context_size();
inference->params.scroll_keep = float(config.scroll_keep) * 0.01f;
co_return true;
}
// Must run in llama thread
CoSched::AwaitableTask<std::shared_ptr<LM::Inference>> llm_start(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
ENSURE_LLM_THREAD();
// Get or create inference
auto inference = co_await llm_pool.create_inference(id, channel_cfg.model->weights_path, llm_get_params(channel_cfg.instruct_mode));
co_await llm_restart(inference, channel_cfg);
if (!co_await llm_restart(inference, channel_cfg)) {
co_return nullptr;
}
co_return inference;
}
@@ -171,6 +176,10 @@ private:
if (!fres) {
// Start new inference
fres = co_await llm_start(id, channel_cfg);
// Check for error
if (!fres) {
co_return nullptr;
}
}
// Set scroll callback
fres->set_scroll_callback([msg = dpp::message(), channel_id = id] (float progress) {
@@ -194,6 +203,7 @@ private:
config.texts.thread_create_fail = co_await llm_translate_from_en(config.texts.thread_create_fail);
config.texts.timeout = co_await llm_translate_from_en(config.texts.timeout);
config.texts.length_error = co_await llm_translate_from_en(config.texts.length_error);
config.texts.empty_response = co_await llm_translate_from_en(config.texts.empty_response);
config.texts.terminated = co_await llm_translate_from_en(config.texts.terminated);
config.texts.translated = true;
}
@@ -275,6 +285,10 @@ private:
ENSURE_LLM_THREAD();
// Get inference
auto inference = co_await llm_get_inference(msg.channel_id, channel_cfg);
if (!inference) {
std::cerr << "Warning: Failed to get inference" << std::endl;
co_return;
}
std::string prefix;
// Define callback for console progress and timeout
utils::Timer timeout;
@@ -300,14 +314,13 @@ private:
if (timeout_exceeded) co_await inference->append("\n");
}
// Must run in llama thread
CoSched::AwaitableTask<void> prompt_add_trigger(const std::shared_ptr<LM::Inference>& inference, const BotChannelConfig& channel_cfg) {
CoSched::AwaitableTask<bool> prompt_add_trigger(const std::shared_ptr<LM::Inference>& inference, const BotChannelConfig& channel_cfg) {
ENSURE_LLM_THREAD();
if (channel_cfg.instruct_mode) {
co_await inference->append('\n'+channel_cfg.model->bot_prompt+"\n\n");
co_return co_await inference->append('\n'+channel_cfg.model->bot_prompt+"\n\n");
} else {
co_await inference->append(bot.me.username+':', show_console_progress);
co_return co_await inference->append(bot.me.username+':', show_console_progress);
}
co_return;
}
// Must run in llama thread
@@ -315,8 +328,18 @@ private:
ENSURE_LLM_THREAD();
// Get inference
auto inference = co_await llm_get_inference(id, channel_cfg);
if (!inference) {
std::cerr << "Warning: Failed to get inference" << std::endl;
co_return;
}
// Trigger LLM correctly
co_await prompt_add_trigger(inference, channel_cfg);
if (!co_await prompt_add_trigger(inference, channel_cfg)) {
std::cerr << "Warning: Failed to add trigger to prompt: " << inference->get_last_error() << std::endl;
co_return;
}
if (CoSched::Task::get_current().is_dead()) {
co_return;
}
// Run model
utils::Timer timeout;
utils::Timer edit_timer;
@@ -345,6 +368,10 @@ private:
}
return true;
});
if (output.empty()) {
std::cerr << "Warning: Failed to generate message: " << inference->get_last_error() << std::endl;
output = '<'+config.texts.empty_response+'>';
}
std::cout << std::endl;
// Handle message length error
if (response_too_long) {