1
0
Fork 0
mirror of https://gitlab.com/niansa/libcrosscoro.git synced 2025-03-06 20:53:32 +01:00

add tcp_scheduler (#18)

Closes #17
This commit is contained in:
Josh Baldwin 2020-11-01 18:46:41 -07:00 committed by GitHub
parent ddd3c76c53
commit 1c7b340c72
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 568 additions and 29 deletions

View file

@ -1,6 +1,6 @@
name: build
on: [push]
on: [pull_request, push]
jobs:
build-ubuntu-20-04:

View file

@ -16,10 +16,13 @@ set(LIBCORO_SOURCE_FILES
inc/coro/generator.hpp
inc/coro/io_scheduler.hpp
inc/coro/latch.hpp
inc/coro/poll.hpp
inc/coro/promise.hpp
inc/coro/shutdown.hpp
inc/coro/socket.hpp
inc/coro/sync_wait.hpp src/sync_wait.cpp
inc/coro/task.hpp
inc/coro/tcp_scheduler.hpp
inc/coro/thread_pool.hpp src/thread_pool.cpp
inc/coro/when_all.hpp
)

View file

@ -15,14 +15,13 @@ namespace coro
* await_resume() -> decltype(auto)
* Where the return type on await_resume is the requested return of the awaitable.
*/
// clang-format off
template<typename type>
concept awaiter = requires(type t, std::coroutine_handle<> c)
{
{
t.await_ready()
}
->std::same_as<bool>;
std::same_as<decltype(t.await_suspend(c)), void> || std::same_as<decltype(t.await_suspend(c)), bool> ||
{ t.await_ready() } -> std::same_as<bool>;
std::same_as<decltype(t.await_suspend(c)), void> ||
std::same_as<decltype(t.await_suspend(c)), bool> ||
std::same_as<decltype(t.await_suspend(c)), std::coroutine_handle<>>;
{t.await_resume()};
};
@ -34,10 +33,24 @@ template<typename type>
concept awaitable = requires(type t)
{
// operator co_await()
{
t.operator co_await()
}
->awaiter;
{ t.operator co_await() } -> awaiter;
};
template<typename type>
concept awaiter_void = requires(type t, std::coroutine_handle<> c)
{
{ t.await_ready() } -> std::same_as<bool>;
std::same_as<decltype(t.await_suspend(c)), void> ||
std::same_as<decltype(t.await_suspend(c)), bool> ||
std::same_as<decltype(t.await_suspend(c)), std::coroutine_handle<>>;
{t.await_resume()} -> std::same_as<void>;
};
template<typename type>
concept awaitable_void = requires(type t)
{
// operator co_await()
{ t.operator co_await() } -> awaiter_void;
};
template<awaitable awaitable, typename = void>
@ -57,5 +70,6 @@ struct awaitable_traits<awaitable>
using awaiter_type = decltype(get_awaiter(std::declval<awaitable>()));
using awaiter_return_type = decltype(std::declval<awaiter_type>().await_resume());
};
// clang-format on
} // namespace coro

View file

@ -8,5 +8,6 @@
#include "coro/promise.hpp"
#include "coro/sync_wait.hpp"
#include "coro/task.hpp"
#include "coro/tcp_scheduler.hpp"
#include "coro/thread_pool.hpp"
#include "coro/when_all.hpp"

View file

@ -1,5 +1,6 @@
#pragma once
#include "coro/poll.hpp"
#include "coro/shutdown.hpp"
#include "coro/task.hpp"
@ -162,16 +163,6 @@ public:
auto resume() noexcept -> void;
};
enum class poll_op
{
/// Poll for read operations.
read = EPOLLIN,
/// Poll for write operations.
write = EPOLLOUT,
/// Poll for read and write operations.
read_write = EPOLLIN | EPOLLOUT
};
class io_scheduler
{
private:
@ -382,7 +373,7 @@ public:
auto operator=(const io_scheduler&) -> io_scheduler& = delete;
auto operator=(io_scheduler &&) -> io_scheduler& = delete;
~io_scheduler()
virtual ~io_scheduler()
{
shutdown();
if (m_epoll_fd != -1)
@ -405,7 +396,7 @@ public:
*/
auto schedule(coro::task<void> task) -> bool
{
if (m_shutdown_requested.load(std::memory_order::relaxed))
if (is_shutdown())
{
return false;
}
@ -434,6 +425,58 @@ public:
return true;
}
template<awaitable_void... tasks_type>
auto schedule(tasks_type&&... tasks) -> bool
{
if (is_shutdown())
{
return false;
}
m_size.fetch_add(sizeof...(tasks), std::memory_order::relaxed);
{
std::lock_guard<std::mutex> lk{m_accept_mutex};
((m_accept_queue.emplace_back(std::forward<tasks_type>(tasks))), ...);
}
bool expected{false};
if (m_event_set.compare_exchange_strong(expected, true, std::memory_order::release, std::memory_order::relaxed))
{
uint64_t value{1};
::write(m_accept_fd, &value, sizeof(value));
}
return true;
}
auto schedule(std::vector<task<void>>& tasks)
{
if (is_shutdown())
{
return false;
}
m_size.fetch_add(tasks.size(), std::memory_order::relaxed);
{
std::lock_guard<std::mutex> lk{m_accept_mutex};
m_accept_queue.insert(
m_accept_queue.end(), std::make_move_iterator(tasks.begin()), std::make_move_iterator(tasks.end()));
// std::move(tasks.begin(), tasks.end(), std::back_inserter(m_accept_queue));
}
tasks.clear();
bool expected{false};
if (m_event_set.compare_exchange_strong(expected, true, std::memory_order::release, std::memory_order::relaxed))
{
uint64_t value{1};
::write(m_accept_fd, &value, sizeof(value));
}
return true;
}
/**
* Schedules a task to be run after waiting for a certain period of time.
* @param task The task to schedule after waiting `after` amount of time.
@ -459,7 +502,7 @@ public:
{
co_await unsafe_yield<void>([&](resume_token<void>& token) {
epoll_event e{};
e.events = static_cast<uint32_t>(op) | EPOLLONESHOT | EPOLLET;
e.events = static_cast<uint32_t>(op) | EPOLLONESHOT | EPOLLET | EPOLLRDHUP;
e.data.ptr = &token;
epoll_ctl(m_epoll_fd, EPOLL_CTL_ADD, fd, &e);
});
@ -477,8 +520,15 @@ public:
*/
auto read(fd_t fd, std::span<char> buffer) -> coro::task<ssize_t>
{
co_await poll(fd, poll_op::read);
/*auto status =*/co_await poll(fd, poll_op::read);
co_return ::read(fd, buffer.data(), buffer.size());
// switch(status)
// {
// case poll_status::success:
// co_return ::read(fd, buffer.data(), buffer.size());
// default:
// co_return 0;
// }
}
/**
@ -490,8 +540,15 @@ public:
*/
auto write(fd_t fd, const std::span<const char> buffer) -> coro::task<ssize_t>
{
co_await poll(fd, poll_op::write);
/*auto status =*/co_await poll(fd, poll_op::write);
co_return ::write(fd, buffer.data(), buffer.size());
// switch(status)
// {
// case poll_status::success:
// co_return ::write(fd, buffer.data(), buffer.size());
// default:
// co_return 0;
// }
}
/**
@ -631,7 +688,7 @@ public:
* the scheduler to shutdown but not wait for all tasks to complete, it returns
* immediately.
*/
auto shutdown(shutdown_t wait_for_tasks = shutdown_t::sync) -> void
virtual auto shutdown(shutdown_t wait_for_tasks = shutdown_t::sync) -> void
{
if (!m_shutdown_requested.exchange(true, std::memory_order::release))
{

29
inc/coro/poll.hpp Normal file
View file

@ -0,0 +1,29 @@
#pragma once

#include <sys/epoll.h>

namespace coro
{
/**
 * The operations that can be polled for on a file descriptor.  The enumerator
 * values map directly onto the epoll event masks so they can be passed to
 * epoll_ctl() with a static_cast.
 */
enum class poll_op
{
    /// Poll for read operations.
    read = EPOLLIN,
    /// Poll for write operations.
    write = EPOLLOUT,
    /// Poll for read and write operations.
    read_write = EPOLLIN | EPOLLOUT
};

/**
 * The outcome of a poll operation.
 */
enum class poll_status
{
    /// The poll operation was successful.
    success,
    /// The poll operation timed out.
    timeout,
    /// The file descriptor had an error while polling.
    error,
    /// The file descriptor has been closed by the remote or an internal error/close.
    closed
};

} // namespace coro

199
inc/coro/socket.hpp Normal file
View file

@ -0,0 +1,199 @@
#pragma once

#include "coro/poll.hpp"

#include <arpa/inet.h>
#include <fcntl.h>
#include <sys/socket.h>
#include <unistd.h>

#include <iostream>
#include <span>
#include <stdexcept>
#include <string>
#include <utility>

namespace coro
{
/**
 * RAII wrapper around an OS socket file descriptor.  The descriptor is closed
 * when the owning object is destroyed.  The type is move-only; moving transfers
 * ownership of the descriptor and leaves the source holding -1.
 */
class socket
{
public:
    /// The address family of the socket.
    enum class domain_t
    {
        ipv4,
        ipv6
    };

    /// The transport type of the socket.
    enum class type_t
    {
        udp,
        tcp
    };

    /// Whether operations on the socket block the calling thread.
    enum class blocking_t
    {
        yes,
        no
    };

    /// Creation options consumed by make_socket() and make_accept_socket().
    struct options
    {
        domain_t   domain;
        type_t     type;
        blocking_t blocking;
    };

    /**
     * Converts a domain_t into its OS level AF_* constant.
     * @throws std::runtime_error If the domain value is unknown.
     */
    static auto domain_to_os(const domain_t& domain) -> int
    {
        switch (domain)
        {
            case domain_t::ipv4:
                return AF_INET;
            case domain_t::ipv6:
                return AF_INET6;
            default:
                throw std::runtime_error{"Unknown socket::domain_t."};
        }
    }

    /**
     * Converts a type_t into its OS level SOCK_* constant.
     * @throws std::runtime_error If the type value is unknown.
     */
    static auto type_to_os(const type_t& type) -> int
    {
        switch (type)
        {
            case type_t::udp:
                return SOCK_DGRAM;
            case type_t::tcp:
                return SOCK_STREAM;
            default:
                throw std::runtime_error{"Unknown socket::type_t."};
        }
    }

    /**
     * Creates a new socket from the given options.
     * @throws std::runtime_error If the OS socket cannot be created or the
     *         requested non-blocking mode cannot be applied.
     */
    static auto make_socket(const options& opts) -> socket
    {
        socket s{::socket(domain_to_os(opts.domain), type_to_os(opts.type), 0)};
        if (s.native_handle() < 0)
        {
            throw std::runtime_error{"Failed to create socket."};
        }

        if (opts.blocking == blocking_t::no)
        {
            if (!s.blocking(blocking_t::no))
            {
                throw std::runtime_error{"Failed to set socket to non-blocking mode."};
            }
        }

        return s;
    }

    /**
     * Creates a socket bound to `address`:`port` and listening for incoming
     * connections.
     * @param opts Socket creation options.
     * @param address Dotted-quad address to bind to; passed as std::string to
     *        guarantee null termination for the C API.
     * @param port Port to bind to (host byte order).
     * @param backlog Maximum number of pending connections for listen().
     * @throws std::runtime_error On any setup failure (setsockopt/bind/listen
     *         or an unparsable address).
     */
    static auto make_accept_socket(
        const options& opts,
        const std::string& address, // force string to guarantee null terminated.
        uint16_t port,
        int32_t backlog = 128) -> socket
    {
        socket s = make_socket(opts);

        int sock_opt{1};
        if (setsockopt(s.native_handle(), SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &sock_opt, sizeof(sock_opt)) < 0)
        {
            throw std::runtime_error{"Failed to setsockopt(SO_REUSEADDR | SO_REUSEPORT)"};
        }

        // NOTE(review): sockaddr_in only holds IPv4 addresses; domain_t::ipv6
        // is not actually usable here until a sockaddr_in6 path is added.
        sockaddr_in server{};
        server.sin_family = domain_to_os(opts.domain);
        // inet_pton() stores the address already in network byte order; the
        // previous htonl(inet_addr(...)) double-converted it on little-endian
        // hosts, and inet_addr() cannot report errors unambiguously.
        if (::inet_pton(AF_INET, address.data(), &server.sin_addr) != 1)
        {
            throw std::runtime_error{"Failed to parse address."};
        }
        server.sin_port = htons(port);

        if (bind(s.native_handle(), (struct sockaddr*)&server, sizeof(server)) < 0)
        {
            throw std::runtime_error{"Failed to bind."};
        }

        if (listen(s.native_handle(), backlog) < 0)
        {
            throw std::runtime_error{"Failed to listen."};
        }

        return s;
    }

    socket() = default;
    /// Takes ownership of an already-open file descriptor (-1 means "empty").
    explicit socket(int fd) : m_fd(fd) {}

    socket(const socket&) = delete;
    // noexcept so containers of sockets can move instead of being forced to copy.
    socket(socket&& other) noexcept : m_fd(std::exchange(other.m_fd, -1)) {}

    auto operator=(const socket&) -> socket& = delete;
    auto operator=(socket&& other) noexcept -> socket&
    {
        if (std::addressof(other) != this)
        {
            // Release any descriptor currently owned, otherwise it would leak.
            close();
            m_fd = std::exchange(other.m_fd, -1);
        }

        return *this;
    }

    ~socket() { close(); }

    /**
     * Switches the socket between blocking and non-blocking mode.
     * @return True if the mode change was applied (false on an empty socket
     *         or a failed fcntl call).
     */
    auto blocking(blocking_t block) -> bool
    {
        if (m_fd < 0)
        {
            return false;
        }

        int flags = fcntl(m_fd, F_GETFL, 0);
        if (flags == -1)
        {
            return false;
        }

        // Add or subtract non-blocking flag.
        flags = (block == blocking_t::yes) ? flags & ~O_NONBLOCK : (flags | O_NONBLOCK);
        return (fcntl(m_fd, F_SETFL, flags) == 0);
    }

    /// Reads up to buffer.size() bytes; returns the ::read() result.
    auto recv(std::span<char> buffer) -> ssize_t { return ::read(m_fd, buffer.data(), buffer.size()); }
    /// Writes up to buffer.size() bytes; returns the ::write() result.
    auto send(const std::span<const char> buffer) -> ssize_t { return ::write(m_fd, buffer.data(), buffer.size()); }

    /**
     * Shuts down the read side, write side, or both sides of the connection.
     * @return True if ::shutdown() succeeded on an open socket.
     */
    auto shutdown(poll_op how = poll_op::read_write) -> bool
    {
        if (m_fd != -1)
        {
            int h{0};
            switch (how)
            {
                case poll_op::read:
                    h = SHUT_RD;
                    break;
                case poll_op::write:
                    h = SHUT_WR;
                    break;
                case poll_op::read_write:
                    h = SHUT_RDWR;
                    break;
            }

            return (::shutdown(m_fd, h) == 0);
        }
        return false;
    }

    /// Closes the descriptor if open; safe to call repeatedly.
    auto close() -> void
    {
        if (m_fd != -1)
        {
            ::close(m_fd);
            m_fd = -1;
        }
    }

    /// The raw file descriptor, or -1 if the socket is empty.
    auto native_handle() const -> int { return m_fd; }

private:
    int m_fd{-1};
};

} // namespace coro

121
inc/coro/tcp_scheduler.hpp Normal file
View file

@ -0,0 +1,121 @@
#pragma once

#include "coro/io_scheduler.hpp"
#include "coro/socket.hpp"
#include "coro/task.hpp"

#include <fcntl.h>
#include <functional>
#include <sys/socket.h>

#include <iostream>

namespace coro
{
/**
 * An io_scheduler that additionally owns a listening TCP socket and invokes a
 * user supplied task factory for every accepted connection.
 */
class tcp_scheduler : public io_scheduler
{
public:
    /// Invoked once per accepted connection with this scheduler and the client socket.
    using on_connection_t = std::function<task<void>(tcp_scheduler&, socket)>;

    struct options
    {
        std::string           address = "0.0.0.0"; ///< Address to bind/listen on.
        uint16_t              port    = 8080;      ///< Port to bind/listen on (host byte order).
        int32_t               backlog = 128;       ///< listen() backlog for pending connections.
        on_connection_t       on_connection = nullptr; ///< Required connection handler.
        io_scheduler::options io_options{};            ///< Forwarded to the underlying io_scheduler.
    };

    /**
     * Starts the underlying io_scheduler, creates the non-blocking accept
     * socket and schedules the internal accept loop.
     * @throws std::runtime_error If on_connection is nullptr or the accept
     *         socket cannot be created/bound.
     */
    tcp_scheduler(
        options opts =
            options{
                "0.0.0.0",
                8080,
                128,
                [](tcp_scheduler&, socket) -> task<void> { co_return; },
                io_scheduler::options{9, 2, io_scheduler::thread_strategy_t::spawn}})
        : io_scheduler(std::move(opts.io_options)),
          m_opts(std::move(opts)),
          m_accept_socket(socket::make_accept_socket(
              socket::options{socket::domain_t::ipv4, socket::type_t::tcp, socket::blocking_t::no},
              m_opts.address,
              m_opts.port,
              m_opts.backlog))
    {
        if (m_opts.on_connection == nullptr)
        {
            throw std::runtime_error{"options::on_connection cannot be nullptr."};
        }

        schedule(make_accept_task());
    }

    tcp_scheduler(const tcp_scheduler&) = delete;
    tcp_scheduler(tcp_scheduler&&)      = delete;
    auto operator=(const tcp_scheduler&) -> tcp_scheduler& = delete;
    auto operator=(tcp_scheduler&&) -> tcp_scheduler& = delete;

    // Run the derived shutdown here: virtual dispatch inside the base class
    // destructor resolves to io_scheduler::shutdown(), which would leave the
    // accept task blocked forever if the user forgot to call shutdown().
    ~tcp_scheduler() override { shutdown(); }

    /**
     * Stops accepting new connections, waits for the accept task to exit and
     * then shuts down the underlying io_scheduler.  Idempotent.
     */
    auto shutdown(shutdown_t wait_for_tasks = shutdown_t::sync) -> void override
    {
        if (m_accept_new_connections.exchange(false, std::memory_order::release))
        {
            m_accept_socket.shutdown(); // wake it up by shutting down read/write operations.

            // NOTE(review): busy-wait until the accept task observes the stop
            // flag; acceptable here since shutdown is a one-time slow path.
            while (m_accept_task_exited.load(std::memory_order::acquire) == false)
            {
                std::this_thread::sleep_for(std::chrono::milliseconds{10});
            }

            io_scheduler::shutdown(wait_for_tasks);
        }
    }

private:
    options m_opts;

    /// Set to false to request the accept loop to stop.
    std::atomic<bool> m_accept_new_connections{true};
    /// Set by the accept loop just before it finishes, observed by shutdown().
    std::atomic<bool> m_accept_task_exited{false};
    socket            m_accept_socket{-1};

    /// The long-lived task that polls the accept socket and drains the
    /// pending-connection queue, spawning one on_connection task per client.
    auto make_accept_task() -> coro::task<void>
    {
        sockaddr_in client{};

        std::vector<task<void>> tasks{};
        tasks.reserve(16);

        while (m_accept_new_connections.load(std::memory_order::acquire))
        {
            co_await poll(m_accept_socket.native_handle(), coro::poll_op::read);
            // TODO: inspect the poll status once io_scheduler.poll() supports timeouts.

            // On accept socket read drain the listen accept queue.
            while (true)
            {
                // accept()'s addrlen is a value-result argument: it must be a
                // writable socklen_t and must be reset before every call.  The
                // previous code passed a casted `constexpr const int`, which
                // accept() then wrote through — undefined behavior.
                socklen_t len{sizeof(client)};
                socket    s{::accept(m_accept_socket.native_handle(), reinterpret_cast<sockaddr*>(&client), &len)};
                if (s.native_handle() < 0)
                {
                    break;
                }

                tasks.emplace_back(m_opts.on_connection(std::ref(*this), std::move(s)));
            }

            if (!tasks.empty())
            {
                schedule(tasks); // the vector overload moves the tasks out and clears the vector.
                tasks.clear();
            }
        }

        m_accept_task_exited.exchange(true, std::memory_order::release);

        co_return;
    };
};

} // namespace coro

View file

@ -9,6 +9,7 @@ set(LIBCORO_TEST_SOURCE_FILES
test_latch.cpp
test_sync_wait.cpp
test_task.cpp
test_tcp_scheduler.cpp
test_thread_pool.cpp
test_when_all.cpp
)
@ -29,4 +30,4 @@ elseif(${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
message(FATAL_ERROR "Clang is currently not supported.")
endif()
add_test(NAME corohttp_test COMMAND ${PROJECT_NAME})
add_test(NAME coro_test COMMAND ${PROJECT_NAME})

View file

@ -141,8 +141,8 @@ TEST_CASE("io_scheduler task with read poll")
auto func = [&]() -> coro::task<void> {
// Poll will block until there is data to read.
co_await s.poll(trigger_fd, coro::poll_op::read);
REQUIRE(true);
/*auto status =*/co_await s.poll(trigger_fd, coro::poll_op::read);
/*REQUIRE(status == coro::poll_status::success);*/
co_return;
};
@ -152,6 +152,7 @@ TEST_CASE("io_scheduler task with read poll")
write(trigger_fd, &value, sizeof(value));
s.shutdown();
REQUIRE(s.empty());
close(trigger_fd);
}
@ -546,3 +547,46 @@ TEST_CASE("io_scheduler task throws after resume")
s.shutdown();
REQUIRE(s.empty());
}
TEST_CASE("io_scheduler schedule parameter pack tasks")
{
    coro::io_scheduler s{};

    std::atomic<uint64_t> counter{0};
    // Every produced task bumps the shared counter exactly once.
    auto task_factory = [&]() -> coro::task<void> {
        counter.fetch_add(1, std::memory_order::relaxed);
        co_return;
    };

    // Hand five tasks to the scheduler through the variadic overload.
    s.schedule(task_factory(), task_factory(), task_factory(), task_factory(), task_factory());

    s.shutdown();

    REQUIRE(s.empty());
    REQUIRE(counter == 5);
}
TEST_CASE("io_scheduler schedule vector<task>")
{
    coro::io_scheduler s{};

    std::atomic<uint64_t> counter{0};
    // Every produced task bumps the shared counter exactly once.
    auto task_factory = [&]() -> coro::task<void> {
        counter.fetch_add(1, std::memory_order::relaxed);
        co_return;
    };

    std::vector<coro::task<void>> tasks;
    tasks.reserve(4);
    for (std::size_t i = 0; i < 4; ++i)
    {
        tasks.emplace_back(task_factory());
    }

    s.schedule(tasks);
    // The vector overload moves the tasks out and leaves the vector empty.
    REQUIRE(tasks.empty());

    s.shutdown();

    REQUIRE(s.empty());
    REQUIRE(counter == 4);
}

View file

@ -0,0 +1,70 @@
#include "catch.hpp"
#include <coro/coro.hpp>
TEST_CASE("tcp_scheduler no on connection throws")
{
    // Constructing a tcp_scheduler without a connection callback must be rejected.
    auto construct_without_callback = []() {
        coro::tcp_scheduler{coro::tcp_scheduler::options{.on_connection = nullptr}};
    };
    REQUIRE_THROWS(construct_without_callback());
}
TEST_CASE("tcp_scheduler ping")
{
    // Payload the blocking client sends and expects to be echoed back verbatim.
    std::string msg{"Hello from client"};

    // Server-side handler: wait for readability, read the client's message,
    // verify it, then wait for writability and echo it back.
    auto on_connection = [&](coro::tcp_scheduler& tcp, coro::socket sock) -> coro::task<void> {
        /*auto status =*/co_await tcp.poll(sock.native_handle(), coro::poll_op::read);
        /*REQUIRE(status == coro::poll_status::success);*/

        std::string in{};
        in.resize(2048, '\0');

        auto read_bytes = sock.recv(std::span<char>{in.data(), in.size()});
        REQUIRE(read_bytes == msg.length());
        in.resize(read_bytes);
        REQUIRE(in == msg);

        /*status =*/co_await tcp.poll(sock.native_handle(), coro::poll_op::write);
        /*REQUIRE(status == coro::poll_status::success);*/

        auto written_bytes = sock.send(std::span<const char>(in.data(), in.length()));
        REQUIRE(written_bytes == in.length());

        co_return;
    };

    // Start the scheduler listening on all interfaces, port 8080.
    // NOTE(review): a fixed port can collide with other processes — consider
    // an ephemeral port if this test flakes in CI.
    coro::tcp_scheduler tcp{coro::tcp_scheduler::options{
        .address = "0.0.0.0",
        .port = 8080,
        .backlog = 128,
        .on_connection = on_connection,
        .io_options = coro::io_scheduler::options{8, 2, coro::io_scheduler::thread_strategy_t::spawn}}};

    // Plain blocking BSD-socket client that connects over loopback.
    int client_socket = ::socket(AF_INET, SOCK_STREAM, 0);

    sockaddr_in server{};
    server.sin_family = AF_INET;
    server.sin_port = htons(8080);

    if (inet_pton(AF_INET, "127.0.0.1", &server.sin_addr) <= 0)
    {
        perror("failed to set sin_addr=127.0.0.1");
        REQUIRE(false);
    }

    if (connect(client_socket, (struct sockaddr*)&server, sizeof(server)) < 0)
    {
        perror("Failed to connect to tcp scheduler server");
        REQUIRE(false);
    }

    // Send the payload and expect the exact same bytes back from the echo handler.
    ::send(client_socket, msg.data(), msg.length(), 0);

    std::string response{};
    response.resize(256, '\0');

    auto bytes_recv = ::recv(client_socket, response.data(), response.length(), 0);
    REQUIRE(bytes_recv == msg.length());
    response.resize(bytes_recv);
    REQUIRE(response == msg);

    // Stop the scheduler and verify all scheduled tasks have drained.
    tcp.shutdown();
    REQUIRE(tcp.empty());

    close(client_socket);
}