From 76b41a6ca06c46e972c0741a3266cefe009a1ec1 Mon Sep 17 00:00:00 2001 From: Josh Baldwin Date: Wed, 28 Oct 2020 17:35:23 -0600 Subject: [PATCH] Scheduler now correctly co_await's the user tasks from cleanup task (#14) Previously it set the continuation manually, which sort of works but is not canonical. --- inc/coro/awaitable.hpp | 18 ++- inc/coro/detail/void_value.hpp | 7 +- inc/coro/promise.hpp | 33 +++-- inc/coro/scheduler.hpp | 89 +++++-------- inc/coro/shutdown.hpp | 1 - inc/coro/sync_wait.hpp | 92 +++++-------- inc/coro/task.hpp | 2 +- inc/coro/thread_pool.hpp | 33 +++-- inc/coro/when_all.hpp | 237 ++++++++++++--------------------- src/sync_wait.cpp | 5 +- src/thread_pool.cpp | 34 +++-- test/bench.cpp | 30 ++--- test/test_scheduler.cpp | 24 +++- test/test_sync_wait.cpp | 7 +- test/test_thread_pool.cpp | 38 ++---- test/test_when_all.cpp | 32 ++--- 16 files changed, 277 insertions(+), 405 deletions(-) diff --git a/inc/coro/awaitable.hpp b/inc/coro/awaitable.hpp index f3e48d0..60feb50 100644 --- a/inc/coro/awaitable.hpp +++ b/inc/coro/awaitable.hpp @@ -7,7 +7,6 @@ namespace coro { - /** * This concept declares a type that is required to meet the c++20 coroutine operator co_await() * retun type. It requires the following three member functions: @@ -19,11 +18,13 @@ namespace coro template concept awaiter = requires(type t, std::coroutine_handle<> c) { - { t.await_ready() } -> std::same_as; - std::same_as || - std::same_as || + { + t.await_ready() + } + ->std::same_as; + std::same_as || std::same_as || std::same_as>; - { t.await_resume() }; + {t.await_resume()}; }; /** @@ -33,7 +34,10 @@ template concept awaitable = requires(type t) { // operator co_await() - { t.operator co_await() } -> awaiter; + { + t.operator co_await() + } + ->awaiter; }; template @@ -50,7 +54,7 @@ static auto get_awaiter(awaitable&& value) template struct awaitable_traits { - using awaiter_type = decltype(get_awaiter(std::declval())); + using awaiter_type = decltype(get_awaiter(std::declval())); using awaiter_return_type = decltype(std::declval().await_resume()); }; diff --git a/inc/coro/detail/void_value.hpp b/inc/coro/detail/void_value.hpp index 7793c89..ff599fc 100644 --- a/inc/coro/detail/void_value.hpp +++ b/inc/coro/detail/void_value.hpp @@ -2,7 +2,8 @@ namespace coro::detail { +struct void_value +{ +}; -struct void_value{}; - -} // coro::detail +} // namespace coro::detail diff --git a/inc/coro/promise.hpp b/inc/coro/promise.hpp index 9bf359e..b51aa40 100644 --- a/inc/coro/promise.hpp +++ b/inc/coro/promise.hpp @@ -6,20 +6,33 @@ namespace coro { - template concept promise_type = requires(type t) { - { t.get_return_object() } -> std::convertible_to>; - { t.initial_suspend() } -> awaiter; - { t.final_suspend() } -> awaiter; - { t.yield_value() } -> awaitable; -} && -requires(type t, return_type return_value) + { + t.get_return_object() + } + ->std::convertible_to>; + { + t.initial_suspend() + } + ->awaiter; + { + t.final_suspend() + } + ->awaiter; + { + t.yield_value() + } + ->awaitable; +} +&&requires(type t, return_type return_value) { - std::same_as || - std::same_as || - requires { t.yield_value(return_value); }; + std::same_as || std::same_as || + requires + { + t.yield_value(return_value); + }; }; } // namespace coro diff --git a/inc/coro/scheduler.hpp b/inc/coro/scheduler.hpp index 8ba254a..19f0e60 100644 --- a/inc/coro/scheduler.hpp +++ b/inc/coro/scheduler.hpp @@ -1,7 +1,7 @@ #pragma once -#include "coro/task.hpp" #include "coro/shutdown.hpp" +#include "coro/task.hpp" #include #include @@ -182,14 +182,6 @@ private: template friend class resume_token; - struct task_data - { - /// The user's task, lifetime is maintained by the scheduler. - coro::task m_user_task; - /// The post processing cleanup tasks to remove a completed task from the scheduler. - coro::task m_cleanup_task; - }; - class task_manager { public: @@ -210,8 +202,9 @@ private: * as deleted upon the coroutines completion. * @param user_task The scheduled user's task to store since it has suspended after its * first execution. + * @return The task just stored wrapped in the self cleanup task. */ - auto store(coro::task user_task) -> void + auto store(coro::task user_task) -> task& { // Only grow if completely full and attempting to add more. if (m_free_pos == m_task_indexes.end()) @@ -219,17 +212,14 @@ private: m_free_pos = grow(); } - // Store the user task with its cleanup task to maintain their lifetimes until completed. - auto index = *m_free_pos; - auto& task_data = m_tasks[index]; - task_data.m_user_task = std::move(user_task); - task_data.m_cleanup_task = cleanup_func(m_free_pos); - - // Attach the cleanup task to be the continuation after the users task. - task_data.m_user_task.promise().continuation(task_data.m_cleanup_task.handle()); + // Store the task inside a cleanup task for self deletion. + auto index = *m_free_pos; + m_tasks[index] = make_cleanup_task(std::move(user_task), m_free_pos); // Mark the current used slot as used. std::advance(m_free_pos, 1); + + return m_tasks[index]; } /** @@ -294,20 +284,34 @@ private: } /** - * Each task the user schedules has this task chained as a continuation to execute after - * the user's task completes. This function takes the task position in the indexes list - * and upon execution marks that slot for deletion. It cannot self delete otherwise it - * would corrupt/double free its own coroutine stack frame. + * Encapsulate the users tasks in a cleanup task which marks itself for deletion upon + * completion. Simply co_await the users task until its completed and then mark the given + * position within the task manager as being deletable. The scheduler's next iteration + * in its event loop will then free that position up to be re-used. + * + * This function will also unconditionally catch all unhandled exceptions by the user's + * task to prevent the scheduler from throwing exceptions. + * @param user_task The user's task. + * @param pos The position where the task data will be stored in the task manager. + * @return The user's task wrapped in a self cleanup task. */ - auto cleanup_func(task_position pos) -> coro::task + auto make_cleanup_task(task user_task, task_position pos) -> task { - // Mark this task for deletion, it cannot delete itself. + try + { + co_await user_task; + } + catch (const std::runtime_error& e) + { + std::cerr << "scheduler user_task had an unhandled exception e.what()= " << e.what() << "\n"; + } + m_tasks_to_delete.push_back(pos); co_return; - }; + } /// Maintains the lifetime of the tasks until they are completed. - std::vector m_tasks{}; + std::vector> m_tasks{}; /// The full set of indexes into `m_tasks`. std::list m_task_indexes{}; /// The set of tasks that have completed and need to be deleted. @@ -732,44 +736,19 @@ private: static constexpr std::size_t m_max_events = 8; std::array m_events{}; - auto task_start(coro::task& task) -> void - { - if (!task.is_ready()) // sanity check, the user could have manually resumed. - { - // Attempt to process the task synchronously before suspending. - task.resume(); - - if (!task.is_ready()) - { - m_task_manager.store(std::move(task)); - // This task is now suspended waiting for an event. - } - else - { - // This task completed synchronously. - m_size.fetch_sub(1, std::memory_order::relaxed); - } - } - else - { - m_size.fetch_sub(1, std::memory_order::relaxed); - } - } - inline auto process_task_variant(task_variant& tv) -> void { if (std::holds_alternative>(tv)) { auto& task = std::get>(tv); - task_start(task); + // Store the users task and immediately start executing it. + m_task_manager.store(std::move(task)).resume(); } else { auto handle = std::get>(tv); - if (!handle.done()) - { - handle.resume(); - } + // The cleanup wrapper task will catch all thrown exceptions unconditionally. + handle.resume(); } } diff --git a/inc/coro/shutdown.hpp b/inc/coro/shutdown.hpp index 640b53e..399f592 100644 --- a/inc/coro/shutdown.hpp +++ b/inc/coro/shutdown.hpp @@ -2,7 +2,6 @@ namespace coro { - enum class shutdown_t { /// Synchronously wait for all tasks to complete when calling shutdown. diff --git a/inc/coro/sync_wait.hpp b/inc/coro/sync_wait.hpp index ef0b792..ba45eca 100644 --- a/inc/coro/sync_wait.hpp +++ b/inc/coro/sync_wait.hpp @@ -2,32 +2,31 @@ #include "coro/awaitable.hpp" -#include #include +#include namespace coro { - namespace detail { - class sync_wait_event { public: sync_wait_event(bool initially_set = false); sync_wait_event(const sync_wait_event&) = delete; - sync_wait_event(sync_wait_event&&) = delete; + sync_wait_event(sync_wait_event&&) = delete; auto operator=(const sync_wait_event&) -> sync_wait_event& = delete; - auto operator=(sync_wait_event&&) -> sync_wait_event& = delete; - ~sync_wait_event() = default; + auto operator=(sync_wait_event &&) -> sync_wait_event& = delete; + ~sync_wait_event() = default; auto set() noexcept -> void; auto reset() noexcept -> void; auto wait() noexcept -> void; + private: - std::mutex m_mutex; + std::mutex m_mutex; std::condition_variable m_cv; - bool m_set{false}; + bool m_set{false}; }; class sync_wait_task_promise_base @@ -36,17 +35,12 @@ public: sync_wait_task_promise_base() noexcept = default; virtual ~sync_wait_task_promise_base() = default; - auto initial_suspend() noexcept -> std::suspend_always - { - return {}; - } + auto initial_suspend() noexcept -> std::suspend_always { return {}; } + + auto unhandled_exception() -> void { m_exception = std::current_exception(); } - auto unhandled_exception() -> void - { - m_exception = std::current_exception(); - } protected: - sync_wait_event* m_event{nullptr}; + sync_wait_event* m_event{nullptr}; std::exception_ptr m_exception; }; @@ -56,7 +50,7 @@ class sync_wait_task_promise : public sync_wait_task_promise_base public: using coroutine_type = std::coroutine_handle>; - sync_wait_task_promise() noexcept = default; + sync_wait_task_promise() noexcept = default; ~sync_wait_task_promise() override = default; auto start(sync_wait_event& event) @@ -65,10 +59,7 @@ public: coroutine_type::from_promise(*this).resume(); } - auto get_return_object() noexcept - { - return coroutine_type::from_promise(*this); - } + auto get_return_object() noexcept { return coroutine_type::from_promise(*this); } auto yield_value(return_type&& value) noexcept { @@ -81,11 +72,8 @@ public: struct completion_notifier { auto await_ready() const noexcept { return false; } - auto await_suspend(coroutine_type coroutine) const noexcept - { - coroutine.promise().m_event->set(); - } - auto await_resume() noexcept { }; + auto await_suspend(coroutine_type coroutine) const noexcept { coroutine.promise().m_event->set(); } + auto await_resume() noexcept {}; }; return completion_notifier{}; @@ -93,7 +81,7 @@ public: auto return_value() -> return_type&& { - if(m_exception) + if (m_exception) { std::rethrow_exception(m_exception); } @@ -105,13 +93,13 @@ private: std::remove_reference_t* m_return_value; }; - template<> class sync_wait_task_promise : public sync_wait_task_promise_base { using coroutine_type = std::coroutine_handle>; + public: - sync_wait_task_promise() noexcept = default; + sync_wait_task_promise() noexcept = default; ~sync_wait_task_promise() override = default; auto start(sync_wait_event& event) @@ -120,31 +108,25 @@ public: coroutine_type::from_promise(*this).resume(); } - auto get_return_object() noexcept - { - return coroutine_type::from_promise(*this); - } + auto get_return_object() noexcept { return coroutine_type::from_promise(*this); } auto final_suspend() noexcept { struct completion_notifier { auto await_ready() const noexcept { return false; } - auto await_suspend(coroutine_type coroutine) const noexcept - { - coroutine.promise().m_event->set(); - } - auto await_resume() noexcept { }; + auto await_suspend(coroutine_type coroutine) const noexcept { coroutine.promise().m_event->set(); } + auto await_resume() noexcept {}; }; return completion_notifier{}; } - auto return_void() noexcept -> void { } + auto return_void() noexcept -> void {} auto return_value() { - if(m_exception) + if (m_exception) { std::rethrow_exception(m_exception); } @@ -155,25 +137,17 @@ template class sync_wait_task { public: - using promise_type = sync_wait_task_promise; + using promise_type = sync_wait_task_promise; using coroutine_type = std::coroutine_handle; - sync_wait_task(coroutine_type coroutine) noexcept - : m_coroutine(coroutine) - { - - } + sync_wait_task(coroutine_type coroutine) noexcept : m_coroutine(coroutine) {} sync_wait_task(const sync_wait_task&) = delete; - sync_wait_task(sync_wait_task&& other) noexcept - : m_coroutine(std::exchange(other.m_coroutine, coroutine_type{})) - { - - } + sync_wait_task(sync_wait_task&& other) noexcept : m_coroutine(std::exchange(other.m_coroutine, coroutine_type{})) {} auto operator=(const sync_wait_task&) -> sync_wait_task& = delete; - auto operator=(sync_wait_task&& other) -> sync_wait_task& + auto operator =(sync_wait_task&& other) -> sync_wait_task& { - if(std::addressof(other) != this) + if (std::addressof(other) != this) { m_coroutine = std::exchange(other.m_coroutine, coroutine_type{}); } @@ -183,16 +157,13 @@ public: ~sync_wait_task() { - if(m_coroutine) + if (m_coroutine) { m_coroutine.destroy(); } } - auto start(sync_wait_event& event) noexcept - { - m_coroutine.promise().start(event); - } + auto start(sync_wait_event& event) noexcept { m_coroutine.promise().start(event); } auto return_value() -> decltype(auto) { @@ -211,7 +182,6 @@ private: coroutine_type m_coroutine; }; - template::awaiter_return_type> static auto make_sync_wait_task(awaitable&& a) -> sync_wait_task { @@ -232,7 +202,7 @@ template auto sync_wait(awaitable&& a) -> decltype(auto) { detail::sync_wait_event e{}; - auto task = detail::make_sync_wait_task(std::forward(a)); + auto task = detail::make_sync_wait_task(std::forward(a)); task.start(e); e.wait(); diff --git a/inc/coro/task.hpp b/inc/coro/task.hpp index 477d18e..bad058e 100644 --- a/inc/coro/task.hpp +++ b/inc/coro/task.hpp @@ -143,7 +143,7 @@ public: explicit task(coroutine_handle handle) : m_coroutine(handle) {} task(const task&) = delete; - task(task&& other) noexcept : m_coroutine(std::exchange(other.m_coroutine, nullptr)) { } + task(task&& other) noexcept : m_coroutine(std::exchange(other.m_coroutine, nullptr)) {} ~task() { diff --git a/inc/coro/thread_pool.hpp b/inc/coro/thread_pool.hpp index f0f2281..7df52b8 100644 --- a/inc/coro/thread_pool.hpp +++ b/inc/coro/thread_pool.hpp @@ -4,18 +4,17 @@ #include "coro/task.hpp" #include -#include -#include -#include #include #include #include -#include #include +#include +#include +#include +#include namespace coro { - /** * Creates a thread pool that executes arbitrary coroutine tasks in a FIFO scheduler policy. * The thread pool by default will create an execution thread per available core on the system. @@ -39,6 +38,7 @@ public: * @param tp The thread pool that created this operation. */ explicit operation(thread_pool& tp) noexcept; + public: /** * Operations always pause so the executing thread and be switched. @@ -54,7 +54,8 @@ public: /** * no-op as this is the function called first by the thread pool's executing thread. */ - auto await_resume() noexcept -> void { } + auto await_resume() noexcept -> void {} + private: /// The thread pool that this operation will execute on. thread_pool& m_thread_pool; @@ -76,16 +77,12 @@ public: /** * @param opts Thread pool configuration options. */ - explicit thread_pool(options opts = options{ - std::thread::hardware_concurrency(), - nullptr, - nullptr - }); + explicit thread_pool(options opts = options{std::thread::hardware_concurrency(), nullptr, nullptr}); thread_pool(const thread_pool&) = delete; - thread_pool(thread_pool&&) = delete; + thread_pool(thread_pool&&) = delete; auto operator=(const thread_pool&) -> thread_pool& = delete; - auto operator=(thread_pool&&) -> thread_pool& = delete; + auto operator=(thread_pool &&) -> thread_pool& = delete; ~thread_pool(); @@ -98,8 +95,7 @@ public: * pool thread. This will return nullopt if the schedule fails, currently the only * way for this to fail is if `shudown()` has been called. */ - [[nodiscard]] - auto schedule() noexcept -> std::optional; + [[nodiscard]] auto schedule() noexcept -> std::optional; /** * @throw std::runtime_error If the thread pool is `shutdown()` scheduling new tasks is not permitted. @@ -108,11 +104,11 @@ public: * @return A task that wraps the given functor to be executed on the thread pool. */ template - [[nodiscard]] - auto schedule(functor&& f, arguments... args) noexcept -> task(args)...))> + [[nodiscard]] auto schedule(functor&& f, arguments... args) noexcept + -> task(args)...))> { auto scheduled = schedule(); - if(!scheduled.has_value()) + if (!scheduled.has_value()) { throw std::runtime_error("coro::thread_pool is shutting down, unable to schedule new tasks."); } @@ -162,6 +158,7 @@ public: * @return True if the task queue is currently empty. */ auto queue_empty() const noexcept -> bool { return queue_size() == 0; } + private: /// The configuration options. options m_opts; diff --git a/inc/coro/when_all.hpp b/inc/coro/when_all.hpp index aa18e7d..85fbbe5 100644 --- a/inc/coro/when_all.hpp +++ b/inc/coro/when_all.hpp @@ -9,42 +9,33 @@ namespace coro { - namespace detail { - class when_all_latch { public: - when_all_latch(std::size_t count) noexcept - : m_count(count + 1) - { } + when_all_latch(std::size_t count) noexcept : m_count(count + 1) {} when_all_latch(const when_all_latch&) = delete; when_all_latch(when_all_latch&& other) : m_count(other.m_count.load(std::memory_order::acquire)), m_awaiting_coroutine(std::exchange(other.m_awaiting_coroutine, nullptr)) - { } + { + } auto operator=(const when_all_latch&) -> when_all_latch& = delete; - auto operator=(when_all_latch&& other) -> when_all_latch& + auto operator =(when_all_latch&& other) -> when_all_latch& { - if(std::addressof(other) != this) + if (std::addressof(other) != this) { - m_count.store( - other.m_count.load(std::memory_order::acquire), - std::memory_order::relaxed - ); + m_count.store(other.m_count.load(std::memory_order::acquire), std::memory_order::relaxed); m_awaiting_coroutine = std::exchange(other.m_awaiting_coroutine, nullptr); } return *this; } - auto is_ready() const noexcept -> bool - { - return m_awaiting_coroutine != nullptr && m_awaiting_coroutine.done(); - } + auto is_ready() const noexcept -> bool { return m_awaiting_coroutine != nullptr && m_awaiting_coroutine.done(); } auto try_await(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool { @@ -54,7 +45,7 @@ public: auto notify_awaitable_completed() noexcept -> void { - if(m_count.fetch_sub(1, std::memory_order::acq_rel) == 1) + if (m_count.fetch_sub(1, std::memory_order::acq_rel) == 1) { m_awaiting_coroutine.resume(); } @@ -82,108 +73,92 @@ public: explicit constexpr when_all_ready_awaitable(std::tuple<>) noexcept {} constexpr auto await_ready() const noexcept -> bool { return true; } - auto await_suspend(std::coroutine_handle<>) noexcept -> void { } - auto await_resume() const noexcept -> std::tuple<> { return {}; } + auto await_suspend(std::coroutine_handle<>) noexcept -> void {} + auto await_resume() const noexcept -> std::tuple<> { return {}; } }; template class when_all_ready_awaitable> { public: - explicit when_all_ready_awaitable(task_types&&... tasks) - noexcept(std::conjunction_v...>) + explicit when_all_ready_awaitable(task_types&&... tasks) noexcept( + std::conjunction_v...>) : m_latch(sizeof...(task_types)), m_tasks(std::move(tasks)...) - {} + { + } - explicit when_all_ready_awaitable(std::tuple&& tasks) - noexcept(std::is_nothrow_move_constructible_v>) + explicit when_all_ready_awaitable(std::tuple&& tasks) noexcept( + std::is_nothrow_move_constructible_v>) : m_latch(sizeof...(task_types)), m_tasks(std::move(tasks)) - { } + { + } when_all_ready_awaitable(const when_all_ready_awaitable&) = delete; when_all_ready_awaitable(when_all_ready_awaitable&& other) : m_latch(std::move(other.m_latch)), m_tasks(std::move(other.m_tasks)) - { } + { + } auto operator=(const when_all_ready_awaitable&) -> when_all_ready_awaitable& = delete; - auto operator=(when_all_ready_awaitable&&) -> when_all_ready_awaitable& = delete; + auto operator=(when_all_ready_awaitable &&) -> when_all_ready_awaitable& = delete; auto operator co_await() & noexcept { struct awaiter { - explicit awaiter(when_all_ready_awaitable& awaitable) noexcept - : m_awaitable(awaitable) - { } + explicit awaiter(when_all_ready_awaitable& awaitable) noexcept : m_awaitable(awaitable) {} - auto await_ready() const noexcept -> bool - { - return m_awaitable.is_ready(); - } + auto await_ready() const noexcept -> bool { return m_awaitable.is_ready(); } auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool { return m_awaitable.try_await(awaiting_coroutine); } - auto await_resume() noexcept -> std::tuple& - { - return m_awaitable.m_tasks; - } + auto await_resume() noexcept -> std::tuple& { return m_awaitable.m_tasks; } + private: when_all_ready_awaitable& m_awaitable; }; - return awaiter{ *this }; + return awaiter{*this}; } auto operator co_await() && noexcept { struct awaiter { - explicit awaiter(when_all_ready_awaitable& awaitable) noexcept - : m_awaitable(awaitable) - { } + explicit awaiter(when_all_ready_awaitable& awaitable) noexcept : m_awaitable(awaitable) {} - auto await_ready() const noexcept -> bool - { - return m_awaitable.is_ready(); - } + auto await_ready() const noexcept -> bool { return m_awaitable.is_ready(); } auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool { return m_awaitable.try_await(awaiting_coroutine); } - auto await_resume() noexcept -> std::tuple&& - { - return std::move(m_awaitable.m_tasks); - } + auto await_resume() noexcept -> std::tuple&& { return std::move(m_awaitable.m_tasks); } + private: when_all_ready_awaitable& m_awaitable; }; - return awaiter{ *this }; + return awaiter{*this}; } -private: - auto is_ready() const noexcept -> bool - { - return m_latch.is_ready(); - } +private: + auto is_ready() const noexcept -> bool { return m_latch.is_ready(); } auto try_await(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool { - std::apply( - [this](auto&&... tasks) { ((tasks.start(m_latch)), ...); }, - m_tasks); + std::apply([this](auto&&... tasks) { ((tasks.start(m_latch)), ...); }, m_tasks); return m_latch.try_await(awaiting_coroutine); } - when_all_latch m_latch; + when_all_latch m_latch; std::tuple m_tasks; }; @@ -194,14 +169,16 @@ public: explicit when_all_ready_awaitable(task_container_type&& tasks) noexcept : m_latch(std::size(tasks)), m_tasks(std::forward(tasks)) - {} + { + } when_all_ready_awaitable(const when_all_ready_awaitable&) = delete; - when_all_ready_awaitable(when_all_ready_awaitable&& other) - noexcept(std::is_nothrow_move_constructible_v) + when_all_ready_awaitable(when_all_ready_awaitable&& other) noexcept( + std::is_nothrow_move_constructible_v) : m_latch(std::move(other.m_latch)), m_tasks(std::move(m_tasks)) - {} + { + } auto operator=(const when_all_ready_awaitable&) -> when_all_ready_awaitable& = delete; auto operator=(when_all_ready_awaitable&) -> when_all_ready_awaitable& = delete; @@ -210,24 +187,17 @@ public: { struct awaiter { - awaiter(when_all_ready_awaitable& awaitable) - : m_awaitable(awaitable) - {} + awaiter(when_all_ready_awaitable& awaitable) : m_awaitable(awaitable) {} - auto await_ready() const noexcept -> bool - { - return m_awaitable.is_ready(); - } + auto await_ready() const noexcept -> bool { return m_awaitable.is_ready(); } auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool { return m_awaitable.try_await(awaiting_coroutine); } - auto await_resume() noexcept -> task_container_type& - { - return m_awaitable.m_tasks; - } + auto await_resume() noexcept -> task_container_type& { return m_awaitable.m_tasks; } + private: when_all_ready_awaitable& m_awaitable; }; @@ -239,39 +209,30 @@ public: { struct awaiter { - awaiter(when_all_ready_awaitable& awaitable) - : m_awaitable(awaitable) - {} + awaiter(when_all_ready_awaitable& awaitable) : m_awaitable(awaitable) {} - auto await_ready() const noexcept -> bool - { - return m_awaitable.is_ready(); - } + auto await_ready() const noexcept -> bool { return m_awaitable.is_ready(); } auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool { return m_awaitable.try_await(awaiting_coroutine); } - auto await_resume() noexcept -> task_container_type&& - { - return std::move(m_awaitable.m_tasks); - } + auto await_resume() noexcept -> task_container_type&& { return std::move(m_awaitable.m_tasks); } + private: when_all_ready_awaitable& m_awaitable; }; return awaiter{*this}; } + private: - auto is_ready() const noexcept -> bool - { - return m_latch.is_ready(); - } + auto is_ready() const noexcept -> bool { return m_latch.is_ready(); } auto try_await(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool { - for(auto& task : m_tasks) + for (auto& task : m_tasks) { task.start(m_latch); } @@ -279,7 +240,7 @@ private: return m_latch.try_await(awaiting_coroutine); } - when_all_latch m_latch; + when_all_latch m_latch; task_container_type m_tasks; }; @@ -289,18 +250,11 @@ class when_all_task_promise public: using coroutine_handle_type = std::coroutine_handle>; - when_all_task_promise() noexcept - {} + when_all_task_promise() noexcept {} - auto get_return_object() noexcept - { - return coroutine_handle_type::from_promise(*this); - } + auto get_return_object() noexcept { return coroutine_handle_type::from_promise(*this); } - auto initial_suspend() noexcept -> std::suspend_always - { - return {}; - } + auto initial_suspend() noexcept -> std::suspend_always { return {}; } auto final_suspend() noexcept { @@ -311,16 +265,13 @@ public: { coroutine.promise().m_latch->notify_awaitable_completed(); } - auto await_resume() const noexcept { } + auto await_resume() const noexcept {} }; return completion_notifier{}; } - auto unhandled_exception() noexcept - { - m_exception_ptr = std::current_exception(); - } + auto unhandled_exception() noexcept { m_exception_ptr = std::current_exception(); } auto yield_value(return_type&& value) noexcept { @@ -336,7 +287,7 @@ public: auto return_value() & -> return_type& { - if(m_exception_ptr) + if (m_exception_ptr) { std::rethrow_exception(m_exception_ptr); } @@ -345,7 +296,7 @@ public: auto return_value() && -> return_type&& { - if(m_exception_ptr) + if (m_exception_ptr) { std::rethrow_exception(m_exception_ptr); } @@ -353,8 +304,8 @@ public: } private: - when_all_latch* m_latch{nullptr}; - std::exception_ptr m_exception_ptr; + when_all_latch* m_latch{nullptr}; + std::exception_ptr m_exception_ptr; std::add_pointer_t m_return_value; }; @@ -364,18 +315,11 @@ class when_all_task_promise public: using coroutine_handle_type = std::coroutine_handle>; - when_all_task_promise() noexcept - {} + when_all_task_promise() noexcept {} - auto get_return_object() noexcept - { - return coroutine_handle_type::from_promise(*this); - } + auto get_return_object() noexcept { return coroutine_handle_type::from_promise(*this); } - auto initial_suspend() noexcept -> std::suspend_always - { - return {}; - } + auto initial_suspend() noexcept -> std::suspend_always { return {}; } auto final_suspend() noexcept { @@ -386,19 +330,15 @@ public: { coroutine.promise().m_latch->notify_awaitable_completed(); } - auto await_resume() const noexcept -> void { } + auto await_resume() const noexcept -> void {} }; return completion_notifier{}; } - auto unhandled_exception() noexcept -> void - { - m_exception_ptr = std::current_exception(); - } + auto unhandled_exception() noexcept -> void { m_exception_ptr = std::current_exception(); } - auto return_void() noexcept -> void - {} + auto return_void() noexcept -> void {} auto start(when_all_latch& latch) -> void { @@ -408,13 +348,14 @@ public: auto return_value() -> void { - if(m_exception_ptr) + if (m_exception_ptr) { std::rethrow_exception(m_exception_ptr); } } + private: - when_all_latch* m_latch{nullptr}; + when_all_latch* m_latch{nullptr}; std::exception_ptr m_exception_ptr; }; @@ -426,24 +367,23 @@ public: template friend class when_all_ready_awaitable; - using promise_type = when_all_task_promise; + using promise_type = when_all_task_promise; using coroutine_handle_type = typename promise_type::coroutine_handle_type; - when_all_task(coroutine_handle_type coroutine) noexcept - : m_coroutine(coroutine) - {} + when_all_task(coroutine_handle_type coroutine) noexcept : m_coroutine(coroutine) {} when_all_task(const when_all_task&) = delete; when_all_task(when_all_task&& other) noexcept : m_coroutine(std::exchange(other.m_coroutine, coroutine_handle_type{})) - {} + { + } auto operator=(const when_all_task&) -> when_all_task& = delete; - auto operator=(when_all_task&&) -> when_all_task& = delete; + auto operator=(when_all_task &&) -> when_all_task& = delete; ~when_all_task() { - if(m_coroutine != nullptr) + if (m_coroutine != nullptr) { m_coroutine.destroy(); } @@ -462,7 +402,7 @@ public: } } - auto return_value() const & -> decltype(auto) + auto return_value() const& -> decltype(auto) { if constexpr (std::is_void_v) { @@ -489,10 +429,7 @@ public: } private: - auto start(when_all_latch& latch) noexcept -> void - { - m_coroutine.promise().start(latch); - } + auto start(when_all_latch& latch) noexcept -> void { m_coroutine.promise().start(latch); } coroutine_handle_type m_coroutine; }; @@ -516,23 +453,19 @@ static auto make_when_all_task(awaitable a) -> when_all_task template [[nodiscard]] auto when_all_awaitable(awaitables_type&&... awaitables) { - return - detail::when_all_ready_awaitable< - std::tuple< - detail::when_all_task< - typename awaitable_traits::awaiter_return_type - >... - > - >(std::make_tuple(detail::make_when_all_task(std::forward(awaitables))...)); + return detail::when_all_ready_awaitable< + std::tuple::awaiter_return_type>...>>( + std::make_tuple(detail::make_when_all_task(std::forward(awaitables))...)); } template::awaiter_return_type> -[[nodiscard]] auto when_all_awaitable(std::vector& awaitables) -> detail::when_all_ready_awaitable>> +[[nodiscard]] auto when_all_awaitable(std::vector& awaitables) + -> detail::when_all_ready_awaitable>> { std::vector> tasks; tasks.reserve(std::size(awaitables)); - for(auto& a : awaitables) + for (auto& a : awaitables) { tasks.emplace_back(detail::make_when_all_task(std::move(a))); } diff --git a/src/sync_wait.cpp b/src/sync_wait.cpp index 95abbd8..61b287f 100644 --- a/src/sync_wait.cpp +++ b/src/sync_wait.cpp @@ -2,11 +2,8 @@ namespace coro::detail { - -sync_wait_event::sync_wait_event(bool initially_set) - : m_set(initially_set) +sync_wait_event::sync_wait_event(bool initially_set) : m_set(initially_set) { - } auto sync_wait_event::set() noexcept -> void diff --git a/src/thread_pool.cpp b/src/thread_pool.cpp index f8ebba6..58a1f9d 100644 --- a/src/thread_pool.cpp +++ b/src/thread_pool.cpp @@ -2,11 +2,8 @@ namespace coro { - -thread_pool::operation::operation(thread_pool& tp) noexcept - : m_thread_pool(tp) +thread_pool::operation::operation(thread_pool& tp) noexcept : m_thread_pool(tp) { - } auto thread_pool::operation::await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> void @@ -19,12 +16,11 @@ auto thread_pool::operation::await_suspend(std::coroutine_handle<> awaiting_coro // something else while this coroutine gets picked up by the thread pool. } -thread_pool::thread_pool(options opts) - : m_opts(std::move(opts)) +thread_pool::thread_pool(options opts) : m_opts(std::move(opts)) { m_threads.reserve(m_opts.thread_count); - for(uint32_t i = 0; i < m_opts.thread_count; ++i) + for (uint32_t i = 0; i < m_opts.thread_count; ++i) { m_threads.emplace_back([this, i](std::stop_token st) { executor(std::move(st), i); }); } @@ -37,7 +33,7 @@ thread_pool::~thread_pool() auto thread_pool::schedule() noexcept -> std::optional { - if(!m_shutdown_requested.load(std::memory_order::relaxed)) + if (!m_shutdown_requested.load(std::memory_order::relaxed)) { m_size.fetch_add(1, std::memory_order_relaxed); return {operation{*this}}; @@ -50,16 +46,16 @@ auto thread_pool::shutdown(shutdown_t wait_for_tasks) noexcept -> void { if (!m_shutdown_requested.exchange(true, std::memory_order::release)) { - for(auto& thread : m_threads) + for (auto& thread : m_threads) { thread.request_stop(); } - if(wait_for_tasks == shutdown_t::sync) + if (wait_for_tasks == shutdown_t::sync) { - for(auto& thread : m_threads) + for (auto& thread : m_threads) { - if(thread.joinable()) + if (thread.joinable()) { thread.join(); } @@ -70,12 +66,12 @@ auto thread_pool::shutdown(shutdown_t wait_for_tasks) noexcept -> void auto thread_pool::executor(std::stop_token stop_token, std::size_t idx) -> void { - if(m_opts.on_thread_start_functor != nullptr) + if (m_opts.on_thread_start_functor != nullptr) { m_opts.on_thread_start_functor(idx); } - while(true) + while (true) { // Wait until the queue has operations to execute or shutdown has been requested. { @@ -84,12 +80,12 @@ auto thread_pool::executor(std::stop_token stop_token, std::size_t idx) -> void } // Continue to pull operations from the global queue until its empty. - while(true) + while (true) { operation* op{nullptr}; { std::lock_guard lk{m_queue_mutex}; - if(!m_queue.empty()) + if (!m_queue.empty()) { op = m_queue.front(); m_queue.pop_front(); @@ -100,7 +96,7 @@ auto thread_pool::executor(std::stop_token stop_token, std::size_t idx) -> void } } - if(op != nullptr && op->m_awaiting_coroutine != nullptr) + if (op != nullptr && op->m_awaiting_coroutine != nullptr) { op->m_awaiting_coroutine.resume(); m_size.fetch_sub(1, std::memory_order_relaxed); @@ -111,13 +107,13 @@ auto thread_pool::executor(std::stop_token stop_token, std::size_t idx) -> void } } - if(stop_token.stop_requested()) + if (stop_token.stop_requested()) { break; // while(true); } } - if(m_opts.on_thread_stop_functor != nullptr) + if (m_opts.on_thread_stop_functor != nullptr) { m_opts.on_thread_stop_functor(idx); } diff --git a/test/bench.cpp b/test/bench.cpp index 5f597b8..c7eece3 100644 --- a/test/bench.cpp +++ b/test/bench.cpp @@ -50,10 +50,8 @@ TEST_CASE("benchmark counter func direct call") TEST_CASE("benchmark counter func coro::sync_wait(awaitable)") { constexpr std::size_t iterations = default_iterations; - uint64_t counter{0}; - auto func = []() -> coro::task { - co_return 1; - }; + uint64_t counter{0}; + auto func = []() -> coro::task { co_return 1; }; auto start = sc::now(); @@ -69,10 +67,8 @@ TEST_CASE("benchmark counter func coro::sync_wait(awaitable)") TEST_CASE("benchmark counter func coro::sync_wait(coro::when_all_awaitable(awaitable)) x10") { constexpr std::size_t iterations = default_iterations; - uint64_t counter{0}; - auto f = []() -> coro::task { - co_return 1; - }; + uint64_t counter{0}; + auto f = []() -> coro::task { co_return 1; }; auto start = sc::now(); @@ -80,13 +76,11 @@ TEST_CASE("benchmark counter func coro::sync_wait(coro::when_all_awaitable(await { auto tasks = coro::sync_wait(coro::when_all_awaitable(f(), f(), f(), f(), f(), f(), f(), f(), f(), f())); - std::apply([&counter](auto&&... t) { - ((counter += t.return_value()), ...); - }, - tasks); + std::apply([&counter](auto&&... t) { ((counter += t.return_value()), ...); }, tasks); } - print_stats("benchmark counter func coro::sync_wait(coro::when_all_awaitable(awaitable))", iterations, start, sc::now()); + print_stats( + "benchmark counter func coro::sync_wait(coro::when_all_awaitable(awaitable))", iterations, start, sc::now()); REQUIRE(counter == iterations); } @@ -94,11 +88,10 @@ TEST_CASE("benchmark thread_pool{1} counter task") { constexpr std::size_t iterations = default_iterations; - coro::thread_pool tp{coro::thread_pool::options{1}}; + coro::thread_pool tp{coro::thread_pool::options{1}}; std::atomic counter{0}; - auto make_task = [](coro::thread_pool& tp, std::atomic& c) -> coro::task - { + auto make_task = [](coro::thread_pool& tp, std::atomic& c) -> coro::task { co_await tp.schedule().value(); c.fetch_add(1, std::memory_order::relaxed); co_return; @@ -126,11 +119,10 @@ TEST_CASE("benchmark thread_pool{2} counter task") { constexpr std::size_t iterations = default_iterations; - coro::thread_pool tp{coro::thread_pool::options{2}}; + coro::thread_pool tp{coro::thread_pool::options{2}}; std::atomic counter{0}; - auto make_task = [](coro::thread_pool& tp, std::atomic& c) -> coro::task - { + auto make_task = [](coro::thread_pool& tp, std::atomic& c) -> coro::task { co_await tp.schedule().value(); c.fetch_add(1, std::memory_order::relaxed); co_return; diff --git a/test/test_scheduler.cpp b/test/test_scheduler.cpp index 79aa251..a6dc875 100644 --- a/test/test_scheduler.cpp +++ b/test/test_scheduler.cpp @@ -516,7 +516,7 @@ TEST_CASE("scheduler task throws") auto func = []() -> coro::task { // Is it possible to actually notify the user when running a task in a scheduler? - // Seems like the user will need to manually catch. + // Seems like the user will need to manually catch within the task themselves. throw std::runtime_error{"I always throw."}; co_return; }; @@ -525,4 +525,24 @@ TEST_CASE("scheduler task throws") s.shutdown(); REQUIRE(s.empty()); -} \ No newline at end of file +} + +TEST_CASE("scheduler task throws after resume") +{ + coro::scheduler s{}; + auto token = s.generate_resume_token(); + + auto func = [&]() -> coro::task { + co_await token; + throw std::runtime_error{"I always throw."}; + co_return; + }; + + s.schedule(func()); + + std::this_thread::sleep_for(50ms); + token.resume(); + + s.shutdown(); + REQUIRE(s.empty()); +} diff --git a/test/test_sync_wait.cpp b/test/test_sync_wait.cpp index 4848268..350579b 100644 --- a/test/test_sync_wait.cpp +++ b/test/test_sync_wait.cpp @@ -4,9 +4,7 @@ TEST_CASE("sync_wait simple integer return") { - auto func = []() -> coro::task { - co_return 11; - }; + auto func = []() -> coro::task { co_return 11; }; auto result = coro::sync_wait(func()); REQUIRE(result == 11); @@ -51,8 +49,7 @@ TEST_CASE("sync_wait task co_await single") TEST_CASE("sync_wait task that throws") { - auto f = []() -> coro::task - { + auto f = []() -> coro::task { throw std::runtime_error("I always throw!"); co_return 1; }; diff --git a/test/test_thread_pool.cpp b/test/test_thread_pool.cpp index 2306c07..8e8afde 100644 --- a/test/test_thread_pool.cpp +++ b/test/test_thread_pool.cpp @@ -8,8 +8,7 @@ TEST_CASE("thread_pool one worker one task") { coro::thread_pool tp{coro::thread_pool::options{1}}; - auto func = [&tp]() -> coro::task - { + auto func = [&tp]() -> coro::task { co_await tp.schedule().value(); // Schedule this coroutine on the scheduler. co_return 42; }; @@ -22,8 +21,7 @@ TEST_CASE("thread_pool one worker many tasks tuple") { coro::thread_pool tp{coro::thread_pool::options{1}}; - auto f = [&tp]() -> coro::task - { + auto f = [&tp]() -> coro::task { co_await tp.schedule().value(); // Schedule this coroutine on the scheduler. co_return 50; }; @@ -32,10 +30,7 @@ TEST_CASE("thread_pool one worker many tasks tuple") REQUIRE(std::tuple_size() == 5); uint64_t counter{0}; - std::apply([&counter](auto&&... t) -> void { - ((counter += t.return_value()), ...); - }, - tasks); + std::apply([&counter](auto&&... t) -> void { ((counter += t.return_value()), ...); }, tasks); REQUIRE(counter == 250); } @@ -44,8 +39,7 @@ TEST_CASE("thread_pool one worker many tasks vector") { coro::thread_pool tp{coro::thread_pool::options{1}}; - auto f = [&tp]() -> coro::task - { + auto f = [&tp]() -> coro::task { co_await tp.schedule().value(); // Schedule this coroutine on the scheduler. co_return 50; }; @@ -60,7 +54,7 @@ TEST_CASE("thread_pool one worker many tasks vector") REQUIRE(output_tasks.size() == 3); uint64_t counter{0}; - for(const auto& task : output_tasks) + for (const auto& task : output_tasks) { counter += task.return_value(); } @@ -71,17 +65,16 @@ TEST_CASE("thread_pool one worker many tasks vector") TEST_CASE("thread_pool N workers 100k tasks") { constexpr const std::size_t iterations = 100'000; - coro::thread_pool tp{}; + coro::thread_pool tp{}; - auto make_task = [](coro::thread_pool& tp) -> coro::task - { + auto make_task = [](coro::thread_pool& tp) -> coro::task { co_await tp.schedule().value(); co_return 1; }; std::vector> input_tasks{}; input_tasks.reserve(iterations); - for(std::size_t i = 0; i < iterations; ++i) + for (std::size_t i = 0; i < iterations; ++i) { input_tasks.emplace_back(make_task(tp)); } @@ -90,7 +83,7 @@ TEST_CASE("thread_pool N workers 100k tasks") REQUIRE(output_tasks.size() == iterations); uint64_t counter{0}; - for(const auto& task : output_tasks) + for (const auto& task : output_tasks) { counter += task.return_value(); } @@ -102,12 +95,10 @@ TEST_CASE("thread_pool 1 worker task spawns another task") { coro::thread_pool tp{coro::thread_pool::options{1}}; - auto f1 = [](coro::thread_pool& tp) -> coro::task - { + auto f1 = [](coro::thread_pool& tp) -> coro::task { co_await tp.schedule().value(); - auto f2 = [](coro::thread_pool& tp) -> coro::task - { + auto f2 = [](coro::thread_pool& tp) -> coro::task { co_await tp.schedule().value(); co_return 5; }; @@ -122,10 +113,9 @@ TEST_CASE("thread_pool shutdown") { coro::thread_pool tp{coro::thread_pool::options{1}}; - auto f = [](coro::thread_pool& tp) -> coro::task - { + auto f = [](coro::thread_pool& tp) -> coro::task { auto scheduled = tp.schedule(); - if(!scheduled.has_value()) + if (!scheduled.has_value()) { co_return true; } @@ -158,7 +148,7 @@ TEST_CASE("thread_pool schedule functor return_type = void") coro::thread_pool tp{coro::thread_pool::options{1}}; std::atomic counter{0}; - auto f = [](std::atomic& c) -> void { c++; }; + auto f = [](std::atomic& c) -> void { c++; }; coro::sync_wait(tp.schedule(f, std::ref(counter))); REQUIRE(counter == 1); diff --git a/test/test_when_all.cpp b/test/test_when_all.cpp index 5c79b30..0fbd453 100644 --- a/test/test_when_all.cpp +++ b/test/test_when_all.cpp @@ -4,47 +4,33 @@ TEST_CASE("when_all_awaitable single task with tuple container") { - auto make_task = [](uint64_t amount) -> coro::task { - co_return amount; - }; + auto make_task = [](uint64_t amount) -> coro::task { co_return amount; }; auto output_tasks = coro::sync_wait(coro::when_all_awaitable(make_task(100))); REQUIRE(std::tuple_size() == 1); uint64_t counter{0}; - std::apply( - [&counter](auto&&... tasks) -> void { - ((counter += tasks.return_value()), ...); - }, - output_tasks); + std::apply([&counter](auto&&... tasks) -> void { ((counter += tasks.return_value()), ...); }, output_tasks); REQUIRE(counter == 100); } TEST_CASE("when_all_awaitable multiple tasks with tuple container") { - auto make_task = [](uint64_t amount) -> coro::task { - co_return amount; - }; + auto make_task = [](uint64_t amount) -> coro::task { co_return amount; }; auto output_tasks = coro::sync_wait(coro::when_all_awaitable(make_task(100), make_task(50), make_task(20))); REQUIRE(std::tuple_size() == 3); uint64_t counter{0}; - std::apply( - [&counter](auto&&... tasks) -> void { - ((counter += tasks.return_value()), ...); - }, - output_tasks); + std::apply([&counter](auto&&... tasks) -> void { ((counter += tasks.return_value()), ...); }, output_tasks); REQUIRE(counter == 170); } TEST_CASE("when_all_awaitable single task with vector container") { - auto make_task = [](uint64_t amount) -> coro::task { - co_return amount; - }; + auto make_task = [](uint64_t amount) -> coro::task { co_return amount; }; std::vector> input_tasks; input_tasks.emplace_back(make_task(100)); @@ -53,7 +39,7 @@ TEST_CASE("when_all_awaitable single task with vector container") REQUIRE(output_tasks.size() == 1); uint64_t counter{0}; - for(const auto& task : output_tasks) + for (const auto& task : output_tasks) { counter += task.return_value(); } @@ -63,9 +49,7 @@ TEST_CASE("when_all_awaitable single task with vector container") TEST_CASE("when_all_ready multple task withs vector container") { - auto make_task = [](uint64_t amount) -> coro::task { - co_return amount; - }; + auto make_task = [](uint64_t amount) -> coro::task { co_return amount; }; std::vector> input_tasks; input_tasks.emplace_back(make_task(100)); @@ -77,7 +61,7 @@ TEST_CASE("when_all_ready multple task withs vector container") REQUIRE(output_tasks.size() == 4); uint64_t counter{0}; - for(const auto& task : output_tasks) + for (const auto& task : output_tasks) { counter += task.return_value(); }