1
0
Fork 0
mirror of https://gitlab.com/niansa/libcrosscoro.git synced 2025-03-06 20:53:32 +01:00
libcrosscoro/inc/coro/when_all.hpp
2021-01-16 20:27:11 -07:00

480 lines
14 KiB
C++

#pragma once
#include "coro/concepts/awaitable.hpp"
#include "coro/detail/void_value.hpp"
#include <atomic>
#include <coroutine>
#include <tuple>
namespace coro
{
namespace detail
{
class when_all_latch
{
public:
when_all_latch(std::size_t count) noexcept : m_count(count + 1) {}
when_all_latch(const when_all_latch&) = delete;
when_all_latch(when_all_latch&& other)
: m_count(other.m_count.load(std::memory_order::acquire)),
m_awaiting_coroutine(std::exchange(other.m_awaiting_coroutine, nullptr))
{
}
auto operator=(const when_all_latch&) -> when_all_latch& = delete;
auto operator =(when_all_latch&& other) -> when_all_latch&
{
if (std::addressof(other) != this)
{
m_count.store(other.m_count.load(std::memory_order::acquire), std::memory_order::relaxed);
m_awaiting_coroutine = std::exchange(other.m_awaiting_coroutine, nullptr);
}
return *this;
}
auto is_ready() const noexcept -> bool { return m_awaiting_coroutine != nullptr && m_awaiting_coroutine.done(); }
auto try_await(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool
{
m_awaiting_coroutine = awaiting_coroutine;
return m_count.fetch_sub(1, std::memory_order::acq_rel) > 1;
}
auto notify_awaitable_completed() noexcept -> void
{
if (m_count.fetch_sub(1, std::memory_order::acq_rel) == 1)
{
m_awaiting_coroutine.resume();
}
}
private:
/// The number of tasks that are being waited on.
std::atomic<std::size_t> m_count;
/// The when_all_task awaiting to be resumed upon all task completions.
std::coroutine_handle<> m_awaiting_coroutine{nullptr};
};
template<typename task_container_type>
class when_all_ready_awaitable;
template<typename return_type>
class when_all_task;
/// Empty tuple<> implementation.
template<>
class when_all_ready_awaitable<std::tuple<>>
{
public:
constexpr when_all_ready_awaitable() noexcept {}
explicit constexpr when_all_ready_awaitable(std::tuple<>) noexcept {}
constexpr auto await_ready() const noexcept -> bool { return true; }
auto await_suspend(std::coroutine_handle<>) noexcept -> void {}
auto await_resume() const noexcept -> std::tuple<> { return {}; }
};
template<typename... task_types>
class when_all_ready_awaitable<std::tuple<task_types...>>
{
public:
explicit when_all_ready_awaitable(task_types&&... tasks) noexcept(
std::conjunction_v<std::is_nothrow_move_constructible_v<task_types>...>)
: m_latch(sizeof...(task_types)),
m_tasks(std::move(tasks)...)
{
}
explicit when_all_ready_awaitable(std::tuple<task_types...>&& tasks) noexcept(
std::is_nothrow_move_constructible_v<std::tuple<task_types...>>)
: m_latch(sizeof...(task_types)),
m_tasks(std::move(tasks))
{
}
when_all_ready_awaitable(const when_all_ready_awaitable&) = delete;
when_all_ready_awaitable(when_all_ready_awaitable&& other)
: m_latch(std::move(other.m_latch)),
m_tasks(std::move(other.m_tasks))
{
}
auto operator=(const when_all_ready_awaitable&) -> when_all_ready_awaitable& = delete;
auto operator=(when_all_ready_awaitable&&) -> when_all_ready_awaitable& = delete;
auto operator co_await() & noexcept
{
struct awaiter
{
explicit awaiter(when_all_ready_awaitable& awaitable) noexcept : m_awaitable(awaitable) {}
auto await_ready() const noexcept -> bool { return m_awaitable.is_ready(); }
auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool
{
return m_awaitable.try_await(awaiting_coroutine);
}
auto await_resume() noexcept -> std::tuple<task_types...>& { return m_awaitable.m_tasks; }
private:
when_all_ready_awaitable& m_awaitable;
};
return awaiter{*this};
}
auto operator co_await() && noexcept
{
struct awaiter
{
explicit awaiter(when_all_ready_awaitable& awaitable) noexcept : m_awaitable(awaitable) {}
auto await_ready() const noexcept -> bool { return m_awaitable.is_ready(); }
auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool
{
return m_awaitable.try_await(awaiting_coroutine);
}
auto await_resume() noexcept -> std::tuple<task_types...>&& { return std::move(m_awaitable.m_tasks); }
private:
when_all_ready_awaitable& m_awaitable;
};
return awaiter{*this};
}
private:
auto is_ready() const noexcept -> bool { return m_latch.is_ready(); }
auto try_await(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool
{
std::apply([this](auto&&... tasks) { ((tasks.start(m_latch)), ...); }, m_tasks);
return m_latch.try_await(awaiting_coroutine);
}
when_all_latch m_latch;
std::tuple<task_types...> m_tasks;
};
template<typename task_container_type>
class when_all_ready_awaitable
{
public:
explicit when_all_ready_awaitable(task_container_type&& tasks) noexcept
: m_latch(std::size(tasks)),
m_tasks(std::forward<task_container_type>(tasks))
{
}
when_all_ready_awaitable(const when_all_ready_awaitable&) = delete;
when_all_ready_awaitable(when_all_ready_awaitable&& other) noexcept(
std::is_nothrow_move_constructible_v<task_container_type>)
: m_latch(std::move(other.m_latch)),
m_tasks(std::move(m_tasks))
{
}
auto operator=(const when_all_ready_awaitable&) -> when_all_ready_awaitable& = delete;
auto operator=(when_all_ready_awaitable&) -> when_all_ready_awaitable& = delete;
auto operator co_await() & noexcept
{
struct awaiter
{
awaiter(when_all_ready_awaitable& awaitable) : m_awaitable(awaitable) {}
auto await_ready() const noexcept -> bool { return m_awaitable.is_ready(); }
auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool
{
return m_awaitable.try_await(awaiting_coroutine);
}
auto await_resume() noexcept -> task_container_type& { return m_awaitable.m_tasks; }
private:
when_all_ready_awaitable& m_awaitable;
};
return awaiter{*this};
}
auto operator co_await() && noexcept
{
struct awaiter
{
awaiter(when_all_ready_awaitable& awaitable) : m_awaitable(awaitable) {}
auto await_ready() const noexcept -> bool { return m_awaitable.is_ready(); }
auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool
{
return m_awaitable.try_await(awaiting_coroutine);
}
auto await_resume() noexcept -> task_container_type&& { return std::move(m_awaitable.m_tasks); }
private:
when_all_ready_awaitable& m_awaitable;
};
return awaiter{*this};
}
private:
auto is_ready() const noexcept -> bool { return m_latch.is_ready(); }
auto try_await(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool
{
for (auto& task : m_tasks)
{
task.start(m_latch);
}
return m_latch.try_await(awaiting_coroutine);
}
when_all_latch m_latch;
task_container_type m_tasks;
};
template<typename return_type>
class when_all_task_promise
{
public:
using coroutine_handle_type = std::coroutine_handle<when_all_task_promise<return_type>>;
when_all_task_promise() noexcept {}
auto get_return_object() noexcept { return coroutine_handle_type::from_promise(*this); }
auto initial_suspend() noexcept -> std::suspend_always { return {}; }
auto final_suspend() noexcept
{
struct completion_notifier
{
auto await_ready() const noexcept -> bool { return false; }
auto await_suspend(coroutine_handle_type coroutine) const noexcept -> void
{
coroutine.promise().m_latch->notify_awaitable_completed();
}
auto await_resume() const noexcept {}
};
return completion_notifier{};
}
auto unhandled_exception() noexcept { m_exception_ptr = std::current_exception(); }
auto yield_value(return_type&& value) noexcept
{
m_return_value = std::addressof(value);
return final_suspend();
}
auto start(when_all_latch& latch) noexcept -> void
{
m_latch = &latch;
coroutine_handle_type::from_promise(*this).resume();
}
auto return_value() & -> return_type&
{
if (m_exception_ptr)
{
std::rethrow_exception(m_exception_ptr);
}
return *m_return_value;
}
auto return_value() && -> return_type&&
{
if (m_exception_ptr)
{
std::rethrow_exception(m_exception_ptr);
}
return std::forward(*m_return_value);
}
private:
when_all_latch* m_latch{nullptr};
std::exception_ptr m_exception_ptr;
std::add_pointer_t<return_type> m_return_value;
};
template<>
class when_all_task_promise<void>
{
public:
using coroutine_handle_type = std::coroutine_handle<when_all_task_promise<void>>;
when_all_task_promise() noexcept {}
auto get_return_object() noexcept { return coroutine_handle_type::from_promise(*this); }
auto initial_suspend() noexcept -> std::suspend_always { return {}; }
auto final_suspend() noexcept
{
struct completion_notifier
{
auto await_ready() const noexcept -> bool { return false; }
auto await_suspend(coroutine_handle_type coroutine) const noexcept -> void
{
coroutine.promise().m_latch->notify_awaitable_completed();
}
auto await_resume() const noexcept -> void {}
};
return completion_notifier{};
}
auto unhandled_exception() noexcept -> void { m_exception_ptr = std::current_exception(); }
auto return_void() noexcept -> void {}
auto start(when_all_latch& latch) -> void
{
m_latch = &latch;
coroutine_handle_type::from_promise(*this).resume();
}
auto return_value() -> void
{
if (m_exception_ptr)
{
std::rethrow_exception(m_exception_ptr);
}
}
private:
when_all_latch* m_latch{nullptr};
std::exception_ptr m_exception_ptr;
};
template<typename return_type>
class when_all_task
{
public:
// To be able to call start().
template<typename task_container_type>
friend class when_all_ready_awaitable;
using promise_type = when_all_task_promise<return_type>;
using coroutine_handle_type = typename promise_type::coroutine_handle_type;
when_all_task(coroutine_handle_type coroutine) noexcept : m_coroutine(coroutine) {}
when_all_task(const when_all_task&) = delete;
when_all_task(when_all_task&& other) noexcept
: m_coroutine(std::exchange(other.m_coroutine, coroutine_handle_type{}))
{
}
auto operator=(const when_all_task&) -> when_all_task& = delete;
auto operator=(when_all_task&&) -> when_all_task& = delete;
~when_all_task()
{
if (m_coroutine != nullptr)
{
m_coroutine.destroy();
}
}
auto return_value() & -> decltype(auto)
{
if constexpr (std::is_void_v<return_type>)
{
m_coroutine.promise().return_void();
return void_value{};
}
else
{
return m_coroutine.promise().return_value();
}
}
auto return_value() const& -> decltype(auto)
{
if constexpr (std::is_void_v<return_type>)
{
m_coroutine.promise().return_void();
return void_value{};
}
else
{
return m_coroutine.promise().return_value();
}
}
auto return_value() && -> decltype(auto)
{
if constexpr (std::is_void_v<return_type>)
{
m_coroutine.promise().return_void();
return void_value{};
}
else
{
return m_coroutine.promise().return_value();
}
}
private:
auto start(when_all_latch& latch) noexcept -> void { m_coroutine.promise().start(latch); }
coroutine_handle_type m_coroutine;
};
template<
concepts::awaitable awaitable,
typename return_type = concepts::awaitable_traits<awaitable&&>::awaiter_return_type>
static auto make_when_all_task(awaitable&& a) -> when_all_task<return_type>
{
if constexpr (std::is_void_v<return_type>)
{
co_await static_cast<awaitable&&>(a);
co_return;
}
else
{
co_yield co_await static_cast<awaitable&&>(a);
}
}
} // namespace detail
template<concepts::awaitable... awaitables_type>
[[nodiscard]] auto when_all_awaitable(awaitables_type&&... awaitables)
{
return detail::when_all_ready_awaitable<std::tuple<
detail::when_all_task<typename concepts::awaitable_traits<awaitables_type>::awaiter_return_type>...>>(
std::make_tuple(detail::make_when_all_task(std::forward<awaitables_type>(awaitables))...));
}
template<
concepts::awaitable awaitable,
typename return_type = concepts::awaitable_traits<awaitable>::awaiter_return_type>
[[nodiscard]] auto when_all_awaitable(std::vector<awaitable>& awaitables)
-> detail::when_all_ready_awaitable<std::vector<detail::when_all_task<return_type>>>
{
std::vector<detail::when_all_task<return_type>> tasks;
tasks.reserve(std::size(awaitables));
for (auto& a : awaitables)
{
tasks.emplace_back(detail::make_when_all_task(std::move(a)));
}
return detail::when_all_ready_awaitable(std::move(tasks));
}
} // namespace coro