diff --git a/.gitmodules b/.gitmodules
index 1496cdb..dc26528 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,6 @@
+[submodule "cosched"]
+	path = cosched
+	url = https://gitlab.com/niansa/cosched.git
 [submodule "anyproc"]
 	path = anyproc
 	url = https://gitlab.com/niansa/anyproc.git
diff --git a/CMakeLists.txt b/CMakeLists.txt
index d11e4ab..166547a 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -7,16 +7,20 @@ set(CMAKE_CXX_STANDARD_REQUIRED ON)
 
 add_subdirectory(libjustlm)
 add_subdirectory(anyproc)
+add_subdirectory(cosched)
 add_subdirectory(DPP)
-add_subdirectory(thread-pool)
 add_subdirectory(fmt)
 
+set(ANYPROC_COSCHED ON CACHE BOOL "" FORCE)
+set(LM_COSCHED ON CACHE BOOL "" FORCE)
+set(ANYPROC_EXAMPLES OFF CACHE BOOL "" FORCE)
+
 add_executable(discord_llama
     main.cpp
     config.hpp config.cpp
     utils.cpp utils.hpp
 )
 
-target_link_libraries(discord_llama PUBLIC dpp fmt pthread libjustlm anyproc ggml threadpool sqlite3)
+target_link_libraries(discord_llama PUBLIC dpp fmt pthread libjustlm anyproc ggml cosched sqlite3)
 
 install(TARGETS discord_llama
     LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR})
diff --git a/anyproc b/anyproc
index c90cea8..3d3de6a 160000
--- a/anyproc
+++ b/anyproc
@@ -1 +1 @@
-Subproject commit c90cea8ac8d5182c0e9ca5dc515500ad668bca4d
+Subproject commit 3d3de6ae875737c88a8c4496651ef1ed1c74b073
diff --git a/cosched b/cosched
new file mode 160000
index 0000000..32bfe1d
--- /dev/null
+++ b/cosched
@@ -0,0 +1 @@
+Subproject commit 32bfe1dbac67b351391460ca530f541050344aca
diff --git a/libjustlm b/libjustlm
index 57364ae..7076f86 160000
--- a/libjustlm
+++ b/libjustlm
@@ -1 +1 @@
-Subproject commit 57364ae560253257f89a5164960a27d4bed242db
+Subproject commit 7076f863d4667a18ac87104e2b722c3ae3ee9336
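The main.cpp diff below swaps the single-worker ThreadPool for cosched's cooperatively scheduled thread. The shape of the change, condensed from the !reset handler further down (both snippets are quoted from this diff; CoSched's exact semantics are inferred from its usage here):

```c
// Before: an opaque job on the worker thread; blocking calls hold the thread.
thread_pool.submit([this, msg] () {
    llm_pool.delete_inference(msg.channel_id);
});

// After: a named cooperative task; every suspension point is an explicit
// co_await, so the scheduler can interleave other tasks at those points.
sched_thread.create_task("Language Model Inference Pool", [=, this] () -> CoSched::AwaitableTask<void> {
    co_await llm_pool.delete_inference(msg.channel_id);
});
```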
diff --git a/main.cpp b/main.cpp
index 11a88ca..0ed186e 100644
--- a/main.cpp
+++ b/main.cpp
@@ -23,12 +23,12 @@
 #include
 #include
 #include
-#include
+#include
 
 
 class Bot {
-    ThreadPool thread_pool{1};
+    CoSched::ScheduledThread sched_thread;
     LM::InferencePool llm_pool;
     std::unique_ptr<Translator> translator;
     std::vector<dpp::snowflake> my_messages;
@@ -65,45 +65,41 @@ private:
 #   define ENSURE_LLM_THREAD() if (std::this_thread::get_id() != llm_tid) {throw std::runtime_error("LLM execution of '"+std::string(__PRETTY_FUNCTION__)+"' on wrong thread detected");} 0
 
     // Must run in llama thread
-    std::string_view llm_translate_to_en(std::string_view text, bool skip = false) {
+    CoSched::AwaitableTask<std::string> llm_translate_to_en(std::string_view text, bool skip = false) {
         ENSURE_LLM_THREAD();
+        std::string fres(text);
         // Skip if there is no translator
         if (translator == nullptr || skip) {
             std::cout << "(" << config.language << ") " << text << std::endl;
-            return text;
+            co_return fres;
         }
-        // I am optimizing heavily for the above case. This function always returns a reference so a trick is needed here
-        static std::string fres;
-        fres = text;
         // Replace bot username with [43]
         utils::str_replace_in_place(fres, bot.me.username, "[43]");
         // Run translation
-        fres = translator->translate(fres, "EN", show_console_progress);
+        fres = co_await translator->translate(fres, "EN", show_console_progress);
         // Replace [43] back with bot username
         utils::str_replace_in_place(fres, "[43]", bot.me.username);
         std::cout << text << " --> (EN) " << fres << std::endl;
-        return fres;
+        co_return fres;
     }
 
     // Must run in llama thread
-    std::string_view llm_translate_from_en(std::string_view text, bool skip = false) {
+    CoSched::AwaitableTask<std::string> llm_translate_from_en(std::string_view text, bool skip = false) {
         ENSURE_LLM_THREAD();
+        std::string fres(text);
         // Skip if there is no translator
         if (translator == nullptr || skip) {
             std::cout << "(" << config.language << ") " << text << std::endl;
-            return text;
+            co_return fres;
         }
-        // I am optimizing heavily for the above case. This function always returns a reference so a trick is needed here
-        static std::string fres;
-        fres = text;
         // Replace bot username with [43]
         utils::str_replace_in_place(fres, bot.me.username, "[43]");
         // Run translation
-        fres = translator->translate(fres, config.language, show_console_progress);
+        fres = co_await translator->translate(fres, config.language, show_console_progress);
         // Replace [43] back with bot username
         utils::str_replace_in_place(fres, "[43]", bot.me.username);
         std::cout << text << " --> (" << config.language << ") " << fres << std::endl;
-        return fres;
+        co_return fres;
     }
 
     LM::Inference::Params llm_get_translation_params() const {
@@ -124,54 +120,55 @@
     }
 
     // Must run in llama thread
-    void llm_restart(LM::Inference& inference, const BotChannelConfig& channel_cfg) {
+    CoSched::AwaitableTask<void> llm_restart(const std::shared_ptr<LM::Inference>& inference, const BotChannelConfig& channel_cfg) {
         ENSURE_LLM_THREAD();
         // Deserialize init cache if not instruct mode without prompt file
-        if (channel_cfg.instruct_mode && config.instruct_prompt_file == "none") return;
+        if (channel_cfg.instruct_mode && config.instruct_prompt_file == "none") co_return;
         std::ifstream f((*channel_cfg.model_name)+(channel_cfg.instruct_mode?"_instruct_init_cache":"_init_cache"), std::ios::binary);
-        inference.deserialize(f);
+        co_await inference->deserialize(f);
         // Set params
-        inference.params.n_ctx_window_top_bar = inference.get_context_size();
-        inference.params.scroll_keep = float(config.scroll_keep) * 0.01f;
+        inference->params.n_ctx_window_top_bar = inference->get_context_size();
+        inference->params.scroll_keep = float(config.scroll_keep) * 0.01f;
     }
 
     // Must run in llama thread
-    LM::Inference &llm_start(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
+    CoSched::AwaitableTask<std::shared_ptr<LM::Inference>> llm_start(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
         ENSURE_LLM_THREAD();
         // Get or create inference
-        auto& inference = llm_pool.create_inference(id, channel_cfg.model->weights_path, llm_get_params(channel_cfg.instruct_mode));
-        llm_restart(inference, channel_cfg);
-        return inference;
+        auto inference = co_await llm_pool.create_inference(id, channel_cfg.model->weights_path, llm_get_params(channel_cfg.instruct_mode));
+        co_await llm_restart(inference, channel_cfg);
+        co_return inference;
    }
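Why the deleted "static std::string fres" trick can go away: returning a std::string_view forced the old code to alias a function-local static buffer, so only one result could be alive at a time. A coroutine can simply return the string by value. A minimal sketch, assuming (from its usage in this diff) that CoSched::AwaitableTask<T> models a coroutine returning T:

```c
CoSched::AwaitableTask<std::string> demo_translate(std::string_view text) {
    std::string fres(text); // own the data; no static buffer to clobber
    // ... transform fres ...
    co_return fres;         // returned by value through the task
}
```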
     // Must run in llama thread
-    LM::Inference &llm_get_inference(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
+    CoSched::AwaitableTask<std::shared_ptr<LM::Inference>> llm_get_inference(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
         ENSURE_LLM_THREAD();
         // Get inference
-        auto inference_opt = llm_pool.get_inference(id);
-        if (!inference_opt.has_value()) {
+        auto fres = co_await llm_pool.get_inference(id);
+        if (!fres) {
             // Start new inference
-            inference_opt = llm_start(id, channel_cfg);
+            fres = co_await llm_start(id, channel_cfg);
         }
-        auto& fres = inference_opt.value();
         // Set scroll callback
-        fres.get().set_scroll_callback([msg = dpp::message(), channel_id = id] (float progress) {
+        fres->set_scroll_callback([msg = dpp::message(), channel_id = id] (float progress) {
             std::cout << "WARNING: " << channel_id << " is scrolling! " << progress << "% \r" << std::flush;
             return true;
         });
         // Return inference
-        return fres;
+        co_return fres;
     }
 
     // Must run in llama thread
-    void llm_init() {
+    CoSched::AwaitableTask<void> llm_init() {
+        // Run at realtime priority
+        CoSched::Task::get_current().set_priority(CoSched::PRIO_REALTIME);
         // Set LLM thread
         llm_tid = std::this_thread::get_id();
         // Translate texts
         if (!config.texts.translated) {
-            config.texts.please_wait = llm_translate_from_en(config.texts.please_wait);
-            config.texts.model_missing = llm_translate_from_en(config.texts.model_missing);
-            config.texts.thread_create_fail = llm_translate_from_en(config.texts.thread_create_fail);
-            config.texts.timeout = llm_translate_from_en(config.texts.timeout);
+            config.texts.please_wait = co_await llm_translate_from_en(config.texts.please_wait);
+            config.texts.model_missing = co_await llm_translate_from_en(config.texts.model_missing);
+            config.texts.thread_create_fail = co_await llm_translate_from_en(config.texts.thread_create_fail);
+            config.texts.timeout = co_await llm_translate_from_en(config.texts.timeout);
             config.texts.translated = true;
         }
         // Set scroll callback
@@ -208,10 +205,10 @@
             using namespace fmt::literals;
             if (prompt.back() != '\n') prompt.push_back('\n');
             llm->set_scroll_callback(scroll_cb);
-            llm->append(fmt::format(fmt::runtime(prompt), "bot_name"_a=bot.me.username), show_console_progress);
+            co_await llm->append(fmt::format(fmt::runtime(prompt), "bot_name"_a=bot.me.username), show_console_progress);
             // Serialize end result
             std::ofstream f(filename, std::ios::binary);
-            llm->serialize(f);
+            co_await llm->serialize(f);
         }
         // Instruct prompt
         filename = model_name+"_instruct_init_cache";
@@ -237,20 +234,21 @@
             using namespace fmt::literals;
             if (prompt.back() != '\n') prompt.push_back('\n');
             llm->set_scroll_callback(scroll_cb);
-            llm->append(fmt::format(fmt::runtime(prompt), "bot_name"_a=bot.me.username)+"\n\n"+model_config.user_prompt, show_console_progress);
+            co_await llm->append(fmt::format(fmt::runtime(prompt), "bot_name"_a=bot.me.username)+"\n\n"+model_config.user_prompt, show_console_progress);
             // Serialize end result
             std::ofstream f(filename, std::ios::binary);
-            llm->serialize(f);
+            co_await llm->serialize(f);
             }
         }
         // Report complete init
         std::cout << "Init done!" << std::endl;
     }
+
     // Must run in llama thread
-    void prompt_add_msg(const dpp::message& msg, const BotChannelConfig& channel_cfg) {
+    CoSched::AwaitableTask<void> prompt_add_msg(const dpp::message& msg, const BotChannelConfig& channel_cfg) {
         ENSURE_LLM_THREAD();
         // Get inference
-        auto& inference = llm_get_inference(msg.channel_id, channel_cfg);
+        auto inference = co_await llm_get_inference(msg.channel_id, channel_cfg);
         std::string prefix;
         // Define callback for console progress and timeout
         utils::Timer timeout;
@@ -266,101 +264,97 @@
         // Instruct mode user prompt
         if (channel_cfg.instruct_mode) {
             // Append line as-is
-            inference.append("\n\n"+std::string(llm_translate_to_en(msg.content, channel_cfg.model->no_translate))+'\n', cb);
+            co_await inference->append("\n\n"+std::string(co_await llm_translate_to_en(msg.content, channel_cfg.model->no_translate))+'\n', cb);
         } else {
             // Format and append lines
             for (const auto line : utils::str_split(msg.content, '\n')) {
-                inference.append(msg.author.username+": "+std::string(llm_translate_to_en(line, channel_cfg.model->no_translate))+'\n', cb);
+                co_await inference->append(msg.author.username+": "+std::string(co_await llm_translate_to_en(line, channel_cfg.model->no_translate))+'\n', cb);
             }
         }
         // Append line break on timeout
-        if (timeout_exceeded) inference.append("\n");
+        if (timeout_exceeded) co_await inference->append("\n");
     }
 
     // Must run in llama thread
-    void prompt_add_trigger(LM::Inference& inference, const BotChannelConfig& channel_cfg) {
+    CoSched::AwaitableTask<void> prompt_add_trigger(const std::shared_ptr<LM::Inference>& inference, const BotChannelConfig& channel_cfg) {
         ENSURE_LLM_THREAD();
         if (channel_cfg.instruct_mode) {
-            inference.append('\n'+channel_cfg.model->bot_prompt+"\n\n");
+            co_await inference->append('\n'+channel_cfg.model->bot_prompt+"\n\n");
         } else {
-            inference.append(bot.me.username+':', show_console_progress);
+            co_await inference->append(bot.me.username+':', show_console_progress);
         }
+        co_return;
     }
 
     // Must run in llama thread
-    void reply(dpp::snowflake id, dpp::message msg, const BotChannelConfig& channel_cfg) {
+    CoSched::AwaitableTask<void> reply(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
         ENSURE_LLM_THREAD();
-        try {
-            // Get inference
-            auto& inference = llm_get_inference(id, channel_cfg);
-            // Trigger LLM correctly
-            prompt_add_trigger(inference, channel_cfg);
-            // Run model
-            utils::Timer timeout;
-            utils::Timer edit_timer;
-            bool timeout_exceeded = false;
-            msg.content.clear();
-            auto output = inference.run(channel_cfg.instruct_mode?channel_cfg.model->user_prompt:"\n", [&] (std::string_view token) {
-                std::cout << token << std::flush;
-                if (timeout.get() > config.timeout) {
-                    timeout_exceeded = true;
-                    std::cerr << "\nWarning: Timeout exceeded generating message";
-                    return false;
-                }
-                if (config.live_edit) {
-                    msg.content += token;
-                    if (edit_timer.get() > 3) {
-                        try {
-                            bot.message_edit(msg);
-                        } catch (...) {}
-                        edit_timer.reset();
-                    }
-                }
-                return true;
-            });
-            std::cout << std::endl;
-            // Handle timeout
-            if (timeout_exceeded) {
-                if (config.live_edit) {
-                    output += "...\n"+config.texts.timeout;
-                } else {
-                    output = config.texts.timeout;
+        // Create initial message
+        auto msg = bot.message_create_sync(dpp::message(id, config.texts.please_wait+" :thinking:"));
+        co_await CoSched::Task::get_current().yield();
+        // Get inference
+        auto inference = co_await llm_get_inference(id, channel_cfg);
+        // Trigger LLM correctly
+        co_await prompt_add_trigger(inference, channel_cfg);
+        // Run model
+        utils::Timer timeout;
+        utils::Timer edit_timer;
+        bool timeout_exceeded = false;
+        msg.content.clear();
+        const std::string reverse_prompt = channel_cfg.instruct_mode?channel_cfg.model->user_prompt:"\n";
+        auto output = co_await inference->run(reverse_prompt, [&] (std::string_view token) {
+            std::cout << token << std::flush;
+            // Check for timeout
+            if (timeout.get() > config.timeout) {
+                timeout_exceeded = true;
+                std::cerr << "\nWarning: Timeout exceeded generating message";
+                return false;
+            }
+            // Edit live
+            if (config.live_edit) {
+                msg.content += token;
+                if (edit_timer.get() > 3) {
+                    try {
+                        bot.message_edit(msg);
+                    } catch (...) {}
+                    edit_timer.reset();
                 }
             }
-            // Send resulting message
-            msg.content = llm_translate_from_en(output, channel_cfg.model->no_translate);
-            bot.message_edit(msg);
-            // Prepare for next message
-            inference.append("\n");
-            if (channel_cfg.model->emits_eos) {
-                inference.append("\n"+channel_cfg.model->user_prompt);
+            return true;
+        });
+        std::cout << std::endl;
+        // Handle timeout
+        if (timeout_exceeded) {
+            if (config.live_edit) {
+                output += "...\n"+config.texts.timeout;
+            } else {
+                output = config.texts.timeout;
             }
-        } catch (const std::exception& e) {
-            std::cerr << "Warning: " << e.what() << std::endl;
+        }
+        // Send resulting message
+        msg.content = co_await llm_translate_from_en(output, channel_cfg.model->no_translate);
+        bot.message_edit(msg);
+        // Prepare for next message
+        co_await inference->append("\n");
+        if (channel_cfg.model->emits_eos) {
+            co_await inference->append("\n"+channel_cfg.model->user_prompt);
         }
     }
 
-    bool attempt_reply(const dpp::message& msg, const BotChannelConfig& channel_cfg) {
+    CoSched::AwaitableTask<bool> attempt_reply(const dpp::message& msg, const BotChannelConfig& channel_cfg) {
         // Reply if message contains username, mention or ID
         if (msg.content.find(bot.me.username) != std::string::npos) {
-            enqueue_reply(msg.channel_id, channel_cfg);
-            return true;
+            co_await reply(msg.channel_id, channel_cfg);
+            co_return true;
         }
         // Reply if message references user
         for (const auto msg_id : my_messages) {
             if (msg.message_reference.message_id == msg_id) {
-                enqueue_reply(msg.channel_id, channel_cfg);
-                return true;
+                co_await reply(msg.channel_id, channel_cfg);
+                co_return true;
             }
         }
         // Don't reply otherwise
-        return false;
-    }
-
-    void enqueue_reply(dpp::snowflake id, const BotChannelConfig& channel_cfg) {
-        bot.message_create(dpp::message(id, config.texts.please_wait+" :thinking:"), [=, this] (const dpp::confirmation_callback_t& ccb) {
-            if (ccb.is_error()) return;
-            thread_pool.submit(std::bind(&Bot::reply, this, id, ccb.get<dpp::message>(), channel_cfg));
-        });
+        co_return false;
     }
 
     bool is_on_own_shard(dpp::snowflake id) const {
@@ -491,18 +485,16 @@
                 " UNIQUE(id)"
                 ");";
 
-        // Configure llm_pool
-        llm_pool.set_store_on_destruct(cfg.persistance);
-
-        // Initialize thread pool
-        thread_pool.init();
+        // Start Scheduled Thread
+        sched_thread.start();
 
         // Prepare translator
         if (cfg.language != "EN") {
-            thread_pool.submit([this] () {
-                std::cout << "Preparing translator..." << std::endl;
-                translator = std::make_unique<Translator>(config.translation_model_cfg->weights_path, llm_get_translation_params());
-            });
+            sched_thread.create_task("Translator", [this] () -> CoSched::AwaitableTask<void> {
+                std::cout << "Preparing translator..." << std::endl;
+                translator = std::make_unique<Translator>(config.translation_model_cfg->weights_path, llm_get_translation_params());
+                co_return;
+            });
         }
 
         // Configure bot
@@ -529,7 +521,9 @@
             }
             if (dpp::run_once<struct register_bot_commands>()) {
                 // Prepare llm
-                thread_pool.submit(std::bind(&Bot::llm_init, this));
+                sched_thread.create_task("Language Model Initialization", [this] () -> CoSched::AwaitableTask<void> {
+                    co_await llm_init();
+                });
             }
         });
         bot.on_slashcommand([=, this](dpp::slashcommand_t event) {
@@ -595,9 +589,9 @@
             // Check for reset command
             if (msg.content == "!reset") {
                 // Delete inference from pool
-                thread_pool.submit([this, msg] () {
-                    llm_pool.delete_inference(msg.channel_id);
-                });
+                sched_thread.create_task("Language Model Inference Pool", [=, this] () -> CoSched::AwaitableTask<void> {
+                    co_await llm_pool.delete_inference(msg.channel_id);
+                });
                 // Delete message
                 bot.message_delete(msg.id, msg.channel_id);
                 return;
@@ -637,21 +631,22 @@
                 channel_cfg.model = config.default_inference_model_cfg;
             }
             // Append message
-            thread_pool.submit([=, this] () {
-                prompt_add_msg(msg, channel_cfg);
-            });
-            // Handle message somehow...
-            if (in_bot_thread) {
-                // Send a reply
-                enqueue_reply(msg.channel_id, channel_cfg);
-            } else if (msg.content == "!trigger") {
-                // Delete message
-                bot.message_delete(msg.id, msg.channel_id);
-                // Send a reply
-                enqueue_reply(msg.channel_id, channel_cfg);
-            } else {
-                attempt_reply(msg, channel_cfg);
-            }
+            sched_thread.create_task("Language Model Inference ("+*channel_cfg.model_name+')', [=, this] () -> CoSched::AwaitableTask<void> {
+                co_await prompt_add_msg(msg, channel_cfg);
+                // Handle message somehow...
+                if (in_bot_thread) {
+                    // Send a reply
+                    co_await reply(msg.channel_id, channel_cfg);
+                } else if (msg.content == "!trigger") {
+                    // Delete message
+                    bot.message_delete(msg.id, msg.channel_id);
+                    // Send a reply
+                    co_await reply(msg.channel_id, channel_cfg);
+                } else {
+                    // Check more conditions in another function...
+                    co_await attempt_reply(msg, channel_cfg);
+                }
+            });
             // Find thread embed
             std::scoped_lock L(thread_embeds_mutex);
             auto res = thread_embeds.find(msg.channel_id);
@@ -675,10 +670,12 @@
         bot.start(dpp::st_wait);
     }
     void stop_prepare() {
-        thread_pool.submit([this] () {
-            llm_pool.store_all();
-        }).wait();
-        thread_pool.shutdown();
+        if (config.persistance) {
+            sched_thread.create_task("Language Model Shutdown", [=, this] () -> CoSched::AwaitableTask<void> {
+                co_await llm_pool.store_all();
+            });
+        }
+        sched_thread.wait();
     }
 };
diff --git a/thread-pool/.gitignore b/thread-pool/.gitignore
deleted file mode 100644
index 4c22b49..0000000
--- a/thread-pool/.gitignore
+++ /dev/null
@@ -1,52 +0,0 @@
-# Build
-#VS
-.vs/
-Debug/
-Release/
-# Netbeans
-nbproject/
-*.user
-*.filters
-*.vcxproj
-*.sln
-
-# Cmake
-build
-CMakeCache.txt
-CMakeFiles
-CMakeScripts
-Makefile
-cmake_install.cmake
-install_manifest.txt
-CTestTestfile.cmake
-
-# Compiled Object files
-*.slo
-*.lo
-*.o
-*.obj
-!assets/models/*/*.obj
-
-# Precompiled Headers
-*.gch
-*.pch
-
-# Compiled Dynamic libraries
-*.so
-*.dylib
-*.dll
-
-# Fortran module files
-*.mod
-*.smod
-
-# Compiled Static libraries
-*.lai
-*.la
-*.a
-*.lib
-
-# Executables
-*.exe
-*.out
-*.app
diff --git a/thread-pool/CMakeLists.txt b/thread-pool/CMakeLists.txt
deleted file mode 100644
index 0b179d7..0000000
--- a/thread-pool/CMakeLists.txt
+++ /dev/null
@@ -1,5 +0,0 @@
-cmake_minimum_required (VERSION 3.5)
-project (threadpool)
-
-add_library(threadpool INTERFACE)
-target_include_directories(threadpool INTERFACE include)
diff --git a/thread-pool/LICENSE b/thread-pool/LICENSE
deleted file mode 100644
index 4746316..0000000
--- a/thread-pool/LICENSE
+++ /dev/null
@@ -1,21 +0,0 @@
-MIT License
-
-Copyright (c) 2016 Mariano Trebino
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
diff --git a/thread-pool/README.md b/thread-pool/README.md
deleted file mode 100644
index 3ceed5c..0000000
--- a/thread-pool/README.md
+++ /dev/null
@@ -1,380 +0,0 @@
-# Table of Contents
- [Introduction](https://github.com/mtrebi/thread-pool/blob/master/README.md#introduction)
- [Build instructions](https://github.com/mtrebi/thread-pool/blob/master/README.md#build-instructions)
- [Thread pool](https://github.com/mtrebi/thread-pool/blob/master/README.md#thread-pool)
-        [Queue](https://github.com/mtrebi/thread-pool/blob/master/README.md#queue)
-        [Submit function](https://github.com/mtrebi/thread-pool/blob/master/README.md#submit-function)
-        [Thread worker](https://github.com/mtrebi/thread-pool/blob/master/README.md#thread-worker)
- [Usage example](https://github.com/mtrebi/thread-pool/blob/master/README.md#usage-example)
-        [Use case#1](https://github.com/mtrebi/thread-pool#use-case-1)
-        [Use case#2](https://github.com/mtrebi/thread-pool#use-case-2)
-        [Use case#3](https://github.com/mtrebi/thread-pool#use-case-3)
- [Future work](https://github.com/mtrebi/thread-pool/blob/master/README.md#future-work)
- [References](https://github.com/mtrebi/thread-pool/blob/master/README.md#references)
-
-# Introduction:
-
-A [thread pool](https://en.wikipedia.org/wiki/Thread_pool) is a technique that allows developers to exploit the concurrency of modern processors in an **easy** and **efficient** manner. It's easy because you send "work" to the pool and somehow this work gets done without blocking the main thread. It's efficient because threads are not initialized each time we want work to be done. Threads are initialized once and remain inactive until some work has to be done. This way we minimize the overhead.
-
-There are many other thread pool implementations in C++; many of them are probably better (safer, faster...) than mine. However, I believe my implementation is **very straightforward and easy to understand**.
-
-__Disclaimer: Please do not use this project in a professional environment. It may contain bugs and/or not work as expected.__ I did this project to learn how C++11 threads work and to provide an easy way for other people to understand them too.
-
-# Build instructions:
-
-This project has been developed using Netbeans and Linux, but it should work on Windows, macOS and Linux. It can be easily built using CMake and various generators. The following commands can be used to generate the VS 2017 project files:
-
-```c
-// VS 2017
-cd <project-folder>
-mkdir build
-cd build/
-cmake .. -G "Visual Studio 15 2017 Win64"
-```
-
-Then, from VS you can edit and execute the project. Make sure that the __main project is set up as the startup project__.
-
-If you are using Linux, you need to change the generator (use the default) and run an extra command to actually build the executable:
-
-```c
-// Linux
-cd <project-folder>
-mkdir build
-cd build/
-cmake ..
-make
-```
-
-# Thread pool
-
-The way I understand things best is with images. So, let's take a look at the thread pool diagram given by Wikipedia:
-
-[Figure: Wikipedia's thread pool diagram: a task queue feeding a pool of worker threads, which deliver completed tasks.]
-
-As you can see, we have three important elements here:
-* *Tasks Queue*. This is where the work that has to be done is stored.
-* *Thread Pool*. This is a set of threads (or workers) that continuously take work from the queue and do it.
-* *Completed Tasks*. When a thread has finished some work, we return "something" to notify that the work is done.
-
-## Queue
-
-We use a queue to store the work because it's the most sensible data structure. We want the work to be **started** in the same order that we sent it. However, this queue is a little bit **special**. As I said in the previous section, threads are continuously (well, not really, but let's assume that they are) querying the queue to ask for work. When there's work available, threads take the work from the queue and do it. What would happen if two threads try to take the same work at the same time? Well, we'd have a data race: undefined behavior and, quite possibly, a crash.
-
-To avoid these kinds of problems, I implemented a wrapper over the standard C++ queue that uses a mutex to restrict concurrent access. Let's see a small sample of the SafeQueue class:
-
-```c
-void enqueue(T& t) {
-  std::unique_lock<std::mutex> lock(m_mutex);
-  m_queue.push(t);
-}
-```
-
-To enqueue, the first thing we do is lock the mutex to make sure that no one else is accessing the resource. Then, we push the element into the queue. When the lock goes out of scope it gets automatically released. Easy, huh? This way, we make the queue thread-safe and thus we don't have to worry about many threads accessing and/or modifying it at the same "time".
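The matching consumer side is worth seeing next to this. It is the dequeue from the SafeQueue.h that appears further down in this diff; rather than blocking, it reports through its return value whether anything was actually popped:

```c
bool dequeue(T& t) {
  std::unique_lock<std::mutex> lock(m_mutex);
  if (m_queue.empty()) {
    return false;   // nothing to do; the caller decides whether to wait
  }
  t = std::move(m_queue.front());
  m_queue.pop();
  return true;
}
```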
-
-## Submit function
-
-The most important method of the thread pool is the one responsible for adding work to the queue. I called this method **submit**. It's not difficult to understand how it works, but its implementation can seem scary at first. Let's think about **what** it should do, and after that we will worry about **how** to do it. What:
-* Accept any function with any parameters.
-* Return "something" immediately to avoid blocking the main thread. This returned object should **eventually** contain the result of the operation.
-
-Cool, let's see **how** we can implement it.
-
-### Submit implementation
-
-The complete submit function looks like this:
-
-```c
-// Submit a function to be executed asynchronously by the pool
-template <typename F, typename... Args>
-auto submit(F&& f, Args&&... args) -> std::future<decltype(f(args...))> {
-  // Create a function with bounded parameters ready to execute
-  std::function<decltype(f(args...))()> func = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
-  // Encapsulate it into a shared ptr in order to be able to copy construct / assign
-  auto task_ptr = std::make_shared<std::packaged_task<decltype(f(args...))()>>(func);
-
-  // Wrap packaged task into void function
-  std::function<void()> wrapper_func = [task_ptr]() {
-    (*task_ptr)();
-  };
-
-  // Enqueue generic wrapper function
-  m_queue.enqueue(wrapper_func);
-
-  // Wake up one thread if it's waiting
-  m_conditional_lock.notify_one();
-
-  // Return future from promise
-  return task_ptr->get_future();
-}
-```
-
-Nevertheless, we're going to inspect, line by line, what's going on in order to fully understand how it works.
-
-#### Variadic template function
-
-```c
-template <typename F, typename... Args>
-```
-
-This means that the next statement is templated. The first template parameter is called F (our function) and the second one is a parameter pack. A parameter pack is a special template parameter that can accept zero or more template arguments. It is, in fact, a way to express a variable number of arguments in a template. A template with at least one parameter pack is called a **variadic template**.
-
-Summarizing, we are telling the compiler that our submit function is going to take one generic parameter of type F (our function) and a parameter pack Args (the parameters of the function F).
-
-#### Function declaration
-
-```c
-auto submit(F&& f, Args&&... args) -> std::future<decltype(f(args...))> {
-```
-
-This may seem weird, but it's not. A function can, in fact, be declared using two different syntaxes. The following is the most well known:
-
-```c
-return-type identifier ( argument-declarations... )
-```
-
-But we can also declare the function like this:
-
-```c
-auto identifier ( argument-declarations... ) -> return-type
-```
-
-Why two syntaxes? Well, imagine that you have a function whose return type depends on its input parameters. Using the first syntax you can't declare that function without getting a compiler error, since you would be using a variable in the return type that has not been declared yet (the return type declaration goes before the parameter type declarations).
-
-Using the second syntax, you declare the function to have return type **auto**; then, after the ->, you can declare the return type in terms of the arguments of the function, which have been declared previously.
-
-Now, let's inspect the parameters of the submit function. When the type of a parameter is declared as **T&&** for some deduced type T, that parameter is a **universal reference**. This term was coined by [Scott Meyers](https://isocpp.org/blog/2012/11/universal-references-in-c11-scott-meyers) because **T&&** can also mean r-value reference. However, in the context of type deduction, it means that it can be bound to both l-values and r-values, unlike l-value references (they bind only to modifiable lvalues) and r-value references (they bind only to rvalues).
-
-The return type of the function is **std::future<decltype(f(args...))>**. An std::future is a special type that provides a mechanism to access the result of asynchronous operations; in our case, the result of executing a specific function. This makes sense with what we said earlier.
-
-Finally, the template type of the std::future is **decltype(f(args...))**. Decltype is a special C++ keyword that inspects the declared type of an entity or the type and value category of an expression. In our case, we want to know the return type of the function _f_, so we give decltype our generic function _f_ called with the parameter pack _args_.
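A toy example (not from the pool itself) that shows why the trailing syntax is needed when the return type mentions the parameters:

```c
#include <iostream>

// 'decltype(a + b)' names the parameters, so it must appear after them;
// with the classic syntax there would be nothing to refer to yet.
template <typename A, typename B>
auto add(A a, B b) -> decltype(a + b) {
  return a + b;
}

int main() {
  std::cout << add(2, 3.5) << std::endl; // deduces double: prints 5.5
}
```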
-
-#### Function body
-
-```c
-// Create a function with bounded parameters ready to execute
-std::function<decltype(f(args...))()> func = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
-```
-
-There are many, many things happening here. First of all, **std::bind(F, Args)** is a function that creates a wrapper for F with the given Args. Calling this wrapper is the same as calling F with the Args it has been bound to. Here, we are simply calling bind with our generic function _f_ and the parameter pack _args_, but using another wrapper, **std::forward<T>(t)**, for each parameter. This second wrapper is needed to achieve perfect forwarding of universal references.
-The result of this bind call is a **std::function**. The std::function is a C++ object that encapsulates a function. It allows you to execute the function as if it were a normal function, by calling its operator() with the required parameters, BUT, because it is an object, you can store it, copy it and move it around. The template type of any std::function is the signature of that function: std::function<return-type (arguments)>. In this case, we already know how to get the return type of this function using decltype. But what about the arguments? Well, because we bound all the arguments _args_ to the function _f_, we just have to add an empty pair of parentheses that represents an empty list of arguments: **decltype(f(args...))()**.
-
-```c
-// Encapsulate it into a shared ptr in order to be able to copy construct / assign
-auto task_ptr = std::make_shared<std::packaged_task<decltype(f(args...))()>>(func);
-```
-
-The next thing we do is create a **std::packaged_task<T>(t)**. A packaged_task is a wrapper around a function that can be executed asynchronously. Its result is stored in a shared state inside an std::future object. The template type T of an std::packaged_task is the signature of the function _t_ that it is wrapping. As we said before, that signature is **decltype(f(args...))()**, which is exactly the template type of the packaged_task. Then, we just wrap this packaged task inside a **std::shared_ptr** using the initialization function **std::make_shared**.
-
-```c
-// Wrap packaged task into void function
-std::function<void()> wrapper_func = [task_ptr]() {
-  (*task_ptr)();
-};
-```
-
-Again, we create a std::function, but note that this time its template type is **void()**. Independently of the function _f_ and its parameters _args_, the return type of _wrapper_func_ will always be **void**. Since all functions _f_ may have different return types, the only way to store them in a container (our queue) is to wrap them with a generic void function. Here, we are just declaring this _wrapper_func_ to execute the actual task _task_ptr_ that will execute the bound function _func_.
-
-```c
-// Enqueue generic wrapper function
-m_queue.enqueue(wrapper_func);
-```
-
-We enqueue this _wrapper_func_.
-
-```c
-// Wake up one thread if it's waiting
-m_conditional_lock.notify_one();
-```
-
-Before finishing, we wake up one thread in case it is waiting.
-
-```c
-// Return future from promise
-return task_ptr->get_future();
-```
-
-And finally, we return the future of the packaged_task. Because we are returning the future that is bound to the packaged_task _task_ptr_, which in turn is bound to the function _func_, executing _task_ptr_ will automatically update the future. And because we wrapped the execution of _task_ptr_ in a generic wrapper function, it is the execution of _wrapper_func_ that, in fact, updates the future. And, since we enqueued this wrapper function, it will be executed by a thread that dequeues it and calls its operator().
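All of that plumbing can be rehearsed outside the pool. A minimal, self-contained sketch of the same mechanism, using only the standard library:

```c
#include <functional>
#include <future>
#include <iostream>
#include <memory>

int main() {
  // The packaged_task owns the work; the future observes its shared state.
  auto task_ptr = std::make_shared<std::packaged_task<int()>>([] { return 6 * 7; });
  std::future<int> fut = task_ptr->get_future();

  // The type-erased void() wrapper is what the queue would actually store.
  std::function<void()> wrapper_func = [task_ptr]() { (*task_ptr)(); };

  wrapper_func();                      // what a worker does after dequeuing
  std::cout << fut.get() << std::endl; // prints 42: the future got updated
}
```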
-
-## Thread worker
-
-Now that we understand how the submit method works, we're going to focus on how the work gets done. Probably the simplest implementation of a thread worker could use polling:
-
-    Loop
-      If Queue is not empty
-        Dequeue work
-        Do it
-
-This looks alright, but it's **not very efficient**. Do you see why? What would happen if there is no work in the queue? The threads would keep looping and asking all the time: is the queue empty?
-
-The more sensible implementation "sleeps" the threads until some work is added to the queue. As we saw before, as soon as we enqueue work, a signal **notify_one()** is sent. This allows us to implement a more efficient algorithm:
-
-    Loop
-      If Queue is empty
-        Wait signal
-      Dequeue work
-      Do it
-
-This signal system is implemented in C++ with **condition variables**. Condition variables are always bound to a mutex, so I added a mutex to the thread pool class just to manage this. The final code of a worker looks like this:
-
-```c
-void operator()() {
-  std::function<void()> func;
-  bool dequeued;
-  while (!m_pool->m_shutdown) {
-    {
-      std::unique_lock<std::mutex> lock(m_pool->m_conditional_mutex);
-      if (m_pool->m_queue.empty()) {
-        m_pool->m_conditional_lock.wait(lock);
-      }
-      dequeued = m_pool->m_queue.dequeue(func);
-    }
-    if (dequeued) {
-      func();
-    }
-  }
-}
-```
-
-The code is really easy to understand, so I am not going to explain much. The only thing to note here is that _func_ is our wrapper function, declared as:
-
-```c
-std::function<void()> wrapper_func = [task_ptr]() {
-  (*task_ptr)();
-};
-```
-
-So, executing this function will automatically update the future.
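The same loop can be tried out in isolation. A self-contained sketch (an editor's rehearsal, not the pool's code) that uses a predicate wait, which covers spurious wakeups and shutdown in a single condition:

```c
#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>

int main() {
  std::queue<std::function<void()>> queue;
  std::mutex m;
  std::condition_variable cv;
  bool shutdown = false;

  std::thread worker([&] {
    for (;;) {
      std::function<void()> func;
      {
        std::unique_lock<std::mutex> lock(m);
        cv.wait(lock, [&] { return shutdown || !queue.empty(); });
        if (queue.empty()) return; // shutdown requested and queue drained
        func = std::move(queue.front());
        queue.pop();
      }
      func(); // run the work outside the lock
    }
  });

  {
    std::lock_guard<std::mutex> lock(m);
    queue.push([] { std::cout << "hello from the worker" << std::endl; });
  }
  cv.notify_one();

  {
    std::lock_guard<std::mutex> lock(m);
    shutdown = true;
  }
  cv.notify_all();
  worker.join();
}
```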
-
-# Usage example
-
-Creating the thread pool is as easy as:
-
-```c
-// Create pool with 3 threads
-ThreadPool pool(3);
-
-// Initialize pool
-pool.init();
-```
-
-When we want to shut down the pool, we just call:
-
-```c
-// Shutdown the pool, releasing all threads
-pool.shutdown()
-```
-
-If we want to send some work to the pool, after we have initialized it, we just have to call the submit function:
-
-```c
-pool.submit(work);
-```
-
-Depending on the type of work, I've distinguished different use cases. Suppose that the work we have to do is to multiply two numbers. We can do it in many different ways. I've implemented the three most common ways to do it that I can imagine:
-* Use-Case #1. Function returns the result
-* Use-Case #2. Function updates a by-ref parameter with the result
-* Use-Case #3. Function prints the result
-
-_Note: This is just to show how the submit function works. The options are not exclusive._
-
-## Use-Case #1
-The multiply function with a return value looks like this:
-
-```c
-// Simple function that multiplies two numbers and returns the result
-int multiply(const int a, const int b) {
-  const int res = a * b;
-  return res;
-}
-```
-
-Then, the submit:
-
-```c
-// The type of the future is given by the return type of the function
-std::future<int> future = pool.submit(multiply, 2, 3);
-```
-
-We can also use the **auto** keyword for convenience:
-
-```c
-auto future = pool.submit(multiply, 2, 3);
-```
-
-Nice. When the work is finished by the thread pool, we know that the future will get updated and we can retrieve the result by calling:
-
-```c
-const int result = future.get();
-std::cout << result << std::endl;
-```
-
-The get() function of std::future always returns the type T of the future. **This type will always be equal to the return type of the function passed to the submit method.** In this case, int.
-
-## Use-Case #2
-The multiply function has a parameter passed by ref:
-
-```c
-// Simple function that multiplies two numbers and updates the out_res variable passed by ref
-void multiply(int& out_res, const int a, const int b) {
-  out_res = a * b;
-}
-```
-
-Now, we have to call the submit function with a subtle difference. Because we are using templates and type deduction (universal references), the parameter passed by ref needs to be wrapped in **std::ref(param)** to make sure that we are passing it by ref and not by value:
-
-```c
-int result = 0;
-auto future = pool.submit(multiply, std::ref(result), 2, 3);
-// result is 0
-future.get();
-// result is 6
-std::cout << result << std::endl;
-```
-
-In this case, what's the type of the future? Well, as I said before, it will always be equal to the return type of the function passed to the submit method. Because this function is of type void, the future is **std::future<void>**. Calling future.get() returns void. That's not very useful, but we still need to call it to make sure that the work has been done.
-
-## Use-Case #3
-The last case is the easiest one: our multiply function simply prints the result, with no output parameter and no return value:
-
-```c
-// Simple function that multiplies two numbers and prints the result
-void multiply(const int a, const int b) {
-  const int result = a * b;
-  std::cout << result << std::endl;
-}
-```
-
-Then, we can simply call:
-
-```c
-auto future = pool.submit(multiply, 2, 3);
-future.get();
-```
-
-In this case, we know that as soon as the multiplication is done it will be printed. If we care about when that happens, we can wait for it by calling future.get().
-
-Check out the [main](https://github.com/mtrebi/thread-pool/blob/master/src/main.cpp) program for a complete example.
-
-# Future work
-
-* Make it more reliable and safer (exceptions)
-* Find a better way to use it with member functions (thanks to @rajenk)
-* Run benchmarks and improve performance if needed
-  * Evaluate the performance and heap impact of std::function and try alternatives if necessary (thanks to @JensMunkHansen)
-
-# References
-
-* [MULTI-THREADED PROGRAMMING TERMINOLOGY - 2017](http://www.bogotobogo.com/cplusplus/multithreaded.php): Fast analysis of how a multi-threaded system works
-
-* [Universal References in C++11—Scott Meyers](https://isocpp.org/blog/2012/11/universal-references-in-c11-scott-meyers): Universal references in C++11 by Scott Meyers
-
-* [Perfect forwarding and universal references in C++](http://eli.thegreenplace.net/2014/perfect-forwarding-and-universal-references-in-c/): Article about how and when to use perfect forwarding and universal references
-
-* [C++ documentation](http://www.cplusplus.com/reference/): Threads, condition variables, mutexes and many others...
diff --git a/thread-pool/affinity.patch b/thread-pool/affinity.patch
deleted file mode 100644
index 16ca12b..0000000
--- a/thread-pool/affinity.patch
+++ /dev/null
@@ -1,100 +0,0 @@
---- ThreadPool.h	Wed May 17 15:01:04 2017
-+++ ThreadPool.h	Sun Mar  4 04:32:07 2018
-@@ -1,5 +1,18 @@
- #pragma once
- 
-+#ifdef AFFINITY
-+#if defined __sun__
-+#include <sys/types.h>
-+#include <sys/processor.h>
-+#include <sys/procset.h>
-+#include <unistd.h> /* For sysconf */
-+#elif __linux__
-+#include <stdio.h> /* For fprintf */
-+#include <pthread.h>
-+#endif
-+#endif
-+
-+#include <cstddef> /* For std::size_t */
- #include <functional>
- #include <future>
- #include <mutex>
-@@ -14,10 +27,10 @@
- private:
-   class ThreadWorker {
-   private:
--    int m_id;
-+    std::size_t m_id;
-     ThreadPool * m_pool;
-   public:
--    ThreadWorker(ThreadPool * pool, const int id)
-+    ThreadWorker(ThreadPool * pool, const std::size_t id)
-       : m_pool(pool), m_id(id) {
-     }
- 
-@@ -45,7 +58,7 @@
-   std::mutex m_conditional_mutex;
-   std::condition_variable m_conditional_lock;
- public:
--  ThreadPool(const int n_threads)
-+  ThreadPool(const std::size_t n_threads)
-     : m_threads(std::vector<std::thread>(n_threads)), m_shutdown(false) {
-   }
- 
-@@ -57,7 +70,44 @@
- 
-   // Inits thread pool
-   void init() {
--    for (int i = 0; i < m_threads.size(); ++i) {
-+    #if (defined __sun__ || defined __linux__) && defined AFFINITY
-+    std::size_t v_cpu = 0;
-+    std::size_t v_cpu_max = std::thread::hardware_concurrency() - 1;
-+    #endif
-+
-+    #if defined __sun__ && defined AFFINITY
-+    std::vector<processorid_t> v_cpu_id; /* Struct for CPU/core ID */
-+
-+    processorid_t i, cpuid_max;
-+    cpuid_max = sysconf(_SC_CPUID_MAX);
-+    for (i = 0; i <= cpuid_max; i++) {
-+      if (p_online(i, P_STATUS) != -1) /* Get only online cores ID */
-+        v_cpu_id.push_back(i);
-+    }
-+    #endif
-+
-+    for (std::size_t i = 0; i < m_threads.size(); ++i) {
-+
-+      #if (defined __sun__ || defined __linux__) && defined AFFINITY
-+      if (v_cpu > v_cpu_max) {
-+        v_cpu = 0;
-+      }
-+
-+      #ifdef __sun__
-+      processor_bind(P_LWPID, P_MYID, v_cpu_id[v_cpu], NULL);
-+      #elif __linux__
-+      cpu_set_t mask;
-+      CPU_ZERO(&mask);
-+      CPU_SET(v_cpu, &mask);
-+      pthread_t thread = pthread_self();
-+      if (pthread_setaffinity_np(thread, sizeof(cpu_set_t), &mask) != 0) {
-+        fprintf(stderr, "Error setting thread affinity\n");
-+      }
-+      #endif
-+
-+      ++v_cpu;
-+      #endif
-+
-       m_threads[i] = std::thread(ThreadWorker(this, i));
-     }
-   }
---- SafeQueue.h	Wed May 17 15:01:04 2017
-+++ SafeQueue.h	Sun Mar  4 04:32:07 2018
-@@ -28,7 +28,7 @@
-     return m_queue.empty();
-   }
- 
--  int size() {
-+  std::size_t size() {
-     std::unique_lock<std::mutex> lock(m_mutex);
-     return m_queue.size();
-   }
diff --git a/thread-pool/example/main.cpp b/thread-pool/example/main.cpp
deleted file mode 100644
index 469d068..0000000
--- a/thread-pool/example/main.cpp
+++ /dev/null
@@ -1,72 +0,0 @@
-#include <iostream>
-#include <random>
-
-#include "../include/ThreadPool.h"
-
-std::random_device rd;
-std::mt19937 mt(rd());
-std::uniform_int_distribution<int> dist(-1000, 1000);
-auto rnd = std::bind(dist, mt);
-
-
-void simulate_hard_computation() {
-  std::this_thread::sleep_for(std::chrono::milliseconds(2000 + rnd()));
-}
-
-// Simple function that multiplies two numbers and prints the result
-void multiply(const int a, const int b) {
-  simulate_hard_computation();
-  const int res = a * b;
-  std::cout << a << " * " << b << " = " << res << std::endl;
-}
-
-// Same as before but now we have an output parameter
-void multiply_output(int & out, const int a, const int b) {
-  simulate_hard_computation();
-  out = a * b;
-  std::cout << a << " * " << b << " = " << out << std::endl;
-}
-
-// Same as before but now we have a return value
-int multiply_return(const int a, const int b) {
-  simulate_hard_computation();
-  const int res = a * b;
-  std::cout << a << " * " << b << " = " << res << std::endl;
-  return res;
-}
-
-
-int main(int argc, char *argv[])
-{
-  // Create pool with 3 threads
-  ThreadPool pool(3);
-
-  // Initialize pool
-  pool.init();
-
-  // Submit (partial) multiplication table
-  for (int i = 1; i < 3; ++i) {
-    for (int j = 1; j < 10; ++j) {
-      pool.submit(multiply, i, j);
-    }
-  }
-
-  // Submit function with output parameter passed by ref
-  int output_ref;
-  auto future1 = pool.submit(multiply_output, std::ref(output_ref), 5, 6);
-
-  // Wait for multiplication output to finish
-  future1.get();
-  std::cout << "Last operation result is equal to " << output_ref << std::endl;
-
-  // Submit function with return parameter
-  auto future2 = pool.submit(multiply_return, 5, 3);
-
-  // Wait for multiplication output to finish
-  int res = future2.get();
-  std::cout << "Last operation result is equal to " << res << std::endl;
-
-  pool.shutdown();
-
-  return 0;
-}
diff --git a/thread-pool/include/SafeQueue.h b/thread-pool/include/SafeQueue.h
deleted file mode 100644
index 3c0ff74..0000000
--- a/thread-pool/include/SafeQueue.h
+++ /dev/null
@@ -1,52 +0,0 @@
-#pragma once
-
-#include <mutex>
-#include <queue>
-
-// Thread safe implementation of a Queue using an std::queue
-template <typename T>
-class SafeQueue {
-private:
-  std::queue<T> m_queue;
-  std::mutex m_mutex;
-public:
-  SafeQueue() {
-
-  }
-
-  SafeQueue(SafeQueue& other) {
-    //TODO:
-  }
-
-  ~SafeQueue() {
-
-  }
-
-
-  bool empty() {
-    std::unique_lock<std::mutex> lock(m_mutex);
-    return m_queue.empty();
-  }
-
-  int size() {
-    std::unique_lock<std::mutex> lock(m_mutex);
-    return m_queue.size();
-  }
-
-  void enqueue(T& t) {
-    std::unique_lock<std::mutex> lock(m_mutex);
-    m_queue.push(t);
-  }
-
-  bool dequeue(T& t) {
-    std::unique_lock<std::mutex> lock(m_mutex);
-
-    if (m_queue.empty()) {
-      return false;
-    }
-    t = std::move(m_queue.front());
-
-    m_queue.pop();
-    return true;
-  }
-};
\ No newline at end of file
diff --git a/thread-pool/include/ThreadPool.h b/thread-pool/include/ThreadPool.h
deleted file mode 100644
index b7bf84d..0000000
--- a/thread-pool/include/ThreadPool.h
+++ /dev/null
@@ -1,99 +0,0 @@
-#pragma once
-
-#include <functional>
-#include <future>
-#include <mutex>
-#include <queue>
-#include <thread>
-#include <utility>
-#include <vector>
-
-#include "SafeQueue.h"
-
-class ThreadPool {
-private:
-  class ThreadWorker {
-  private:
-    int m_id;
-    ThreadPool * m_pool;
-  public:
-    ThreadWorker(ThreadPool * pool, const int id)
-      : m_pool(pool), m_id(id) {
-    }
-
-    void operator()() {
-      std::function<void()> func;
-      bool dequeued;
-      while (!m_pool->m_shutdown) {
-        {
-          std::unique_lock<std::mutex> lock(m_pool->m_conditional_mutex);
-          if (m_pool->m_queue.empty()) {
-            m_pool->m_conditional_lock.wait(lock);
-          }
-          dequeued = m_pool->m_queue.dequeue(func);
-        }
-        if (dequeued) {
-          func();
-        }
-      }
-    }
-  };
-
-  bool m_shutdown;
-  SafeQueue<std::function<void()>> m_queue;
-  std::vector<std::thread> m_threads;
-  std::mutex m_conditional_mutex;
-  std::condition_variable m_conditional_lock;
-public:
-  ThreadPool(const int n_threads)
-    : m_threads(std::vector<std::thread>(n_threads)), m_shutdown(false) {
-  }
-
-  ThreadPool(const ThreadPool &) = delete;
-  ThreadPool(ThreadPool &&) = delete;
-
-  ThreadPool & operator=(const ThreadPool &) = delete;
-  ThreadPool & operator=(ThreadPool &&) = delete;
-
-  // Inits thread pool
-  void init() {
-    for (int i = 0; i < m_threads.size(); ++i) {
-      m_threads[i] = std::thread(ThreadWorker(this, i));
-    }
-  }
-
-  // Waits until threads finish their current task and shuts down the pool
-  void shutdown() {
-    m_shutdown = true;
-    m_conditional_lock.notify_all();
-
-    for (int i = 0; i < m_threads.size(); ++i) {
-      if(m_threads[i].joinable()) {
-        m_threads[i].join();
-      }
-    }
-  }
-
-  // Submit a function to be executed asynchronously by the pool
-  template <typename F, typename... Args>
-  auto submit(F&& f, Args&&... args) -> std::future<decltype(f(args...))> {
-    // Create a function with bounded parameters ready to execute
-    std::function<decltype(f(args...))()> func = std::bind(std::forward<F>(f), std::forward<Args>(args)...);
-    // Encapsulate it into a shared ptr in order to be able to copy construct / assign
-    auto task_ptr = std::make_shared<std::packaged_task<decltype(f(args...))()>>(func);
-
-    // Wrap packaged task into void function
-    std::function<void()> wrapper_func = [task_ptr]() {
-      (*task_ptr)();
-    };
-
-    // Enqueue generic wrapper function
-    m_queue.enqueue(wrapper_func);
-
-    // Wake up one thread if it's waiting
-    m_conditional_lock.notify_one();
-
-    // Return future from promise
-    return task_ptr->get_future();
-  }
-};
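For contrast with the deleted pool above, the cooperative replacement's lifecycle, as used in main.cpp earlier in this diff (CoSched's API is inferred from that usage; exact signatures are assumptions):

```c
CoSched::ScheduledThread sched_thread;
sched_thread.start();                  // takes the place of ThreadPool::init()

sched_thread.create_task("Example", [] () -> CoSched::AwaitableTask<void> {
    // Tasks cooperate: they can raise their priority and yield explicitly.
    CoSched::Task::get_current().set_priority(CoSched::PRIO_REALTIME);
    co_await CoSched::Task::get_current().yield();
    co_return;
});

sched_thread.wait();                   // takes the place of submit(...).wait() + shutdown()
```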