1
0
Fork 0
mirror of https://gitlab.com/niansa/libcrosscoro.git synced 2025-03-06 20:53:32 +01:00

udp client + server (#31)

This commit is contained in:
Josh Baldwin 2021-01-08 20:28:55 -07:00 committed by GitHub
parent f81acc9fcd
commit 92a42699bc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 488 additions and 167 deletions

View file

@ -20,9 +20,11 @@ set(LIBCORO_SOURCE_FILES
inc/coro/net/dns_client.hpp src/net/dns_client.cpp
inc/coro/net/hostname.hpp
inc/coro/net/ip_address.hpp src/net/ip_address.cpp
inc/coro/net/socket.hpp
inc/coro/net/socket.hpp src/net/socket.cpp
inc/coro/net/tcp_client.hpp src/net/tcp_client.cpp
inc/coro/net/tcp_server.hpp src/net/tcp_server.cpp
inc/coro/net/udp_client.hpp src/net/udp_client.cpp
inc/coro/net/udp_server.hpp src/net/udp_server.cpp
inc/coro/awaitable.hpp
inc/coro/coro.hpp

View file

@ -7,6 +7,8 @@
#include "coro/net/socket.hpp"
#include "coro/net/tcp_client.hpp"
#include "coro/net/tcp_server.hpp"
#include "coro/net/udp_client.hpp"
#include "coro/net/udp_server.hpp"
#include "coro/awaitable.hpp"
#include "coro/event.hpp"

View file

@ -350,6 +350,12 @@ public:
auto poll(fd_t fd, poll_op op, std::chrono::milliseconds timeout = std::chrono::milliseconds{0})
-> coro::task<poll_status>;
auto poll(
const net::socket& sock,
poll_op op,
std::chrono::milliseconds timeout = std::chrono::milliseconds{0})
-> coro::task<poll_status>;
/**
* This function will first poll the given `fd` to make sure it can be read from. Once notified
* that the `fd` has data available to read the given `buffer` is filled with up to the buffer's

View file

@ -13,6 +13,7 @@
namespace coro::net
{
class socket
{
public:
@ -35,74 +36,7 @@ public:
blocking_t blocking;
};
static auto type_to_os(const type_t& type) -> int
{
switch (type)
{
case type_t::udp:
return SOCK_DGRAM;
case type_t::tcp:
return SOCK_STREAM;
default:
throw std::runtime_error{"Unknown socket::type_t."};
}
}
static auto make_socket(const options& opts) -> socket
{
socket s{::socket(static_cast<int>(opts.domain), type_to_os(opts.type), 0)};
if (s.native_handle() < 0)
{
throw std::runtime_error{"Failed to create socket."};
}
if (opts.blocking == blocking_t::no)
{
if (s.blocking(blocking_t::no) == false)
{
throw std::runtime_error{"Failed to set socket to non-blocking mode."};
}
}
return s;
}
static auto make_accept_socket(
const options& opts,
const net::ip_address& address,
uint16_t port,
int32_t backlog = 128) -> socket
{
socket s = make_socket(opts);
int sock_opt{1};
if (setsockopt(s.native_handle(), SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &sock_opt, sizeof(sock_opt)) < 0)
{
throw std::runtime_error{"Failed to setsockopt(SO_REUSEADDR | SO_REUSEPORT)"};
}
sockaddr_in server{};
server.sin_family = static_cast<int>(opts.domain);
server.sin_port = htons(port);
server.sin_addr = *reinterpret_cast<const in_addr*>(address.data().data());
// if (inet_pton(server.sin_family, address.data(), &server.sin_addr) <= 0)
// {
// throw std::runtime_error{"Failed to translate IP Address."};
// }
if (bind(s.native_handle(), (struct sockaddr*)&server, sizeof(server)) < 0)
{
throw std::runtime_error{"Failed to bind."};
}
if (listen(s.native_handle(), backlog) < 0)
{
throw std::runtime_error{"Failed to listen."};
}
return s;
}
static auto type_to_os(type_t type) -> int;
socket() = default;
@ -112,72 +46,16 @@ public:
socket(socket&& other) : m_fd(std::exchange(other.m_fd, -1)) {}
auto operator=(const socket&) -> socket& = delete;
auto operator =(socket&& other) noexcept -> socket&
{
if (std::addressof(other) != this)
{
m_fd = std::exchange(other.m_fd, -1);
}
return *this;
}
auto operator=(socket&& other) noexcept -> socket&;
~socket() { close(); }
auto blocking(blocking_t block) -> bool
{
if (m_fd < 0)
{
return false;
}
auto blocking(blocking_t block) -> bool;
int flags = fcntl(m_fd, F_GETFL, 0);
if (flags == -1)
{
return false;
}
auto shutdown(poll_op how = poll_op::read_write) -> bool;
// Add or subtract non-blocking flag.
flags = (block == blocking_t::yes) ? flags & ~O_NONBLOCK : (flags | O_NONBLOCK);
return (fcntl(m_fd, F_SETFL, flags) == 0);
}
auto recv(std::span<char> buffer) -> ssize_t { return ::read(m_fd, buffer.data(), buffer.size()); }
auto send(const std::span<const char> buffer) -> ssize_t { return ::write(m_fd, buffer.data(), buffer.size()); }
auto shutdown(poll_op how = poll_op::read_write) -> bool
{
if (m_fd != -1)
{
int h{0};
switch (how)
{
case poll_op::read:
h = SHUT_RD;
break;
case poll_op::write:
h = SHUT_WR;
break;
case poll_op::read_write:
h = SHUT_RDWR;
break;
}
return (::shutdown(m_fd, h) == 0);
}
return false;
}
auto close() -> void
{
if (m_fd != -1)
{
::close(m_fd);
m_fd = -1;
}
}
auto close() -> void;
auto native_handle() const -> int { return m_fd; }
@ -185,4 +63,12 @@ private:
int m_fd{-1};
};
auto make_socket(const socket::options& opts) -> socket;
auto make_accept_socket(
const socket::options& opts,
const net::ip_address& address,
uint16_t port,
int32_t backlog = 128) -> socket;
} // namespace coro::net

View file

@ -12,6 +12,7 @@
#include <optional>
#include <variant>
#include <memory>
#include <chrono>
namespace coro
{
@ -50,8 +51,12 @@ public:
auto connect(std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task<net::connect_status>;
auto socket() const -> const net::socket& { return m_socket; }
auto socket() -> net::socket& { return m_socket; }
auto recv(
std::span<char> buffer,
std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task<std::pair<poll_status, ssize_t>>;
auto send(
const std::span<const char> buffer,
std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task<std::pair<poll_status, ssize_t>>;
private:
/// The scheduler that will drive this tcp client.

View file

@ -28,11 +28,11 @@ public:
explicit tcp_server(
options opts =
options{
net::ip_address::from_string("0.0.0.0"),
8080,
128,
[](tcp_server&, net::socket) -> task<void> { co_return; },
io_scheduler::options{9, 2, io_scheduler::thread_strategy_t::spawn}});
.address = net::ip_address::from_string("0.0.0.0"),
.port = 8080,
.backlog = 128,
.on_connection = [](tcp_server&, net::socket) -> task<void> { co_return; },
.io_options = io_scheduler::options{}});
tcp_server(const tcp_server&) = delete;
tcp_server(tcp_server&&) = delete;
@ -53,7 +53,7 @@ public:
auto shutdown(shutdown_t wait_for_tasks = shutdown_t::sync) -> void override;
private:
options m_opts;
options m_options;
/// Should the accept task continue accepting new connections?
std::atomic<bool> m_accept_new_connections{true};

View file

@ -0,0 +1,53 @@
#pragma once
#include "coro/net/hostname.hpp"
#include "coro/net/ip_address.hpp"
#include "coro/net/socket.hpp"
#include "coro/task.hpp"
#include <chrono>
#include <variant>
#include <span>
namespace coro
{
class io_scheduler;
} // namespace coro
namespace coro::net
{
class udp_client
{
public:
struct options
{
/// The ip address to connect to. If using hostname then a dns client must be provided.
net::ip_address address{net::ip_address::from_string("127.0.0.1")};
/// The port to connect to.
uint16_t port{8080};
};
udp_client(io_scheduler& scheduler, options opts = options{
.address = {net::ip_address::from_string("127.0.0.1")},
.port = 8080});
udp_client(const udp_client&) = delete;
udp_client(udp_client&&) = default;
auto operator=(const udp_client&) noexcept -> udp_client& = delete;
auto operator=(udp_client&&) noexcept -> udp_client& = default;
~udp_client() = default;
auto sendto(
const std::span<const char> buffer,
std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task<ssize_t>;
private:
/// The scheduler that will drive this udp client.
io_scheduler& m_io_scheduler;
/// Options for what server to connect to.
options m_options;
/// The udp socket.
net::socket m_socket{-1};
};
} // namespace coro::net

View file

@ -0,0 +1,51 @@
#pragma once
#include "coro/net/socket.hpp"
#include "coro/net/udp_client.hpp"
#include "coro/io_scheduler.hpp"
#include <string>
#include <functional>
#include <span>
#include <optional>
namespace coro::net
{
class udp_server
{
public:
struct options
{
/// The local address to bind to to recv packets from.
net::ip_address address{net::ip_address::from_string("0.0.0.0")};
/// The port to recv packets from.
uint16_t port{8080};
};
explicit udp_server(
io_scheduler& io_scheduler,
options opts =
options{
.address = net::ip_address::from_string("0.0.0.0"),
.port = 8080,
}
);
udp_server(const udp_server&) = delete;
udp_server(udp_server&&) = default;
auto operator=(const udp_server&) -> udp_server& = delete;
auto operator=(udp_server&&) -> udp_server& = default;
~udp_server() = default;
auto recvfrom(
std::span<char>& buffer,
std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task<std::optional<udp_client::options>>;
private:
io_scheduler& m_io_scheduler;
options m_options;
net::socket m_accept_socket{-1};
};
} // namespace coro::net

View file

@ -318,6 +318,15 @@ auto io_scheduler::poll(fd_t fd, poll_op op, std::chrono::milliseconds timeout)
co_return status;
}
auto io_scheduler::poll(
const net::socket& sock,
poll_op op,
std::chrono::milliseconds timeout)
-> coro::task<poll_status>
{
return poll(sock.native_handle(), op, timeout);
}
auto io_scheduler::read(fd_t fd, std::span<char> buffer, std::chrono::milliseconds timeout)
-> coro::task<std::pair<poll_status, ssize_t>>
{

134
src/net/socket.cpp Normal file
View file

@ -0,0 +1,134 @@
#include "coro/net/socket.hpp"
namespace coro::net
{
auto socket::type_to_os(type_t type) -> int
{
switch (type)
{
case type_t::udp:
return SOCK_DGRAM;
case type_t::tcp:
return SOCK_STREAM;
default:
throw std::runtime_error{"Unknown socket::type_t."};
}
}
auto socket::operator=(socket&& other) noexcept -> socket&
{
if (std::addressof(other) != this)
{
m_fd = std::exchange(other.m_fd, -1);
}
return *this;
}
auto socket::blocking(blocking_t block) -> bool
{
if (m_fd < 0)
{
return false;
}
int flags = fcntl(m_fd, F_GETFL, 0);
if (flags == -1)
{
return false;
}
// Add or subtract non-blocking flag.
flags = (block == blocking_t::yes) ? flags & ~O_NONBLOCK : (flags | O_NONBLOCK);
return (fcntl(m_fd, F_SETFL, flags) == 0);
}
auto socket::shutdown(poll_op how) -> bool
{
if (m_fd != -1)
{
int h{0};
switch (how)
{
case poll_op::read:
h = SHUT_RD;
break;
case poll_op::write:
h = SHUT_WR;
break;
case poll_op::read_write:
h = SHUT_RDWR;
break;
}
return (::shutdown(m_fd, h) == 0);
}
return false;
}
auto socket::close() -> void
{
if (m_fd != -1)
{
::close(m_fd);
m_fd = -1;
}
}
auto make_socket(const socket::options& opts) -> socket
{
socket s{::socket(static_cast<int>(opts.domain), socket::type_to_os(opts.type), 0)};
if (s.native_handle() < 0)
{
throw std::runtime_error{"Failed to create socket."};
}
if (opts.blocking == socket::blocking_t::no)
{
if (s.blocking(socket::blocking_t::no) == false)
{
throw std::runtime_error{"Failed to set socket to non-blocking mode."};
}
}
return s;
}
auto make_accept_socket(
const socket::options& opts,
const net::ip_address& address,
uint16_t port,
int32_t backlog) -> socket
{
socket s = make_socket(opts);
int sock_opt{1};
if (setsockopt(s.native_handle(), SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &sock_opt, sizeof(sock_opt)) < 0)
{
throw std::runtime_error{"Failed to setsockopt(SO_REUSEADDR | SO_REUSEPORT)"};
}
sockaddr_in server{};
server.sin_family = static_cast<int>(opts.domain);
server.sin_port = htons(port);
server.sin_addr = *reinterpret_cast<const in_addr*>(address.data().data());
if (bind(s.native_handle(), (struct sockaddr*)&server, sizeof(server)) < 0)
{
throw std::runtime_error{"Failed to bind."};
}
if(opts.type == socket::type_t::tcp)
{
if (listen(s.native_handle(), backlog) < 0)
{
throw std::runtime_error{"Failed to listen."};
}
}
return s;
}
} // namespace coro::net

View file

@ -5,10 +5,12 @@
namespace coro::net
{
using namespace std::chrono_literals;
tcp_client::tcp_client(io_scheduler& scheduler, options opts)
: m_io_scheduler(scheduler),
m_options(std::move(opts)),
m_socket(net::socket::make_socket(net::socket::options{m_options.domain, net::socket::type_t::tcp, net::socket::blocking_t::no}))
m_socket(net::make_socket(net::socket::options{m_options.domain, net::socket::type_t::tcp, net::socket::blocking_t::no}))
{
}
@ -100,4 +102,28 @@ auto tcp_client::connect(std::chrono::milliseconds timeout) -> coro::task<connec
co_return connect_status::error;
}
auto tcp_client::recv(std::span<char> buffer, std::chrono::milliseconds timeout) -> coro::task<std::pair<poll_status, ssize_t>>
{
auto pstatus = co_await m_io_scheduler.poll(m_socket, poll_op::read, timeout);
ssize_t bread{0};
if(pstatus == poll_status::event)
{
bread = ::read(m_socket.native_handle(), buffer.data(), buffer.size());
}
co_return {pstatus, bread};
}
auto tcp_client::send(const std::span<const char> buffer, std::chrono::milliseconds timeout) -> coro::task<std::pair<poll_status, ssize_t>>
{
auto pstatus = co_await m_io_scheduler.poll(m_socket, poll_op::write, timeout);
ssize_t bwrite{0};
if(pstatus == poll_status::event)
{
bwrite = ::write(m_socket.native_handle(), buffer.data(), buffer.size());
}
co_return {pstatus, bwrite};
}
} // namespace coro::net

View file

@ -5,14 +5,14 @@ namespace coro::net
tcp_server::tcp_server(options opts)
: io_scheduler(std::move(opts.io_options)),
m_opts(std::move(opts)),
m_accept_socket(net::socket::make_accept_socket(
m_options(std::move(opts)),
m_accept_socket(net::make_accept_socket(
net::socket::options{net::domain_t::ipv4, net::socket::type_t::tcp, net::socket::blocking_t::no},
m_opts.address,
m_opts.port,
m_opts.backlog))
m_options.address,
m_options.port,
m_options.backlog))
{
if (m_opts.on_connection == nullptr)
if (m_options.on_connection == nullptr)
{
throw std::runtime_error{"options::on_connection cannot be nullptr."};
}
@ -50,25 +50,25 @@ auto tcp_server::make_accept_task() -> coro::task<void>
while (m_accept_new_connections.load(std::memory_order::acquire))
{
co_await poll(m_accept_socket.native_handle(), coro::poll_op::read);
// auto status = co_await poll(m_accept_socket.native_handle(), coro::poll_op::read);
// (void)status; // TODO: introduce timeouts on io_scheduer.poll();
// On accept socket read drain the listen accept queue.
while (true)
auto pstatus = co_await poll(m_accept_socket, coro::poll_op::read, std::chrono::seconds{1});
if(pstatus == poll_status::event)
{
net::socket s{::accept(m_accept_socket.native_handle(), (struct sockaddr*)&client, (socklen_t*)&len)};
if (s.native_handle() < 0)
// On accept socket read drain the listen accept queue.
while (true)
{
break;
net::socket s{::accept(m_accept_socket.native_handle(), (struct sockaddr*)&client, (socklen_t*)&len)};
if (s.native_handle() < 0)
{
break;
}
tasks.emplace_back(m_options.on_connection(std::ref(*this), std::move(s)));
}
tasks.emplace_back(m_opts.on_connection(std::ref(*this), std::move(s)));
}
if (!tasks.empty())
{
schedule(std::move(tasks));
if (!tasks.empty())
{
schedule(std::move(tasks));
}
}
}

39
src/net/udp_client.cpp Normal file
View file

@ -0,0 +1,39 @@
#include "coro/net/udp_client.hpp"
#include "coro/io_scheduler.hpp"
namespace coro::net
{
udp_client::udp_client(io_scheduler& scheduler, options opts)
: m_io_scheduler(scheduler),
m_options(std::move(opts)),
m_socket({net::make_socket(net::socket::options{m_options.address.domain(), net::socket::type_t::udp, net::socket::blocking_t::no})})
{
}
auto udp_client::sendto(const std::span<const char> buffer, std::chrono::milliseconds timeout) -> coro::task<ssize_t>
{
auto pstatus = co_await m_io_scheduler.poll(m_socket, poll_op::write, timeout);
if(pstatus != poll_status::event)
{
co_return 0;
}
sockaddr_in server{};
server.sin_family = static_cast<int>(m_options.address.domain());
server.sin_port = htons(m_options.port);
server.sin_addr = *reinterpret_cast<const in_addr*>(m_options.address.data().data());
socklen_t server_len{sizeof(server)};
co_return ::sendto(
m_socket.native_handle(),
buffer.data(),
buffer.size(),
0,
reinterpret_cast<sockaddr*>(&server),
server_len);
}
} // namespace coro::net

57
src/net/udp_server.cpp Normal file
View file

@ -0,0 +1,57 @@
#include "coro/net/udp_server.hpp"
#include "coro/io_scheduler.hpp"
namespace coro::net
{
udp_server::udp_server(io_scheduler& io_scheduler, options opts)
: m_io_scheduler(io_scheduler),
m_options(std::move(opts)),
m_accept_socket(net::make_accept_socket(
net::socket::options{m_options.address.domain(), net::socket::type_t::udp, net::socket::blocking_t::no},
m_options.address,
m_options.port
))
{
}
auto udp_server::recvfrom(std::span<char>& buffer, std::chrono::milliseconds timeout) -> coro::task<std::optional<udp_client::options>>
{
auto pstatus = co_await m_io_scheduler.poll(m_accept_socket, poll_op::read, timeout);
if(pstatus != poll_status::event)
{
co_return std::nullopt;
}
sockaddr_in client{};
socklen_t client_len{sizeof(client)};
auto bytes_read = ::recvfrom(
m_accept_socket.native_handle(),
buffer.data(),
buffer.size(),
0,
reinterpret_cast<sockaddr*>(&client),
&client_len);
if(bytes_read == -1)
{
co_return std::nullopt;
}
buffer = buffer.subspan(0, bytes_read);
std::span<const uint8_t> ip_addr_view{
reinterpret_cast<uint8_t*>(&client.sin_addr.s_addr),
sizeof(client.sin_addr.s_addr),
};
co_return udp_client::options{
.address = net::ip_address{ip_addr_view, static_cast<net::domain_t>(client.sin_family)},
.port = ntohs(client.sin_port)
};
}
} // namespace coro::net

View file

@ -5,6 +5,7 @@ set(LIBCORO_TEST_SOURCE_FILES
net/test_dns_client.cpp
net/test_ip_address.cpp
net/test_tcp_server.cpp
net/test_udp_peers.cpp
bench.cpp
test_event.cpp

View file

@ -399,16 +399,14 @@ TEST_CASE("benchmark tcp_client and tcp_server")
msg_ptr = &done_msg;
}
auto [wstatus, wbytes] =
co_await client_scheduler.write(client.socket(), std::span<const char>{msg_ptr->data(), msg_ptr->length()});
auto [wstatus, wbytes] = co_await client.send(std::span<const char>{msg_ptr->data(), msg_ptr->length()});
REQUIRE(wstatus == coro::poll_status::event);
REQUIRE(wbytes == msg_ptr->length());
std::string response(64, '\0');
auto [rstatus, rbytes] =
co_await client_scheduler.read(client.socket(), std::span<char>{response.data(), response.length()});
auto [rstatus, rbytes] = co_await client.recv(std::span<char>{response.data(), response.length()});
REQUIRE(rstatus == coro::poll_status::event);
REQUIRE(rbytes == msg_ptr->length());

View file

@ -49,16 +49,14 @@ static auto tcp_server_echo(
auto cstatus = co_await client.connect();
REQUIRE(cstatus == coro::net::connect_status::connected);
auto [wstatus, wbytes] =
co_await scheduler.write(client.socket(), std::span<const char>{msg.data(), msg.length()});
auto [wstatus, wbytes] = co_await client.send(std::span<const char>{msg.data(), msg.length()});
REQUIRE(wstatus == coro::poll_status::event);
REQUIRE(wbytes == msg.length());
std::string response(64, '\0');
auto [rstatus, rbytes] =
co_await scheduler.read(client.socket(), std::span<char>{response.data(), response.length()});
auto [rstatus, rbytes] = co_await client.recv(std::span<char>{response.data(), response.length()});
REQUIRE(rstatus == coro::poll_status::event);
REQUIRE(rbytes == msg.length());

View file

@ -0,0 +1,54 @@
#include "catch.hpp"
#include <coro/coro.hpp>
TEST_CASE("udp echo peers")
{
const std::string client_msg{"Hello from client!"};
const std::string server_msg{"Hello from server!!"};
coro::io_scheduler scheduler{};
auto make_client_task = [&](uint16_t client_port, uint16_t server_port) -> coro::task<void> {
std::string owning_buffer(4096, '\0');
coro::net::udp_client client{scheduler, coro::net::udp_client::options{.port = client_port}};
auto wbytes = co_await client.sendto(std::span<const char>{client_msg.data(), client_msg.length()});
REQUIRE(wbytes == client_msg.length());
coro::net::udp_server server{scheduler, coro::net::udp_server::options{.port = server_port}};
std::span<char> buffer{owning_buffer.data(), owning_buffer.length()};
auto client_opt = co_await server.recvfrom(buffer);
REQUIRE(client_opt.has_value());
REQUIRE(buffer.size() == server_msg.length());
owning_buffer.resize(buffer.size());
REQUIRE(owning_buffer == server_msg);
co_return;
};
auto make_server_task = [&](uint16_t server_port, uint16_t client_port) -> coro::task<void> {
std::string owning_buffer(4096, '\0');
coro::net::udp_server server{scheduler, coro::net::udp_server::options{.port = server_port}};
std::span<char> buffer{owning_buffer.data(), owning_buffer.length()};
auto client_opt = co_await server.recvfrom(buffer);
REQUIRE(client_opt.has_value());
REQUIRE(buffer.size() == client_msg.length());
owning_buffer.resize(buffer.size());
REQUIRE(owning_buffer == client_msg);
auto options = client_opt.value();
options.port = client_port; // we'll change the port for this test since its the same host
coro::net::udp_client client{scheduler, options};
auto wbytes = co_await client.sendto(std::span<const char>{server_msg.data(), server_msg.length()});
REQUIRE(wbytes == server_msg.length());
co_return;
};
scheduler.schedule(make_server_task(8080, 8081));
scheduler.schedule(make_client_task(8080, 8081));
}