mirror of
https://gitlab.com/niansa/libcrosscoro.git
synced 2025-03-06 20:53:32 +01:00
task<void> working, task co_await task working
Turns out that the final_suspend() method is required to be std::suspend_always() otherwise the coroutine_handle<>.done() function will not trigger properly. Refactored the task class to allow the user to decide if they want to suspend at the beginning but it now forces a suspend at the end to guarantee that task.is_ready() will work properly.
This commit is contained in:
parent
fb04c43370
commit
4aa248cd17
6 changed files with 202 additions and 156 deletions
|
@ -19,7 +19,14 @@ set_target_properties(${PROJECT_NAME} PROPERTIES LINKER_LANGUAGE CXX)
|
|||
target_compile_features(${PROJECT_NAME} PUBLIC cxx_std_20)
|
||||
target_include_directories(${PROJECT_NAME} PUBLIC src)
|
||||
target_link_libraries(${PROJECT_NAME} PUBLIC zmq pthread)
|
||||
target_compile_options(${PROJECT_NAME} PUBLIC -fcoroutines)
|
||||
|
||||
|
||||
if(${CMAKE_CXX_COMPILER_ID} MATCHES "GNU")
|
||||
target_compile_options(${PROJECT_NAME} PUBLIC -fcoroutines)
|
||||
elseif(${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
|
||||
target_compile_options(${PROJECT_NAME} PUBLIC -fcoroutines-ts)
|
||||
endif()
|
||||
|
||||
|
||||
if(CORO_BUILD_TESTS)
|
||||
if(CORO_CODE_COVERAGE)
|
||||
|
|
|
@ -14,49 +14,12 @@
|
|||
namespace coro
|
||||
{
|
||||
|
||||
// class message
|
||||
// {
|
||||
// public:
|
||||
// enum class type
|
||||
// {
|
||||
// new_web_request,
|
||||
// async_resume
|
||||
// };
|
||||
|
||||
// message() = default;
|
||||
// message(type t, int socket)
|
||||
// : m_type(t),
|
||||
// m_socket(socket)
|
||||
// {
|
||||
|
||||
// }
|
||||
// ~message() = default;
|
||||
|
||||
// type m_type;
|
||||
// int m_socket;
|
||||
// };
|
||||
|
||||
// class web_request
|
||||
// {
|
||||
// public:
|
||||
// web_request() = default;
|
||||
|
||||
// web_request(int socket) : m_socket(socket)
|
||||
// {
|
||||
|
||||
// }
|
||||
|
||||
// ~web_request() = default;
|
||||
// private:
|
||||
// int m_socket{0};
|
||||
// };
|
||||
|
||||
class engine
|
||||
{
|
||||
public:
|
||||
/// Always suspend at the start since the engine will call the first `resume()`.
|
||||
using task = coro::task<void, std::suspend_always>;
|
||||
using message = uint8_t;
|
||||
using task_type = coro::task<void>;
|
||||
using message_type = uint8_t;
|
||||
|
||||
engine()
|
||||
:
|
||||
|
@ -92,14 +55,14 @@ public:
|
|||
m_background_thread.join();
|
||||
}
|
||||
|
||||
auto submit_task(std::unique_ptr<task> t) -> bool
|
||||
auto submit_task(task_type t) -> bool
|
||||
{
|
||||
{
|
||||
std::lock_guard<std::mutex> lock{m_queued_tasks_mutex};
|
||||
m_queued_tasks.push_back(std::move(t));
|
||||
}
|
||||
|
||||
message msg = 1;
|
||||
message_type msg = 1;
|
||||
zmq::message_t zmq_msg{&msg, sizeof(msg)};
|
||||
|
||||
zmq::send_result_t result;
|
||||
|
@ -133,7 +96,6 @@ public:
|
|||
using namespace std::chrono_literals;
|
||||
|
||||
m_is_running = true;
|
||||
std::cerr << "running\n";
|
||||
|
||||
std::vector<zmq::pollitem_t> poll_items {
|
||||
zmq::pollitem_t{static_cast<void*>(m_async_recv_events_socket), 0, ZMQ_POLLIN, 0}
|
||||
|
@ -141,32 +103,27 @@ public:
|
|||
|
||||
while(!m_stop)
|
||||
{
|
||||
std::cerr << "polling\n";
|
||||
auto events = zmq::poll(poll_items, 1000ms);
|
||||
|
||||
if(events > 0)
|
||||
{
|
||||
while(true)
|
||||
{
|
||||
message msg;
|
||||
zmq::mutable_buffer buffer(static_cast<void*>(&msg), sizeof(message));
|
||||
message_type msg;
|
||||
zmq::mutable_buffer buffer(static_cast<void*>(&msg), sizeof(msg));
|
||||
auto result = m_async_recv_events_socket.recv(buffer, zmq::recv_flags::dontwait);
|
||||
|
||||
if(!result.has_value())
|
||||
{
|
||||
std::cerr << "result no value\n";
|
||||
// zmq returns 0 on no messages available
|
||||
break; // while(true)
|
||||
}
|
||||
else if(result.value().truncated())
|
||||
{
|
||||
std::cerr << "message received with incorrect size " << result.value().size << "\n";
|
||||
// let the task die? malformed message
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cerr << "message received\n";
|
||||
|
||||
std::vector<std::unique_ptr<task>> grabbed_tasks;
|
||||
std::vector<task_type> grabbed_tasks;
|
||||
{
|
||||
std::lock_guard<std::mutex> lock{m_queued_tasks_mutex};
|
||||
grabbed_tasks.swap(m_queued_tasks);
|
||||
|
@ -174,24 +131,28 @@ public:
|
|||
|
||||
for(auto& t : grabbed_tasks)
|
||||
{
|
||||
// start executing now
|
||||
t->resume();
|
||||
t.resume();
|
||||
|
||||
// if the task is awaiting then push into active tasks.
|
||||
if(!t->is_done())
|
||||
// if the task is still awaiting then push into active tasks.
|
||||
if(!t.is_ready())
|
||||
{
|
||||
m_active_tasks.push_back(std::move(t));
|
||||
}
|
||||
}
|
||||
m_active_tasks_count = m_active_tasks.size();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m_is_running = false;
|
||||
std::cerr << "stopping\n";
|
||||
}
|
||||
|
||||
/**
|
||||
* @return The number of active tasks still executing.
|
||||
*/
|
||||
auto size() const -> std::size_t { return m_active_tasks_count; }
|
||||
|
||||
private:
|
||||
static std::atomic<uint32_t> m_engine_id_counter;
|
||||
const uint32_t m_engine_id{m_engine_id_counter++};
|
||||
|
@ -206,8 +167,10 @@ private:
|
|||
std::thread m_background_thread;
|
||||
|
||||
std::mutex m_queued_tasks_mutex;
|
||||
std::vector<std::unique_ptr<task>> m_queued_tasks;
|
||||
std::vector<std::unique_ptr<task>> m_active_tasks;
|
||||
std::vector<task_type> m_queued_tasks;
|
||||
std::vector<task_type> m_active_tasks;
|
||||
|
||||
std::atomic<std::size_t> m_active_tasks_count{0};
|
||||
};
|
||||
|
||||
} // namespace coro
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
#pragma once
|
||||
|
||||
#include <atomic>
|
||||
#include <coroutine>
|
||||
#include <optional>
|
||||
|
||||
|
@ -8,16 +9,14 @@ namespace coro
|
|||
|
||||
template<
|
||||
typename return_type = void,
|
||||
typename initial_suspend_type = std::suspend_never,
|
||||
typename final_suspend_type = std::suspend_never>
|
||||
typename initial_suspend_type = std::suspend_always>
|
||||
class task;
|
||||
|
||||
namespace detail
|
||||
{
|
||||
|
||||
template<
|
||||
typename initial_suspend_type,
|
||||
typename final_suspend_type>
|
||||
typename initial_suspend_type>
|
||||
struct promise_base
|
||||
{
|
||||
promise_base() noexcept = default;
|
||||
|
@ -30,7 +29,7 @@ struct promise_base
|
|||
|
||||
auto final_suspend()
|
||||
{
|
||||
return final_suspend_type();
|
||||
return std::suspend_always();
|
||||
}
|
||||
|
||||
auto unhandled_exception() -> void
|
||||
|
@ -38,23 +37,17 @@ struct promise_base
|
|||
m_exception_ptr = std::current_exception();
|
||||
}
|
||||
|
||||
auto return_void() -> void
|
||||
{
|
||||
// no-op
|
||||
}
|
||||
|
||||
protected:
|
||||
std::optional<std::exception_ptr> m_exception_ptr;
|
||||
};
|
||||
|
||||
template<
|
||||
typename return_type,
|
||||
typename initial_suspend_type,
|
||||
typename final_suspend_type>
|
||||
struct promise : public promise_base<initial_suspend_type, final_suspend_type>
|
||||
typename initial_suspend_type>
|
||||
struct promise : public promise_base<initial_suspend_type>
|
||||
{
|
||||
using task_type = task<return_type, initial_suspend_type, final_suspend_type>;
|
||||
using coro_handle = std::coroutine_handle<promise<return_type, initial_suspend_type, final_suspend_type>>;
|
||||
using task_type = task<return_type, initial_suspend_type>;
|
||||
using coro_handle = std::coroutine_handle<promise<return_type, initial_suspend_type>>;
|
||||
|
||||
promise() noexcept = default;
|
||||
~promise() = default;
|
||||
|
@ -91,18 +84,19 @@ private:
|
|||
};
|
||||
|
||||
template<
|
||||
typename initial_suspend_type,
|
||||
typename final_suspend_type>
|
||||
struct promise<void, initial_suspend_type, final_suspend_type> : public promise_base<initial_suspend_type, final_suspend_type>
|
||||
typename initial_suspend_type>
|
||||
struct promise<void, initial_suspend_type> : public promise_base<initial_suspend_type>
|
||||
{
|
||||
using task_type = task<void, initial_suspend_type, final_suspend_type>;
|
||||
using coro_handle = std::coroutine_handle<promise<void, initial_suspend_type, final_suspend_type>>;
|
||||
using task_type = task<void, initial_suspend_type>;
|
||||
using coro_handle = std::coroutine_handle<promise<void, initial_suspend_type>>;
|
||||
|
||||
promise() noexcept = default;
|
||||
~promise() = default;
|
||||
|
||||
auto get_return_object() -> task_type;
|
||||
|
||||
auto return_void() -> void { }
|
||||
|
||||
auto result() const -> void
|
||||
{
|
||||
if(this->m_exception_ptr.has_value())
|
||||
|
@ -116,37 +110,51 @@ struct promise<void, initial_suspend_type, final_suspend_type> : public promise_
|
|||
|
||||
template<
|
||||
typename return_type,
|
||||
typename initial_suspend_type,
|
||||
typename final_suspend_type>
|
||||
typename initial_suspend_type>
|
||||
class task
|
||||
{
|
||||
public:
|
||||
using promise_type = detail::promise<return_type, initial_suspend_type, final_suspend_type>;
|
||||
using task_type = task<return_type, initial_suspend_type>;
|
||||
using promise_type = detail::promise<return_type, initial_suspend_type>;
|
||||
using coro_handle = std::coroutine_handle<promise_type>;
|
||||
|
||||
task() noexcept
|
||||
: m_handle(nullptr)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
task(coro_handle handle)
|
||||
: m_handle(handle)
|
||||
{
|
||||
|
||||
}
|
||||
task(const task&) = delete;
|
||||
task(task&&) = delete;
|
||||
task(task&& other) noexcept
|
||||
: m_handle(other.m_handle)
|
||||
{
|
||||
other.m_handle = nullptr;
|
||||
}
|
||||
|
||||
auto operator=(const task&) -> task& = delete;
|
||||
auto operator=(task&& other) -> task& = delete;
|
||||
// {
|
||||
// if(std::addressof(other) != this)
|
||||
// {
|
||||
// if(m_handle)
|
||||
// {
|
||||
// m_handle.destroy();
|
||||
// }
|
||||
auto operator=(task&& other) noexcept -> task&
|
||||
{
|
||||
if(std::addressof(other) != this)
|
||||
{
|
||||
if(m_handle)
|
||||
{
|
||||
m_handle.destroy();
|
||||
}
|
||||
|
||||
// m_handle = other.m_handle;
|
||||
// other.m_handle = nullptr;
|
||||
// }
|
||||
// }
|
||||
m_handle = other.m_handle;
|
||||
other.m_handle = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
auto is_done() const noexcept -> bool
|
||||
/**
|
||||
* @return True if the task is in its final suspend or if the task has been destroyed.
|
||||
*/
|
||||
auto is_ready() const noexcept -> bool
|
||||
{
|
||||
return m_handle == nullptr || m_handle.done();
|
||||
}
|
||||
|
@ -160,35 +168,47 @@ public:
|
|||
return !m_handle.done();
|
||||
}
|
||||
|
||||
auto destroy() -> bool
|
||||
{
|
||||
if(m_handle != nullptr)
|
||||
{
|
||||
m_handle.destroy();
|
||||
m_handle = nullptr;
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
struct awaiter
|
||||
{
|
||||
awaiter(coro_handle handle) noexcept
|
||||
: m_handle(handle)
|
||||
awaiter(const task_type& t) noexcept
|
||||
: m_task(t)
|
||||
{
|
||||
|
||||
}
|
||||
|
||||
auto await_ready() const noexcept -> bool
|
||||
{
|
||||
return !m_handle || m_handle.done();
|
||||
return m_task.is_ready();
|
||||
}
|
||||
|
||||
auto await_suspend(std::coroutine_handle<>) noexcept -> void
|
||||
{
|
||||
|
||||
// no-op, the handle passed in is the same as m_task.promise()
|
||||
}
|
||||
|
||||
auto await_resume() noexcept -> return_type
|
||||
{
|
||||
return m_handle.promise().result();
|
||||
return m_task.promise().result();
|
||||
}
|
||||
|
||||
coro_handle m_handle;
|
||||
const task_type& m_task;
|
||||
};
|
||||
|
||||
auto operator co_await() const noexcept -> awaiter
|
||||
{
|
||||
return awaiter(m_handle);
|
||||
return awaiter(*this);
|
||||
}
|
||||
|
||||
auto promise() const & -> const promise_type& { return m_handle.promise(); }
|
||||
|
@ -203,18 +223,16 @@ namespace detail
|
|||
|
||||
template<
|
||||
typename return_type,
|
||||
typename initial_suspend_type,
|
||||
typename final_suspend_type>
|
||||
auto promise<return_type, initial_suspend_type, final_suspend_type>::get_return_object()
|
||||
typename initial_suspend_type>
|
||||
auto promise<return_type, initial_suspend_type>::get_return_object()
|
||||
-> task_type
|
||||
{
|
||||
return coro_handle::from_promise(*this);
|
||||
}
|
||||
|
||||
template<
|
||||
typename initial_suspend_type,
|
||||
typename final_suspend_type>
|
||||
auto promise<void,initial_suspend_type, final_suspend_type>::get_return_object()
|
||||
typename initial_suspend_type>
|
||||
auto promise<void,initial_suspend_type>::get_return_object()
|
||||
-> task_type
|
||||
{
|
||||
return coro_handle::from_promise(*this);
|
||||
|
|
|
@ -10,14 +10,14 @@ auto mre_producer(coro::async_manual_reset_event<return_type>& event, return_typ
|
|||
{
|
||||
// simulate complicated background task
|
||||
using namespace std::chrono_literals;
|
||||
std::this_thread::sleep_for(100ms);
|
||||
std::this_thread::sleep_for(10ms);
|
||||
event.set(std::move(produced_value));
|
||||
}
|
||||
|
||||
template<typename return_type>
|
||||
auto mre_consumer(
|
||||
const coro::async_manual_reset_event<return_type>& event
|
||||
) -> coro::task<return_type, std::suspend_never, std::suspend_always>
|
||||
) -> coro::task<return_type, std::suspend_never>
|
||||
{
|
||||
co_await event;
|
||||
co_return event.return_value();
|
||||
|
|
|
@ -1,16 +1,53 @@
|
|||
// #include "catch.hpp"
|
||||
#include "catch.hpp"
|
||||
|
||||
// #include <coro/coro.hpp>
|
||||
#include <coro/coro.hpp>
|
||||
|
||||
// auto execute_task() -> coro::engine::task
|
||||
// {
|
||||
// std::cerr << "engine task successfully executed\n";
|
||||
// co_return;
|
||||
// }
|
||||
#include <thread>
|
||||
#include <chrono>
|
||||
|
||||
// TEST_CASE("engine submit one request")
|
||||
// {
|
||||
// coro::engine eng{};
|
||||
TEST_CASE("engine submit single task")
|
||||
{
|
||||
using namespace std::chrono_literals;
|
||||
using task_type = coro::engine::task_type;
|
||||
|
||||
// eng.submit_task(execute_task());
|
||||
// }
|
||||
coro::engine eng{};
|
||||
|
||||
std::atomic<uint64_t> counter{0};
|
||||
|
||||
auto task1 = [&]() -> task_type { counter++; co_return; }();
|
||||
|
||||
eng.submit_task(std::move(task1));
|
||||
while(counter != 1)
|
||||
{
|
||||
std::this_thread::sleep_for(10ms);
|
||||
}
|
||||
|
||||
REQUIRE(eng.size() == 0);
|
||||
}
|
||||
|
||||
TEST_CASE("engine submit mutiple tasks")
|
||||
{
|
||||
using namespace std::chrono_literals;
|
||||
using task_type = coro::engine::task_type;
|
||||
|
||||
coro::engine eng{};
|
||||
|
||||
std::atomic<uint64_t> counter{0};
|
||||
|
||||
auto func = [&]() -> task_type { counter++; co_return; };
|
||||
|
||||
auto task1 = func();
|
||||
auto task2 = func();
|
||||
auto task3 = func();
|
||||
|
||||
eng.submit_task(std::move(task1));
|
||||
eng.submit_task(std::move(task2));
|
||||
eng.submit_task(std::move(task3));
|
||||
while(counter != 3)
|
||||
{
|
||||
std::this_thread::sleep_for(10ms);
|
||||
}
|
||||
|
||||
// Make sure every task is also destroyed since they have completed.
|
||||
REQUIRE(eng.size() == 0);
|
||||
}
|
|
@ -5,32 +5,13 @@
|
|||
#include <chrono>
|
||||
#include <thread>
|
||||
|
||||
static auto hello() -> coro::task<std::string, std::suspend_always>
|
||||
{
|
||||
co_return "Hello";
|
||||
}
|
||||
|
||||
static auto world() -> coro::task<std::string, std::suspend_always>
|
||||
{
|
||||
co_return "World";
|
||||
}
|
||||
|
||||
static auto void_task() -> coro::task<void, std::suspend_always>
|
||||
{
|
||||
co_return;
|
||||
}
|
||||
|
||||
static auto throws_exception() -> coro::task<std::string, std::suspend_always>
|
||||
{
|
||||
co_await std::suspend_always();
|
||||
throw std::runtime_error("I'll be reached");
|
||||
co_return "I'll never be reached";
|
||||
}
|
||||
|
||||
TEST_CASE("hello world task")
|
||||
{
|
||||
auto h = hello();
|
||||
auto w = world();
|
||||
using task_type = coro::task<std::string>;
|
||||
|
||||
auto h = []() -> task_type { co_return "Hello"; }();
|
||||
auto w = []() -> task_type { co_return "World"; }();
|
||||
|
||||
REQUIRE(h.promise().result().empty());
|
||||
REQUIRE(w.promise().result().empty());
|
||||
|
@ -38,6 +19,9 @@ TEST_CASE("hello world task")
|
|||
h.resume(); // task suspends immediately
|
||||
w.resume();
|
||||
|
||||
REQUIRE(h.is_ready());
|
||||
REQUIRE(w.is_ready());
|
||||
|
||||
auto w_value = std::move(w).promise().result();
|
||||
|
||||
REQUIRE(h.promise().result() == "Hello");
|
||||
|
@ -45,25 +29,62 @@ TEST_CASE("hello world task")
|
|||
REQUIRE(w.promise().result().empty());
|
||||
}
|
||||
|
||||
// This currently won't report as is_done(), not sure why yet...
|
||||
// TEST_CASE("void task")
|
||||
// {
|
||||
// auto task = void_task();
|
||||
// task.resume();
|
||||
TEST_CASE("void task")
|
||||
{
|
||||
using namespace std::chrono_literals;
|
||||
using task_type = coro::task<void>;
|
||||
|
||||
// REQUIRE(task.is_done());
|
||||
// }
|
||||
auto t = []() -> task_type {
|
||||
std::this_thread::sleep_for(10ms);
|
||||
co_return;
|
||||
}();
|
||||
t.resume();
|
||||
|
||||
REQUIRE(t.is_ready());
|
||||
}
|
||||
|
||||
TEST_CASE("Exception thrown")
|
||||
{
|
||||
auto task = throws_exception();
|
||||
using task_type = coro::task<std::string>;
|
||||
|
||||
std::string throw_msg = "I'll be reached";
|
||||
|
||||
auto task = [&]() -> task_type {
|
||||
throw std::runtime_error(throw_msg);
|
||||
co_return "I'll never be reached";
|
||||
}();
|
||||
|
||||
task.resume();
|
||||
|
||||
REQUIRE(task.is_ready());
|
||||
|
||||
bool thrown{false};
|
||||
try
|
||||
{
|
||||
task.resume();
|
||||
auto value = task.promise().result();
|
||||
}
|
||||
catch(const std::exception& e)
|
||||
{
|
||||
REQUIRE(e.what() == "I'll be reached");
|
||||
thrown = true;
|
||||
REQUIRE(e.what() == throw_msg);
|
||||
}
|
||||
|
||||
REQUIRE(thrown);
|
||||
}
|
||||
|
||||
TEST_CASE("Task in a task")
|
||||
{
|
||||
auto inner_task = []() -> coro::task<int> {
|
||||
std::cerr << "inner_task start\n";
|
||||
std::cerr << "inner_task stop\n";
|
||||
co_return 42;
|
||||
};
|
||||
auto outer_task = [&]() -> coro::task<> {
|
||||
std::cerr << "outer_task start\n";
|
||||
auto v = co_await inner_task();
|
||||
REQUIRE(v == 42);
|
||||
std::cerr << "outer_task stop\n";
|
||||
}();
|
||||
|
||||
outer_task.resume();
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue