1
0
Fork 0
mirror of https://gitlab.com/niansa/libcrosscoro.git synced 2025-03-06 20:53:32 +01:00

task<void> working, task co_await task working

Turns out that the final_suspend() method is required
to be std::suspend_always() otherwise the coroutine_handle<>.done()
function will not trigger properly.  Refactored the task class
to allow the user to decide if they want to suspend at the beginning
but it now forces a suspend at the end to guarantee that
task.is_ready() will work properly.
This commit is contained in:
jbaldwin 2020-09-08 22:44:38 -06:00
parent fb04c43370
commit 4aa248cd17
6 changed files with 202 additions and 156 deletions

View file

@ -19,7 +19,14 @@ set_target_properties(${PROJECT_NAME} PROPERTIES LINKER_LANGUAGE CXX)
target_compile_features(${PROJECT_NAME} PUBLIC cxx_std_20)
target_include_directories(${PROJECT_NAME} PUBLIC src)
target_link_libraries(${PROJECT_NAME} PUBLIC zmq pthread)
target_compile_options(${PROJECT_NAME} PUBLIC -fcoroutines)
if(${CMAKE_CXX_COMPILER_ID} MATCHES "GNU")
target_compile_options(${PROJECT_NAME} PUBLIC -fcoroutines)
elseif(${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
target_compile_options(${PROJECT_NAME} PUBLIC -fcoroutines-ts)
endif()
if(CORO_BUILD_TESTS)
if(CORO_CODE_COVERAGE)

View file

@ -14,49 +14,12 @@
namespace coro
{
// class message
// {
// public:
// enum class type
// {
// new_web_request,
// async_resume
// };
// message() = default;
// message(type t, int socket)
// : m_type(t),
// m_socket(socket)
// {
// }
// ~message() = default;
// type m_type;
// int m_socket;
// };
// class web_request
// {
// public:
// web_request() = default;
// web_request(int socket) : m_socket(socket)
// {
// }
// ~web_request() = default;
// private:
// int m_socket{0};
// };
class engine
{
public:
/// Always suspend at the start since the engine will call the first `resume()`.
using task = coro::task<void, std::suspend_always>;
using message = uint8_t;
using task_type = coro::task<void>;
using message_type = uint8_t;
engine()
:
@ -92,14 +55,14 @@ public:
m_background_thread.join();
}
auto submit_task(std::unique_ptr<task> t) -> bool
auto submit_task(task_type t) -> bool
{
{
std::lock_guard<std::mutex> lock{m_queued_tasks_mutex};
m_queued_tasks.push_back(std::move(t));
}
message msg = 1;
message_type msg = 1;
zmq::message_t zmq_msg{&msg, sizeof(msg)};
zmq::send_result_t result;
@ -133,7 +96,6 @@ public:
using namespace std::chrono_literals;
m_is_running = true;
std::cerr << "running\n";
std::vector<zmq::pollitem_t> poll_items {
zmq::pollitem_t{static_cast<void*>(m_async_recv_events_socket), 0, ZMQ_POLLIN, 0}
@ -141,32 +103,27 @@ public:
while(!m_stop)
{
std::cerr << "polling\n";
auto events = zmq::poll(poll_items, 1000ms);
if(events > 0)
{
while(true)
{
message msg;
zmq::mutable_buffer buffer(static_cast<void*>(&msg), sizeof(message));
message_type msg;
zmq::mutable_buffer buffer(static_cast<void*>(&msg), sizeof(msg));
auto result = m_async_recv_events_socket.recv(buffer, zmq::recv_flags::dontwait);
if(!result.has_value())
{
std::cerr << "result no value\n";
// zmq returns 0 on no messages available
break; // while(true)
}
else if(result.value().truncated())
{
std::cerr << "message received with incorrect size " << result.value().size << "\n";
// let the task die? malformed message
}
else
{
std::cerr << "message received\n";
std::vector<std::unique_ptr<task>> grabbed_tasks;
std::vector<task_type> grabbed_tasks;
{
std::lock_guard<std::mutex> lock{m_queued_tasks_mutex};
grabbed_tasks.swap(m_queued_tasks);
@ -174,24 +131,28 @@ public:
for(auto& t : grabbed_tasks)
{
// start executing now
t->resume();
t.resume();
// if the task is awaiting then push into active tasks.
if(!t->is_done())
// if the task is still awaiting then push into active tasks.
if(!t.is_ready())
{
m_active_tasks.push_back(std::move(t));
}
}
m_active_tasks_count = m_active_tasks.size();
}
}
}
}
m_is_running = false;
std::cerr << "stopping\n";
}
/**
* @return The number of active tasks still executing.
*/
auto size() const -> std::size_t { return m_active_tasks_count; }
private:
static std::atomic<uint32_t> m_engine_id_counter;
const uint32_t m_engine_id{m_engine_id_counter++};
@ -206,8 +167,10 @@ private:
std::thread m_background_thread;
std::mutex m_queued_tasks_mutex;
std::vector<std::unique_ptr<task>> m_queued_tasks;
std::vector<std::unique_ptr<task>> m_active_tasks;
std::vector<task_type> m_queued_tasks;
std::vector<task_type> m_active_tasks;
std::atomic<std::size_t> m_active_tasks_count{0};
};
} // namespace coro

View file

@ -1,5 +1,6 @@
#pragma once
#include <atomic>
#include <coroutine>
#include <optional>
@ -8,16 +9,14 @@ namespace coro
template<
typename return_type = void,
typename initial_suspend_type = std::suspend_never,
typename final_suspend_type = std::suspend_never>
typename initial_suspend_type = std::suspend_always>
class task;
namespace detail
{
template<
typename initial_suspend_type,
typename final_suspend_type>
typename initial_suspend_type>
struct promise_base
{
promise_base() noexcept = default;
@ -30,7 +29,7 @@ struct promise_base
auto final_suspend()
{
return final_suspend_type();
return std::suspend_always();
}
auto unhandled_exception() -> void
@ -38,23 +37,17 @@ struct promise_base
m_exception_ptr = std::current_exception();
}
auto return_void() -> void
{
// no-op
}
protected:
std::optional<std::exception_ptr> m_exception_ptr;
};
template<
typename return_type,
typename initial_suspend_type,
typename final_suspend_type>
struct promise : public promise_base<initial_suspend_type, final_suspend_type>
typename initial_suspend_type>
struct promise : public promise_base<initial_suspend_type>
{
using task_type = task<return_type, initial_suspend_type, final_suspend_type>;
using coro_handle = std::coroutine_handle<promise<return_type, initial_suspend_type, final_suspend_type>>;
using task_type = task<return_type, initial_suspend_type>;
using coro_handle = std::coroutine_handle<promise<return_type, initial_suspend_type>>;
promise() noexcept = default;
~promise() = default;
@ -91,18 +84,19 @@ private:
};
template<
typename initial_suspend_type,
typename final_suspend_type>
struct promise<void, initial_suspend_type, final_suspend_type> : public promise_base<initial_suspend_type, final_suspend_type>
typename initial_suspend_type>
struct promise<void, initial_suspend_type> : public promise_base<initial_suspend_type>
{
using task_type = task<void, initial_suspend_type, final_suspend_type>;
using coro_handle = std::coroutine_handle<promise<void, initial_suspend_type, final_suspend_type>>;
using task_type = task<void, initial_suspend_type>;
using coro_handle = std::coroutine_handle<promise<void, initial_suspend_type>>;
promise() noexcept = default;
~promise() = default;
auto get_return_object() -> task_type;
auto return_void() -> void { }
auto result() const -> void
{
if(this->m_exception_ptr.has_value())
@ -116,37 +110,51 @@ struct promise<void, initial_suspend_type, final_suspend_type> : public promise_
template<
typename return_type,
typename initial_suspend_type,
typename final_suspend_type>
typename initial_suspend_type>
class task
{
public:
using promise_type = detail::promise<return_type, initial_suspend_type, final_suspend_type>;
using task_type = task<return_type, initial_suspend_type>;
using promise_type = detail::promise<return_type, initial_suspend_type>;
using coro_handle = std::coroutine_handle<promise_type>;
task() noexcept
: m_handle(nullptr)
{
}
task(coro_handle handle)
: m_handle(handle)
{
}
task(const task&) = delete;
task(task&&) = delete;
task(task&& other) noexcept
: m_handle(other.m_handle)
{
other.m_handle = nullptr;
}
auto operator=(const task&) -> task& = delete;
auto operator=(task&& other) -> task& = delete;
// {
// if(std::addressof(other) != this)
// {
// if(m_handle)
// {
// m_handle.destroy();
// }
auto operator=(task&& other) noexcept -> task&
{
if(std::addressof(other) != this)
{
if(m_handle)
{
m_handle.destroy();
}
// m_handle = other.m_handle;
// other.m_handle = nullptr;
// }
// }
m_handle = other.m_handle;
other.m_handle = nullptr;
}
}
auto is_done() const noexcept -> bool
/**
* @return True if the task is in its final suspend or if the task has been destroyed.
*/
auto is_ready() const noexcept -> bool
{
return m_handle == nullptr || m_handle.done();
}
@ -160,35 +168,47 @@ public:
return !m_handle.done();
}
auto destroy() -> bool
{
if(m_handle != nullptr)
{
m_handle.destroy();
m_handle = nullptr;
return true;
}
return false;
}
struct awaiter
{
awaiter(coro_handle handle) noexcept
: m_handle(handle)
awaiter(const task_type& t) noexcept
: m_task(t)
{
}
auto await_ready() const noexcept -> bool
{
return !m_handle || m_handle.done();
return m_task.is_ready();
}
auto await_suspend(std::coroutine_handle<>) noexcept -> void
{
// no-op, the handle passed in is the same as m_task.promise()
}
auto await_resume() noexcept -> return_type
{
return m_handle.promise().result();
return m_task.promise().result();
}
coro_handle m_handle;
const task_type& m_task;
};
auto operator co_await() const noexcept -> awaiter
{
return awaiter(m_handle);
return awaiter(*this);
}
auto promise() const & -> const promise_type& { return m_handle.promise(); }
@ -203,18 +223,16 @@ namespace detail
template<
typename return_type,
typename initial_suspend_type,
typename final_suspend_type>
auto promise<return_type, initial_suspend_type, final_suspend_type>::get_return_object()
typename initial_suspend_type>
auto promise<return_type, initial_suspend_type>::get_return_object()
-> task_type
{
return coro_handle::from_promise(*this);
}
template<
typename initial_suspend_type,
typename final_suspend_type>
auto promise<void,initial_suspend_type, final_suspend_type>::get_return_object()
typename initial_suspend_type>
auto promise<void,initial_suspend_type>::get_return_object()
-> task_type
{
return coro_handle::from_promise(*this);

View file

@ -10,14 +10,14 @@ auto mre_producer(coro::async_manual_reset_event<return_type>& event, return_typ
{
// simulate complicated background task
using namespace std::chrono_literals;
std::this_thread::sleep_for(100ms);
std::this_thread::sleep_for(10ms);
event.set(std::move(produced_value));
}
template<typename return_type>
auto mre_consumer(
const coro::async_manual_reset_event<return_type>& event
) -> coro::task<return_type, std::suspend_never, std::suspend_always>
) -> coro::task<return_type, std::suspend_never>
{
co_await event;
co_return event.return_value();

View file

@ -1,16 +1,53 @@
// #include "catch.hpp"
#include "catch.hpp"
// #include <coro/coro.hpp>
#include <coro/coro.hpp>
// auto execute_task() -> coro::engine::task
// {
// std::cerr << "engine task successfully executed\n";
// co_return;
// }
#include <thread>
#include <chrono>
// TEST_CASE("engine submit one request")
// {
// coro::engine eng{};
TEST_CASE("engine submit single task")
{
using namespace std::chrono_literals;
using task_type = coro::engine::task_type;
// eng.submit_task(execute_task());
// }
coro::engine eng{};
std::atomic<uint64_t> counter{0};
auto task1 = [&]() -> task_type { counter++; co_return; }();
eng.submit_task(std::move(task1));
while(counter != 1)
{
std::this_thread::sleep_for(10ms);
}
REQUIRE(eng.size() == 0);
}
TEST_CASE("engine submit mutiple tasks")
{
using namespace std::chrono_literals;
using task_type = coro::engine::task_type;
coro::engine eng{};
std::atomic<uint64_t> counter{0};
auto func = [&]() -> task_type { counter++; co_return; };
auto task1 = func();
auto task2 = func();
auto task3 = func();
eng.submit_task(std::move(task1));
eng.submit_task(std::move(task2));
eng.submit_task(std::move(task3));
while(counter != 3)
{
std::this_thread::sleep_for(10ms);
}
// Make sure every task is also destroyed since they have completed.
REQUIRE(eng.size() == 0);
}

View file

@ -5,32 +5,13 @@
#include <chrono>
#include <thread>
static auto hello() -> coro::task<std::string, std::suspend_always>
{
co_return "Hello";
}
static auto world() -> coro::task<std::string, std::suspend_always>
{
co_return "World";
}
static auto void_task() -> coro::task<void, std::suspend_always>
{
co_return;
}
static auto throws_exception() -> coro::task<std::string, std::suspend_always>
{
co_await std::suspend_always();
throw std::runtime_error("I'll be reached");
co_return "I'll never be reached";
}
TEST_CASE("hello world task")
{
auto h = hello();
auto w = world();
using task_type = coro::task<std::string>;
auto h = []() -> task_type { co_return "Hello"; }();
auto w = []() -> task_type { co_return "World"; }();
REQUIRE(h.promise().result().empty());
REQUIRE(w.promise().result().empty());
@ -38,6 +19,9 @@ TEST_CASE("hello world task")
h.resume(); // task suspends immediately
w.resume();
REQUIRE(h.is_ready());
REQUIRE(w.is_ready());
auto w_value = std::move(w).promise().result();
REQUIRE(h.promise().result() == "Hello");
@ -45,25 +29,62 @@ TEST_CASE("hello world task")
REQUIRE(w.promise().result().empty());
}
// This currently won't report as is_done(), not sure why yet...
// TEST_CASE("void task")
// {
// auto task = void_task();
// task.resume();
TEST_CASE("void task")
{
using namespace std::chrono_literals;
using task_type = coro::task<void>;
// REQUIRE(task.is_done());
// }
auto t = []() -> task_type {
std::this_thread::sleep_for(10ms);
co_return;
}();
t.resume();
REQUIRE(t.is_ready());
}
TEST_CASE("Exception thrown")
{
auto task = throws_exception();
using task_type = coro::task<std::string>;
std::string throw_msg = "I'll be reached";
auto task = [&]() -> task_type {
throw std::runtime_error(throw_msg);
co_return "I'll never be reached";
}();
task.resume();
REQUIRE(task.is_ready());
bool thrown{false};
try
{
task.resume();
auto value = task.promise().result();
}
catch(const std::exception& e)
{
REQUIRE(e.what() == "I'll be reached");
thrown = true;
REQUIRE(e.what() == throw_msg);
}
REQUIRE(thrown);
}
TEST_CASE("Task in a task")
{
auto inner_task = []() -> coro::task<int> {
std::cerr << "inner_task start\n";
std::cerr << "inner_task stop\n";
co_return 42;
};
auto outer_task = [&]() -> coro::task<> {
std::cerr << "outer_task start\n";
auto v = co_await inner_task();
REQUIRE(v == 42);
std::cerr << "outer_task stop\n";
}();
outer_task.resume();
}