// Mirror of https://gitlab.com/niansa/libcrosscoro.git (synced 2025-03-06 20:53:32 +01:00).
//
// Commit note: see the issue for more details. In general, attempting to implement a
// coro::thread_pool exposed that coro::sync_wait and coro::when_all only worked if the
// coroutines executed on the same thread. They should now be able to execute on another
// thread; whether they fully can is to be determined in a later issue. Fixes #7.
//
// File: 543 lines, 14 KiB, C++.
#pragma once
|
|
|
|
#include "coro/awaitable.hpp"
#include "coro/detail/void_value.hpp"

#include <atomic>
#include <coroutine>
#include <cstddef>
#include <exception>
#include <memory>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
|
|
|
|
namespace coro
|
|
{
|
|
|
|
namespace detail
|
|
{
|
|
|
|
class when_all_latch
|
|
{
|
|
public:
|
|
when_all_latch(std::size_t count) noexcept
|
|
: m_count(count + 1)
|
|
{ }
|
|
|
|
when_all_latch(const when_all_latch&) = delete;
|
|
when_all_latch(when_all_latch&& other)
|
|
: m_count(other.m_count.load(std::memory_order::acquire)),
|
|
m_awaiting_coroutine(std::exchange(other.m_awaiting_coroutine, nullptr))
|
|
{ }
|
|
|
|
auto operator=(const when_all_latch&) -> when_all_latch& = delete;
|
|
auto operator=(when_all_latch&& other) -> when_all_latch&
|
|
{
|
|
if(std::addressof(other) != this)
|
|
{
|
|
m_count.store(
|
|
other.m_count.load(std::memory_order::acquire),
|
|
std::memory_order::relaxed
|
|
);
|
|
m_awaiting_coroutine = std::exchange(other.m_awaiting_coroutine, nullptr);
|
|
}
|
|
|
|
return *this;
|
|
}
|
|
|
|
auto is_ready() const noexcept -> bool
|
|
{
|
|
return m_awaiting_coroutine != nullptr && m_awaiting_coroutine.done();
|
|
}
|
|
|
|
auto try_await(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool
|
|
{
|
|
m_awaiting_coroutine = awaiting_coroutine;
|
|
return m_count.fetch_sub(1, std::memory_order::acq_rel) > 1;
|
|
}
|
|
|
|
auto notify_awaitable_completed() noexcept -> void
|
|
{
|
|
if(m_count.fetch_sub(1, std::memory_order::acq_rel) == 1)
|
|
{
|
|
m_awaiting_coroutine.resume();
|
|
}
|
|
}
|
|
|
|
private:
|
|
/// The number of tasks that are being waited on.
|
|
std::atomic<std::size_t> m_count;
|
|
/// The when_all_task awaiting to be resumed upon all task completions.
|
|
std::coroutine_handle<> m_awaiting_coroutine{nullptr};
|
|
};
|
|
|
|
// Forward declarations. The primary when_all_ready_awaitable template must be declared
// before its std::tuple specializations that follow. Both templates are defined further
// down in this namespace.
template<typename container_t>
class when_all_ready_awaitable;

template<typename result_t>
class when_all_task;
|
|
|
|
/// Empty tuple<> implementation.
|
|
template<>
|
|
class when_all_ready_awaitable<std::tuple<>>
|
|
{
|
|
public:
|
|
constexpr when_all_ready_awaitable() noexcept {}
|
|
explicit constexpr when_all_ready_awaitable(std::tuple<>) noexcept {}
|
|
|
|
constexpr auto await_ready() const noexcept -> bool { return true; }
|
|
auto await_suspend(std::coroutine_handle<>) noexcept -> void { }
|
|
auto await_resume() const noexcept -> std::tuple<> { return {}; }
|
|
};
|
|
|
|
template<typename... task_types>
|
|
class when_all_ready_awaitable<std::tuple<task_types...>>
|
|
{
|
|
public:
|
|
explicit when_all_ready_awaitable(task_types&&... tasks)
|
|
noexcept(std::conjunction_v<std::is_nothrow_move_constructible_v<task_types>...>)
|
|
: m_latch(sizeof...(task_types)),
|
|
m_tasks(std::move(tasks)...)
|
|
{}
|
|
|
|
explicit when_all_ready_awaitable(std::tuple<task_types...>&& tasks)
|
|
noexcept(std::is_nothrow_move_constructible_v<std::tuple<task_types...>>)
|
|
: m_latch(sizeof...(task_types)),
|
|
m_tasks(std::move(tasks))
|
|
{ }
|
|
|
|
when_all_ready_awaitable(const when_all_ready_awaitable&) = delete;
|
|
when_all_ready_awaitable(when_all_ready_awaitable&& other)
|
|
: m_latch(std::move(other.m_latch)),
|
|
m_tasks(std::move(other.m_tasks))
|
|
{ }
|
|
|
|
auto operator=(const when_all_ready_awaitable&) -> when_all_ready_awaitable& = delete;
|
|
auto operator=(when_all_ready_awaitable&&) -> when_all_ready_awaitable& = delete;
|
|
|
|
auto operator co_await() & noexcept
|
|
{
|
|
struct awaiter
|
|
{
|
|
explicit awaiter(when_all_ready_awaitable& awaitable) noexcept
|
|
: m_awaitable(awaitable)
|
|
{ }
|
|
|
|
auto await_ready() const noexcept -> bool
|
|
{
|
|
return m_awaitable.is_ready();
|
|
}
|
|
|
|
auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool
|
|
{
|
|
return m_awaitable.try_await(awaiting_coroutine);
|
|
}
|
|
|
|
auto await_resume() noexcept -> std::tuple<task_types...>&
|
|
{
|
|
return m_awaitable.m_tasks;
|
|
}
|
|
private:
|
|
when_all_ready_awaitable& m_awaitable;
|
|
};
|
|
|
|
return awaiter{ *this };
|
|
}
|
|
|
|
auto operator co_await() && noexcept
|
|
{
|
|
struct awaiter
|
|
{
|
|
explicit awaiter(when_all_ready_awaitable& awaitable) noexcept
|
|
: m_awaitable(awaitable)
|
|
{ }
|
|
|
|
auto await_ready() const noexcept -> bool
|
|
{
|
|
return m_awaitable.is_ready();
|
|
}
|
|
|
|
auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool
|
|
{
|
|
return m_awaitable.try_await(awaiting_coroutine);
|
|
}
|
|
|
|
auto await_resume() noexcept -> std::tuple<task_types...>&&
|
|
{
|
|
return std::move(m_awaitable.m_tasks);
|
|
}
|
|
private:
|
|
when_all_ready_awaitable& m_awaitable;
|
|
};
|
|
|
|
return awaiter{ *this };
|
|
}
|
|
private:
|
|
|
|
auto is_ready() const noexcept -> bool
|
|
{
|
|
return m_latch.is_ready();
|
|
}
|
|
|
|
auto try_await(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool
|
|
{
|
|
std::apply(
|
|
[this](auto&&... tasks) { ((tasks.start(m_latch)), ...); },
|
|
m_tasks);
|
|
return m_latch.try_await(awaiting_coroutine);
|
|
}
|
|
|
|
when_all_latch m_latch;
|
|
std::tuple<task_types...> m_tasks;
|
|
};
|
|
|
|
template<typename task_container_type>
|
|
class when_all_ready_awaitable
|
|
{
|
|
public:
|
|
explicit when_all_ready_awaitable(task_container_type&& tasks) noexcept
|
|
: m_latch(std::size(tasks)),
|
|
m_tasks(std::forward<task_container_type>(tasks))
|
|
{}
|
|
|
|
when_all_ready_awaitable(const when_all_ready_awaitable&) = delete;
|
|
when_all_ready_awaitable(when_all_ready_awaitable&& other)
|
|
noexcept(std::is_nothrow_move_constructible_v<task_container_type>)
|
|
: m_latch(std::move(other.m_latch)),
|
|
m_tasks(std::move(m_tasks))
|
|
{}
|
|
|
|
auto operator=(const when_all_ready_awaitable&) -> when_all_ready_awaitable& = delete;
|
|
auto operator=(when_all_ready_awaitable&) -> when_all_ready_awaitable& = delete;
|
|
|
|
auto operator co_await() & noexcept
|
|
{
|
|
struct awaiter
|
|
{
|
|
awaiter(when_all_ready_awaitable& awaitable)
|
|
: m_awaitable(awaitable)
|
|
{}
|
|
|
|
auto await_ready() const noexcept -> bool
|
|
{
|
|
return m_awaitable.is_ready();
|
|
}
|
|
|
|
auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool
|
|
{
|
|
return m_awaitable.try_await(awaiting_coroutine);
|
|
}
|
|
|
|
auto await_resume() noexcept -> task_container_type&
|
|
{
|
|
return m_awaitable.m_tasks;
|
|
}
|
|
private:
|
|
when_all_ready_awaitable& m_awaitable;
|
|
};
|
|
|
|
return awaiter{*this};
|
|
}
|
|
|
|
auto operator co_await() && noexcept
|
|
{
|
|
struct awaiter
|
|
{
|
|
awaiter(when_all_ready_awaitable& awaitable)
|
|
: m_awaitable(awaitable)
|
|
{}
|
|
|
|
auto await_ready() const noexcept -> bool
|
|
{
|
|
return m_awaitable.is_ready();
|
|
}
|
|
|
|
auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool
|
|
{
|
|
return m_awaitable.try_await(awaiting_coroutine);
|
|
}
|
|
|
|
auto await_resume() noexcept -> task_container_type&&
|
|
{
|
|
return std::move(m_awaitable.m_tasks);
|
|
}
|
|
private:
|
|
when_all_ready_awaitable& m_awaitable;
|
|
};
|
|
|
|
return awaiter{*this};
|
|
}
|
|
private:
|
|
auto is_ready() const noexcept -> bool
|
|
{
|
|
return m_latch.is_ready();
|
|
}
|
|
|
|
auto try_await(std::coroutine_handle<> awaiting_coroutine) noexcept -> bool
|
|
{
|
|
for(auto& task : m_tasks)
|
|
{
|
|
task.start(m_latch);
|
|
}
|
|
|
|
return m_latch.try_await(awaiting_coroutine);
|
|
}
|
|
|
|
when_all_latch m_latch;
|
|
task_container_type m_tasks;
|
|
};
|
|
|
|
template<typename return_type>
|
|
class when_all_task_promise
|
|
{
|
|
public:
|
|
using coroutine_handle_type = std::coroutine_handle<when_all_task_promise<return_type>>;
|
|
|
|
when_all_task_promise() noexcept
|
|
{}
|
|
|
|
auto get_return_object() noexcept
|
|
{
|
|
return coroutine_handle_type::from_promise(*this);
|
|
}
|
|
|
|
auto initial_suspend() noexcept -> std::suspend_always
|
|
{
|
|
return {};
|
|
}
|
|
|
|
auto final_suspend() noexcept
|
|
{
|
|
struct completion_notifier
|
|
{
|
|
auto await_ready() const noexcept -> bool { return false; }
|
|
auto await_suspend(coroutine_handle_type coroutine) const noexcept -> void
|
|
{
|
|
coroutine.promise().m_latch->notify_awaitable_completed();
|
|
}
|
|
auto await_resume() const noexcept { }
|
|
};
|
|
|
|
return completion_notifier{};
|
|
}
|
|
|
|
auto unhandled_exception() noexcept
|
|
{
|
|
m_exception_ptr = std::current_exception();
|
|
}
|
|
|
|
auto yield_value(return_type&& value) noexcept
|
|
{
|
|
m_return_value = std::addressof(value);
|
|
return final_suspend();
|
|
}
|
|
|
|
auto start(when_all_latch& latch) noexcept -> void
|
|
{
|
|
m_latch = &latch;
|
|
coroutine_handle_type::from_promise(*this).resume();
|
|
}
|
|
|
|
auto return_value() & -> return_type&
|
|
{
|
|
if(m_exception_ptr)
|
|
{
|
|
std::rethrow_exception(m_exception_ptr);
|
|
}
|
|
return *m_return_value;
|
|
}
|
|
|
|
auto return_value() && -> return_type&&
|
|
{
|
|
if(m_exception_ptr)
|
|
{
|
|
std::rethrow_exception(m_exception_ptr);
|
|
}
|
|
return std::forward(*m_return_value);
|
|
}
|
|
|
|
private:
|
|
when_all_latch* m_latch{nullptr};
|
|
std::exception_ptr m_exception_ptr;
|
|
std::add_pointer_t<return_type> m_return_value;
|
|
};
|
|
|
|
template<>
|
|
class when_all_task_promise<void>
|
|
{
|
|
public:
|
|
using coroutine_handle_type = std::coroutine_handle<when_all_task_promise<void>>;
|
|
|
|
when_all_task_promise() noexcept
|
|
{}
|
|
|
|
auto get_return_object() noexcept
|
|
{
|
|
return coroutine_handle_type::from_promise(*this);
|
|
}
|
|
|
|
auto initial_suspend() noexcept -> std::suspend_always
|
|
{
|
|
return {};
|
|
}
|
|
|
|
auto final_suspend() noexcept
|
|
{
|
|
struct completion_notifier
|
|
{
|
|
auto await_ready() const noexcept -> bool { return false; }
|
|
auto await_suspend(coroutine_handle_type coroutine) const noexcept -> void
|
|
{
|
|
coroutine.promise().m_latch->notify_awaitable_completed();
|
|
}
|
|
auto await_resume() const noexcept -> void { }
|
|
};
|
|
|
|
return completion_notifier{};
|
|
}
|
|
|
|
auto unhandled_exception() noexcept -> void
|
|
{
|
|
m_exception_ptr = std::current_exception();
|
|
}
|
|
|
|
auto return_void() noexcept -> void
|
|
{}
|
|
|
|
auto start(when_all_latch& latch) -> void
|
|
{
|
|
m_latch = &latch;
|
|
coroutine_handle_type::from_promise(*this).resume();
|
|
}
|
|
|
|
auto return_value() -> void
|
|
{
|
|
if(m_exception_ptr)
|
|
{
|
|
std::rethrow_exception(m_exception_ptr);
|
|
}
|
|
}
|
|
private:
|
|
when_all_latch* m_latch{nullptr};
|
|
std::exception_ptr m_exception_ptr;
|
|
};
|
|
|
|
template<typename return_type>
|
|
class when_all_task
|
|
{
|
|
public:
|
|
// To be able to call start().
|
|
template<typename task_container_type>
|
|
friend class when_all_ready_awaitable;
|
|
|
|
using promise_type = when_all_task_promise<return_type>;
|
|
using coroutine_handle_type = typename promise_type::coroutine_handle_type;
|
|
|
|
when_all_task(coroutine_handle_type coroutine) noexcept
|
|
: m_coroutine(coroutine)
|
|
{}
|
|
|
|
when_all_task(const when_all_task&) = delete;
|
|
when_all_task(when_all_task&& other) noexcept
|
|
: m_coroutine(std::exchange(other.m_coroutine, coroutine_handle_type{}))
|
|
{}
|
|
|
|
auto operator=(const when_all_task&) -> when_all_task& = delete;
|
|
auto operator=(when_all_task&&) -> when_all_task& = delete;
|
|
|
|
~when_all_task()
|
|
{
|
|
if(m_coroutine != nullptr)
|
|
{
|
|
m_coroutine.destroy();
|
|
}
|
|
}
|
|
|
|
auto return_value() & -> decltype(auto)
|
|
{
|
|
if constexpr (std::is_void_v<return_type>)
|
|
{
|
|
m_coroutine.promise().return_void();
|
|
return void_value{};
|
|
}
|
|
else
|
|
{
|
|
return m_coroutine.promise().return_value();
|
|
}
|
|
}
|
|
|
|
auto return_value() const & -> decltype(auto)
|
|
{
|
|
if constexpr (std::is_void_v<return_type>)
|
|
{
|
|
m_coroutine.promise().return_void();
|
|
return void_value{};
|
|
}
|
|
else
|
|
{
|
|
return m_coroutine.promise().return_value();
|
|
}
|
|
}
|
|
|
|
auto return_value() && -> decltype(auto)
|
|
{
|
|
if constexpr (std::is_void_v<return_type>)
|
|
{
|
|
m_coroutine.promise().return_void();
|
|
return void_value{};
|
|
}
|
|
else
|
|
{
|
|
return m_coroutine.promise().return_value();
|
|
}
|
|
}
|
|
|
|
private:
|
|
auto start(when_all_latch& latch) noexcept -> void
|
|
{
|
|
m_coroutine.promise().start(latch);
|
|
}
|
|
|
|
coroutine_handle_type m_coroutine;
|
|
};
|
|
|
|
template<awaitable_type awaitable, typename return_type = awaitable_traits<awaitable&&>::awaiter_return_t>
|
|
static auto make_when_all_task(awaitable a) -> when_all_task<return_type>
|
|
{
|
|
if constexpr (std::is_void_v<return_type>)
|
|
{
|
|
co_await static_cast<awaitable&&>(a);
|
|
co_return;
|
|
}
|
|
else
|
|
{
|
|
co_yield co_await static_cast<awaitable&&>(a);
|
|
}
|
|
}
|
|
|
|
} // namespace detail
|
|
|
|
template<awaitable_type... awaitables_type>
|
|
[[nodiscard]] auto when_all_awaitable(awaitables_type&&... awaitables)
|
|
{
|
|
return
|
|
detail::when_all_ready_awaitable<
|
|
std::tuple<
|
|
detail::when_all_task<
|
|
typename awaitable_traits<awaitables_type>::awaiter_return_t
|
|
>...
|
|
>
|
|
>(std::make_tuple(detail::make_when_all_task(std::forward<awaitables_type>(awaitables))...));
|
|
}
|
|
|
|
template<awaitable_type awaitable, typename return_type = awaitable_traits<awaitable>::awaiter_return_t>
|
|
[[nodiscard]] auto when_all_awaitable(std::vector<awaitable>&& awaitables) -> detail::when_all_ready_awaitable<std::vector<detail::when_all_task<return_type>>>
|
|
{
|
|
std::vector<detail::when_all_task<return_type>> tasks;
|
|
tasks.reserve(std::size(awaitables));
|
|
|
|
for(auto& a : awaitables)
|
|
{
|
|
tasks.emplace_back(detail::make_when_all_task(std::move(a)));
|
|
}
|
|
|
|
return detail::when_all_ready_awaitable(std::move(tasks));
|
|
}
|
|
|
|
} // namespace coro
|