From c548433dd93904e959c1782aeddcc3b36f7c1507 Mon Sep 17 00:00:00 2001 From: Josh Baldwin Date: Sun, 25 Oct 2020 20:54:19 -0600 Subject: [PATCH] Correctly implement sync_wait and when_all_awaitable (#8) See issue for more details, in general attempting to implement a coro::thread_pool exposed that the coro::sync_wait and coro::when_all only worked if the coroutines executed on that same thread. They should now possibly have the ability to execute on another thread, to be determined in a later issue. Fixes #7 --- .gitignore | 1 + CMakeLists.txt | 9 +- Testing/Temporary/CTestCostData.txt | 1 - Testing/Temporary/LastTest.log | 3 - inc/coro/awaitable.hpp | 57 +++ inc/coro/coro.hpp | 4 + inc/coro/detail/void_value.hpp | 8 + inc/coro/promise.hpp | 25 ++ inc/coro/scheduler.hpp | 17 +- inc/coro/shutdown.hpp | 14 + inc/coro/sync_wait.hpp | 235 +++++++++++- inc/coro/task.hpp | 3 +- inc/coro/thread_pool.hpp | 73 ++++ inc/coro/when_all.hpp | 543 ++++++++++++++++++++++++++++ src/sync_wait.cpp | 34 ++ src/thread_pool.cpp | 124 +++++++ test/CMakeLists.txt | 2 + test/bench.cpp | 27 +- test/test_sync_wait.cpp | 32 +- test/test_thread_pool.cpp | 22 ++ test/test_when_all.cpp | 86 +++++ 21 files changed, 1257 insertions(+), 63 deletions(-) delete mode 100644 Testing/Temporary/CTestCostData.txt delete mode 100644 Testing/Temporary/LastTest.log create mode 100644 inc/coro/awaitable.hpp create mode 100644 inc/coro/detail/void_value.hpp create mode 100644 inc/coro/promise.hpp create mode 100644 inc/coro/shutdown.hpp create mode 100644 inc/coro/thread_pool.hpp create mode 100644 inc/coro/when_all.hpp create mode 100644 src/sync_wait.cpp create mode 100644 src/thread_pool.cpp create mode 100644 test/test_thread_pool.cpp create mode 100644 test/test_when_all.cpp diff --git a/.gitignore b/.gitignore index 1ea9d37..c8a3549 100644 --- a/.gitignore +++ b/.gitignore @@ -35,5 +35,6 @@ /Debug/ /RelWithDebInfo/ /Release/ +/Testing/ /.vscode/ diff --git a/CMakeLists.txt b/CMakeLists.txt index 253e1fc..40c8fef 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,13 +8,20 @@ message("${PROJECT_NAME} CORO_BUILD_TESTS = ${CORO_BUILD_TESTS}") message("${PROJECT_NAME} CORO_CODE_COVERAGE = ${CORO_CODE_COVERAGE}") set(LIBCORO_SOURCE_FILES + inc/coro/detail/void_value.hpp + + inc/coro/awaitable.hpp inc/coro/coro.hpp inc/coro/event.hpp src/event.cpp inc/coro/generator.hpp inc/coro/latch.hpp + inc/coro/promise.hpp inc/coro/scheduler.hpp - inc/coro/sync_wait.hpp + inc/coro/shutdown.hpp + inc/coro/sync_wait.hpp src/sync_wait.cpp inc/coro/task.hpp + inc/coro/thread_pool.hpp src/thread_pool.cpp + inc/coro/when_all.hpp ) add_library(${PROJECT_NAME} STATIC ${LIBCORO_SOURCE_FILES}) diff --git a/Testing/Temporary/CTestCostData.txt b/Testing/Temporary/CTestCostData.txt deleted file mode 100644 index ed97d53..0000000 --- a/Testing/Temporary/CTestCostData.txt +++ /dev/null @@ -1 +0,0 @@ ---- diff --git a/Testing/Temporary/LastTest.log b/Testing/Temporary/LastTest.log deleted file mode 100644 index de83ee1..0000000 --- a/Testing/Temporary/LastTest.log +++ /dev/null @@ -1,3 +0,0 @@ -Start testing: Oct 12 14:19 MST ----------------------------------------------------------- -End testing: Oct 12 14:19 MST diff --git a/inc/coro/awaitable.hpp b/inc/coro/awaitable.hpp new file mode 100644 index 0000000..c187023 --- /dev/null +++ b/inc/coro/awaitable.hpp @@ -0,0 +1,57 @@ +#pragma once + +#include +#include +#include + +namespace coro +{ + +/** + * This concept declares a type that is required to meet the c++20 coroutine operator co_await() + * retun type. It requires the following three member functions: + * await_ready() -> bool + * await_suspend(std::coroutine_handle<>) -> void|bool|std::coroutine_handle<> + * await_resume() -> decltype(auto) + * Where the return type on await_resume is the requested return of the awaitable. + */ +template +concept awaiter_type = requires(T t, std::coroutine_handle<> c) +{ + { t.await_ready() } -> std::same_as; + std::same_as || + std::same_as || + std::same_as>; + { t.await_resume() }; +}; + +/** + * This concept declares a type that can be operator co_await()'ed and returns an awaiter_type. + */ +template +concept awaitable_type = requires(T t) +{ + // operator co_await() + { t.operator co_await() } -> awaiter_type; +}; + +template +struct awaitable_traits +{ +}; + +template +static auto get_awaiter(T&& value) +{ + return static_cast(value).operator co_await(); +} + +template +struct awaitable_traits +{ + using awaiter_t = decltype(get_awaiter(std::declval())); + using awaiter_return_t = decltype(std::declval().await_resume()); + // using awaiter_return_decay_t = std::decay_t().await_resume())>; +}; + +} // namespace coro diff --git a/inc/coro/coro.hpp b/inc/coro/coro.hpp index eb7b59e..c31a9d2 100644 --- a/inc/coro/coro.hpp +++ b/inc/coro/coro.hpp @@ -1,8 +1,12 @@ #pragma once +#include "coro/awaitable.hpp" #include "coro/event.hpp" #include "coro/generator.hpp" #include "coro/latch.hpp" +#include "coro/promise.hpp" #include "coro/scheduler.hpp" #include "coro/sync_wait.hpp" #include "coro/task.hpp" +#include "coro/thread_pool.hpp" +#include "coro/when_all.hpp" diff --git a/inc/coro/detail/void_value.hpp b/inc/coro/detail/void_value.hpp new file mode 100644 index 0000000..7793c89 --- /dev/null +++ b/inc/coro/detail/void_value.hpp @@ -0,0 +1,8 @@ +#pragma once + +namespace coro::detail +{ + +struct void_value{}; + +} // coro::detail diff --git a/inc/coro/promise.hpp b/inc/coro/promise.hpp new file mode 100644 index 0000000..c83cf9f --- /dev/null +++ b/inc/coro/promise.hpp @@ -0,0 +1,25 @@ +#pragma once + +#include "coro/awaitable.hpp" + +#include + +namespace coro +{ + +template +concept promise_type = requires(T t) +{ + { t.get_return_object() } -> std::convertible_to>; + { t.initial_suspend() } -> awaiter_type; + { t.final_suspend() } -> awaiter_type; + { t.yield_value() } -> awaitable_type; +} && +requires(T t, return_type return_value) +{ + std::same_as || + std::same_as || + requires { t.yield_value(return_value); }; +}; + +} // namespace coro diff --git a/inc/coro/scheduler.hpp b/inc/coro/scheduler.hpp index a40fca9..8ba254a 100644 --- a/inc/coro/scheduler.hpp +++ b/inc/coro/scheduler.hpp @@ -1,6 +1,7 @@ #pragma once #include "coro/task.hpp" +#include "coro/shutdown.hpp" #include #include @@ -42,17 +43,18 @@ public: resume_token_base(resume_token_base&& other) { m_scheduler = other.m_scheduler; - m_state = other.m_state.exchange(0); + m_state = other.m_state.exchange(nullptr); other.m_scheduler = nullptr; } auto operator=(const resume_token_base&) -> resume_token_base& = delete; - auto operator =(resume_token_base&& other) -> resume_token_base& + + auto operator=(resume_token_base&& other) -> resume_token_base& { if (std::addressof(other) != this) { m_scheduler = other.m_scheduler; - m_state = other.m_state.exchange(0); + m_state = other.m_state.exchange(nullptr); other.m_scheduler = nullptr; } @@ -323,14 +325,6 @@ private: public: using fd_t = int; - enum class shutdown_t - { - /// Synchronously wait for all tasks to complete when calling shutdown. - sync, - /// Asynchronously let tasks finish on the background thread on shutdown. - async - }; - enum class thread_strategy_t { /// Spawns a background thread for the scheduler to run on. @@ -810,7 +804,6 @@ private: std::atomic_thread_fence(std::memory_order::acquire); bool tasks_ready = !m_accept_queue.empty(); - // bool tasks_ready = m_event_set.load(std::memory_order::acquire); auto timeout = (tasks_ready) ? m_no_timeout : user_timeout; // Poll is run every iteration to make sure 'waiting' events are properly put into diff --git a/inc/coro/shutdown.hpp b/inc/coro/shutdown.hpp new file mode 100644 index 0000000..640b53e --- /dev/null +++ b/inc/coro/shutdown.hpp @@ -0,0 +1,14 @@ +#pragma once + +namespace coro +{ + +enum class shutdown_t +{ + /// Synchronously wait for all tasks to complete when calling shutdown. + sync, + /// Asynchronously let tasks finish on the background thread on shutdown. + async +}; + +} // namespace coro diff --git a/inc/coro/sync_wait.hpp b/inc/coro/sync_wait.hpp index 0a620ba..bc1684c 100644 --- a/inc/coro/sync_wait.hpp +++ b/inc/coro/sync_wait.hpp @@ -1,30 +1,235 @@ #pragma once -#include "coro/scheduler.hpp" -#include "coro/task.hpp" +#include "coro/awaitable.hpp" + +#include +#include namespace coro { -template -auto sync_wait(task_type&& task) -> decltype(auto) + +namespace detail { - while (!task.is_ready()) + +class sync_wait_event +{ +public: + sync_wait_event(bool initially_set = false); + sync_wait_event(const sync_wait_event&) = delete; + sync_wait_event(sync_wait_event&&) = delete; + auto operator=(const sync_wait_event&) -> sync_wait_event& = delete; + auto operator=(sync_wait_event&&) -> sync_wait_event& = delete; + ~sync_wait_event() = default; + + auto set() noexcept -> void; + auto reset() noexcept -> void; + auto wait() noexcept -> void; +private: + std::mutex m_mutex; + std::condition_variable m_cv; + bool m_set{false}; +}; + +class sync_wait_task_promise_base +{ +public: + sync_wait_task_promise_base() noexcept = default; + virtual ~sync_wait_task_promise_base() = default; + + auto initial_suspend() noexcept -> std::suspend_always { - task.resume(); + return {}; + } + + auto unhandled_exception() -> void + { + m_exception = std::current_exception(); + } +protected: + sync_wait_event* m_event{nullptr}; + std::exception_ptr m_exception; +}; + +template +class sync_wait_task_promise : public sync_wait_task_promise_base +{ +public: + using coroutine_type = std::coroutine_handle>; + + sync_wait_task_promise() noexcept = default; + ~sync_wait_task_promise() override = default; + + auto start(sync_wait_event& event) + { + m_event = &event; + coroutine_type::from_promise(*this).resume(); + } + + auto get_return_object() noexcept + { + return coroutine_type::from_promise(*this); + } + + auto yield_value(return_type&& value) noexcept + { + m_return_value = std::addressof(value); + return final_suspend(); + } + + auto final_suspend() noexcept + { + struct completion_notifier + { + auto await_ready() const noexcept { return false; } + auto await_suspend(coroutine_type coroutine) const noexcept + { + coroutine.promise().m_event->set(); + } + auto await_resume() noexcept { }; + }; + + return completion_notifier{}; + } + + auto return_value() -> return_type&& + { + if(m_exception) + { + std::rethrow_exception(m_exception); + } + + return static_cast(*m_return_value); + } + +private: + std::remove_reference_t* m_return_value; +}; + + +template<> +class sync_wait_task_promise : public sync_wait_task_promise_base +{ + using coroutine_type = std::coroutine_handle>; +public: + sync_wait_task_promise() noexcept = default; + ~sync_wait_task_promise() override = default; + + auto start(sync_wait_event& event) + { + m_event = &event; + coroutine_type::from_promise(*this).resume(); + } + + auto get_return_object() noexcept + { + return coroutine_type::from_promise(*this); + } + + auto final_suspend() noexcept + { + struct completion_notifier + { + auto await_ready() const noexcept { return false; } + auto await_suspend(coroutine_type coroutine) const noexcept + { + coroutine.promise().m_event->set(); + } + auto await_resume() noexcept { }; + }; + + return completion_notifier{}; + } + + auto return_void() noexcept -> void { } + + auto return_value() + { + if(m_exception) + { + std::rethrow_exception(m_exception); + } + } +}; + +template +class sync_wait_task +{ +public: + using promise_type = sync_wait_task_promise; + using coroutine_type = std::coroutine_handle; + + sync_wait_task(coroutine_type coroutine) noexcept + : m_coroutine(coroutine) + { + + } + + sync_wait_task(const sync_wait_task&) = delete; + sync_wait_task(sync_wait_task&& other) noexcept + : m_coroutine(std::exchange(other.m_coroutine, coroutine_type{})) + { + + } + auto operator=(const sync_wait_task&) -> sync_wait_task& = delete; + auto operator=(sync_wait_task&& other) -> sync_wait_task& + { + if(std::addressof(other) != this) + { + m_coroutine = std::exchange(other.m_coroutine, coroutine_type{}); + } + + return *this; + } + + ~sync_wait_task() + { + if(m_coroutine) + { + m_coroutine.destroy(); + } + } + + auto start(sync_wait_event& event) noexcept + { + m_coroutine.promise().start(event); + } + + // todo specialize for type void + auto return_value() -> return_type + { + return m_coroutine.promise().return_value(); + } + +private: + coroutine_type m_coroutine; +}; + + +template::awaiter_return_t> +static auto make_sync_wait_task(awaitable&& a) -> sync_wait_task +{ + if constexpr (std::is_void_v) + { + co_await std::forward(a); + co_return; + } + else + { + co_yield co_await std::forward(a); } - return task.promise().return_value(); } -template -auto sync_wait_all(tasks&&... awaitables) -> void +} // namespace detail + +template +auto sync_wait(awaitable&& a) -> decltype(auto) { - scheduler s{scheduler::options{ - .reserve_size = sizeof...(awaitables), .thread_strategy = scheduler::thread_strategy_t::manual}}; + detail::sync_wait_event e{}; + auto task = detail::make_sync_wait_task(std::forward(a)); + task.start(e); + e.wait(); - (s.schedule(std::move(awaitables)), ...); - - while (s.process_events() > 0) - ; + return task.return_value(); } } // namespace coro diff --git a/inc/coro/task.hpp b/inc/coro/task.hpp index 9b387b6..6dd59cc 100644 --- a/inc/coro/task.hpp +++ b/inc/coro/task.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include namespace coro { @@ -198,7 +199,7 @@ public: { struct awaitable : public awaitable_base { - auto await_resume() noexcept -> decltype(auto) { return this->m_coroutine.promise().return_value(); } + auto await_resume() -> decltype(auto) { return this->m_coroutine.promise().return_value(); } }; return awaitable{m_coroutine}; diff --git a/inc/coro/thread_pool.hpp b/inc/coro/thread_pool.hpp new file mode 100644 index 0000000..8d7bbbb --- /dev/null +++ b/inc/coro/thread_pool.hpp @@ -0,0 +1,73 @@ +#pragma once + +#include "coro/shutdown.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace coro +{ + +class thread_pool; + +class thread_pool +{ +public: + class operation + { + friend class thread_pool; + public: + explicit operation(thread_pool& tp) noexcept; + + auto await_ready() noexcept -> bool { std::cerr << "thread_pool::operation::await_ready()\n"; return false; } + auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool; + auto await_resume() noexcept -> void { std::cerr << "thread_pool::operation::await_resume()\n";/* no-op */ } + private: + thread_pool& m_thread_pool; + std::coroutine_handle<> m_awaiting_coroutine{nullptr}; + }; + + explicit thread_pool(uint32_t thread_count = std::thread::hardware_concurrency()); + + thread_pool(const thread_pool&) = delete; + thread_pool(thread_pool&&) = delete; + auto operator=(const thread_pool&) -> thread_pool& = delete; + auto operator=(thread_pool&&) -> thread_pool& = delete; + + ~thread_pool(); + + auto thread_count() const -> uint32_t { return m_threads.size(); } + + [[nodiscard]] + auto schedule() noexcept -> std::optional; + + auto shutdown(shutdown_t wait_for_tasks = shutdown_t::sync) -> void; + + auto size() const -> std::size_t { return m_size.load(std::memory_order::relaxed); } + auto empty() const -> bool { return size() == 0; } +private: + std::atomic m_shutdown_requested{false}; + + std::vector m_threads; + + std::mutex m_queue_cv_mutex; + std::condition_variable m_queue_cv; + + std::mutex m_queue_mutex; + std::deque m_queue; + std::atomic m_size{0}; + + auto run(uint32_t worker_idx) -> void; + auto join() -> void; + auto schedule_impl(operation* op) -> void; +}; + +} // namespace coro diff --git a/inc/coro/when_all.hpp b/inc/coro/when_all.hpp new file mode 100644 index 0000000..c0eebda --- /dev/null +++ b/inc/coro/when_all.hpp @@ -0,0 +1,543 @@ +#pragma once + +#include "coro/awaitable.hpp" +#include "coro/detail/void_value.hpp" + +#include +#include +#include + +namespace coro +{ + +namespace detail +{ + +class when_all_latch +{ +public: + when_all_latch(std::size_t count) noexcept + : m_count(count + 1) + { } + + when_all_latch(const when_all_latch&) = delete; + when_all_latch(when_all_latch&& other) + : m_count(other.m_count.load(std::memory_order::acquire)), + m_awaiting_coroutine(std::exchange(other.m_awaiting_coroutine, nullptr)) + { } + + auto operator=(const when_all_latch&) -> when_all_latch& = delete; + auto operator=(when_all_latch&& other) -> when_all_latch& + { + if(std::addressof(other) != this) + { + m_count.store( + other.m_count.load(std::memory_order::acquire), + std::memory_order::relaxed + ); + m_awaiting_coroutine = std::exchange(other.m_awaiting_coroutine, nullptr); + } + + return *this; + } + + auto is_ready() const noexcept -> bool + { + return m_awaiting_coroutine != nullptr && m_awaiting_coroutine.done(); + } + + auto try_await(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool + { + m_awaiting_coroutine = awaiting_coroutine; + return m_count.fetch_sub(1, std::memory_order::acq_rel) > 1; + } + + auto notify_awaitable_completed() noexcept -> void + { + if(m_count.fetch_sub(1, std::memory_order::acq_rel) == 1) + { + m_awaiting_coroutine.resume(); + } + } + +private: + /// The number of tasks that are being waited on. + std::atomic m_count; + /// The when_all_task awaiting to be resumed upon all task completions. + std::coroutine_handle<> m_awaiting_coroutine{nullptr}; +}; + +template +class when_all_ready_awaitable; + +template +class when_all_task; + +/// Empty tuple<> implementation. +template<> +class when_all_ready_awaitable> +{ +public: + constexpr when_all_ready_awaitable() noexcept {} + explicit constexpr when_all_ready_awaitable(std::tuple<>) noexcept {} + + constexpr auto await_ready() const noexcept -> bool { return true; } + auto await_suspend(std::coroutine_handle<>) noexcept -> void { } + auto await_resume() const noexcept -> std::tuple<> { return {}; } +}; + +template +class when_all_ready_awaitable> +{ +public: + explicit when_all_ready_awaitable(task_types&&... tasks) + noexcept(std::conjunction_v...>) + : m_latch(sizeof...(task_types)), + m_tasks(std::move(tasks)...) + {} + + explicit when_all_ready_awaitable(std::tuple&& tasks) + noexcept(std::is_nothrow_move_constructible_v>) + : m_latch(sizeof...(task_types)), + m_tasks(std::move(tasks)) + { } + + when_all_ready_awaitable(const when_all_ready_awaitable&) = delete; + when_all_ready_awaitable(when_all_ready_awaitable&& other) + : m_latch(std::move(other.m_latch)), + m_tasks(std::move(other.m_tasks)) + { } + + auto operator=(const when_all_ready_awaitable&) -> when_all_ready_awaitable& = delete; + auto operator=(when_all_ready_awaitable&&) -> when_all_ready_awaitable& = delete; + + auto operator co_await() & noexcept + { + struct awaiter + { + explicit awaiter(when_all_ready_awaitable& awaitable) noexcept + : m_awaitable(awaitable) + { } + + auto await_ready() const noexcept -> bool + { + return m_awaitable.is_ready(); + } + + auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool + { + return m_awaitable.try_await(awaiting_coroutine); + } + + auto await_resume() noexcept -> std::tuple& + { + return m_awaitable.m_tasks; + } + private: + when_all_ready_awaitable& m_awaitable; + }; + + return awaiter{ *this }; + } + + auto operator co_await() && noexcept + { + struct awaiter + { + explicit awaiter(when_all_ready_awaitable& awaitable) noexcept + : m_awaitable(awaitable) + { } + + auto await_ready() const noexcept -> bool + { + return m_awaitable.is_ready(); + } + + auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool + { + return m_awaitable.try_await(awaiting_coroutine); + } + + auto await_resume() noexcept -> std::tuple&& + { + return std::move(m_awaitable.m_tasks); + } + private: + when_all_ready_awaitable& m_awaitable; + }; + + return awaiter{ *this }; + } +private: + + auto is_ready() const noexcept -> bool + { + return m_latch.is_ready(); + } + + auto try_await(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool + { + std::apply( + [this](auto&&... tasks) { ((tasks.start(m_latch)), ...); }, + m_tasks); + return m_latch.try_await(awaiting_coroutine); + } + + when_all_latch m_latch; + std::tuple m_tasks; +}; + +template +class when_all_ready_awaitable +{ +public: + explicit when_all_ready_awaitable(task_container_type&& tasks) noexcept + : m_latch(std::size(tasks)), + m_tasks(std::forward(tasks)) + {} + + when_all_ready_awaitable(const when_all_ready_awaitable&) = delete; + when_all_ready_awaitable(when_all_ready_awaitable&& other) + noexcept(std::is_nothrow_move_constructible_v) + : m_latch(std::move(other.m_latch)), + m_tasks(std::move(m_tasks)) + {} + + auto operator=(const when_all_ready_awaitable&) -> when_all_ready_awaitable& = delete; + auto operator=(when_all_ready_awaitable&) -> when_all_ready_awaitable& = delete; + + auto operator co_await() & noexcept + { + struct awaiter + { + awaiter(when_all_ready_awaitable& awaitable) + : m_awaitable(awaitable) + {} + + auto await_ready() const noexcept -> bool + { + return m_awaitable.is_ready(); + } + + auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool + { + return m_awaitable.try_await(awaiting_coroutine); + } + + auto await_resume() noexcept -> task_container_type& + { + return m_awaitable.m_tasks; + } + private: + when_all_ready_awaitable& m_awaitable; + }; + + return awaiter{*this}; + } + + auto operator co_await() && noexcept + { + struct awaiter + { + awaiter(when_all_ready_awaitable& awaitable) + : m_awaitable(awaitable) + {} + + auto await_ready() const noexcept -> bool + { + return m_awaitable.is_ready(); + } + + auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool + { + return m_awaitable.try_await(awaiting_coroutine); + } + + auto await_resume() noexcept -> task_container_type&& + { + return std::move(m_awaitable.m_tasks); + } + private: + when_all_ready_awaitable& m_awaitable; + }; + + return awaiter{*this}; + } +private: + auto is_ready() const noexcept -> bool + { + return m_latch.is_ready(); + } + + auto try_await(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool + { + for(auto& task : m_tasks) + { + task.start(m_latch); + } + + return m_latch.try_await(awaiting_coroutine); + } + + when_all_latch m_latch; + task_container_type m_tasks; +}; + +template +class when_all_task_promise +{ +public: + using coroutine_handle_type = std::coroutine_handle>; + + when_all_task_promise() noexcept + {} + + auto get_return_object() noexcept + { + return coroutine_handle_type::from_promise(*this); + } + + auto initial_suspend() noexcept -> std::suspend_always + { + return {}; + } + + auto final_suspend() noexcept + { + struct completion_notifier + { + auto await_ready() const noexcept -> bool { return false; } + auto await_suspend(coroutine_handle_type coroutine) const noexcept -> void + { + coroutine.promise().m_latch->notify_awaitable_completed(); + } + auto await_resume() const noexcept { } + }; + + return completion_notifier{}; + } + + auto unhandled_exception() noexcept + { + m_exception_ptr = std::current_exception(); + } + + auto yield_value(return_type&& value) noexcept + { + m_return_value = std::addressof(value); + return final_suspend(); + } + + auto start(when_all_latch& latch) noexcept -> void + { + m_latch = &latch; + coroutine_handle_type::from_promise(*this).resume(); + } + + auto return_value() & -> return_type& + { + if(m_exception_ptr) + { + std::rethrow_exception(m_exception_ptr); + } + return *m_return_value; + } + + auto return_value() && -> return_type&& + { + if(m_exception_ptr) + { + std::rethrow_exception(m_exception_ptr); + } + return std::forward(*m_return_value); + } + +private: + when_all_latch* m_latch{nullptr}; + std::exception_ptr m_exception_ptr; + std::add_pointer_t m_return_value; +}; + +template<> +class when_all_task_promise +{ +public: + using coroutine_handle_type = std::coroutine_handle>; + + when_all_task_promise() noexcept + {} + + auto get_return_object() noexcept + { + return coroutine_handle_type::from_promise(*this); + } + + auto initial_suspend() noexcept -> std::suspend_always + { + return {}; + } + + auto final_suspend() noexcept + { + struct completion_notifier + { + auto await_ready() const noexcept -> bool { return false; } + auto await_suspend(coroutine_handle_type coroutine) const noexcept -> void + { + coroutine.promise().m_latch->notify_awaitable_completed(); + } + auto await_resume() const noexcept -> void { } + }; + + return completion_notifier{}; + } + + auto unhandled_exception() noexcept -> void + { + m_exception_ptr = std::current_exception(); + } + + auto return_void() noexcept -> void + {} + + auto start(when_all_latch& latch) -> void + { + m_latch = &latch; + coroutine_handle_type::from_promise(*this).resume(); + } + + auto return_value() -> void + { + if(m_exception_ptr) + { + std::rethrow_exception(m_exception_ptr); + } + } +private: + when_all_latch* m_latch{nullptr}; + std::exception_ptr m_exception_ptr; +}; + +template +class when_all_task +{ +public: + // To be able to call start(). + template + friend class when_all_ready_awaitable; + + using promise_type = when_all_task_promise; + using coroutine_handle_type = typename promise_type::coroutine_handle_type; + + when_all_task(coroutine_handle_type coroutine) noexcept + : m_coroutine(coroutine) + {} + + when_all_task(const when_all_task&) = delete; + when_all_task(when_all_task&& other) noexcept + : m_coroutine(std::exchange(other.m_coroutine, coroutine_handle_type{})) + {} + + auto operator=(const when_all_task&) -> when_all_task& = delete; + auto operator=(when_all_task&&) -> when_all_task& = delete; + + ~when_all_task() + { + if(m_coroutine != nullptr) + { + m_coroutine.destroy(); + } + } + + auto return_value() & -> decltype(auto) + { + if constexpr (std::is_void_v) + { + m_coroutine.promise().return_void(); + return void_value{}; + } + else + { + return m_coroutine.promise().return_value(); + } + } + + auto return_value() const & -> decltype(auto) + { + if constexpr (std::is_void_v) + { + m_coroutine.promise().return_void(); + return void_value{}; + } + else + { + return m_coroutine.promise().return_value(); + } + } + + auto return_value() && -> decltype(auto) + { + if constexpr (std::is_void_v) + { + m_coroutine.promise().return_void(); + return void_value{}; + } + else + { + return m_coroutine.promise().return_value(); + } + } + +private: + auto start(when_all_latch& latch) noexcept -> void + { + m_coroutine.promise().start(latch); + } + + coroutine_handle_type m_coroutine; +}; + +template::awaiter_return_t> +static auto make_when_all_task(awaitable a) -> when_all_task +{ + if constexpr (std::is_void_v) + { + co_await static_cast(a); + co_return; + } + else + { + co_yield co_await static_cast(a); + } +} + +} // namespace detail + +template +[[nodiscard]] auto when_all_awaitable(awaitables_type&&... awaitables) +{ + return + detail::when_all_ready_awaitable< + std::tuple< + detail::when_all_task< + typename awaitable_traits::awaiter_return_t + >... + > + >(std::make_tuple(detail::make_when_all_task(std::forward(awaitables))...)); +} + +template::awaiter_return_t> +[[nodiscard]] auto when_all_awaitable(std::vector&& awaitables) -> detail::when_all_ready_awaitable>> +{ + std::vector> tasks; + tasks.reserve(std::size(awaitables)); + + for(auto& a : awaitables) + { + tasks.emplace_back(detail::make_when_all_task(std::move(a))); + } + + return detail::when_all_ready_awaitable(std::move(tasks)); +} + +} // namespace coro diff --git a/src/sync_wait.cpp b/src/sync_wait.cpp new file mode 100644 index 0000000..95abbd8 --- /dev/null +++ b/src/sync_wait.cpp @@ -0,0 +1,34 @@ +#include "coro/sync_wait.hpp" + +namespace coro::detail +{ + +sync_wait_event::sync_wait_event(bool initially_set) + : m_set(initially_set) +{ + +} + +auto sync_wait_event::set() noexcept -> void +{ + { + std::lock_guard g{m_mutex}; + m_set = true; + } + + m_cv.notify_all(); +} + +auto sync_wait_event::reset() noexcept -> void +{ + std::lock_guard g{m_mutex}; + m_set = false; +} + +auto sync_wait_event::wait() noexcept -> void +{ + std::unique_lock lk{m_mutex}; + m_cv.wait(lk, [this] { return m_set; }); +} + +} // namespace coro::detail \ No newline at end of file diff --git a/src/thread_pool.cpp b/src/thread_pool.cpp new file mode 100644 index 0000000..3992bda --- /dev/null +++ b/src/thread_pool.cpp @@ -0,0 +1,124 @@ +#include "coro/thread_pool.hpp" + +namespace coro +{ + +thread_pool::operation::operation(thread_pool& tp) noexcept + : m_thread_pool(tp) +{ + +} + +auto thread_pool::operation::await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool +{ + std::cerr << "thread_pool::operation::await_suspend()\n"; + m_awaiting_coroutine = awaiting_coroutine; + m_thread_pool.schedule_impl(this); + return false; +} + +thread_pool::thread_pool(uint32_t thread_count) +{ + m_threads.reserve(thread_count); + for(uint32_t i = 0; i < thread_count; ++i) + { + m_threads.emplace_back([this, i] { run(i); }); + } +} + +thread_pool::~thread_pool() +{ + shutdown(); + + // If shutdown was called manually by the user with shutdown_t::async then the background + // worker threads need to be joined upon the thread pool destruction. + join(); +} + +auto thread_pool::schedule() noexcept -> std::optional +{ + std::cerr << "thread_pool::schedule()\n"; + if(!m_shutdown_requested.load(std::memory_order::relaxed)) + { + m_size.fetch_add(1, std::memory_order::relaxed); + return {operation{*this}}; + } + + return std::nullopt; +} + +auto thread_pool::shutdown(shutdown_t wait_for_tasks) -> void +{ + if (!m_shutdown_requested.exchange(true, std::memory_order::release)) + { + m_queue_cv.notify_all(); + if(wait_for_tasks == shutdown_t::sync) + { + join(); + } + } +} + +auto thread_pool::run(uint32_t worker_idx) -> void +{ + while(true) + { + // Wait until the queue has operations to execute or shutdown has been requested. + { + std::unique_lock lk{m_queue_cv_mutex}; + m_queue_cv.wait(lk, [this] { return !m_queue.empty() || m_shutdown_requested.load(std::memory_order::relaxed); }); + } + + // Continue to pull operations from the global queue until its empty. + while(true) + { + operation* op{nullptr}; + { + std::lock_guard lk{m_queue_mutex}; + if(!m_queue.empty()) + { + std::cerr << "thread_pool::run m_queue.pop_front()\n"; + op = m_queue.front(); + m_queue.pop_front(); + } + else + { + break; // while true, the queue is currently empty + } + } + + if(op != nullptr && op->m_awaiting_coroutine != nullptr) + { + op->m_awaiting_coroutine.resume(); + m_size.fetch_sub(1, std::memory_order::relaxed); + } + } + + if(m_shutdown_requested.load(std::memory_order::relaxed)) + { + break; // while(true); + } + } +} + +auto thread_pool::join() -> void +{ + for(auto& thread : m_threads) + { + thread.join(); + } + m_threads.clear(); +} + +auto thread_pool::schedule_impl(operation* op) -> void +{ + std::cerr << "thread_pool::schedule_impl()\n"; + { + std::lock_guard lk{m_queue_mutex}; + m_queue.emplace_back(op); + } + + m_queue_cv.notify_one(); +} + +} // namespace coro diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 1f89be8..ea1ec05 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -9,6 +9,8 @@ set(LIBCORO_TEST_SOURCE_FILES test_scheduler.cpp test_sync_wait.cpp test_task.cpp + test_thread_pool.cpp + test_when_all.cpp ) add_executable(${PROJECT_NAME} main.cpp ${LIBCORO_TEST_SOURCE_FILES}) diff --git a/test/bench.cpp b/test/bench.cpp index b14fbf3..70ae765 100644 --- a/test/bench.cpp +++ b/test/bench.cpp @@ -50,40 +50,43 @@ TEST_CASE("benchmark counter func direct call") TEST_CASE("benchmark counter func coro::sync_wait(awaitable)") { constexpr std::size_t iterations = default_iterations; - std::atomic counter{0}; - auto func = [&]() -> coro::task { - counter.fetch_add(1, std::memory_order::relaxed); - co_return; + uint64_t counter{0}; + auto func = []() -> coro::task { + co_return 1; }; auto start = sc::now(); for (std::size_t i = 0; i < iterations; ++i) { - coro::sync_wait(func()); + counter += coro::sync_wait(func()); } print_stats("benchmark counter func coro::sync_wait(awaitable)", iterations, start, sc::now()); REQUIRE(counter == iterations); } -TEST_CASE("benchmark counter func coro::sync_wait_all(awaitable)") +TEST_CASE("benchmark counter func coro::sync_wait(coro::when_all_awaitable(awaitable)) x10") { constexpr std::size_t iterations = default_iterations; - std::atomic counter{0}; - auto func = [&]() -> coro::task { - counter.fetch_add(1, std::memory_order::relaxed); - co_return; + uint64_t counter{0}; + auto f = []() -> coro::task { + co_return 1; }; auto start = sc::now(); for (std::size_t i = 0; i < iterations; i += 10) { - coro::sync_wait_all(func(), func(), func(), func(), func(), func(), func(), func(), func(), func()); + auto tasks = coro::sync_wait(coro::when_all_awaitable(f(), f(), f(), f(), f(), f(), f(), f(), f(), f())); + + std::apply([&counter](auto&&... t) { + ((counter += t.return_value()), ...); + }, + tasks); } - print_stats("benchmark counter func coro::sync_wait_all(awaitable)", iterations, start, sc::now()); + print_stats("benchmark counter func coro::sync_wait(coro::when_all_awaitable(awaitable))", iterations, start, sc::now()); REQUIRE(counter == iterations); } diff --git a/test/test_sync_wait.cpp b/test/test_sync_wait.cpp index 10cc986..2974704 100644 --- a/test/test_sync_wait.cpp +++ b/test/test_sync_wait.cpp @@ -2,12 +2,9 @@ #include -TEST_CASE("sync_wait task multiple suspends return integer with sync_wait") +TEST_CASE("sync_wait simple integer return") { auto func = []() -> coro::task { - co_await std::suspend_always{}; - co_await std::suspend_always{}; - co_await std::suspend_always{}; co_return 11; }; @@ -15,6 +12,19 @@ TEST_CASE("sync_wait task multiple suspends return integer with sync_wait") REQUIRE(result == 11); } +TEST_CASE("sync_wait void") +{ + std::string output; + + auto func = [&]() -> coro::task { + output = "hello from sync_wait\n"; + co_return; + }; + + coro::sync_wait(func()); + REQUIRE(output == "hello from sync_wait\n"); +} + TEST_CASE("sync_wait task co_await single") { auto answer = []() -> coro::task { @@ -38,17 +48,3 @@ TEST_CASE("sync_wait task co_await single") auto output = coro::sync_wait(await_answer()); REQUIRE(output == 1337); } - -TEST_CASE("sync_wait_all accumulate") -{ - std::atomic counter{0}; - auto func = [&](uint64_t amount) -> coro::task { - std::cerr << "amount=" << amount << "\n"; - counter += amount; - co_return; - }; - - coro::sync_wait_all(func(100), func(10), func(50)); - - REQUIRE(counter == 160); -} diff --git a/test/test_thread_pool.cpp b/test/test_thread_pool.cpp new file mode 100644 index 0000000..891e3a0 --- /dev/null +++ b/test/test_thread_pool.cpp @@ -0,0 +1,22 @@ +#include "catch.hpp" + +#include + +#include + +// TEST_CASE("thread_pool one worker, one task") +// { +// coro::thread_pool tp{1}; + +// auto func = [&tp]() -> coro::task +// { +// std::cerr << "func()\n"; +// co_await tp.schedule().value(); // Schedule this coroutine on the scheduler. +// std::cerr << "func co_return 42\n"; +// co_return 42; +// }; + +// std::cerr << "coro::sync_wait(func()) start\n"; +// coro::sync_wait(func()); +// std::cerr << "coro::sync_wait(func()) end\n"; +// } \ No newline at end of file diff --git a/test/test_when_all.cpp b/test/test_when_all.cpp new file mode 100644 index 0000000..fcf6949 --- /dev/null +++ b/test/test_when_all.cpp @@ -0,0 +1,86 @@ +#include "catch.hpp" + +#include + +TEST_CASE("when_all_awaitable single task with tuple container") +{ + auto make_task = [](uint64_t amount) -> coro::task { + co_return amount; + }; + + auto output_tasks = coro::sync_wait(coro::when_all_awaitable(make_task(100))); + REQUIRE(std::tuple_size() == 1); + + uint64_t counter{0}; + std::apply( + [&counter](auto&&... tasks) -> void { + ((counter += tasks.return_value()), ...); + }, + output_tasks); + + REQUIRE(counter == 100); +} + +TEST_CASE("when_all_awaitable multiple tasks with tuple container") +{ + auto make_task = [](uint64_t amount) -> coro::task { + co_return amount; + }; + + auto output_tasks = coro::sync_wait(coro::when_all_awaitable(make_task(100), make_task(50), make_task(20))); + REQUIRE(std::tuple_size() == 3); + + uint64_t counter{0}; + std::apply( + [&counter](auto&&... tasks) -> void { + ((counter += tasks.return_value()), ...); + }, + output_tasks); + + REQUIRE(counter == 170); +} + +TEST_CASE("when_all_awaitable single task with vector container") +{ + auto make_task = [](uint64_t amount) -> coro::task { + co_return amount; + }; + + std::vector> input_tasks; + input_tasks.emplace_back(make_task(100)); + + auto output_tasks = coro::sync_wait(coro::when_all_awaitable(std::move(input_tasks))); + REQUIRE(output_tasks.size() == 1); + + uint64_t counter{0}; + for(const auto& task : output_tasks) + { + counter += task.return_value(); + } + + REQUIRE(counter == 100); +} + +TEST_CASE("when_all_ready multple task withs vector container") +{ + auto make_task = [](uint64_t amount) -> coro::task { + co_return amount; + }; + + std::vector> input_tasks; + input_tasks.emplace_back(make_task(100)); + input_tasks.emplace_back(make_task(200)); + input_tasks.emplace_back(make_task(550)); + input_tasks.emplace_back(make_task(1000)); + + auto output_tasks = coro::sync_wait(coro::when_all_awaitable(std::move(input_tasks))); + REQUIRE(output_tasks.size() == 4); + + uint64_t counter{0}; + for(const auto& task : output_tasks) + { + counter += task.return_value(); + } + + REQUIRE(counter == 1850); +}