mirror of https://gitlab.com/niansa/libcrosscoro.git synced 2025-03-06 20:53:32 +01:00

But cleanup

Nils 2021-07-28 11:50:15 +02:00
parent d1263aebd7
commit 73d8424008
25 changed files with 0 additions and 2459 deletions

@@ -19,7 +19,6 @@ set(LIBCORO_SOURCE_FILES
inc/coro/concepts/promise.hpp
inc/coro/concepts/range_of.hpp
inc/coro/detail/poll_info.hpp
inc/coro/detail/void_value.hpp
inc/coro/coro.hpp

@@ -1,74 +0,0 @@
#pragma once
#include "coro/fd.hpp"
#include "coro/poll.hpp"
#include <atomic>
#include <chrono>
#include <coroutine>
#include <map>
#include <optional>
namespace coro::detail
{
/**
* Poll Info encapsulates everything about a poll operation for the event as well as its paired
* timeout. This is important since coroutines that are waiting on an event or timeout do not
* immediately execute; they are re-scheduled onto the thread pool, so it's possible its paired
* event or timeout also triggers while the coroutine is still waiting to resume. This means that
* the first one to happen, the event itself or its timeout, needs to disable the other paired item
* prior to resuming the coroutine.
*
* Finally, it's also important to note that the event and its paired timeout could happen during
* the same epoll_wait and possibly trigger the coroutine to start twice. Only one can win, so the
* first one processed sets m_processed to true and any subsequent events in the same epoll batch
* are effectively discarded.
*/
struct poll_info
{
using clock = std::chrono::steady_clock;
using time_point = clock::time_point;
using timed_events = std::multimap<time_point, detail::poll_info*>;
poll_info() = default;
~poll_info() = default;
poll_info(const poll_info&) = delete;
poll_info(poll_info&&) = delete;
auto operator=(const poll_info&) -> poll_info& = delete;
auto operator=(poll_info&&) -> poll_info& = delete;
struct poll_awaiter
{
explicit poll_awaiter(poll_info& pi) noexcept : m_pi(pi) {}
auto await_ready() const noexcept -> bool { return false; }
auto await_suspend(std::coroutine_handle<> awaiting_coroutine) noexcept -> void
{
m_pi.m_awaiting_coroutine = awaiting_coroutine;
std::atomic_thread_fence(std::memory_order::release);
}
auto await_resume() noexcept -> coro::poll_status { return m_pi.m_poll_status; }
poll_info& m_pi;
};
auto operator co_await() noexcept -> poll_awaiter { return poll_awaiter{*this}; }
/// The file descriptor being polled on. This is needed so that if the timeout occurs first then
/// the event loop can immediately disable the event within epoll.
fd_t m_fd{-1};
/// The timeout's position in the timeout map. For a poll() with no timeout, or a yield(), this is empty.
/// This is needed so that if the event occurs first then the event loop can immediately disable
/// the timeout within epoll.
std::optional<timed_events::iterator> m_timer_pos{std::nullopt};
/// The awaiting coroutine for this poll info to resume upon event or timeout.
std::coroutine_handle<> m_awaiting_coroutine;
/// The status of the poll operation.
coro::poll_status m_poll_status{coro::poll_status::error};
/// Did the timeout and event trigger at the same time on the same epoll_wait call?
/// Once this is set to true all future events on this poll info are null and void.
bool m_processed{false};
};
} // namespace coro::detail
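To make the "first one wins" contract concrete, here is a minimal sketch of how the event side and the timeout side are expected to claim a poll_info. The real logic lives in io_scheduler::process_event_execute()/process_timeout_execute() later in this diff; the direct resume() here is illustrative only.

#include "coro/detail/poll_info.hpp"

// Illustrative sketch: whichever of the paired {event, timeout} arrives first claims the
// poll_info; the loser sees m_processed == true and is discarded.
auto claim_and_resume(coro::detail::poll_info& pi, coro::poll_status status) -> void
{
    if (!pi.m_processed)
    {
        pi.m_processed   = true;   // disable the paired item
        pi.m_poll_status = status; // value returned from poll_awaiter::await_resume()
        pi.m_awaiting_coroutine.resume();
    }
}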

@@ -1,25 +0,0 @@
#pragma once
#include <string>
namespace coro::net
{
enum class connect_status
{
/// The connection has been established.
connected,
/// The given ip address could not be parsed or is invalid.
invalid_ip_address,
/// The connection operation timed out.
timeout,
/// There was an error, use errno to get more information on the specific error.
error
};
/**
* @param status The connection status to convert.
* @return The string representation of the given connection status.
* @throw std::logic_error If provided an invalid connect_status enum value.
*/
auto to_string(const connect_status& status) -> const std::string&;
} // namespace coro::net
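A tiny usage sketch of the helper above (assumes linking against the connect.cpp translation unit shown later in this diff).

#include "coro/net/connect.hpp"
#include <iostream>

int main()
{
    // Prints "timeout"; passing an out-of-range enum value would throw std::logic_error.
    std::cout << coro::net::to_string(coro::net::connect_status::timeout) << "\n";
}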

@@ -1,100 +0,0 @@
#pragma once
#include "coro/fd.hpp"
#include "coro/io_scheduler.hpp"
#include "coro/net/hostname.hpp"
#include "coro/net/ip_address.hpp"
#include "coro/task.hpp"
#include "coro/task_container.hpp"
#include <ares.h>
#include <array>
#include <chrono>
#include <functional>
#include <memory>
#include <mutex>
#include <sys/epoll.h>
#include <unordered_set>
#include <vector>
namespace coro::net
{
class dns_resolver;
enum class dns_status
{
complete,
error
};
class dns_result
{
friend dns_resolver;
public:
dns_result(coro::io_scheduler& scheduler, coro::event& resume, uint64_t pending_dns_requests);
~dns_result() = default;
/**
* @return The status of the dns lookup.
*/
auto status() const -> dns_status { return m_status; }
/**
* @return If the result of the dns lookup was successful then the list of ip addresses that
* were resolved from the hostname.
*/
auto ip_addresses() const -> const std::vector<coro::net::ip_address>& { return m_ip_addresses; }
private:
coro::io_scheduler& m_io_scheduler;
coro::event& m_resume;
uint64_t m_pending_dns_requests{0};
dns_status m_status{dns_status::complete};
std::vector<coro::net::ip_address> m_ip_addresses{};
friend auto ares_dns_callback(void* arg, int status, int timeouts, struct hostent* host) -> void;
};
class dns_resolver
{
public:
explicit dns_resolver(std::shared_ptr<io_scheduler> scheduler, std::chrono::milliseconds timeout);
dns_resolver(const dns_resolver&) = delete;
dns_resolver(dns_resolver&&) = delete;
auto operator=(const dns_resolver&) noexcept -> dns_resolver& = delete;
auto operator=(dns_resolver&&) noexcept -> dns_resolver& = delete;
~dns_resolver();
/**
* @param hn The hostname to resolve ip addresses for.
*/
auto host_by_name(const net::hostname& hn) -> coro::task<std::unique_ptr<dns_result>>;
private:
/// The io scheduler to drive the events for dns lookups.
std::shared_ptr<io_scheduler> m_io_scheduler;
/// The global timeout per dns lookup request.
std::chrono::milliseconds m_timeout{0};
/// The c-ares channel for looking up dns entries.
ares_channel m_ares_channel{nullptr};
/// This is the set of sockets that are currently being actively polled so multiple poll tasks
/// are not set up when ares_poll() is called.
std::unordered_set<fd_t> m_active_sockets{};
task_container<io_scheduler> m_task_container;
/// Global count to track if c-ares has been initialized or cleaned up.
static uint64_t m_ares_count;
/// Critical section around the c-ares global init/cleanup to prevent heap corruption.
static std::mutex m_ares_mutex;
auto ares_poll() -> void;
auto make_poll_task(fd_t fd, poll_op ops) -> coro::task<void>;
};
} // namespace coro::net
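A hedged usage sketch of the resolver. It assumes the library's coro::sync_wait helper and a default-constructible io_scheduler, neither of which appears in this diff.

#include "coro/io_scheduler.hpp"
#include "coro/net/dns_resolver.hpp"
#include "coro/sync_wait.hpp" // assumed to exist elsewhere in the library
#include <chrono>
#include <iostream>
#include <memory>

int main()
{
    auto scheduler = std::make_shared<coro::io_scheduler>();
    coro::net::dns_resolver resolver{scheduler, std::chrono::milliseconds{5000}};

    auto lookup = [&]() -> coro::task<void> {
        // host_by_name() issues both A and AAAA requests and suspends until c-ares completes them.
        auto result = co_await resolver.host_by_name(coro::net::hostname{"www.example.com"});
        if (result->status() == coro::net::dns_status::complete)
        {
            for (const auto& ip : result->ip_addresses())
            {
                std::cout << ip.to_string() << "\n";
            }
        }
        co_return;
    };

    coro::sync_wait(lookup());
}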

@@ -1,26 +0,0 @@
#pragma once
#include <string>
namespace coro::net
{
class hostname
{
public:
hostname() = default;
explicit hostname(std::string hn) : m_hostname(std::move(hn)) {}
hostname(const hostname&) = default;
hostname(hostname&&) = default;
auto operator=(const hostname&) noexcept -> hostname& = default;
auto operator=(hostname&&) noexcept -> hostname& = default;
~hostname() = default;
auto data() const -> const std::string& { return m_hostname; }
auto operator<=>(const hostname& other) const { return m_hostname <=> other.m_hostname; }
private:
std::string m_hostname;
};
} // namespace coro::net

@@ -1,106 +0,0 @@
#pragma once
#include <arpa/inet.h>
#include <array>
#include <cstring>
#include <span>
#include <stdexcept>
#include <string>
namespace coro::net
{
enum class domain_t : int
{
ipv4 = AF_INET,
ipv6 = AF_INET6
};
auto to_string(domain_t domain) -> const std::string&;
class ip_address
{
public:
static const constexpr size_t ipv4_len{4};
static const constexpr size_t ipv6_len{16};
ip_address() = default;
ip_address(std::span<const uint8_t> binary_address, domain_t domain = domain_t::ipv4) : m_domain(domain)
{
if (m_domain == domain_t::ipv4 && binary_address.size() > ipv4_len)
{
throw std::runtime_error{"coro::net::ip_address provided binary ip address is too long"};
}
else if (binary_address.size() > ipv6_len)
{
throw std::runtime_error{"coro::net::ip_address provided binary ip address is too long"};
}
std::copy(binary_address.begin(), binary_address.end(), m_data.begin());
}
ip_address(const ip_address&) = default;
ip_address(ip_address&&) = default;
auto operator=(const ip_address&) noexcept -> ip_address& = default;
auto operator=(ip_address&&) noexcept -> ip_address& = default;
~ip_address() = default;
auto domain() const -> domain_t { return m_domain; }
auto data() const -> std::span<const uint8_t>
{
if (m_domain == domain_t::ipv4)
{
return std::span<const uint8_t>{m_data.data(), ipv4_len};
}
else
{
return std::span<const uint8_t>{m_data.data(), ipv6_len};
}
}
static auto from_string(const std::string& address, domain_t domain = domain_t::ipv4) -> ip_address
{
ip_address addr{};
addr.m_domain = domain;
auto success = inet_pton(static_cast<int>(addr.m_domain), address.data(), addr.m_data.data());
if (success != 1)
{
throw std::runtime_error{"coro::net::ip_address faild to convert from string"};
}
return addr;
}
auto to_string() const -> std::string
{
std::string output;
if (m_domain == domain_t::ipv4)
{
output.resize(INET_ADDRSTRLEN, '\0');
}
else
{
output.resize(INET6_ADDRSTRLEN, '\0');
}
auto success = inet_ntop(static_cast<int>(m_domain), m_data.data(), output.data(), output.length());
if (success != nullptr)
{
auto len = strnlen(success, output.length());
output.resize(len);
}
else
{
throw std::runtime_error{"coro::net::ip_address failed to convert to string representation"};
}
return output;
}
auto operator<=>(const ip_address& other) const = default;
private:
domain_t m_domain{domain_t::ipv4};
std::array<uint8_t, ipv6_len> m_data{};
};
} // namespace coro::net
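A short round-trip sketch for ip_address (to_string(domain_t) is defined in the ip_address.cpp shown later in this diff).

#include "coro/net/ip_address.hpp"
#include <iostream>

int main()
{
    // String -> 4-byte binary form -> string again.
    auto addr = coro::net::ip_address::from_string("192.168.1.10");
    std::cout << coro::net::to_string(addr.domain()) << " " << addr.to_string() << "\n";

    // An unparsable address throws std::runtime_error.
    try
    {
        coro::net::ip_address::from_string("not-an-ip");
    }
    catch (const std::runtime_error& e)
    {
        std::cout << e.what() << "\n";
    }
}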

@@ -1,32 +0,0 @@
#pragma once
#include <cstdint>
#include <errno.h>
#include <string>
namespace coro::net
{
enum class recv_status : int64_t
{
ok = 0,
/// The peer closed the socket.
closed = -1,
/// The udp socket has not been bind()'ed to a local port.
udp_not_bound = -2,
try_again = EAGAIN,
would_block = EWOULDBLOCK,
bad_file_descriptor = EBADF,
connection_refused = ECONNREFUSED,
memory_fault = EFAULT,
interrupted = EINTR,
invalid_argument = EINVAL,
no_memory = ENOMEM,
not_connected = ENOTCONN,
not_a_socket = ENOTSOCK,
ssl_error = -3
};
auto to_string(recv_status status) -> const std::string&;
} // namespace coro::net

@@ -1,32 +0,0 @@
#pragma once
#include <cstdint>
#include <errno.h>
namespace coro::net
{
enum class send_status : int64_t
{
ok = 0,
permission_denied = EACCES,
try_again = EAGAIN,
would_block = EWOULDBLOCK,
already_in_progress = EALREADY,
bad_file_descriptor = EBADF,
connection_reset = ECONNRESET,
no_peer_address = EDESTADDRREQ,
memory_fault = EFAULT,
interrupted = EINTR,
is_connection = EISCONN,
message_size = EMSGSIZE,
output_queue_full = ENOBUFS,
no_memory = ENOMEM,
not_connected = ENOTCONN,
not_a_socket = ENOTSOCK,
operation_not_supported = EOPNOTSUPP,
pipe_closed = EPIPE,
ssl_error = -3
};
} // namespace coro::net

@@ -1,109 +0,0 @@
#pragma once
#include "coro/net/ip_address.hpp"
#include "coro/poll.hpp"
#include <arpa/inet.h>
#include <fcntl.h>
#include <span>
#include <unistd.h>
#include <utility>
#include <iostream>
namespace coro::net
{
class socket
{
public:
enum class type_t
{
/// udp datagram socket
udp,
/// tcp streaming socket
tcp
};
enum class blocking_t
{
/// This socket should block on system calls.
yes,
/// This socket should not block on system calls.
no
};
struct options
{
/// The domain for the socket.
domain_t domain;
/// The type of socket.
type_t type;
/// If the socket should be blocking or non-blocking.
blocking_t blocking;
};
static auto type_to_os(type_t type) -> int;
socket() = default;
explicit socket(int fd) : m_fd(fd) {}
socket(const socket&) = delete;
socket(socket&& other) : m_fd(std::exchange(other.m_fd, -1)) {}
auto operator=(const socket&) -> socket& = delete;
auto operator =(socket&& other) noexcept -> socket&;
~socket() { close(); }
/**
* This function returns true if the socket's file descriptor is a valid number, however it does
* not imply that the socket is still usable.
* @return True if the socket file descriptor is not -1.
*/
auto is_valid() const -> bool { return m_fd != -1; }
/**
* @param block Sets the socket to the given blocking mode.
*/
auto blocking(blocking_t block) -> bool;
/**
* @param how Shuts the socket down with the given operations.
* @return True if the socket's given operations were shut down.
*/
auto shutdown(poll_op how = poll_op::read_write) -> bool;
/**
* Closes the socket and sets this socket to an invalid state.
*/
auto close() -> void;
/**
* @return The native handle (file descriptor) for this socket.
*/
auto native_handle() const -> int { return m_fd; }
private:
int m_fd{-1};
};
/**
* Creates a socket with the given socket options; this is typically used for creating sockets to
* use within client objects, e.g. tcp_client and udp_client.
* @param opts See socket::options for more details.
*/
auto make_socket(const socket::options& opts) -> socket;
/**
* Creates a socket that can accept connections or packets with the given socket options, address,
* port and backlog. This is used for creating sockets to use within server objects, e.g.
* tcp_server and udp_server.
* @param opts See socket::options for more details
* @param address The ip address to bind to. If the type of socket is tcp then it will also listen.
* @param port The port to bind to.
* @param backlog If the type of socket is tcp then the backlog of connections to allow. Does nothing
* for udp types.
*/
auto make_accept_socket(
const socket::options& opts, const net::ip_address& address, uint16_t port, int32_t backlog = 128) -> socket;
} // namespace coro::net
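A minimal sketch of the two factory helpers (the address and port here are arbitrary).

#include "coro/net/socket.hpp"

int main()
{
    // A non-blocking TCP listening socket bound to 127.0.0.1:8080 with the default backlog.
    auto listener = coro::net::make_accept_socket(
        coro::net::socket::options{
            coro::net::domain_t::ipv4, coro::net::socket::type_t::tcp, coro::net::socket::blocking_t::no},
        coro::net::ip_address::from_string("127.0.0.1"),
        8080);

    // The destructor closes the fd; shutdown() could be called explicitly beforehand.
    return listener.is_valid() ? 0 : 1;
}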

@@ -1,61 +0,0 @@
#pragma once
#include <filesystem>
#include <mutex>
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/pem.h>
#include <openssl/ssl.h>
namespace coro::net
{
class tcp_client;
enum class ssl_file_type : int
{
/// The file is of type ASN1
asn1 = SSL_FILETYPE_ASN1,
/// The file is of type PEM
pem = SSL_FILETYPE_PEM
};
/**
* SSL context, used with client or server types to provide secure connections.
*/
class ssl_context
{
public:
/**
* Creates a context with no certificate and no private key; this may be useful for testing.
*/
ssl_context();
/**
* Creates a context with the given certificate and the given private key.
* @param certificate The location of the certificate file.
* @param certificate_type See `ssl_file_type`.
* @param private_key The location of the private key file.
* @param private_key_type See `ssl_file_type`.
*/
ssl_context(
std::filesystem::path certificate,
ssl_file_type certificate_type,
std::filesystem::path private_key,
ssl_file_type private_key_type);
~ssl_context();
private:
static uint64_t m_ssl_context_count;
static std::mutex m_ssl_context_mutex;
SSL_CTX* m_ssl_ctx{nullptr};
/// The following classes use the underlying SSL_CTX* object for performing SSL functions.
friend tcp_client;
auto native_handle() -> SSL_CTX* { return m_ssl_ctx; }
auto native_handle() const -> const SSL_CTX* { return m_ssl_ctx; }
};
} // namespace coro::net
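A brief construction sketch; the certificate and key paths below are hypothetical, and both constructors throw std::runtime_error on failure.

#include "coro/net/ssl_context.hpp"

int main()
{
    // Hypothetical PEM files; they must exist and the key must match the certificate.
    coro::net::ssl_context ctx{
        "server.crt", coro::net::ssl_file_type::pem,
        "server.key", coro::net::ssl_file_type::pem};

    // The context is then passed by pointer to tcp_client/tcp_server options (ssl_ctx = &ctx).
}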

@@ -1,28 +0,0 @@
#pragma once
namespace coro::net
{
enum class ssl_handshake_status
{
/// The ssl handshake was successful.
ok,
/// The connection hasn't been established yet, use connect() prior to the ssl_handshake().
not_connected,
/// The connection needs a coro::net::ssl_context to perform the handshake.
ssl_context_required,
/// The internal ssl memory allocation failed.
ssl_resource_allocation_failed,
/// Attempting to set the connection's ssl socket/file descriptor failed.
ssl_set_fd_failure,
/// The handshake had an error.
handshake_failed,
/// The handshake timed out.
timeout,
/// An error occurred while polling for read or write operations on the socket.
poll_error,
/// The socket was unexpectedly closed while attempting the handshake.
unexpected_close
};
} // namespace coro::net

@@ -1,303 +0,0 @@
#pragma once
#include "coro/concepts/buffer.hpp"
#include "coro/io_scheduler.hpp"
#include "coro/net/connect.hpp"
#include "coro/net/ip_address.hpp"
#include "coro/net/recv_status.hpp"
#include "coro/net/send_status.hpp"
#include "coro/net/socket.hpp"
#include "coro/net/ssl_context.hpp"
#include "coro/net/ssl_handshake_status.hpp"
#include "coro/poll.hpp"
#include "coro/task.hpp"
#include <chrono>
#include <memory>
#include <optional>
namespace coro::net
{
class tcp_server;
class tcp_client
{
public:
struct options
{
/// The ip address to connect to. Use a dns_resolver to turn hostnames into ip addresses.
net::ip_address address{net::ip_address::from_string("127.0.0.1")};
/// The port to connect to.
uint16_t port{8080};
/// Should this tcp_client connect using a secure connection SSL/TLS?
ssl_context* ssl_ctx{nullptr};
};
/**
* Creates a new tcp client that can connect to an ip address + port. By default the socket
* created will be in non-blocking mode, meaning that any sending or receiving of data should
* poll for event readiness beforehand.
* @param scheduler The io scheduler to drive the tcp client.
* @param opts See tcp_client::options for more information.
*/
tcp_client(
std::shared_ptr<io_scheduler> scheduler,
options opts = options{
.address = {net::ip_address::from_string("127.0.0.1")}, .port = 8080, .ssl_ctx = nullptr});
tcp_client(const tcp_client&) = delete;
tcp_client(tcp_client&& other);
auto operator=(const tcp_client&) noexcept -> tcp_client& = delete;
auto operator =(tcp_client&& other) noexcept -> tcp_client&;
~tcp_client();
/**
* @return The tcp socket this client is using.
* @{
**/
auto socket() -> net::socket& { return m_socket; }
auto socket() const -> const net::socket& { return m_socket; }
/** @} */
/**
* Connects to the address+port with the given timeout. Once connected calling this function
* only returns the connected status; it will not reconnect.
* @param timeout How long to wait for the connection to establish. A timeout of zero waits indefinitely.
* @return The result status of trying to connect.
*/
auto connect(std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task<net::connect_status>;
/**
* If this client is connected and the connection is SSL/TLS then perform the ssl handshake.
* This must be done after a successful connect() call for clients that are initiating a
* connection to a server. This must be done after a successful accept() call for clients that
* have been accepted by a tcp_server. TCP server 'client's start in the connected state and
* thus skip the connect() call.
*
* tcp_client initiating to a server:
* tcp_client client{...options...};
* co_await client.connect();
* co_await client.ssl_handshake(); // <-- only perform if ssl/tls connection
*
* tcp_server accepting a client connection:
* tcp_server server{...options...};
* co_await server.poll();
* auto client = server.accept();
* if(client.socket().is_valid())
* {
* co_await client.ssl_handshake(); // <-- only perform if ssl/tls connection
* }
* @param timeout How long to allow for the ssl handshake to successfully complete.
* @return The result of the ssl handshake.
*/
auto ssl_handshake(std::chrono::milliseconds timeout = std::chrono::milliseconds{0})
-> coro::task<ssl_handshake_status>;
/**
* Polls for the given operation on this client's tcp socket. This should be done prior to
* calling recv and after a send that doesn't send the entire buffer.
* @param op The poll operation to perform, use read for incoming data and write for outgoing.
* @param timeout The amount of time to wait for the poll event to be ready. Use zero for an infinite timeout.
* @return The status result of the poll operation. When poll_status::event is returned then the
* event operation is ready.
*/
auto poll(coro::poll_op op, std::chrono::milliseconds timeout = std::chrono::milliseconds{0})
-> coro::task<poll_status>
{
return m_io_scheduler->poll(m_socket, op, timeout);
}
/**
* Receives incoming data into the given buffer. Since all tcp client sockets are set to
* non-blocking by default, use co_await poll() to determine when data is ready to be received.
* @param buffer Received bytes are written into this buffer up to the buffer's size.
* @return The status of the recv call and a span of the bytes received (if any). The span of
* bytes will be a subspan or full span of the given input buffer.
*/
template<concepts::mutable_buffer buffer_type>
auto recv(buffer_type&& buffer) -> std::pair<recv_status, std::span<char>>
{
// If the user requested zero bytes, just return.
if (buffer.empty())
{
return {recv_status::ok, std::span<char>{}};
}
if (m_options.ssl_ctx == nullptr)
{
auto bytes_recv = ::recv(m_socket.native_handle(), buffer.data(), buffer.size(), 0);
if (bytes_recv > 0)
{
// Ok, we've received some data.
return {recv_status::ok, std::span<char>{buffer.data(), static_cast<size_t>(bytes_recv)}};
}
else if (bytes_recv == 0)
{
// On TCP stream sockets 0 indicates the connection has been closed by the peer.
return {recv_status::closed, std::span<char>{}};
}
else
{
// Report the error to the user.
return {static_cast<recv_status>(errno), std::span<char>{}};
}
}
else
{
ERR_clear_error();
size_t bytes_recv{0};
int r = SSL_read_ex(m_ssl_info.m_ssl_ptr.get(), buffer.data(), buffer.size(), &bytes_recv);
if (r == 0)
{
int err = SSL_get_error(m_ssl_info.m_ssl_ptr.get(), r);
if (err == SSL_ERROR_WANT_READ)
{
return {recv_status::would_block, std::span<char>{}};
}
else
{
// TODO: Flesh out all possible ssl errors:
// https://www.openssl.org/docs/man1.1.1/man3/SSL_get_error.html
return {recv_status::ssl_error, std::span<char>{}};
}
}
else
{
return {recv_status::ok, std::span<char>{buffer.data(), static_cast<size_t>(bytes_recv)}};
}
}
}
/**
* Sends outgoing data from the given buffer. If a partial write occurs then use co_await poll()
* to determine when the tcp client socket is ready to be written to again. On partial writes
* the status will be 'ok' and the span returned will be non-empty; it will contain the portion
* of the buffer that was not written to the client's socket.
* @param buffer The data to write on the tcp socket.
* @return The status of the send call and a span of any remaining bytes not sent. If all bytes
* were successfully sent the status will be 'ok' and the remaining span will be empty.
*/
template<concepts::const_buffer buffer_type>
auto send(const buffer_type& buffer) -> std::pair<send_status, std::span<const char>>
{
// If the user requested zero bytes, just return.
if (buffer.empty())
{
return {send_status::ok, std::span<const char>{buffer.data(), buffer.size()}};
}
if (m_options.ssl_ctx == nullptr)
{
auto bytes_sent = ::send(m_socket.native_handle(), buffer.data(), buffer.size(), 0);
if (bytes_sent >= 0)
{
// Some or all of the bytes were written.
return {send_status::ok, std::span<const char>{buffer.data() + bytes_sent, buffer.size() - bytes_sent}};
}
else
{
// Due to the error none of the bytes were written.
return {static_cast<send_status>(errno), std::span<const char>{buffer.data(), buffer.size()}};
}
}
else
{
ERR_clear_error();
size_t bytes_sent{0};
int r = SSL_write_ex(m_ssl_info.m_ssl_ptr.get(), buffer.data(), buffer.size(), &bytes_sent);
if (r == 0)
{
int err = SSL_get_error(m_ssl_info.m_ssl_ptr.get(), r);
if (err == SSL_ERROR_WANT_WRITE)
{
return {send_status::would_block, std::span<char>{}};
}
else
{
// TODO: Flesh out all possible ssl errors:
// https://www.openssl.org/docs/man1.1.1/man3/SSL_get_error.html
return {send_status::ssl_error, std::span<char>{}};
}
}
else
{
return {send_status::ok, std::span<const char>{buffer.data() + bytes_sent, buffer.size() - bytes_sent}};
}
}
}
private:
struct ssl_deleter
{
auto operator()(SSL* ssl) const -> void { SSL_free(ssl); }
};
using ssl_unique_ptr = std::unique_ptr<SSL, ssl_deleter>;
enum class ssl_connection_type
{
/// This connection is a client connecting to a server.
connect,
/// This connection is an accepted connection on a server.
accept
};
struct ssl_info
{
ssl_info() {}
explicit ssl_info(ssl_connection_type type) : m_ssl_connection_type(type) {}
ssl_info(const ssl_info&) noexcept = delete;
ssl_info(ssl_info&& other) noexcept
: m_ssl_connection_type(std::exchange(other.m_ssl_connection_type, ssl_connection_type::connect)),
m_ssl_ptr(std::move(other.m_ssl_ptr)),
m_ssl_error(std::exchange(other.m_ssl_error, false)),
m_ssl_handshake_status(std::move(other.m_ssl_handshake_status))
{
}
auto operator=(const ssl_info&) noexcept -> ssl_info& = delete;
auto operator=(ssl_info&& other) noexcept -> ssl_info&
{
if (std::addressof(other) != this)
{
m_ssl_connection_type = std::exchange(other.m_ssl_connection_type, ssl_connection_type::connect);
m_ssl_ptr = std::move(other.m_ssl_ptr);
m_ssl_error = std::exchange(other.m_ssl_error, false);
m_ssl_handshake_status = std::move(other.m_ssl_handshake_status);
}
return *this;
}
/// What kind of connection is this, client initiated connect or server side accept?
ssl_connection_type m_ssl_connection_type{ssl_connection_type::connect};
/// OpenSSL ssl connection.
ssl_unique_ptr m_ssl_ptr{nullptr};
/// Was there an error with the SSL/TLS connection?
bool m_ssl_error{false};
/// The result of the ssl handshake.
std::optional<ssl_handshake_status> m_ssl_handshake_status{std::nullopt};
};
/// The tcp_server creates already connected clients and provides a tcp socket pre-built.
friend tcp_server;
tcp_client(std::shared_ptr<io_scheduler> scheduler, net::socket socket, options opts);
/// The scheduler that will drive this tcp client.
std::shared_ptr<io_scheduler> m_io_scheduler{nullptr};
/// Options for what server to connect to.
options m_options{};
/// The tcp socket.
net::socket m_socket{-1};
/// Cache the status of the connect in the event the user calls connect() again.
std::optional<net::connect_status> m_connect_status{std::nullopt};
/// SSL/TLS specific information if m_options.ssl_ctx != nullptr.
ssl_info m_ssl_info{};
static auto ssl_shutdown_and_free(
std::shared_ptr<io_scheduler> io_scheduler,
net::socket s,
ssl_unique_ptr ssl_ptr,
std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task<void>;
};
} // namespace coro::net
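A fuller client sketch building on the doc comments above. It assumes coro::sync_wait and a default-constructible io_scheduler (neither appears in this diff) and that std::string satisfies the buffer concepts referenced by send()/recv(); the address and port are the option defaults.

#include "coro/io_scheduler.hpp"
#include "coro/net/tcp_client.hpp"
#include "coro/sync_wait.hpp" // assumed to exist elsewhere in the library
#include <iostream>
#include <memory>
#include <span>
#include <string>

int main()
{
    auto scheduler = std::make_shared<coro::io_scheduler>();

    auto echo_once = [&]() -> coro::task<void> {
        coro::net::tcp_client client{scheduler}; // defaults: 127.0.0.1:8080, no SSL/TLS
        if (co_await client.connect() != coro::net::connect_status::connected)
        {
            co_return;
        }

        // Send, polling for writability whenever a partial write leaves unsent bytes.
        std::string           request = "hello";
        std::span<const char> remaining{request.data(), request.size()};
        while (!remaining.empty())
        {
            auto [sstatus, unsent] = client.send(remaining);
            if (sstatus != coro::net::send_status::ok)
            {
                co_return;
            }
            remaining = unsent;
            if (!remaining.empty())
            {
                co_await client.poll(coro::poll_op::write);
            }
        }

        // Wait for readability, then receive into a local buffer.
        if (co_await client.poll(coro::poll_op::read) == coro::poll_status::event)
        {
            std::string buffer(256, '\0');
            auto [rstatus, data] = client.recv(buffer);
            if (rstatus == coro::net::recv_status::ok)
            {
                std::cout.write(data.data(), static_cast<std::streamsize>(data.size()));
            }
        }
        co_return;
    };

    coro::sync_wait(echo_once());
}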

@@ -1,74 +0,0 @@
#pragma once
#include "coro/net/ip_address.hpp"
#include "coro/net/socket.hpp"
#include "coro/net/tcp_client.hpp"
#include "coro/task.hpp"
#include <fcntl.h>
#include <sys/socket.h>
namespace coro
{
class io_scheduler;
} // namespace coro
namespace coro::net
{
class ssl_context;
class tcp_server
{
public:
struct options
{
/// The ip address for the tcp server to bind and listen on.
net::ip_address address{net::ip_address::from_string("0.0.0.0")};
/// The port for the tcp server to bind and listen on.
uint16_t port{8080};
/// The kernel backlog of connections to buffer.
int32_t backlog{128};
/// Should this tcp server use TLS/SSL? If provided all accepted connections will use the
/// given SSL certificate and private key to secure the connections.
ssl_context* ssl_ctx{nullptr};
};
tcp_server(
std::shared_ptr<io_scheduler> scheduler,
options opts = options{
.address = net::ip_address::from_string("0.0.0.0"), .port = 8080, .backlog = 128, .ssl_ctx = nullptr});
tcp_server(const tcp_server&) = delete;
tcp_server(tcp_server&& other);
auto operator=(const tcp_server&) -> tcp_server& = delete;
auto operator =(tcp_server&& other) -> tcp_server&;
~tcp_server() = default;
/**
* Polls for new incoming tcp connections.
* @param timeout How long to wait for a new connection before timing out; zero waits indefinitely.
* @return The result of the poll, 'event' means the poll was successful and there is at least 1
* connection ready to be accepted.
*/
auto poll(std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task<coro::poll_status>
{
return m_io_scheduler->poll(m_accept_socket, coro::poll_op::read, timeout);
}
/**
* Accepts an incoming tcp client connection. On failure the tcp client's socket will be set to
* an invalid state; use socket().is_valid() to verify the client was correctly accepted.
* @return The newly connected tcp client connection.
*/
auto accept() -> coro::net::tcp_client;
private:
/// The io scheduler for awaiting new connections.
std::shared_ptr<io_scheduler> m_io_scheduler{nullptr};
/// The bind and listen options for this server.
options m_options;
/// The socket for accepting new tcp connections on.
net::socket m_accept_socket{-1};
};
} // namespace coro::net
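A matching accept-side sketch (again assuming coro::sync_wait, a default-constructible io_scheduler, and the default options; a real server would accept in a loop and hand each client off to its own task).

#include "coro/io_scheduler.hpp"
#include "coro/net/tcp_server.hpp"
#include "coro/sync_wait.hpp" // assumed to exist elsewhere in the library
#include <memory>
#include <span>
#include <string>

int main()
{
    auto scheduler = std::make_shared<coro::io_scheduler>();

    auto serve_one = [&]() -> coro::task<void> {
        coro::net::tcp_server server{scheduler}; // defaults: 0.0.0.0:8080, backlog 128, no SSL/TLS
        // Wait until at least one connection is ready to be accepted.
        if (co_await server.poll() == coro::poll_status::event)
        {
            auto client = server.accept();
            if (client.socket().is_valid())
            {
                std::string buffer(256, '\0');
                co_await client.poll(coro::poll_op::read);
                auto [status, data] = client.recv(buffer);
                if (status == coro::net::recv_status::ok)
                {
                    // Echo the received bytes straight back.
                    client.send(std::span<const char>{data.data(), data.size()});
                }
            }
        }
        co_return;
    };

    coro::sync_wait(serve_one());
}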

@@ -1,147 +0,0 @@
#pragma once
#include "coro/concepts/buffer.hpp"
#include "coro/io_scheduler.hpp"
#include "coro/net/ip_address.hpp"
#include "coro/net/recv_status.hpp"
#include "coro/net/send_status.hpp"
#include "coro/net/socket.hpp"
#include "coro/task.hpp"
#include <chrono>
#include <span>
#include <variant>
namespace coro
{
class io_scheduler;
} // namespace coro
namespace coro::net
{
class udp_peer
{
public:
struct info
{
/// The ip address of the peer.
net::ip_address address{net::ip_address::from_string("127.0.0.1")};
/// The port of the peer.
uint16_t port{8080};
auto operator<=>(const info& other) const = default;
};
/**
* Creates a udp peer that can send packets but not receive them. This udp peer will not explicitly
* bind to a local ip+port.
*/
explicit udp_peer(std::shared_ptr<io_scheduler> scheduler, net::domain_t domain = net::domain_t::ipv4);
/**
* Creates a udp peer that can send and receive packets. This peer will bind to the given ip_port.
*/
explicit udp_peer(std::shared_ptr<io_scheduler> scheduler, const info& bind_info);
udp_peer(const udp_peer&) = delete;
udp_peer(udp_peer&&) = default;
auto operator=(const udp_peer&) noexcept -> udp_peer& = delete;
auto operator=(udp_peer&&) noexcept -> udp_peer& = default;
~udp_peer() = default;
/**
* @param op The poll operation to perform on the udp socket. Note that if this is a send only
* udp socket (did not bind) then polling for read will not work.
* @param timeout The timeout for the poll operation to be ready.
* @return The result status of the poll operation.
*/
auto poll(poll_op op, std::chrono::milliseconds timeout = std::chrono::milliseconds{0})
-> coro::task<coro::poll_status>
{
co_return co_await m_io_scheduler->poll(m_socket, op, timeout);
}
/**
* @param peer_info The peer to send the data to.
* @param buffer The data to send.
* @return The status of the send call and a span view of any data that wasn't sent. Any unsent
* data will correspond to bytes at the end of the given buffer.
*/
template<concepts::const_buffer buffer_type>
auto sendto(const info& peer_info, const buffer_type& buffer) -> std::pair<send_status, std::span<const char>>
{
if (buffer.empty())
{
return {send_status::ok, std::span<const char>{}};
}
sockaddr_in peer{};
peer.sin_family = static_cast<int>(peer_info.address.domain());
peer.sin_port = htons(peer_info.port);
peer.sin_addr = *reinterpret_cast<const in_addr*>(peer_info.address.data().data());
socklen_t peer_len{sizeof(peer)};
auto bytes_sent = ::sendto(
m_socket.native_handle(), buffer.data(), buffer.size(), 0, reinterpret_cast<sockaddr*>(&peer), peer_len);
if (bytes_sent >= 0)
{
return {send_status::ok, std::span<const char>{buffer.data() + bytes_sent, buffer.size() - bytes_sent}};
}
else
{
return {static_cast<send_status>(errno), std::span<const char>{}};
}
}
/**
* @param buffer The buffer to receive data into.
* @return The receive status, if ok then also the peer who sent the data and the data.
* The span view of the data will be set to the size of the received data; this will
* always start at the beginning of the buffer, but depending on how large the received
* data was it might not fill the entire buffer.
*/
template<concepts::mutable_buffer buffer_type>
auto recvfrom(buffer_type&& buffer) -> std::tuple<recv_status, udp_peer::info, std::span<char>>
{
// The user must bind locally to be able to receive packets.
if (!m_bound)
{
return {recv_status::udp_not_bound, udp_peer::info{}, std::span<char>{}};
}
sockaddr_in peer{};
socklen_t peer_len{sizeof(peer)};
auto bytes_read = ::recvfrom(
m_socket.native_handle(), buffer.data(), buffer.size(), 0, reinterpret_cast<sockaddr*>(&peer), &peer_len);
if (bytes_read < 0)
{
return {static_cast<recv_status>(errno), udp_peer::info{}, std::span<char>{}};
}
std::span<const uint8_t> ip_addr_view{
reinterpret_cast<uint8_t*>(&peer.sin_addr.s_addr),
sizeof(peer.sin_addr.s_addr),
};
return {
recv_status::ok,
udp_peer::info{
.address = net::ip_address{ip_addr_view, static_cast<net::domain_t>(peer.sin_family)},
.port = ntohs(peer.sin_port)},
std::span<char>{buffer.data(), static_cast<size_t>(bytes_read)}};
}
private:
/// The scheduler that will drive this udp client.
std::shared_ptr<io_scheduler> m_io_scheduler;
/// The udp socket.
net::socket m_socket{-1};
/// Did the user request this udp socket is bound locally to receive packets?
bool m_bound{false};
};
} // namespace coro::net
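A short echo sketch for the bound variant (assumes coro::sync_wait and a default-constructible io_scheduler, neither of which is part of this diff; the port is arbitrary).

#include "coro/io_scheduler.hpp"
#include "coro/net/udp_peer.hpp"
#include "coro/sync_wait.hpp" // assumed to exist elsewhere in the library
#include <memory>
#include <span>
#include <string>

int main()
{
    auto scheduler = std::make_shared<coro::io_scheduler>();

    auto echo_one_datagram = [&]() -> coro::task<void> {
        // Binding to a local address/port makes this peer able to receive as well as send.
        coro::net::udp_peer peer{
            scheduler,
            coro::net::udp_peer::info{
                .address = coro::net::ip_address::from_string("0.0.0.0"), .port = 8080}};

        co_await peer.poll(coro::poll_op::read);
        std::string buffer(512, '\0');
        auto [status, from, data] = peer.recvfrom(buffer);
        if (status == coro::net::recv_status::ok)
        {
            // Send the datagram back to whoever sent it.
            peer.sendto(from, std::span<const char>{data.data(), data.size()});
        }
        co_return;
    };

    coro::sync_wait(echo_one_datagram());
}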

@@ -1,500 +0,0 @@
#include "coro/io_scheduler.hpp"
#include <atomic>
#include <cstring>
#include <iostream>
#include <optional>
#include <sys/epoll.h>
#include <sys/eventfd.h>
#include <sys/socket.h>
#include <sys/timerfd.h>
#include <sys/types.h>
#include <unistd.h>
using namespace std::chrono_literals;
namespace coro
{
io_scheduler::io_scheduler(options opts)
: m_opts(std::move(opts)),
m_epoll_fd(epoll_create1(EPOLL_CLOEXEC)),
m_shutdown_fd(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK)),
m_timer_fd(timerfd_create(CLOCK_MONOTONIC, TFD_NONBLOCK | TFD_CLOEXEC)),
m_schedule_fd(eventfd(0, EFD_CLOEXEC | EFD_NONBLOCK)),
m_owned_tasks(new coro::task_container<coro::io_scheduler>(*this))
{
if (opts.execution_strategy == execution_strategy_t::process_tasks_on_thread_pool)
{
m_thread_pool = std::make_unique<thread_pool>(std::move(m_opts.pool));
}
epoll_event e{};
e.events = EPOLLIN;
e.data.ptr = const_cast<void*>(m_shutdown_ptr);
epoll_ctl(m_epoll_fd, EPOLL_CTL_ADD, m_shutdown_fd, &e);
e.data.ptr = const_cast<void*>(m_timer_ptr);
epoll_ctl(m_epoll_fd, EPOLL_CTL_ADD, m_timer_fd, &e);
e.data.ptr = const_cast<void*>(m_schedule_ptr);
epoll_ctl(m_epoll_fd, EPOLL_CTL_ADD, m_schedule_fd, &e);
if (m_opts.thread_strategy == thread_strategy_t::spawn)
{
m_io_thread = std::thread([this]() { process_events_dedicated_thread(); });
}
// else manual mode, the user must call process_events.
}
io_scheduler::~io_scheduler()
{
shutdown();
if (m_io_thread.joinable())
{
m_io_thread.join();
}
if (m_epoll_fd != -1)
{
close(m_epoll_fd);
m_epoll_fd = -1;
}
if (m_timer_fd != -1)
{
close(m_timer_fd);
m_timer_fd = -1;
}
if (m_schedule_fd != -1)
{
close(m_schedule_fd);
m_schedule_fd = -1;
}
if (m_owned_tasks != nullptr)
{
delete static_cast<coro::task_container<coro::io_scheduler>*>(m_owned_tasks);
m_owned_tasks = nullptr;
}
}
auto io_scheduler::process_events(std::chrono::milliseconds timeout) -> std::size_t
{
process_events_manual(timeout);
return size();
}
auto io_scheduler::schedule_after(std::chrono::milliseconds amount) -> coro::task<void>
{
return yield_for(amount);
}
auto io_scheduler::schedule_at(time_point time) -> coro::task<void>
{
return yield_until(time);
}
auto io_scheduler::yield_for(std::chrono::milliseconds amount) -> coro::task<void>
{
if (amount <= 0ms)
{
co_await schedule();
}
else
{
// Yield/timeout tasks are considered live in the scheduler and must be accounted for. Note
// that if the user gives an invalid amount and schedule() is directly called it will account
// for the scheduled task there.
m_size.fetch_add(1, std::memory_order::release);
// Yielding does not require setting the timer position on the poll info since
// it doesn't have a corresponding 'event' that can trigger; it always waits for
// the timeout to occur before resuming.
detail::poll_info pi{};
add_timer_token(clock::now() + amount, pi);
co_await pi;
m_size.fetch_sub(1, std::memory_order::release);
}
co_return;
}
auto io_scheduler::yield_until(time_point time) -> coro::task<void>
{
auto now = clock::now();
// If the requested time is in the past (or now!) bail out!
if (time <= now)
{
co_await schedule();
}
else
{
m_size.fetch_add(1, std::memory_order::release);
auto amount = std::chrono::duration_cast<std::chrono::milliseconds>(time - now);
detail::poll_info pi{};
add_timer_token(now + amount, pi);
co_await pi;
m_size.fetch_sub(1, std::memory_order::release);
}
co_return;
}
auto io_scheduler::poll(fd_t fd, coro::poll_op op, std::chrono::milliseconds timeout) -> coro::task<poll_status>
{
// Every poll operation counts as a live task in the scheduler while this coroutine is suspended;
// the count is incremented here and decremented once the event loop resumes this task.
m_size.fetch_add(1, std::memory_order::release);
// Setup two events, a timeout event and the actual poll for op event.
// Whichever triggers first will delete the other to guarantee only one wins.
// The resume token will be set by the scheduler to what the event turned out to be.
bool timeout_requested = (timeout > 0ms);
detail::poll_info pi{};
pi.m_fd = fd;
if (timeout_requested)
{
pi.m_timer_pos = add_timer_token(clock::now() + timeout, pi);
}
epoll_event e{};
e.events = static_cast<uint32_t>(op) | EPOLLONESHOT | EPOLLRDHUP;
e.data.ptr = &pi;
if (epoll_ctl(m_epoll_fd, EPOLL_CTL_ADD, fd, &e) == -1)
{
std::cerr << "epoll ctl error on fd " << fd << "\n";
}
// The event loop will 'clean-up' whichever event didn't win. Since the coroutine is scheduled
// onto the thread pool it's possible the other type of event could trigger while it's waiting
// to execute again, thus resuming the coroutine twice, which would be quite bad.
auto result = co_await pi;
m_size.fetch_sub(1, std::memory_order::release);
co_return result;
}
auto io_scheduler::shutdown() noexcept -> void
{
// Only allow shutdown to occur once.
if (m_shutdown_requested.exchange(true, std::memory_order::acq_rel) == false)
{
if (m_thread_pool != nullptr)
{
m_thread_pool->shutdown();
}
// Signal the event loop to stop asap, triggering the event fd is safe.
uint64_t value{1};
auto written = ::write(m_shutdown_fd, &value, sizeof(value));
(void)written;
if (m_io_thread.joinable())
{
m_io_thread.join();
}
}
}
auto io_scheduler::process_events_manual(std::chrono::milliseconds timeout) -> void
{
bool expected{false};
if (m_io_processing.compare_exchange_strong(expected, true, std::memory_order::release, std::memory_order::relaxed))
{
process_events_execute(timeout);
m_io_processing.exchange(false, std::memory_order::release);
}
}
auto io_scheduler::process_events_dedicated_thread() -> void
{
if (m_opts.on_io_thread_start_functor != nullptr)
{
m_opts.on_io_thread_start_functor();
}
m_io_processing.exchange(true, std::memory_order::release);
// Execute tasks until stopped or there are no more tasks to complete.
while (!m_shutdown_requested.load(std::memory_order::acquire) || size() > 0)
{
process_events_execute(m_default_timeout);
}
m_io_processing.exchange(false, std::memory_order::release);
if (m_opts.on_io_thread_stop_functor != nullptr)
{
m_opts.on_io_thread_stop_functor();
}
}
auto io_scheduler::process_events_execute(std::chrono::milliseconds timeout) -> void
{
auto event_count = epoll_wait(m_epoll_fd, m_events.data(), m_max_events, timeout.count());
if (event_count > 0)
{
for (std::size_t i = 0; i < static_cast<std::size_t>(event_count); ++i)
{
epoll_event& event = m_events[i];
void* handle_ptr = event.data.ptr;
if (handle_ptr == m_timer_ptr)
{
// Process all events that have timed out.
process_timeout_execute();
}
else if (handle_ptr == m_schedule_ptr)
{
// Process scheduled coroutines.
process_scheduled_execute_inline();
}
else if (handle_ptr == m_shutdown_ptr)
[[unlikely]]
{
// Nothing to do, just needed to wake up and smell the flowers.
}
else
{
// Individual poll task wake-up.
process_event_execute(static_cast<detail::poll_info*>(handle_ptr), event_to_poll_status(event.events));
}
}
}
// It's important not to resume any handles until the full set is accounted for. If a timeout
// and an event for the same handle happen in the same epoll_wait() call then inline processing
// will destruct the poll_info object before the second event is handled. This is also possible
// with thread pool processing, but probably has an extremely low chance of occurring due to
// the thread switch required. If m_max_events == 1 this would be unnecessary.
if (!m_handles_to_resume.empty())
{
if (m_opts.execution_strategy == execution_strategy_t::process_tasks_inline)
{
for (auto& handle : m_handles_to_resume)
{
handle.resume();
}
}
else
{
m_thread_pool->resume(m_handles_to_resume);
}
m_handles_to_resume.clear();
}
}
auto io_scheduler::event_to_poll_status(uint32_t events) -> poll_status
{
if (events & EPOLLIN || events & EPOLLOUT)
{
return poll_status::event;
}
else if (events & EPOLLERR)
{
return poll_status::error;
}
else if (events & EPOLLRDHUP || events & EPOLLHUP)
{
return poll_status::closed;
}
throw std::runtime_error{"invalid epoll state"};
}
auto io_scheduler::process_scheduled_execute_inline() -> void
{
std::vector<std::coroutine_handle<>> tasks{};
{
// Acquire the entire list, and then reset it.
std::scoped_lock lk{m_scheduled_tasks_mutex};
tasks.swap(m_scheduled_tasks);
// Clear the schedule eventfd if this is a scheduled task.
eventfd_t value{0};
eventfd_read(m_schedule_fd, &value);
// Clear the in memory flag to reduce eventfd_* calls on scheduling.
m_schedule_fd_triggered.exchange(false, std::memory_order::release);
}
// This set of handles can be safely resumed now since they do not have a corresponding timeout event.
for (auto& task : tasks)
{
task.resume();
}
m_size.fetch_sub(tasks.size(), std::memory_order::release);
}
auto io_scheduler::process_event_execute(detail::poll_info* pi, poll_status status) -> void
{
if (!pi->m_processed)
{
std::atomic_thread_fence(std::memory_order::acquire);
// It's possible the event and the timeout occurred in the same epoll_wait call; make sure only
// one is ever processed and the other is discarded.
pi->m_processed = true;
// Given a valid fd always remove it from epoll so the next poll can blindly EPOLL_CTL_ADD.
if (pi->m_fd != -1)
{
epoll_ctl(m_epoll_fd, EPOLL_CTL_DEL, pi->m_fd, nullptr);
}
// Since this event triggered, remove its corresponding timeout if it has one.
if (pi->m_timer_pos.has_value())
{
remove_timer_token(pi->m_timer_pos.value());
}
pi->m_poll_status = status;
while (pi->m_awaiting_coroutine == nullptr)
{
std::atomic_thread_fence(std::memory_order::acquire);
}
m_handles_to_resume.emplace_back(pi->m_awaiting_coroutine);
}
}
auto io_scheduler::process_timeout_execute() -> void
{
std::vector<detail::poll_info*> poll_infos{};
auto now = clock::now();
{
std::scoped_lock lk{m_timed_events_mutex};
while (!m_timed_events.empty())
{
auto first = m_timed_events.begin();
auto [tp, pi] = *first;
if (tp <= now)
{
m_timed_events.erase(first);
poll_infos.emplace_back(pi);
}
else
{
break;
}
}
}
for (auto pi : poll_infos)
{
if (!pi->m_processed)
{
// It's possible the event and the timeout occurred in the same epoll_wait call; make sure only
// one is ever processed and the other is discarded.
pi->m_processed = true;
// Since this timed out, remove its corresponding event if it has one.
if (pi->m_fd != -1)
{
epoll_ctl(m_epoll_fd, EPOLL_CTL_DEL, pi->m_fd, nullptr);
}
while (pi->m_awaiting_coroutine == nullptr)
{
std::atomic_thread_fence(std::memory_order::acquire);
// std::cerr << "process_event_execute() has a nullptr event\n";
}
m_handles_to_resume.emplace_back(pi->m_awaiting_coroutine);
pi->m_poll_status = coro::poll_status::timeout;
}
}
// Update the time to the next smallest time point, re-take the current now time
// since updating and resuming tasks could shift the time.
update_timeout(clock::now());
}
auto io_scheduler::add_timer_token(time_point tp, detail::poll_info& pi) -> timed_events::iterator
{
std::scoped_lock lk{m_timed_events_mutex};
auto pos = m_timed_events.emplace(tp, &pi);
// If this item was inserted as the smallest time point, update the timeout.
if (pos == m_timed_events.begin())
{
update_timeout(clock::now());
}
return pos;
}
auto io_scheduler::remove_timer_token(timed_events::iterator pos) -> void
{
{
std::scoped_lock lk{m_timed_events_mutex};
auto is_first = (m_timed_events.begin() == pos);
m_timed_events.erase(pos);
// If this was the first item, update the timeout. It would be acceptable to just let it
// also fire the timeout as the event loop will ignore it since nothing will have timed
// out but it feels like the right thing to do to update it to the correct timeout value.
if (is_first)
{
update_timeout(clock::now());
}
}
}
auto io_scheduler::update_timeout(time_point now) -> void
{
if (!m_timed_events.empty())
{
auto& [tp, pi] = *m_timed_events.begin();
auto amount = tp - now;
auto seconds = std::chrono::duration_cast<std::chrono::seconds>(amount);
amount -= seconds;
auto nanoseconds = std::chrono::duration_cast<std::chrono::nanoseconds>(amount);
// As a safeguard if both values end up as zero (or negative) then trigger the timeout
// immediately as zero disarms timerfd according to the man pages and negative values
// will result in an error return value.
if (seconds <= 0s)
{
seconds = 0s;
if (nanoseconds <= 0ns)
{
// just trigger "immediately"!
nanoseconds = 1ns;
}
}
itimerspec ts{};
ts.it_value.tv_sec = seconds.count();
ts.it_value.tv_nsec = nanoseconds.count();
if (timerfd_settime(m_timer_fd, 0, &ts, nullptr) == -1)
{
std::cerr << "Failed to set timerfd errorno=[" << std::string{strerror(errno)} << "].";
}
}
else
{
// Setting these values to zero disables the timer.
itimerspec ts{};
ts.it_value.tv_sec = 0;
ts.it_value.tv_nsec = 0;
if (timerfd_settime(m_timer_fd, 0, &ts, nullptr) == -1)
{
std::cerr << "Failed to set timerfd errorno=[" << std::string{strerror(errno)} << "].";
}
}
}
} // namespace coro
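For reference, a minimal driver for the timer path above (yield_for() goes through add_timer_token()/update_timeout() with no paired event). coro::sync_wait and a default-constructible io_scheduler are assumed and are not part of this diff.

#include "coro/io_scheduler.hpp"
#include "coro/sync_wait.hpp" // assumed to exist elsewhere in the library
#include <chrono>
#include <iostream>
#include <memory>

int main()
{
    using namespace std::chrono_literals;
    auto scheduler = std::make_shared<coro::io_scheduler>();

    auto timed_task = [&]() -> coro::task<void> {
        co_await scheduler->schedule();      // hop onto the scheduler's execution context
        co_await scheduler->yield_for(50ms); // timerfd-backed sleep; no event half to race against
        std::cout << "resumed after ~50ms\n";
        co_return;
    };

    coro::sync_wait(timed_task());
}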

@@ -1,29 +0,0 @@
#include "coro/net/connect.hpp"
#include <stdexcept>
namespace coro::net
{
const static std::string connect_status_connected{"connected"};
const static std::string connect_status_invalid_ip_address{"invalid_ip_address"};
const static std::string connect_status_timeout{"timeout"};
const static std::string connect_status_error{"error"};
auto to_string(const connect_status& status) -> const std::string&
{
switch (status)
{
case connect_status::connected:
return connect_status_connected;
case connect_status::invalid_ip_address:
return connect_status_invalid_ip_address;
case connect_status::timeout:
return connect_status_timeout;
case connect_status::error:
return connect_status_error;
}
throw std::logic_error{"Invalid/unknown connect status."};
}
} // namespace coro::net

@@ -1,193 +0,0 @@
#include "coro/net/dns_resolver.hpp"
#include <arpa/inet.h>
#include <iostream>
#include <netdb.h>
namespace coro::net
{
uint64_t dns_resolver::m_ares_count{0};
std::mutex dns_resolver::m_ares_mutex{};
auto ares_dns_callback(void* arg, int status, int /*timeouts*/, struct hostent* host) -> void
{
auto& result = *static_cast<dns_result*>(arg);
--result.m_pending_dns_requests;
if (host == nullptr || status != ARES_SUCCESS)
{
result.m_status = dns_status::error;
}
else
{
result.m_status = dns_status::complete;
for (size_t i = 0; host->h_addr_list[i] != nullptr; ++i)
{
size_t len = (host->h_addrtype == AF_INET) ? net::ip_address::ipv4_len : net::ip_address::ipv6_len;
net::ip_address ip_addr{
std::span<const uint8_t>{reinterpret_cast<const uint8_t*>(host->h_addr_list[i]), len},
static_cast<net::domain_t>(host->h_addrtype)};
result.m_ip_addresses.emplace_back(std::move(ip_addr));
}
}
if (result.m_pending_dns_requests == 0)
{
result.m_resume.set(result.m_io_scheduler);
}
}
dns_result::dns_result(coro::io_scheduler& scheduler, coro::event& resume, uint64_t pending_dns_requests)
: m_io_scheduler(scheduler),
m_resume(resume),
m_pending_dns_requests(pending_dns_requests)
{
}
dns_resolver::dns_resolver(std::shared_ptr<io_scheduler> scheduler, std::chrono::milliseconds timeout)
: m_io_scheduler(std::move(scheduler)),
m_timeout(timeout),
m_task_container(m_io_scheduler)
{
if (m_io_scheduler == nullptr)
{
throw std::runtime_error{"dns_resolver cannot have nullptr scheduler"};
}
{
std::scoped_lock g{m_ares_mutex};
if (m_ares_count == 0)
{
auto ares_status = ares_library_init(ARES_LIB_INIT_ALL);
if (ares_status != ARES_SUCCESS)
{
throw std::runtime_error{ares_strerror(ares_status)};
}
}
++m_ares_count;
}
auto channel_init_status = ares_init(&m_ares_channel);
if (channel_init_status != ARES_SUCCESS)
{
throw std::runtime_error{ares_strerror(channel_init_status)};
}
}
dns_resolver::~dns_resolver()
{
if (m_ares_channel != nullptr)
{
ares_destroy(m_ares_channel);
m_ares_channel = nullptr;
}
{
std::scoped_lock g{m_ares_mutex};
--m_ares_count;
if (m_ares_count == 0)
{
ares_library_cleanup();
}
}
}
auto dns_resolver::host_by_name(const net::hostname& hn) -> coro::task<std::unique_ptr<dns_result>>
{
coro::event resume_event{};
auto result_ptr = std::make_unique<dns_result>(*m_io_scheduler.get(), resume_event, 2);
ares_gethostbyname(m_ares_channel, hn.data().data(), AF_INET, ares_dns_callback, result_ptr.get());
ares_gethostbyname(m_ares_channel, hn.data().data(), AF_INET6, ares_dns_callback, result_ptr.get());
// Add all required poll calls for ares to kick off the dns requests.
ares_poll();
// Suspend until this specific result is completed by ares.
co_await resume_event;
co_return result_ptr;
}
auto dns_resolver::ares_poll() -> void
{
std::array<ares_socket_t, ARES_GETSOCK_MAXNUM> ares_sockets{};
std::array<poll_op, ARES_GETSOCK_MAXNUM> poll_ops{};
int bitmask = ares_getsock(m_ares_channel, ares_sockets.data(), ARES_GETSOCK_MAXNUM);
size_t new_sockets{0};
for (size_t i = 0; i < ARES_GETSOCK_MAXNUM; ++i)
{
uint64_t ops{0};
if (ARES_GETSOCK_READABLE(bitmask, i))
{
ops |= static_cast<uint64_t>(poll_op::read);
}
if (ARES_GETSOCK_WRITABLE(bitmask, i))
{
ops |= static_cast<uint64_t>(poll_op::write);
}
if (ops != 0)
{
poll_ops[i] = static_cast<poll_op>(ops);
++new_sockets;
}
else
{
// According to ares usage within curl once a bitmask for a socket is zero the rest of
// the bitmask will also be zero.
break;
}
}
std::vector<coro::task<void>> poll_tasks{};
for (size_t i = 0; i < new_sockets; ++i)
{
auto fd = static_cast<fd_t>(ares_sockets[i]);
// If this socket is not currently actively polling, start polling!
if (m_active_sockets.emplace(fd).second)
{
m_task_container.start(make_poll_task(fd, poll_ops[i]));
}
}
}
auto dns_resolver::make_poll_task(fd_t fd, poll_op ops) -> coro::task<void>
{
auto result = co_await m_io_scheduler->poll(fd, ops, m_timeout);
switch (result)
{
case poll_status::event:
{
auto read_sock = poll_op_readable(ops) ? fd : ARES_SOCKET_BAD;
auto write_sock = poll_op_writeable(ops) ? fd : ARES_SOCKET_BAD;
ares_process_fd(m_ares_channel, read_sock, write_sock);
}
break;
case poll_status::timeout:
ares_process_fd(m_ares_channel, ARES_SOCKET_BAD, ARES_SOCKET_BAD);
break;
case poll_status::closed:
// might need to do something like call with two ARES_SOCKET_BAD?
break;
case poll_status::error:
// might need to do something like call with two ARES_SOCKET_BAD?
break;
}
// Remove from the list of actively polling sockets.
m_active_sockets.erase(fd);
// Re-initialize sockets/polls for ares since this one has now triggered.
ares_poll();
co_return;
};
} // namespace coro::net

@@ -1,20 +0,0 @@
#include "coro/net/ip_address.hpp"
namespace coro::net
{
static std::string domain_ipv4{"ipv4"};
static std::string domain_ipv6{"ipv6"};
auto to_string(domain_t domain) -> const std::string&
{
switch (domain)
{
case domain_t::ipv4:
return domain_ipv4;
case domain_t::ipv6:
return domain_ipv6;
}
throw std::runtime_error{"coro::net::to_string(domain_t) unknown domain"};
}
} // namespace coro::net

@@ -1,59 +0,0 @@
#include "coro/net/recv_status.hpp"
namespace coro::net
{
static const std::string recv_status_ok{"ok"};
static const std::string recv_status_closed{"closed"};
static const std::string recv_status_udp_not_bound{"udp_not_bound"};
// static const std::string recv_status_try_again{"try_again"};
static const std::string recv_status_would_block{"would_block"};
static const std::string recv_status_bad_file_descriptor{"bad_file_descriptor"};
static const std::string recv_status_connection_refused{"connection_refused"};
static const std::string recv_status_memory_fault{"memory_fault"};
static const std::string recv_status_interrupted{"interrupted"};
static const std::string recv_status_invalid_argument{"invalid_argument"};
static const std::string recv_status_no_memory{"no_memory"};
static const std::string recv_status_not_connected{"not_connected"};
static const std::string recv_status_not_a_socket{"not_a_socket"};
static const std::string recv_status_unknown{"unknown"};
static const std::string recv_status_ssl_error{"ssl_error"};
auto to_string(recv_status status) -> const std::string&
{
switch (status)
{
case recv_status::ok:
return recv_status_ok;
case recv_status::closed:
return recv_status_closed;
case recv_status::udp_not_bound:
return recv_status_udp_not_bound;
// case recv_status::try_again: return recv_status_try_again;
case recv_status::would_block:
return recv_status_would_block;
case recv_status::bad_file_descriptor:
return recv_status_bad_file_descriptor;
case recv_status::connection_refused:
return recv_status_connection_refused;
case recv_status::memory_fault:
return recv_status_memory_fault;
case recv_status::interrupted:
return recv_status_interrupted;
case recv_status::invalid_argument:
return recv_status_invalid_argument;
case recv_status::no_memory:
return recv_status_no_memory;
case recv_status::not_connected:
return recv_status_not_connected;
case recv_status::not_a_socket:
return recv_status_not_a_socket;
case recv_status::ssl_error:
return recv_status_ssl_error;
}
return recv_status_unknown;
}
} // namespace coro::net

@@ -1,5 +0,0 @@
#include "coro/net/send_status.hpp"
namespace coro::net
{
} // namespace coro::net

@@ -1,130 +0,0 @@
#include "coro/net/socket.hpp"
namespace coro::net
{
auto socket::type_to_os(type_t type) -> int
{
switch (type)
{
case type_t::udp:
return SOCK_DGRAM;
case type_t::tcp:
return SOCK_STREAM;
default:
throw std::runtime_error{"Unknown socket::type_t."};
}
}
auto socket::operator=(socket&& other) noexcept -> socket&
{
if (std::addressof(other) != this)
{
m_fd = std::exchange(other.m_fd, -1);
}
return *this;
}
auto socket::blocking(blocking_t block) -> bool
{
if (m_fd < 0)
{
return false;
}
int flags = fcntl(m_fd, F_GETFL, 0);
if (flags == -1)
{
return false;
}
// Add or subtract non-blocking flag.
flags = (block == blocking_t::yes) ? flags & ~O_NONBLOCK : (flags | O_NONBLOCK);
return (fcntl(m_fd, F_SETFL, flags) == 0);
}
auto socket::shutdown(poll_op how) -> bool
{
if (m_fd != -1)
{
int h{0};
switch (how)
{
case poll_op::read:
h = SHUT_RD;
break;
case poll_op::write:
h = SHUT_WR;
break;
case poll_op::read_write:
h = SHUT_RDWR;
break;
}
return (::shutdown(m_fd, h) == 0);
}
return false;
}
auto socket::close() -> void
{
if (m_fd != -1)
{
::close(m_fd);
m_fd = -1;
}
}
auto make_socket(const socket::options& opts) -> socket
{
socket s{::socket(static_cast<int>(opts.domain), socket::type_to_os(opts.type), 0)};
if (s.native_handle() < 0)
{
throw std::runtime_error{"Failed to create socket."};
}
if (opts.blocking == socket::blocking_t::no)
{
if (s.blocking(socket::blocking_t::no) == false)
{
throw std::runtime_error{"Failed to set socket to non-blocking mode."};
}
}
return s;
}
auto make_accept_socket(const socket::options& opts, const net::ip_address& address, uint16_t port, int32_t backlog)
-> socket
{
socket s = make_socket(opts);
int sock_opt{1};
if (setsockopt(s.native_handle(), SOL_SOCKET, SO_REUSEADDR | SO_REUSEPORT, &sock_opt, sizeof(sock_opt)) < 0)
{
throw std::runtime_error{"Failed to setsockopt(SO_REUSEADDR | SO_REUSEPORT)"};
}
sockaddr_in server{};
server.sin_family = static_cast<int>(opts.domain);
server.sin_port = htons(port);
server.sin_addr = *reinterpret_cast<const in_addr*>(address.data().data());
if (bind(s.native_handle(), (struct sockaddr*)&server, sizeof(server)) < 0)
{
throw std::runtime_error{"Failed to bind."};
}
if (opts.type == socket::type_t::tcp)
{
if (listen(s.native_handle(), backlog) < 0)
{
throw std::runtime_error{"Failed to listen."};
}
}
return s;
}
} // namespace coro::net

@@ -1,65 +0,0 @@
#include "coro/net/ssl_context.hpp"
#include <iostream>
namespace coro::net
{
uint64_t ssl_context::m_ssl_context_count{0};
std::mutex ssl_context::m_ssl_context_mutex{};
ssl_context::ssl_context()
{
{
std::scoped_lock g{m_ssl_context_mutex};
if (m_ssl_context_count == 0)
{
OPENSSL_init_ssl(0, nullptr);
}
++m_ssl_context_count;
}
m_ssl_ctx = SSL_CTX_new(TLS_method());
if (m_ssl_ctx == nullptr)
{
throw std::runtime_error{"Failed to initialize OpenSSL Context object."};
}
// Disable SSLv3
SSL_CTX_set_options(m_ssl_ctx, SSL_OP_ALL | SSL_OP_NO_SSLv3);
}
ssl_context::ssl_context(
std::filesystem::path certificate,
ssl_file_type certificate_type,
std::filesystem::path private_key,
ssl_file_type private_key_type)
: ssl_context()
{
if (auto r = SSL_CTX_use_certificate_file(m_ssl_ctx, certificate.c_str(), static_cast<int>(certificate_type));
r != 1)
{
throw std::runtime_error{"Failed to load certificate file " + certificate.string()};
}
if (auto r = SSL_CTX_use_PrivateKey_file(m_ssl_ctx, private_key.c_str(), static_cast<int>(private_key_type));
r != 1)
{
throw std::runtime_error{"Failed to load private key file " + private_key.string()};
}
if (auto r = SSL_CTX_check_private_key(m_ssl_ctx); r != 1)
{
throw std::runtime_error{"Certificate and private key do not match."};
}
}
ssl_context::~ssl_context()
{
if (m_ssl_ctx != nullptr)
{
SSL_CTX_free(m_ssl_ctx);
m_ssl_ctx = nullptr;
}
}
} // namespace coro::net
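And a construction-only sketch for the deleted ssl_context (editor addition; the ssl_file_type::pem enumerator is declared in ssl_context.hpp, which is not shown here, so treat its exact name as an assumption):
#include "coro/net/ssl_context.hpp"
int main()
{
    // Client-side context: the default constructor above only needs an OpenSSL TLS_method() context.
    coro::net::ssl_context client_ctx{};
    // Server-side context: certificate and matching private key, both PEM encoded.
    // The four-argument constructor above throws if either file fails to load or the pair does not match.
    coro::net::ssl_context server_ctx{
        "cert.pem", coro::net::ssl_file_type::pem, "key.pem", coro::net::ssl_file_type::pem};
    return 0;
}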

View file

@ -1,259 +0,0 @@
#include "coro/net/tcp_client.hpp"
namespace coro::net
{
using namespace std::chrono_literals;
tcp_client::tcp_client(std::shared_ptr<io_scheduler> scheduler, options opts)
: m_io_scheduler(std::move(scheduler)),
m_options(std::move(opts)),
m_socket(net::make_socket(
net::socket::options{m_options.address.domain(), net::socket::type_t::tcp, net::socket::blocking_t::no}))
{
if (m_io_scheduler == nullptr)
{
throw std::runtime_error{"tcp_client cannot have nullptr io_scheduler"};
}
}
tcp_client::tcp_client(std::shared_ptr<io_scheduler> scheduler, net::socket socket, options opts)
: m_io_scheduler(std::move(scheduler)),
m_options(std::move(opts)),
m_socket(std::move(socket)),
m_connect_status(connect_status::connected),
m_ssl_info(ssl_connection_type::accept)
{
// io_scheduler is assumed good since it comes from a tcp_server.
// Force the socket to be non-blocking.
m_socket.blocking(coro::net::socket::blocking_t::no);
}
tcp_client::tcp_client(tcp_client&& other)
: m_io_scheduler(std::move(other.m_io_scheduler)),
m_options(std::move(other.m_options)),
m_socket(std::move(other.m_socket)),
m_connect_status(std::exchange(other.m_connect_status, std::nullopt)),
m_ssl_info(std::move(other.m_ssl_info))
{
}
tcp_client::~tcp_client()
{
// If this tcp client is using SSL and the connection did not have an ssl error, schedule a task
// to shut down the connection cleanly. This is done on a background scheduled task since the
// tcp client's destructor cannot co_await the SSL_shutdown() read and write poll operations.
if (m_ssl_info.m_ssl_ptr != nullptr && !m_ssl_info.m_ssl_error)
{
// Should the shutdown timeout be configurable?
m_io_scheduler->schedule(ssl_shutdown_and_free(
m_io_scheduler, std::move(m_socket), std::move(m_ssl_info.m_ssl_ptr), std::chrono::seconds{30}));
}
}
auto tcp_client::operator=(tcp_client&& other) noexcept -> tcp_client&
{
if (std::addressof(other) != this)
{
m_io_scheduler = std::move(other.m_io_scheduler);
m_options = std::move(other.m_options);
m_socket = std::move(other.m_socket);
m_connect_status = std::exchange(other.m_connect_status, std::nullopt);
m_ssl_info = std::move(other.m_ssl_info);
}
return *this;
}
auto tcp_client::connect(std::chrono::milliseconds timeout) -> coro::task<connect_status>
{
// Only allow the user to connect this tcp client once; if they need to re-connect they should
// make a new tcp_client.
if (m_connect_status.has_value())
{
co_return m_connect_status.value();
}
// This ensures the connection status is always set on the client object upon returning.
auto return_value = [this](connect_status s) -> connect_status {
m_connect_status = s;
return s;
};
sockaddr_in server{};
server.sin_family = static_cast<int>(m_options.address.domain());
server.sin_port = htons(m_options.port);
server.sin_addr = *reinterpret_cast<const in_addr*>(m_options.address.data().data());
auto cret = ::connect(m_socket.native_handle(), (struct sockaddr*)&server, sizeof(server));
if (cret == 0)
{
co_return return_value(connect_status::connected);
}
else if (cret == -1)
{
// If the connect is happening in the background, poll for write on the socket; the write
// event triggers once the connection is established.
if (errno == EAGAIN || errno == EINPROGRESS)
{
auto pstatus = co_await m_io_scheduler->poll(m_socket, poll_op::write, timeout);
if (pstatus == poll_status::event)
{
int result{0};
socklen_t result_length{sizeof(result)};
if (getsockopt(m_socket.native_handle(), SOL_SOCKET, SO_ERROR, &result, &result_length) < 0)
{
std::cerr << "connect failed to getsockopt after write poll event\n";
}
if (result == 0)
{
co_return return_value(connect_status::connected);
}
}
else if (pstatus == poll_status::timeout)
{
co_return return_value(connect_status::timeout);
}
}
}
co_return return_value(connect_status::error);
}
auto tcp_client::ssl_handshake(std::chrono::milliseconds timeout) -> coro::task<ssl_handshake_status>
{
if (!m_connect_status.has_value() || m_connect_status.value() != connect_status::connected)
{
// Can't perform the SSL handshake if the connection isn't established.
co_return ssl_handshake_status::not_connected;
}
if (m_options.ssl_ctx == nullptr)
{
// SSL isn't set up for this client.
co_return ssl_handshake_status::ssl_context_required;
}
if (m_ssl_info.m_ssl_handshake_status.has_value())
{
// The user has already called this function.
co_return m_ssl_info.m_ssl_handshake_status.value();
}
// Ensure that any return past this point sets the cached handshake status.
auto return_value = [this](ssl_handshake_status s) -> ssl_handshake_status {
m_ssl_info.m_ssl_handshake_status = s;
return s;
};
m_ssl_info.m_ssl_ptr = ssl_unique_ptr{SSL_new(m_options.ssl_ctx->native_handle())};
if (m_ssl_info.m_ssl_ptr == nullptr)
{
co_return return_value(ssl_handshake_status::ssl_resource_allocation_failed);
}
if (auto r = SSL_set_fd(m_ssl_info.m_ssl_ptr.get(), m_socket.native_handle()); r == 0)
{
co_return return_value(ssl_handshake_status::ssl_set_fd_failure);
}
if (m_ssl_info.m_ssl_connection_type == ssl_connection_type::connect)
{
SSL_set_connect_state(m_ssl_info.m_ssl_ptr.get());
}
else // ssl_connection_type::accept
{
SSL_set_accept_state(m_ssl_info.m_ssl_ptr.get());
}
int r{0};
ERR_clear_error();
while ((r = SSL_do_handshake(m_ssl_info.m_ssl_ptr.get())) != 1)
{
poll_op op{poll_op::read_write};
int err = SSL_get_error(m_ssl_info.m_ssl_ptr.get(), r);
if (err == SSL_ERROR_WANT_WRITE)
{
op = poll_op::write;
}
else if (err == SSL_ERROR_WANT_READ)
{
op = poll_op::read;
}
else
{
// char error_buffer[256];
// ERR_error_string(err, error_buffer);
// std::cerr << "ssl_handleshake error=[" << error_buffer << "]\n";
co_return return_value(ssl_handshake_status::handshake_failed);
}
// TODO: adjust timeout based on elapsed time so far.
auto pstatus = co_await m_io_scheduler->poll(m_socket, op, timeout);
switch (pstatus)
{
case poll_status::timeout:
co_return return_value(ssl_handshake_status::timeout);
case poll_status::error:
co_return return_value(ssl_handshake_status::poll_error);
case poll_status::closed:
co_return return_value(ssl_handshake_status::unexpected_close);
default:
// Event triggered, continue handshake.
break;
}
}
co_return return_value(ssl_handshake_status::ok);
}
auto tcp_client::ssl_shutdown_and_free(
std::shared_ptr<io_scheduler> io_scheduler,
net::socket s,
ssl_unique_ptr ssl_ptr,
std::chrono::milliseconds timeout) -> coro::task<void>
{
while (true)
{
auto r = SSL_shutdown(ssl_ptr.get());
if (r == 1) // shutdown complete
{
co_return;
}
else if (r == 0) // shutdown in progress
{
coro::poll_op op{coro::poll_op::read_write};
auto err = SSL_get_error(ssl_ptr.get(), r);
if (err == SSL_ERROR_WANT_WRITE)
{
op = coro::poll_op::write;
}
else if (err == SSL_ERROR_WANT_READ)
{
op = coro::poll_op::read;
}
else
{
co_return;
}
auto pstatus = co_await io_scheduler->poll(s, op, timeout);
switch (pstatus)
{
case poll_status::timeout:
case poll_status::error:
case poll_status::closed:
co_return;
default:
// continue shutdown.
break;
}
}
else // r < 0 error
{
co_return;
}
}
}
} // namespace coro::net
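To make the deleted connect()/ssl_handshake() flow concrete, a hedged sketch of a client coroutine follows (editor addition). The io_scheduler is taken as a parameter because its construction, along with the client's send()/recv() members, lives in headers outside this diff; the options field names are taken from the designated initializers in tcp_server.cpp below.
#include "coro/coro.hpp"
auto talk_to_server(std::shared_ptr<coro::io_scheduler> scheduler) -> coro::task<void>
{
    coro::net::tcp_client client{
        scheduler,
        coro::net::tcp_client::options{
            .address = coro::net::ip_address::from_string("127.0.0.1"), .port = 8080, .ssl_ctx = nullptr}};
    // connect() caches its result in m_connect_status, so a second call simply returns the first outcome.
    if (co_await client.connect(std::chrono::seconds{5}) != coro::net::connect_status::connected)
    {
        co_return;
    }
    // With a non-null ssl_ctx in the options, the TLS upgrade would follow the same pattern:
    // co_await client.ssl_handshake(std::chrono::seconds{5});
    co_return;
}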

View file

@ -1,60 +0,0 @@
#include "coro/net/tcp_server.hpp"
#include "coro/io_scheduler.hpp"
namespace coro::net
{
tcp_server::tcp_server(std::shared_ptr<io_scheduler> scheduler, options opts)
: m_io_scheduler(std::move(scheduler)),
m_options(std::move(opts)),
m_accept_socket(net::make_accept_socket(
net::socket::options{net::domain_t::ipv4, net::socket::type_t::tcp, net::socket::blocking_t::no},
m_options.address,
m_options.port,
m_options.backlog))
{
if (m_io_scheduler == nullptr)
{
throw std::runtime_error{"tcp_server cannot have a nullptr io_scheduler"};
}
}
tcp_server::tcp_server(tcp_server&& other)
: m_io_scheduler(std::move(other.m_io_scheduler)),
m_options(std::move(other.m_options)),
m_accept_socket(std::move(other.m_accept_socket))
{
}
auto tcp_server::operator=(tcp_server&& other) -> tcp_server&
{
if (std::addressof(other) != this)
{
m_io_scheduler = std::move(other.m_io_scheduler);
m_options = std::move(other.m_options);
m_accept_socket = std::move(other.m_accept_socket);
}
return *this;
}
auto tcp_server::accept() -> coro::net::tcp_client
{
sockaddr_in client{};
// ::accept() writes the actual address length back through its third argument, so it must be mutable.
socklen_t len = sizeof(struct sockaddr_in);
net::socket s{::accept(m_accept_socket.native_handle(), (struct sockaddr*)&client, &len)};
std::span<const uint8_t> ip_addr_view{
reinterpret_cast<uint8_t*>(&client.sin_addr.s_addr),
sizeof(client.sin_addr.s_addr),
};
return tcp_client{
m_io_scheduler,
std::move(s),
tcp_client::options{
.address = net::ip_address{ip_addr_view, static_cast<net::domain_t>(client.sin_family)},
.port = ntohs(client.sin_port),
.ssl_ctx = m_options.ssl_ctx}};
}
} // namespace coro::net
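A matching sketch for the accept side (editor addition). tcp_server::poll() and the exact declaration order of the options fields are assumed from tcp_server.hpp, which is not part of this diff.
#include "coro/coro.hpp"
auto serve(std::shared_ptr<coro::io_scheduler> scheduler) -> coro::task<void>
{
    coro::net::tcp_server server{
        scheduler,
        coro::net::tcp_server::options{
            .address = coro::net::ip_address::from_string("0.0.0.0"), .port = 8080, .backlog = 128}};
    // Wait until a client is pending so the non-blocking ::accept() in accept() above cannot spuriously fail.
    if (co_await server.poll() == coro::poll_status::event)
    {
        coro::net::tcp_client client = server.accept();
        // The returned client is already in connect_status::connected, per the
        // tcp_client(scheduler, socket, opts) constructor in the previous file.
    }
    co_return;
}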

View file

@ -1,21 +0,0 @@
#include "coro/net/udp_peer.hpp"
namespace coro::net
{
udp_peer::udp_peer(std::shared_ptr<io_scheduler> scheduler, net::domain_t domain)
: m_io_scheduler(std::move(scheduler)),
m_socket(net::make_socket(net::socket::options{domain, net::socket::type_t::udp, net::socket::blocking_t::no}))
{
}
udp_peer::udp_peer(std::shared_ptr<io_scheduler> scheduler, const info& bind_info)
: m_io_scheduler(std::move(scheduler)),
m_socket(net::make_accept_socket(
net::socket::options{bind_info.address.domain(), net::socket::type_t::udp, net::socket::blocking_t::no},
bind_info.address,
bind_info.port)),
m_bound(true)
{
}
} // namespace coro::net
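Finally, a construction-only sketch for the deleted udp_peer (editor addition; its sendto()/recvfrom() members are declared in udp_peer.hpp and are not part of this diff, so only the two constructors shown above are exercised):
#include "coro/coro.hpp"
auto make_udp_peers(std::shared_ptr<coro::io_scheduler> scheduler) -> void
{
    // Unbound peer, suitable only for sending datagrams.
    coro::net::udp_peer sender{scheduler, coro::net::domain_t::ipv4};
    // Bound peer on 0.0.0.0:8081; the second constructor above binds via make_accept_socket()
    // and marks the peer as m_bound.
    coro::net::udp_peer receiver{
        scheduler,
        coro::net::udp_peer::info{.address = coro::net::ip_address::from_string("0.0.0.0"), .port = 8081}};
}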