#pragma once #include "coro/concepts/awaitable.hpp" #include "coro/detail/void_value.hpp" #include #include #include namespace coro { namespace detail { class when_all_latch { public: when_all_latch(std::size_t count) noexcept : m_count(count + 1) {} when_all_latch(const when_all_latch&) = delete; when_all_latch(when_all_latch&& other) : m_count(other.m_count.load(std::memory_order::acquire)), m_awaiting_coroutine(std::exchange(other.m_awaiting_coroutine, nullptr)) { } auto operator=(const when_all_latch&) -> when_all_latch& = delete; auto operator =(when_all_latch&& other) -> when_all_latch& { if (std::addressof(other) != this) { m_count.store(other.m_count.load(std::memory_order::acquire), std::memory_order::relaxed); m_awaiting_coroutine = std::exchange(other.m_awaiting_coroutine, nullptr); } return *this; } auto is_ready() const noexcept -> bool { return m_awaiting_coroutine != nullptr && m_awaiting_coroutine.done(); } auto try_await(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool { m_awaiting_coroutine = awaiting_coroutine; return m_count.fetch_sub(1, std::memory_order::acq_rel) > 1; } auto notify_awaitable_completed() noexcept -> void { if (m_count.fetch_sub(1, std::memory_order::acq_rel) == 1) { m_awaiting_coroutine.resume(); } } private: /// The number of tasks that are being waited on. std::atomic m_count; /// The when_all_task awaiting to be resumed upon all task completions. std::coroutine_handle<> m_awaiting_coroutine{nullptr}; }; template class when_all_ready_awaitable; template class when_all_task; /// Empty tuple<> implementation. template<> class when_all_ready_awaitable> { public: constexpr when_all_ready_awaitable() noexcept {} explicit constexpr when_all_ready_awaitable(std::tuple<>) noexcept {} constexpr auto await_ready() const noexcept -> bool { return true; } auto await_suspend(std::coroutine_handle<>) noexcept -> void {} auto await_resume() const noexcept -> std::tuple<> { return {}; } }; template class when_all_ready_awaitable> { public: explicit when_all_ready_awaitable(task_types&&... tasks) noexcept( std::conjunction_v...>) : m_latch(sizeof...(task_types)), m_tasks(std::move(tasks)...) { } explicit when_all_ready_awaitable(std::tuple&& tasks) noexcept( std::is_nothrow_move_constructible_v>) : m_latch(sizeof...(task_types)), m_tasks(std::move(tasks)) { } when_all_ready_awaitable(const when_all_ready_awaitable&) = delete; when_all_ready_awaitable(when_all_ready_awaitable&& other) : m_latch(std::move(other.m_latch)), m_tasks(std::move(other.m_tasks)) { } auto operator=(const when_all_ready_awaitable&) -> when_all_ready_awaitable& = delete; auto operator=(when_all_ready_awaitable&&) -> when_all_ready_awaitable& = delete; auto operator co_await() & noexcept { struct awaiter { explicit awaiter(when_all_ready_awaitable& awaitable) noexcept : m_awaitable(awaitable) {} auto await_ready() const noexcept -> bool { return m_awaitable.is_ready(); } auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool { return m_awaitable.try_await(awaiting_coroutine); } auto await_resume() noexcept -> std::tuple& { return m_awaitable.m_tasks; } private: when_all_ready_awaitable& m_awaitable; }; return awaiter{*this}; } auto operator co_await() && noexcept { struct awaiter { explicit awaiter(when_all_ready_awaitable& awaitable) noexcept : m_awaitable(awaitable) {} auto await_ready() const noexcept -> bool { return m_awaitable.is_ready(); } auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool { return m_awaitable.try_await(awaiting_coroutine); } auto await_resume() noexcept -> std::tuple&& { return std::move(m_awaitable.m_tasks); } private: when_all_ready_awaitable& m_awaitable; }; return awaiter{*this}; } private: auto is_ready() const noexcept -> bool { return m_latch.is_ready(); } auto try_await(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool { std::apply([this](auto&&... tasks) { ((tasks.start(m_latch)), ...); }, m_tasks); return m_latch.try_await(awaiting_coroutine); } when_all_latch m_latch; std::tuple m_tasks; }; template class when_all_ready_awaitable { public: explicit when_all_ready_awaitable(task_container_type&& tasks) noexcept : m_latch(std::size(tasks)), m_tasks(std::forward(tasks)) { } when_all_ready_awaitable(const when_all_ready_awaitable&) = delete; when_all_ready_awaitable(when_all_ready_awaitable&& other) noexcept( std::is_nothrow_move_constructible_v) : m_latch(std::move(other.m_latch)), m_tasks(std::move(m_tasks)) { } auto operator=(const when_all_ready_awaitable&) -> when_all_ready_awaitable& = delete; auto operator=(when_all_ready_awaitable&) -> when_all_ready_awaitable& = delete; auto operator co_await() & noexcept { struct awaiter { awaiter(when_all_ready_awaitable& awaitable) : m_awaitable(awaitable) {} auto await_ready() const noexcept -> bool { return m_awaitable.is_ready(); } auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool { return m_awaitable.try_await(awaiting_coroutine); } auto await_resume() noexcept -> task_container_type& { return m_awaitable.m_tasks; } private: when_all_ready_awaitable& m_awaitable; }; return awaiter{*this}; } auto operator co_await() && noexcept { struct awaiter { awaiter(when_all_ready_awaitable& awaitable) : m_awaitable(awaitable) {} auto await_ready() const noexcept -> bool { return m_awaitable.is_ready(); } auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool { return m_awaitable.try_await(awaiting_coroutine); } auto await_resume() noexcept -> task_container_type&& { return std::move(m_awaitable.m_tasks); } private: when_all_ready_awaitable& m_awaitable; }; return awaiter{*this}; } private: auto is_ready() const noexcept -> bool { return m_latch.is_ready(); } auto try_await(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool { for (auto& task : m_tasks) { task.start(m_latch); } return m_latch.try_await(awaiting_coroutine); } when_all_latch m_latch; task_container_type m_tasks; }; template class when_all_task_promise { public: using coroutine_handle_type = std::coroutine_handle>; when_all_task_promise() noexcept {} auto get_return_object() noexcept { return coroutine_handle_type::from_promise(*this); } auto initial_suspend() noexcept -> std::suspend_always { return {}; } auto final_suspend() noexcept { struct completion_notifier { auto await_ready() const noexcept -> bool { return false; } auto await_suspend(coroutine_handle_type coroutine) const noexcept -> void { coroutine.promise().m_latch->notify_awaitable_completed(); } auto await_resume() const noexcept {} }; return completion_notifier{}; } auto unhandled_exception() noexcept { m_exception_ptr = std::current_exception(); } auto yield_value(return_type&& value) noexcept { m_return_value = std::addressof(value); return final_suspend(); } auto start(when_all_latch& latch) noexcept -> void { m_latch = &latch; coroutine_handle_type::from_promise(*this).resume(); } auto return_value() & -> return_type& { if (m_exception_ptr) { std::rethrow_exception(m_exception_ptr); } return *m_return_value; } auto return_value() && -> return_type&& { if (m_exception_ptr) { std::rethrow_exception(m_exception_ptr); } return std::forward(*m_return_value); } private: when_all_latch* m_latch{nullptr}; std::exception_ptr m_exception_ptr; std::add_pointer_t m_return_value; }; template<> class when_all_task_promise { public: using coroutine_handle_type = std::coroutine_handle>; when_all_task_promise() noexcept {} auto get_return_object() noexcept { return coroutine_handle_type::from_promise(*this); } auto initial_suspend() noexcept -> std::suspend_always { return {}; } auto final_suspend() noexcept { struct completion_notifier { auto await_ready() const noexcept -> bool { return false; } auto await_suspend(coroutine_handle_type coroutine) const noexcept -> void { coroutine.promise().m_latch->notify_awaitable_completed(); } auto await_resume() const noexcept -> void {} }; return completion_notifier{}; } auto unhandled_exception() noexcept -> void { m_exception_ptr = std::current_exception(); } auto return_void() noexcept -> void {} auto start(when_all_latch& latch) -> void { m_latch = &latch; coroutine_handle_type::from_promise(*this).resume(); } auto return_value() -> void { if (m_exception_ptr) { std::rethrow_exception(m_exception_ptr); } } private: when_all_latch* m_latch{nullptr}; std::exception_ptr m_exception_ptr; }; template class when_all_task { public: // To be able to call start(). template friend class when_all_ready_awaitable; using promise_type = when_all_task_promise; using coroutine_handle_type = typename promise_type::coroutine_handle_type; when_all_task(coroutine_handle_type coroutine) noexcept : m_coroutine(coroutine) {} when_all_task(const when_all_task&) = delete; when_all_task(when_all_task&& other) noexcept : m_coroutine(std::exchange(other.m_coroutine, coroutine_handle_type{})) { } auto operator=(const when_all_task&) -> when_all_task& = delete; auto operator=(when_all_task&&) -> when_all_task& = delete; ~when_all_task() { if (m_coroutine != nullptr) { m_coroutine.destroy(); } } auto return_value() & -> decltype(auto) { if constexpr (std::is_void_v) { m_coroutine.promise().return_void(); return void_value{}; } else { return m_coroutine.promise().return_value(); } } auto return_value() const& -> decltype(auto) { if constexpr (std::is_void_v) { m_coroutine.promise().return_void(); return void_value{}; } else { return m_coroutine.promise().return_value(); } } auto return_value() && -> decltype(auto) { if constexpr (std::is_void_v) { m_coroutine.promise().return_void(); return void_value{}; } else { return m_coroutine.promise().return_value(); } } private: auto start(when_all_latch& latch) noexcept -> void { m_coroutine.promise().start(latch); } coroutine_handle_type m_coroutine; }; template< concepts::awaitable awaitable, typename return_type = concepts::awaitable_traits::awaiter_return_type> static auto make_when_all_task(awaitable&& a) -> when_all_task { if constexpr (std::is_void_v) { co_await static_cast(a); co_return; } else { co_yield co_await static_cast(a); } } } // namespace detail template [[nodiscard]] auto when_all_awaitable(awaitables_type&&... awaitables) { return detail::when_all_ready_awaitable::awaiter_return_type>...>>( std::make_tuple(detail::make_when_all_task(std::forward(awaitables))...)); } template< concepts::awaitable awaitable, typename return_type = concepts::awaitable_traits::awaiter_return_type> [[nodiscard]] auto when_all_awaitable(std::vector& awaitables) -> detail::when_all_ready_awaitable>> { std::vector> tasks; tasks.reserve(std::size(awaitables)); for (auto& a : awaitables) { tasks.emplace_back(detail::make_when_all_task(std::move(a))); } return detail::when_all_ready_awaitable(std::move(tasks)); } } // namespace coro