Mirror of https://gitlab.com/niansa/discord_llama.git (synced 2025-03-06 20:48:25 +01:00)
Make sure to not throw in coroutines
This commit is contained in:
parent 52288eb0c7
commit 91f5d683bb

4 changed files with 40 additions and 11 deletions
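The diff below switches several coroutines from CoSched::AwaitableTask<void> (errors thrown) to CoSched::AwaitableTask<bool> or a null-able pointer (errors returned), in line with the new LM_NOEXCEPT build flag. As a rough, generic illustration of that pattern only (a toy task<T>, not the project's actual CoSched::AwaitableTask or libjustlm API; the names task, restart and start are made up for this sketch):

// Generic sketch: an eagerly-started task<T> whose failures are reported
// through the awaited value instead of by throwing, roughly the shape the
// commit moves llm_restart()/llm_start()/prompt_add_trigger() towards.
#include <coroutine>
#include <iostream>
#include <optional>
#include <utility>

template <typename T>
struct task {
    struct promise_type {
        std::optional<T> value;
        task get_return_object() {
            return task{std::coroutine_handle<promise_type>::from_promise(*this)};
        }
        std::suspend_never initial_suspend() noexcept { return {}; }  // run body immediately
        std::suspend_always final_suspend() noexcept { return {}; }   // keep the result alive
        void return_value(T v) { value = std::move(v); }
        void unhandled_exception() noexcept { std::terminate(); }     // "never throw" policy
    };

    explicit task(std::coroutine_handle<promise_type> h) : handle(h) {}
    task(task&& other) noexcept : handle(std::exchange(other.handle, {})) {}
    task(const task&) = delete;
    ~task() { if (handle) handle.destroy(); }

    // The body already ran to completion, so awaiting just yields the stored value.
    bool await_ready() const noexcept { return handle.done(); }
    void await_suspend(std::coroutine_handle<>) const noexcept {}
    T await_resume() const { return *handle.promise().value; }

    T result() const { return *handle.promise().value; }  // for non-coroutine callers

    std::coroutine_handle<promise_type> handle;
};

// Analogue of the new llm_restart(): failure becomes `co_return false`.
task<bool> restart(bool cache_ok) {
    if (!cache_ok) co_return false;
    co_return true;
}

// Analogue of the new llm_start(): a failed sub-step is checked and mapped to a
// "bad" result; prompt_add_trigger()'s `co_return co_await ...` forwards the
// awaited bool the same way.
task<bool> start(bool cache_ok) {
    if (!co_await restart(cache_ok)) co_return false;
    co_return true;
}

int main() {
    std::cout << std::boolalpha << start(true).result() << '\n';   // true
    std::cout << std::boolalpha << start(false).result() << '\n';  // false
}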
@@ -6,8 +6,9 @@ set(CMAKE_CXX_STANDARD 20)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)

 set(ANYPROC_COSCHED ON CACHE BOOL "" FORCE)
-set(LM_COSCHED ON CACHE BOOL "" FORCE)
 set(ANYPROC_EXAMPLES OFF CACHE BOOL "" FORCE)
+set(LM_COSCHED ON CACHE BOOL "" FORCE)
+set(LM_NOEXCEPT ON CACHE BOOL "" FORCE)

 add_subdirectory(libjustlm)
 add_subdirectory(anyproc)
@@ -76,6 +76,7 @@ public:
         model_missing = "Error: The model that was used in this thread could no longer be found.",
         timeout = "Error: Timeout",
         length_error = "Error: Message length error",
+        empty_response = "Empty response",
         terminated = "Error: Terminated";
     bool translated = false;

@@ -1 +1 @@
-Subproject commit 05fa44e1e0567cb09e3b026d7963076ac34a676d
+Subproject commit 087fe1396b8444f39010822a93bd68ef08d4d00e
main.cpp (45 changed lines)
@@ -144,22 +144,27 @@ private:
     }

     // Must run in llama thread
-    CoSched::AwaitableTask<void> llm_restart(const std::shared_ptr<LM::Inference>& inference, const BotChannelConfig& channel_cfg) {
+    CoSched::AwaitableTask<bool> llm_restart(const std::shared_ptr<LM::Inference>& inference, const BotChannelConfig& channel_cfg) {
         ENSURE_LLM_THREAD();
         // Deserialize init cache if not instruct mode without prompt file
-        if (channel_cfg.instruct_mode && config.instruct_prompt_file == "none") co_return;
+        if (channel_cfg.instruct_mode && config.instruct_prompt_file == "none") co_return true;
         std::ifstream f((*channel_cfg.model_name)+(channel_cfg.instruct_mode?"_instruct_init_cache":"_init_cache"), std::ios::binary);
-        co_await inference->deserialize(f);
+        if (!co_await inference->deserialize(f)) {
+            co_return false;
+        }
         // Set params
         inference->params.n_ctx_window_top_bar = inference->get_context_size();
         inference->params.scroll_keep = float(config.scroll_keep) * 0.01f;
+        co_return true;
     }
     // Must run in llama thread
     CoSched::AwaitableTask<std::shared_ptr<LM::Inference>> llm_start(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
         ENSURE_LLM_THREAD();
         // Get or create inference
         auto inference = co_await llm_pool.create_inference(id, channel_cfg.model->weights_path, llm_get_params(channel_cfg.instruct_mode));
-        co_await llm_restart(inference, channel_cfg);
+        if (!co_await llm_restart(inference, channel_cfg)) {
+            co_return nullptr;
+        }
         co_return inference;
     }

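For context on why the throwing variants were dropped (my reading of the commit message, not a statement from the author): an exception that escapes a coroutine body is not unwound into the resumer; it is handed to promise_type::unhandled_exception(), so a cooperative scheduler would have to store, rethrow, or terminate. A small self-contained demonstration in generic C++20 (the type and function names are illustrative, not from discord_llama):

#include <coroutine>
#include <iostream>
#include <stdexcept>

// Toy fire-and-forget coroutine type.
struct fire_and_forget {
    struct promise_type {
        fire_and_forget get_return_object() { return {}; }
        std::suspend_never initial_suspend() noexcept { return {}; }
        std::suspend_never final_suspend() noexcept { return {}; }
        void return_void() {}
        void unhandled_exception() {
            // A throw inside the coroutine body lands here, not at the call site.
            std::cerr << "exception swallowed inside the coroutine frame\n";
        }
    };
};

fire_and_forget throwing_coroutine() {
    throw std::runtime_error("deserialize failed");
    co_return;  // the co_return is what makes this function a coroutine
}

int main() {
    throwing_coroutine();  // prints the message above; nothing propagates to main
    std::cout << "caller continues as if nothing happened\n";
}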
@@ -171,6 +176,10 @@ private:
         if (!fres) {
             // Start new inference
             fres = co_await llm_start(id, channel_cfg);
+            // Check for error
+            if (!fres) {
+                co_return nullptr;
+            }
         }
         // Set scroll callback
         fres->set_scroll_callback([msg = dpp::message(), channel_id = id] (float progress) {

@@ -194,6 +203,7 @@ private:
         config.texts.thread_create_fail = co_await llm_translate_from_en(config.texts.thread_create_fail);
         config.texts.timeout = co_await llm_translate_from_en(config.texts.timeout);
         config.texts.length_error = co_await llm_translate_from_en(config.texts.length_error);
+        config.texts.empty_response = co_await llm_translate_from_en(config.texts.empty_response);
         config.texts.terminated = co_await llm_translate_from_en(config.texts.terminated);
         config.texts.translated = true;
     }

@@ -275,6 +285,10 @@ private:
         ENSURE_LLM_THREAD();
         // Get inference
         auto inference = co_await llm_get_inference(msg.channel_id, channel_cfg);
+        if (!inference) {
+            std::cerr << "Warning: Failed to get inference" << std::endl;
+            co_return;
+        }
         std::string prefix;
         // Define callback for console progress and timeout
         utils::Timer timeout;

@@ -300,14 +314,13 @@ private:
         if (timeout_exceeded) co_await inference->append("\n");
     }
     // Must run in llama thread
-    CoSched::AwaitableTask<void> prompt_add_trigger(const std::shared_ptr<LM::Inference>& inference, const BotChannelConfig& channel_cfg) {
+    CoSched::AwaitableTask<bool> prompt_add_trigger(const std::shared_ptr<LM::Inference>& inference, const BotChannelConfig& channel_cfg) {
         ENSURE_LLM_THREAD();
         if (channel_cfg.instruct_mode) {
-            co_await inference->append('\n'+channel_cfg.model->bot_prompt+"\n\n");
+            co_return co_await inference->append('\n'+channel_cfg.model->bot_prompt+"\n\n");
         } else {
-            co_await inference->append(bot.me.username+':', show_console_progress);
+            co_return co_await inference->append(bot.me.username+':', show_console_progress);
         }
-        co_return;
     }

     // Must run in llama thread
@@ -315,8 +328,18 @@ private:
         ENSURE_LLM_THREAD();
         // Get inference
         auto inference = co_await llm_get_inference(id, channel_cfg);
+        if (!inference) {
+            std::cerr << "Warning: Failed to get inference" << std::endl;
+            co_return;
+        }
         // Trigger LLM correctly
-        co_await prompt_add_trigger(inference, channel_cfg);
+        if (!co_await prompt_add_trigger(inference, channel_cfg)) {
+            std::cerr << "Warning: Failed to add trigger to prompt: " << inference->get_last_error() << std::endl;
+            co_return;
+        }
+        if (CoSched::Task::get_current().is_dead()) {
+            co_return;
+        }
         // Run model
         utils::Timer timeout;
         utils::Timer edit_timer;

@@ -345,6 +368,10 @@ private:
             }
             return true;
         });
+        if (output.empty()) {
+            std::cerr << "Warning: Failed to generate message: " << inference->get_last_error() << std::endl;
+            output = '<'+config.texts.empty_response+'>';
+        }
         std::cout << std::endl;
         // Handle message length error
         if (response_too_long) {