1
0
Fork 0
mirror of https://gitlab.com/niansa/cosched.git synced 2025-03-06 20:53:26 +01:00
cosched/basic-coro/include/AwaitableTask.hpp
2023-05-06 01:13:50 +02:00

195 lines
4.3 KiB
C++

#pragma once
#include <concepts>
#include <coroutine>
#include <exception>
#include <stdexcept>
#include <utility>
namespace basiccoro
{
namespace detail
{
template<class Derived>
struct PromiseBase
{
auto get_return_object() { return std::coroutine_handle<Derived>::from_promise(static_cast<Derived&>(*this)); }
void unhandled_exception() { std::terminate(); }
};
/// Promise mix-in for coroutines that finish with `co_return expr;`: the
/// returned value is kept in `val` until the awaiter moves it out.
/// Note: `val` is a plain data member, so T must also be default-initializable
/// for the promise to be constructed, even though the constraint below only
/// spells out movability.
template<class Derived, class T>
    requires std::movable<T> || std::same_as<T, void>
struct ValuePromise : public PromiseBase<Derived>
{
    using value_type = T;

    /// Result of the coroutine.
    T val;

    /// Stores the value given to `co_return`.
    void return_value(T result)
    {
        val = std::move(result);
    }
};
/// Specialization for coroutines that end with a bare `co_return;`:
/// there is no result to store.
template<class Derived>
struct ValuePromise<Derived, void> : public PromiseBase<Derived>
{
    /// Nothing to record for a void coroutine.
    void return_void()
    {
    }

    using value_type = void;
};
/// Promise type of AwaitableTask<T>.
///
/// The coroutine starts running as soon as it is called. At most one other
/// coroutine may wait on it (see storeWaiting()); that waiter is resumed from
/// final_suspend() when this coroutine finishes.
template<class T>
class AwaitablePromise : public ValuePromise<AwaitablePromise<T>, T>
{
public:
    /// Never suspends initially: the body runs until its first real suspension.
    auto initial_suspend()
    {
        return std::suspend_never();
    }

    /// Runs when the coroutine finishes. If a waiter is registered, it is
    /// resumed right here (on this call stack). If it ran to completion during
    /// that resumption, its frame is destroyed. This coroutine itself always
    /// stays suspended at its final point, so its promise (and stored result)
    /// remain alive.
    /// NOTE(review): calling done() after resume() assumes the waiter's frame
    /// still exists, i.e. every awaiting coroutine type suspends at its own
    /// final point; confirm for any non-AwaitableTask awaiters.
    auto final_suspend() noexcept
    {
        if (!waiting_)
        {
            return std::suspend_always();
        }

        waiting_.resume();
        if (waiting_.done())
        {
            waiting_.destroy();
        }
        waiting_ = nullptr;
        return std::suspend_always();
    }

    /// Registers `handle` as the coroutine to resume on completion.
    /// Throws std::runtime_error if a waiter is already registered.
    void storeWaiting(std::coroutine_handle<> handle)
    {
        if (!waiting_)
        {
            waiting_ = handle;
            return;
        }
        throw std::runtime_error("AwaitablePromise::storeWaiting(): already waiting");
    }

    /// Destroying the promise while a waiter is still registered also destroys
    /// the waiter's coroutine frame.
    ~AwaitablePromise()
    {
        if (waiting_)
        {
            waiting_.destroy();
        }
    }

private:
    // Coroutine suspended on this one, or null when nobody is waiting.
    std::coroutine_handle<> waiting_{nullptr};
};
/// Common base of the task types. It wraps the handle of the task's coroutine
/// and records whether this object is responsible for destroying that frame.
///
/// Ownership is decided once, when constructed from a handle:
///  - if the coroutine has already finished, control continued in the caller
///    without suspending, so nothing else will destroy the frame and this
///    object destroys it in its destructor;
///  - otherwise the coroutine is suspended on something it awaits, and that
///    object manages the frame (AwaitablePromise::final_suspend destroys a
///    waiter that runs to completion).
///
/// Move-only. Ownership, if any, travels with the object. A moved-from or
/// default-constructed task holds a null handle.
template<class Promise>
class TaskBase
{
public:
    using promise_type = Promise;

    /// Empty task: no coroutine, nothing to destroy.
    TaskBase() = default;

    // TODO: this whole system needs revamping with something like UniqueCoroutineHandle
    // and custom static interface to awaiter types - so await_suspend method would take in UniqueCoroutineHandle
    //
    /// Wraps `handle`. This constructor is intentionally not `explicit`: the
    /// coroutine_handle returned by get_return_object() is implicitly converted
    /// into the task object.
    /// NOTE(review): whether `handle.done()` can already be true here depends on
    /// when the compiler performs that conversion (eagerly, before the body runs,
    /// or delayed until the coroutine first returns to the caller). This varies
    /// between compilers; confirm on the toolchains in use.
    TaskBase(std::coroutine_handle<promise_type> handle)
        : handle_(handle)
        // A null handle owns nothing (and calling done() on it would be UB).
        , handleShouldBeDestroyed_(handle && handle.done())
    {
    }

    TaskBase(const TaskBase&) = delete;
    TaskBase& operator=(const TaskBase&) = delete;

    /// Takes over `other`'s handle and ownership; `other` is left empty.
    TaskBase(TaskBase&& other) noexcept
        : handle_(std::exchange(other.handle_, nullptr))
        , handleShouldBeDestroyed_(std::exchange(other.handleShouldBeDestroyed_, false))
    {
    }

    /// Releases the frame currently owned (if any), then takes over `other`'s
    /// handle and ownership; `other` is left empty.
    TaskBase& operator=(TaskBase&& other) noexcept
    {
        if (this != &other)
        {
            // Without this, overwriting handle_ would leak the frame we own.
            if (handleShouldBeDestroyed_)
            {
                handle_.destroy();
            }
            handle_ = std::exchange(other.handle_, nullptr);
            handleShouldBeDestroyed_ = std::exchange(other.handleShouldBeDestroyed_, false);
        }
        return *this;
    }

    /// Destroys the coroutine frame only if this object owns it.
    ~TaskBase()
    {
        if (handleShouldBeDestroyed_)
        {
            handle_.destroy();
        }
    }

    /// True once the coroutine is suspended at its final suspend point.
    /// Precondition: the task holds a coroutine (not empty / moved-from).
    bool done() const { return handle_.done(); }

protected:
    std::coroutine_handle<promise_type> handle_ = nullptr;
    bool handleShouldBeDestroyed_ = false;
};
} // namespace detail
/// Return type for coroutines whose result is obtained with `co_await task`.
/// T is the coroutine's result type (void for a bare `co_return;`).
template<class T>
class AwaitableTask : public detail::TaskBase<detail::AwaitablePromise<T>>
{
private:
    using Base = detail::TaskBase<detail::AwaitablePromise<T>>;

public:
    // Reuse TaskBase's constructors, including the conversion from a coroutine handle.
    using Base::Base;

    // Awaiter returned by operator co_await; it reads the protected handle_.
    struct awaiter;
    friend struct awaiter;

    awaiter operator co_await() const;
};
/// Awaiter produced by `co_await task`. It suspends the awaiting coroutine until
/// the task has finished and then yields the task's result.
/// Holds the task by reference; the task must outlive the co_await expression.
template<class T>
struct AwaitableTask<T>::awaiter
{
    /// No suspension is needed if the task has already run to completion.
    bool await_ready()
    {
        return task_.done();
    }

    /// Registers the awaiting coroutine with the task's promise, which resumes
    /// it from its final_suspend(). storeWaiting() throws std::runtime_error if
    /// another coroutine is already waiting on the same task.
    template<class Promise>
    void await_suspend(std::coroutine_handle<Promise> handle)
    {
        task_.handle_.promise().storeWaiting(handle);
    }

    /// Moves the stored result out of the promise (nothing for void tasks).
    /// Awaiting the same task again therefore yields a moved-from value.
    T await_resume()
    {
        // std::same_as comes from <concepts>, which this header already includes;
        // std::is_same_v would need <type_traits>, which it does not.
        if constexpr (!std::same_as<T, void>)
        {
            return std::move(task_.handle_.promise().val);
        }
    }

    const AwaitableTask& task_;
};
/// Creates the awaiter used by `co_await task`; the awaiter refers to this
/// task by reference.
template<class T>
auto AwaitableTask<T>::operator co_await() const -> awaiter
{
    return awaiter{*this};
}
} // namespace basiccoro