diff --git a/CMakeLists.txt b/CMakeLists.txt index d230292..edabb7c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -14,26 +14,29 @@ set(CARES_INSTALL OFF CACHE INTERNAL "") add_subdirectory(vendor/c-ares/c-ares) set(LIBCORO_SOURCE_FILES + inc/coro/concepts/awaitable.hpp + inc/coro/concepts/buffer.hpp + inc/coro/concepts/promise.hpp + inc/coro/detail/void_value.hpp inc/coro/net/connect.hpp src/net/connect.cpp - inc/coro/net/dns_client.hpp src/net/dns_client.cpp + inc/coro/net/dns_resolver.hpp src/net/dns_resolver.cpp inc/coro/net/hostname.hpp inc/coro/net/ip_address.hpp src/net/ip_address.cpp + inc/coro/net/recv_status.hpp src/net/recv_status.cpp + inc/coro/net/send_status.hpp src/net/send_status.cpp inc/coro/net/socket.hpp src/net/socket.cpp inc/coro/net/tcp_client.hpp src/net/tcp_client.cpp inc/coro/net/tcp_server.hpp src/net/tcp_server.cpp - inc/coro/net/udp_client.hpp src/net/udp_client.cpp - inc/coro/net/udp_server.hpp src/net/udp_server.cpp + inc/coro/net/udp_peer.hpp src/net/udp_peer.cpp - inc/coro/awaitable.hpp inc/coro/coro.hpp inc/coro/event.hpp src/event.cpp inc/coro/generator.hpp inc/coro/io_scheduler.hpp src/io_scheduler.cpp inc/coro/latch.hpp inc/coro/poll.hpp - inc/coro/promise.hpp inc/coro/shutdown.hpp inc/coro/sync_wait.hpp src/sync_wait.cpp inc/coro/task.hpp diff --git a/inc/coro/awaitable.hpp b/inc/coro/concepts/awaitable.hpp similarity index 97% rename from inc/coro/awaitable.hpp rename to inc/coro/concepts/awaitable.hpp index 3fe9e9a..264ade0 100644 --- a/inc/coro/awaitable.hpp +++ b/inc/coro/concepts/awaitable.hpp @@ -5,7 +5,7 @@ #include #include -namespace coro +namespace coro::concepts { /** * This concept declares a type that is required to meet the c++20 coroutine operator co_await() @@ -72,4 +72,4 @@ struct awaitable_traits }; // clang-format on -} // namespace coro +} // namespace coro::concepts diff --git a/inc/coro/concepts/buffer.hpp b/inc/coro/concepts/buffer.hpp new file mode 100644 index 0000000..e1c90a7 --- /dev/null +++ b/inc/coro/concepts/buffer.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include +#include +#include + +namespace coro::concepts +{ + +// clang-format off +template +concept const_buffer = requires(const type t) +{ + { t.empty() } -> std::same_as; + { t.data() } -> std::same_as; + { t.size() } -> std::same_as; +}; + +template +concept mutable_buffer = requires(type t) +{ + { t.empty() } -> std::same_as; + { t.data() } -> std::same_as; + { t.size() } -> std::same_as; +}; +// clang-format on + +} // namespace coro::concepts diff --git a/inc/coro/concepts/promise.hpp b/inc/coro/concepts/promise.hpp new file mode 100644 index 0000000..c3e714a --- /dev/null +++ b/inc/coro/concepts/promise.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include "coro/concepts/awaitable.hpp" + +#include + +namespace coro::concepts +{ + +// clang-format off +template +concept promise = requires(type t) +{ + { t.get_return_object() } -> std::convertible_to>; + { t.initial_suspend() } -> awaiter; + { t.final_suspend() } -> awaiter; + { t.yield_value() } -> awaitable; +} +&& requires(type t, return_type return_value) +{ + std::same_as || + std::same_as || + requires { t.yield_value(return_value); }; +}; +// clang-format on + +} // namespace coro::concepts diff --git a/inc/coro/coro.hpp b/inc/coro/coro.hpp index e2b0b6b..e84bafa 100644 --- a/inc/coro/coro.hpp +++ b/inc/coro/coro.hpp @@ -1,21 +1,24 @@ #pragma once +#include "coro/concepts/awaitable.hpp" +#include "coro/concepts/buffer.hpp" +#include "coro/concepts/promise.hpp" + #include "coro/net/connect.hpp" -#include "coro/net/dns_client.hpp" +#include "coro/net/dns_resolver.hpp" #include "coro/net/hostname.hpp" #include "coro/net/ip_address.hpp" +#include "coro/net/recv_status.hpp" +#include "coro/net/send_status.hpp" #include "coro/net/socket.hpp" #include "coro/net/tcp_client.hpp" #include "coro/net/tcp_server.hpp" -#include "coro/net/udp_client.hpp" -#include "coro/net/udp_server.hpp" +#include "coro/net/udp_peer.hpp" -#include "coro/awaitable.hpp" #include "coro/event.hpp" #include "coro/generator.hpp" #include "coro/io_scheduler.hpp" #include "coro/latch.hpp" -#include "coro/promise.hpp" #include "coro/sync_wait.hpp" #include "coro/task.hpp" #include "coro/thread_pool.hpp" diff --git a/inc/coro/io_scheduler.hpp b/inc/coro/io_scheduler.hpp index 295e77d..22c5fe5 100644 --- a/inc/coro/io_scheduler.hpp +++ b/inc/coro/io_scheduler.hpp @@ -1,6 +1,6 @@ #pragma once -#include "coro/awaitable.hpp" +#include "coro/concepts/awaitable.hpp" #include "coro/poll.hpp" #include "coro/shutdown.hpp" #include "coro/net/socket.hpp" @@ -299,7 +299,7 @@ public: auto schedule(std::vector> tasks) -> bool; - template + template auto schedule(tasks_type&&... tasks) -> bool { if (is_shutdown()) diff --git a/inc/coro/net/connect.hpp b/inc/coro/net/connect.hpp index 693d369..2d1e1bd 100644 --- a/inc/coro/net/connect.hpp +++ b/inc/coro/net/connect.hpp @@ -14,8 +14,6 @@ enum class connect_status timeout, /// There was an error, use errno to get more information on the specific error. error, - /// The client was given a hostname but no dns client to resolve the ip address. - dns_client_required, /// The dns hostname lookup failed dns_lookup_failure }; diff --git a/inc/coro/net/dns_client.hpp b/inc/coro/net/dns_resolver.hpp similarity index 74% rename from inc/coro/net/dns_client.hpp rename to inc/coro/net/dns_resolver.hpp index a2c7c24..b5908ba 100644 --- a/inc/coro/net/dns_client.hpp +++ b/inc/coro/net/dns_resolver.hpp @@ -19,7 +19,7 @@ namespace coro::net { -class dns_client; +class dns_resolver; enum class dns_status { @@ -29,12 +29,20 @@ enum class dns_status class dns_result { - friend dns_client; + friend dns_resolver; public: explicit dns_result(coro::resume_token& token, uint64_t pending_dns_requests); ~dns_result() = default; + /** + * @return The status of the dns lookup. + */ auto status() const -> dns_status { return m_status; } + + /** + * @return If the result of the dns looked was successful then the list of ip addresses that + * were resolved from the hostname. + */ auto ip_addresses() const -> const std::vector& { return m_ip_addresses; } private: coro::resume_token& m_token; @@ -50,16 +58,19 @@ private: ) -> void; }; -class dns_client +class dns_resolver { public: - explicit dns_client(io_scheduler& scheduler, std::chrono::milliseconds timeout); - dns_client(const dns_client&) = delete; - dns_client(dns_client&&) = delete; - auto operator=(const dns_client&) noexcept -> dns_client& = delete; - auto operator=(dns_client&&) noexcept -> dns_client& = delete; - ~dns_client(); + explicit dns_resolver(io_scheduler& scheduler, std::chrono::milliseconds timeout); + dns_resolver(const dns_resolver&) = delete; + dns_resolver(dns_resolver&&) = delete; + auto operator=(const dns_resolver&) noexcept -> dns_resolver& = delete; + auto operator=(dns_resolver&&) noexcept -> dns_resolver& = delete; + ~dns_resolver(); + /** + * @param hn The hostname to resolve its ip addresses. + */ auto host_by_name(const net::hostname& hn) -> coro::task>; private: /// The io scheduler to drive the events for dns lookups. diff --git a/inc/coro/net/ip_address.hpp b/inc/coro/net/ip_address.hpp index dfa1761..5ee5dbe 100644 --- a/inc/coro/net/ip_address.hpp +++ b/inc/coro/net/ip_address.hpp @@ -98,6 +98,8 @@ public: return output; } + auto operator<=>(const ip_address& other) const = default; + private: domain_t m_domain{domain_t::ipv4}; std::array m_data{}; diff --git a/inc/coro/net/recv_status.hpp b/inc/coro/net/recv_status.hpp new file mode 100644 index 0000000..b04a832 --- /dev/null +++ b/inc/coro/net/recv_status.hpp @@ -0,0 +1,31 @@ +#pragma once + +#include +#include +#include + +namespace coro::net +{ + +enum class recv_status : int64_t +{ + ok = 0, + /// The peer closed the socket. + closed = -1, + /// The udp socket has not been bind()'ed to a local port. + udp_not_bound = -2, + try_again = EAGAIN, + would_block = EWOULDBLOCK, + bad_file_descriptor = EBADF, + connection_refused = ECONNREFUSED, + memory_fault = EFAULT, + interrupted = EINTR, + invalid_argument = EINVAL, + no_memory = ENOMEM, + not_connected = ENOTCONN, + not_a_socket = ENOTSOCK +}; + +auto to_string(recv_status status) -> const std::string&; + +} // namespace coro::net diff --git a/inc/coro/net/send_status.hpp b/inc/coro/net/send_status.hpp new file mode 100644 index 0000000..ae77691 --- /dev/null +++ b/inc/coro/net/send_status.hpp @@ -0,0 +1,31 @@ +#pragma once + +#include +#include + +namespace coro::net +{ + +enum class send_status : int64_t +{ + ok = 0, + permission_denied = EACCES, + try_again = EAGAIN, + would_block = EWOULDBLOCK, + already_in_progress = EALREADY, + bad_file_descriptor = EBADF, + connection_reset = ECONNRESET, + no_peer_address = EDESTADDRREQ, + memory_fault = EFAULT, + interrupted = EINTR, + is_connection = EISCONN, + message_size = EMSGSIZE, + output_queue_full = ENOBUFS, + no_memory = ENOMEM, + not_connected = ENOTCONN, + not_a_socket = ENOTSOCK, + operationg_not_supported = EOPNOTSUPP, + pipe_closed = EPIPE +}; + +} // namespace coro::net diff --git a/inc/coro/net/socket.hpp b/inc/coro/net/socket.hpp index ed6f75e..0a615e9 100644 --- a/inc/coro/net/socket.hpp +++ b/inc/coro/net/socket.hpp @@ -19,54 +19,93 @@ class socket public: enum class type_t { + /// udp datagram socket udp, + /// tcp streaming socket tcp }; enum class blocking_t { + /// This socket should block on system calls. yes, + /// This socket should not block on system calls. no }; struct options { + /// The domain for the socket. domain_t domain; + /// The type of socket. type_t type; + /// If the socket should be blocking or non-blocking. blocking_t blocking; }; static auto type_to_os(type_t type) -> int; socket() = default; - explicit socket(int fd) : m_fd(fd) {} socket(const socket&) = delete; socket(socket&& other) : m_fd(std::exchange(other.m_fd, -1)) {} - auto operator=(const socket&) -> socket& = delete; - auto operator=(socket&& other) noexcept -> socket&; ~socket() { close(); } + /** + * This function returns true if the socket's file descriptor is a valid number, however it does + * not imply if the socket is still usable. + * @return True if the socket file descriptor is > 0. + */ + auto is_valid() const -> bool { return m_fd != -1; } + + /** + * @param block Sets the socket to the given blocking mode. + */ auto blocking(blocking_t block) -> bool; + /** + * @param how Shuts the socket down with the given operations. + * @param Returns true if the sockets given operations were shutdown. + */ auto shutdown(poll_op how = poll_op::read_write) -> bool; + /** + * Closes the socket and sets this socket to an invalid state. + */ auto close() -> void; + /** + * @return The native handle (file descriptor) for this socket. + */ auto native_handle() const -> int { return m_fd; } private: int m_fd{-1}; }; +/** + * Creates a socket with the given socket options, this typically is used for creating sockets to + * use within client objects, e.g. tcp_client and udp_client. + * @param opts See socket::options for more details. + */ auto make_socket(const socket::options& opts) -> socket; +/** + * Creates a socket that can accept connections or packets with the given socket options, address, + * port and backlog. This is used for creating sockets to use within server objects, e.g. + * tcp_server and udp_server. + * @param opts See socket::options for more details + * @param address The ip address to bind to. If the type of socket is tcp then it will also listen. + * @param port The port to bind to. + * @param backlog If the type of socket is tcp then the backlog of connections to allow. Does nothing + * for udp types. + */ auto make_accept_socket( - const socket::options& opts, + const socket::options& opts, const net::ip_address& address, uint16_t port, int32_t backlog = 128) -> socket; diff --git a/inc/coro/net/tcp_client.hpp b/inc/coro/net/tcp_client.hpp index 2290172..58367f9 100644 --- a/inc/coro/net/tcp_client.hpp +++ b/inc/coro/net/tcp_client.hpp @@ -1,18 +1,22 @@ #pragma once -#include "coro/net/dns_client.hpp" +#include "coro/concepts/buffer.hpp" #include "coro/net/ip_address.hpp" -#include "coro/net/hostname.hpp" +#include "coro/net/recv_status.hpp" +#include "coro/net/send_status.hpp" #include "coro/net/socket.hpp" #include "coro/net/connect.hpp" #include "coro/poll.hpp" #include "coro/task.hpp" +#include "coro/io_scheduler.hpp" #include #include #include #include #include +#include +#include namespace coro { @@ -22,49 +26,150 @@ class io_scheduler; namespace coro::net { +class tcp_server; + class tcp_client { public: struct options { - /// The hostname or ip address to connect to. If using hostname then a dns client must be provided. - std::variant address{net::ip_address::from_string("127.0.0.1")}; + /// The ip address to connect to. Use a dns_resolver to turn hostnames into ip addresses. + net::ip_address address{net::ip_address::from_string("127.0.0.1")}; /// The port to connect to. - int16_t port{8080}; - /// The protocol domain to connect with. - net::domain_t domain{net::domain_t::ipv4}; - /// If using a hostname to connect to then provide a dns client to lookup the host's ip address. - /// This is optional if using ip addresses directly. - net::dns_client* dns{nullptr}; + uint16_t port{8080}; }; - tcp_client(io_scheduler& scheduler, options opts = options{ - .address = {net::ip_address::from_string("127.0.0.1")}, - .port = 8080, - .domain = net::domain_t::ipv4, - .dns = nullptr}); + /** + * Creates a new tcp client that can connect to an ip address + port. By default the socket + * created will be in non-blocking mode, meaning that any sending or receiving of data should + * poll for event readiness prior. + * @param scheduler The io scheduler to drive the tcp client. + * @param opts See tcp_client::options for more information. + */ + tcp_client( + io_scheduler& scheduler, + options opts = options{ + .address = {net::ip_address::from_string("127.0.0.1")}, + .port = 8080}); tcp_client(const tcp_client&) = delete; tcp_client(tcp_client&&) = default; auto operator=(const tcp_client&) noexcept -> tcp_client& = delete; auto operator=(tcp_client&&) noexcept -> tcp_client& = default; ~tcp_client() = default; - auto connect(std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task; + /** + * @return The tcp socket this client is using. + * @{ + **/ + auto socket() -> net::socket& { return m_socket; } + auto socket() const -> const net::socket& { return m_socket; } + /** @} */ + /** + * Connects to the address+port with the given timeout. Once connected calling this function + * only returns the connected status, it will not reconnect. + * @param timeout How long to wait for the connection to establish? Timeout of zero is indefinite. + * @return The result status of trying to connect. + */ + auto connect( + std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task; + + /** + * Polls for the given operation on this client's tcp socket. This should be done prior to + * calling recv and after a send that doesn't send the entire buffer. + * @param op The poll operation to perform, use read for incoming data and write for outgoing. + * @param timeout The amount of time to wait for the poll event to be ready. Use zero for infinte timeout. + * @return The status result of th poll operation. When poll_status::event is returned then the + * event operation is ready. + */ + auto poll( + coro::poll_op op, + std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task + { + co_return co_await m_io_scheduler.poll(m_socket, op, timeout); + } + + /** + * Receives incoming data into the given buffer. By default since all tcp client sockets are set + * to non-blocking use co_await poll() to determine when data is ready to be received. + * @param buffer Received bytes are written into this buffer up to the buffers size. + * @return The status of the recv call and a span of the bytes recevied (if any). The span of + * bytes will be a subspan or full span of the given input buffer. + */ + template auto recv( - std::span buffer, - std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task>; + buffer_type&& buffer) -> std::pair> + { + // If the user requested zero bytes, just return. + if(buffer.empty()) + { + return {recv_status::ok, std::span{}}; + } + + auto bytes_recv = ::recv(m_socket.native_handle(), buffer.data(), buffer.size(), 0); + if(bytes_recv > 0) + { + // Ok, we've recieved some data. + return {recv_status::ok, std::span{buffer.data(), static_cast(bytes_recv)}}; + } + else if (bytes_recv == 0) + { + // On TCP stream sockets 0 indicates the connection has been closed by the peer. + return {recv_status::closed, std::span{}}; + } + else + { + // Report the error to the user. + return {static_cast(errno), std::span{}}; + } + } + + /** + * Sends outgoing data from the given buffer. If a partial write occurs then use co_await poll() + * to determine when the tcp client socket is ready to be written to again. On partial writes + * the status will be 'ok' and the span returned will be non-empty, it will contain the buffer + * span data that was not written to the client's socket. + * @param buffer The data to write on the tcp socket. + * @return The status of the send call and a span of any remaining bytes not sent. If all bytes + * were successfully sent the status will be 'ok' and the remaining span will be empty. + */ + template auto send( - const std::span buffer, - std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task>; + const buffer_type& buffer) -> std::pair> + { + // If the user requested zero bytes, just return. + if(buffer.empty()) + { + return {send_status::ok, std::span{buffer.data(), buffer.size()}}; + } + + auto bytes_sent = ::send(m_socket.native_handle(), buffer.data(), buffer.size(), 0); + if(bytes_sent >= 0) + { + // Some or all of the bytes were written. + return {send_status::ok, std::span{buffer.data() + bytes_sent, buffer.size() - bytes_sent}}; + } + else + { + // Due to the error none of the bytes were written. + return {static_cast(errno), std::span{buffer.data(), buffer.size()}}; + } + } private: + /// The tcp_server creates already connected clients and provides a tcp socket pre-built. + friend tcp_server; + tcp_client( + io_scheduler& scheduler, + net::socket socket, + options opts); + /// The scheduler that will drive this tcp client. io_scheduler& m_io_scheduler; /// Options for what server to connect to. - options m_options; + options m_options{}; /// The tcp socket. - net::socket m_socket; + net::socket m_socket{-1}; /// Cache the status of the connect in the event the user calls connect() again. std::optional m_connect_status{std::nullopt}; }; diff --git a/inc/coro/net/tcp_server.hpp b/inc/coro/net/tcp_server.hpp index 235cad2..6a18f5d 100644 --- a/inc/coro/net/tcp_server.hpp +++ b/inc/coro/net/tcp_server.hpp @@ -1,6 +1,7 @@ #pragma once #include "coro/net/ip_address.hpp" +#include "coro/net/tcp_client.hpp" #include "coro/io_scheduler.hpp" #include "coro/net/socket.hpp" #include "coro/task.hpp" @@ -14,53 +15,52 @@ namespace coro::net class tcp_server : public io_scheduler { public: - using on_connection_t = std::function(tcp_server&, net::socket)>; - struct options { - net::ip_address address = net::ip_address::from_string("0.0.0.0"); - uint16_t port = 8080; - int32_t backlog = 128; - on_connection_t on_connection = nullptr; - io_scheduler::options io_options{}; + /// The ip address for the tcp server to bind and listen on. + net::ip_address address{net::ip_address::from_string("0.0.0.0")}; + /// The port for the tcp server to bind and listen on. + uint16_t port{8080}; + /// The kernel backlog of connections to buffer. + int32_t backlog{128}; }; - explicit tcp_server( + tcp_server( + io_scheduler& scheduler, options opts = options{ .address = net::ip_address::from_string("0.0.0.0"), .port = 8080, - .backlog = 128, - .on_connection = [](tcp_server&, net::socket) -> task { co_return; }, - .io_options = io_scheduler::options{}}); + .backlog = 128}); tcp_server(const tcp_server&) = delete; tcp_server(tcp_server&&) = delete; auto operator=(const tcp_server&) -> tcp_server& = delete; auto operator=(tcp_server&&) -> tcp_server& = delete; + ~tcp_server() = default; - ~tcp_server() override; + /** + * Polls for new incoming tcp connections. + * @param timeout How long to wait for a new connection before timing out, zero waits indefinitely. + * @return The result of the poll, 'event' means the poll was successful and there is at least 1 + * connection ready to be accepted. + */ + auto poll(std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task; - auto empty() const -> bool { return size() == 0; } - - auto size() const -> size_t - { - // Take one off for the accept task so the user doesn't have to account for the hidden task. - auto size = io_scheduler::size(); - return (size > 0) ? size - 1 : 0; - } - - auto shutdown(shutdown_t wait_for_tasks = shutdown_t::sync) -> void override; + /** + * Accepts an incoming tcp client connection. On failure the tcp clients socket will be set to + * and invalid state, use the socket.is_value() to verify the client was correctly accepted. + * @return The newly connected tcp client connection. + */ + auto accept() -> coro::net::tcp_client; private: + /// The io scheduler for awaiting new connections. + io_scheduler& m_io_scheduler; + /// The bind and listen options for this server. options m_options; - - /// Should the accept task continue accepting new connections? - std::atomic m_accept_new_connections{true}; - std::atomic m_accept_task_exited{false}; - net::socket m_accept_socket{-1}; - - auto make_accept_task() -> coro::task; + /// The socket for accepting new tcp connections on. + net::socket m_accept_socket{-1}; }; } // namespace coro::net diff --git a/inc/coro/net/udp_client.hpp b/inc/coro/net/udp_client.hpp deleted file mode 100644 index fecb0a9..0000000 --- a/inc/coro/net/udp_client.hpp +++ /dev/null @@ -1,53 +0,0 @@ -#pragma once - -#include "coro/net/hostname.hpp" -#include "coro/net/ip_address.hpp" -#include "coro/net/socket.hpp" -#include "coro/task.hpp" - -#include -#include -#include - -namespace coro -{ -class io_scheduler; -} // namespace coro - -namespace coro::net -{ - -class udp_client -{ -public: - struct options - { - /// The ip address to connect to. If using hostname then a dns client must be provided. - net::ip_address address{net::ip_address::from_string("127.0.0.1")}; - /// The port to connect to. - uint16_t port{8080}; - }; - - udp_client(io_scheduler& scheduler, options opts = options{ - .address = {net::ip_address::from_string("127.0.0.1")}, - .port = 8080}); - udp_client(const udp_client&) = delete; - udp_client(udp_client&&) = default; - auto operator=(const udp_client&) noexcept -> udp_client& = delete; - auto operator=(udp_client&&) noexcept -> udp_client& = default; - ~udp_client() = default; - - auto sendto( - const std::span buffer, - std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task; - -private: - /// The scheduler that will drive this udp client. - io_scheduler& m_io_scheduler; - /// Options for what server to connect to. - options m_options; - /// The udp socket. - net::socket m_socket{-1}; -}; - -} // namespace coro::net diff --git a/inc/coro/net/udp_peer.hpp b/inc/coro/net/udp_peer.hpp new file mode 100644 index 0000000..74f3d85 --- /dev/null +++ b/inc/coro/net/udp_peer.hpp @@ -0,0 +1,168 @@ +#pragma once + +#include "coro/concepts/buffer.hpp" +#include "coro/net/ip_address.hpp" +#include "coro/net/socket.hpp" +#include "coro/net/send_status.hpp" +#include "coro/net/recv_status.hpp" +#include "coro/task.hpp" +#include "coro/io_scheduler.hpp" + +#include +#include +#include + +namespace coro +{ +class io_scheduler; +} // namespace coro + +namespace coro::net +{ + +class udp_peer +{ +public: + struct info + { + /// The ip address of the peer. + net::ip_address address{net::ip_address::from_string("127.0.0.1")}; + /// The port of the peer. + uint16_t port{8080}; + + auto operator<=>(const info& other) const = default; + }; + + /** + * Creates a udp peer that can send packets but not receive them. This udp peer will not explicitly + * bind to a local ip+port. + */ + explicit udp_peer( + io_scheduler& scheduler, + net::domain_t domain = net::domain_t::ipv4); + + /** + * Creates a udp peer that can send and receive packets. This peer will bind to the given ip_port. + */ + explicit udp_peer( + io_scheduler& scheduler, + const info& bind_info); + + udp_peer(const udp_peer&) = delete; + udp_peer(udp_peer&&) = default; + auto operator=(const udp_peer&) noexcept -> udp_peer& = delete; + auto operator=(udp_peer&&) noexcept -> udp_peer& = default; + ~udp_peer() = default; + + /** + * @param op The poll operation to perform on the udp socket. Note that if this is a send only + * udp socket (did not bind) then polling for read will not work. + * @param timeout The timeout for the poll operation to be ready. + * @return The result status of the poll operation. + */ + auto poll( + poll_op op, + std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task + { + co_return co_await m_io_scheduler.poll(m_socket, op, timeout); + } + + /** + * @param peer_info The peer to send the data to. + * @param buffer The data to send. + * @return The status of send call and a span view of any data that wasn't sent. This data if + * un-sent will correspond to bytes at the end of the given buffer. + */ + template + auto sendto( + const info& peer_info, + const buffer_type& buffer) -> std::pair> + { + if(buffer.empty()) + { + return {send_status::ok, std::span{}}; + } + + sockaddr_in peer{}; + peer.sin_family = static_cast(peer_info.address.domain()); + peer.sin_port = htons(peer_info.port); + peer.sin_addr = *reinterpret_cast(peer_info.address.data().data()); + + socklen_t peer_len{sizeof(peer)}; + + auto bytes_sent = ::sendto( + m_socket.native_handle(), + buffer.data(), + buffer.size(), + 0, + reinterpret_cast(&peer), + peer_len); + + if(bytes_sent >= 0) + { + return {send_status::ok, std::span{buffer.data() + bytes_sent, buffer.size() - bytes_sent}}; + } + else + { + return {static_cast(errno), std::span{}}; + } + } + + /** + * @param buffer The buffer to receive data into. + * @return The receive status, if ok then also the peer who sent the data and the data. + * The span view of the data will be set to the size of the received data, this will + * always start at the beggining of the buffer but depending on how large the data was + * it might not fill the entire buffer. + */ + template + auto recvfrom( + buffer_type&& buffer) -> std::tuple> + { + // The user must bind locally to be able to receive packets. + if(!m_bound) + { + return {recv_status::udp_not_bound, udp_peer::info{}, std::span{}}; + } + + sockaddr_in peer{}; + socklen_t peer_len{sizeof(peer)}; + + auto bytes_read = ::recvfrom( + m_socket.native_handle(), + buffer.data(), + buffer.size(), + 0, + reinterpret_cast(&peer), + &peer_len); + + if(bytes_read < 0) + { + return {static_cast(errno), udp_peer::info{}, std::span{}}; + } + + std::span ip_addr_view{ + reinterpret_cast(&peer.sin_addr.s_addr), + sizeof(peer.sin_addr.s_addr), + }; + + return { + recv_status::ok, + udp_peer::info{ + .address = net::ip_address{ip_addr_view, static_cast(peer.sin_family)}, + .port = ntohs(peer.sin_port) + }, + std::span{buffer.data(), static_cast(bytes_read)} + }; + } + +private: + /// The scheduler that will drive this udp client. + io_scheduler& m_io_scheduler; + /// The udp socket. + net::socket m_socket{-1}; + /// Did the user request this udp socket is bound locally to receive packets? + bool m_bound{false}; +}; + +} // namespace coro::net diff --git a/inc/coro/net/udp_server.hpp b/inc/coro/net/udp_server.hpp deleted file mode 100644 index 93757f2..0000000 --- a/inc/coro/net/udp_server.hpp +++ /dev/null @@ -1,51 +0,0 @@ -#pragma once - -#include "coro/net/socket.hpp" -#include "coro/net/udp_client.hpp" -#include "coro/io_scheduler.hpp" - -#include -#include -#include -#include - -namespace coro::net -{ - -class udp_server -{ -public: - struct options - { - /// The local address to bind to to recv packets from. - net::ip_address address{net::ip_address::from_string("0.0.0.0")}; - /// The port to recv packets from. - uint16_t port{8080}; - }; - - explicit udp_server( - io_scheduler& io_scheduler, - options opts = - options{ - .address = net::ip_address::from_string("0.0.0.0"), - .port = 8080, - } - ); - - udp_server(const udp_server&) = delete; - udp_server(udp_server&&) = default; - auto operator=(const udp_server&) -> udp_server& = delete; - auto operator=(udp_server&&) -> udp_server& = default; - ~udp_server() = default; - - auto recvfrom( - std::span& buffer, - std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task>; - -private: - io_scheduler& m_io_scheduler; - options m_options; - net::socket m_accept_socket{-1}; -}; - -} // namespace coro::net diff --git a/inc/coro/promise.hpp b/inc/coro/promise.hpp deleted file mode 100644 index b51aa40..0000000 --- a/inc/coro/promise.hpp +++ /dev/null @@ -1,38 +0,0 @@ -#pragma once - -#include "coro/awaitable.hpp" - -#include - -namespace coro -{ -template -concept promise_type = requires(type t) -{ - { - t.get_return_object() - } - ->std::convertible_to>; - { - t.initial_suspend() - } - ->awaiter; - { - t.final_suspend() - } - ->awaiter; - { - t.yield_value() - } - ->awaitable; -} -&&requires(type t, return_type return_value) -{ - std::same_as || std::same_as || - requires - { - t.yield_value(return_value); - }; -}; - -} // namespace coro diff --git a/inc/coro/sync_wait.hpp b/inc/coro/sync_wait.hpp index 702cc48..cf9b61f 100644 --- a/inc/coro/sync_wait.hpp +++ b/inc/coro/sync_wait.hpp @@ -1,6 +1,6 @@ #pragma once -#include "coro/awaitable.hpp" +#include "coro/concepts/awaitable.hpp" #include #include @@ -182,7 +182,7 @@ private: coroutine_type m_coroutine; }; -template::awaiter_return_type> +template::awaiter_return_type> static auto make_sync_wait_task(awaitable&& a) -> sync_wait_task { if constexpr (std::is_void_v) @@ -198,7 +198,7 @@ static auto make_sync_wait_task(awaitable&& a) -> sync_wait_task } // namespace detail -template +template auto sync_wait(awaitable&& a) -> decltype(auto) { detail::sync_wait_event e{}; diff --git a/inc/coro/task.hpp b/inc/coro/task.hpp index 425d272..75bc53a 100644 --- a/inc/coro/task.hpp +++ b/inc/coro/task.hpp @@ -1,5 +1,7 @@ #pragma once +#include "coro/concepts/promise.hpp" + #include #include #include diff --git a/inc/coro/when_all.hpp b/inc/coro/when_all.hpp index 404afc0..96d5f2e 100644 --- a/inc/coro/when_all.hpp +++ b/inc/coro/when_all.hpp @@ -1,6 +1,6 @@ #pragma once -#include "coro/awaitable.hpp" +#include "coro/concepts/awaitable.hpp" #include "coro/detail/void_value.hpp" #include @@ -434,7 +434,7 @@ private: coroutine_handle_type m_coroutine; }; -template::awaiter_return_type> +template::awaiter_return_type> static auto make_when_all_task(awaitable&& a) -> when_all_task { if constexpr (std::is_void_v) @@ -450,15 +450,15 @@ static auto make_when_all_task(awaitable&& a) -> when_all_task } // namespace detail -template +template [[nodiscard]] auto when_all_awaitable(awaitables_type&&... awaitables) { return detail::when_all_ready_awaitable< - std::tuple::awaiter_return_type>...>>( + std::tuple::awaiter_return_type>...>>( std::make_tuple(detail::make_when_all_task(std::forward(awaitables))...)); } -template::awaiter_return_type> +template::awaiter_return_type> [[nodiscard]] auto when_all_awaitable(std::vector& awaitables) -> detail::when_all_ready_awaitable>> { diff --git a/src/net/connect.cpp b/src/net/connect.cpp index d472773..a7ad9f8 100644 --- a/src/net/connect.cpp +++ b/src/net/connect.cpp @@ -8,7 +8,6 @@ static std::string connect_status_connected{"connected"}; static std::string connect_status_invalid_ip_address{"invalid_ip_address"}; static std::string connect_status_timeout{"timeout"}; static std::string connect_status_error{"error"}; -static std::string connect_status_dns_client_required{"dns_client_required"}; static std::string connect_status_dns_lookup_failure{"dns_lookup_failure"}; auto to_string(const connect_status& status) -> const std::string& @@ -23,8 +22,6 @@ auto to_string(const connect_status& status) -> const std::string& return connect_status_timeout; case connect_status::error: return connect_status_error; - case connect_status::dns_client_required: - return connect_status_dns_client_required; case connect_status::dns_lookup_failure: return connect_status_dns_lookup_failure; } diff --git a/src/net/dns_client.cpp b/src/net/dns_resolver.cpp similarity index 91% rename from src/net/dns_client.cpp rename to src/net/dns_resolver.cpp index 5a9f1ae..34593bf 100644 --- a/src/net/dns_client.cpp +++ b/src/net/dns_resolver.cpp @@ -1,4 +1,4 @@ -#include "coro/net/dns_client.hpp" +#include "coro/net/dns_resolver.hpp" #include #include @@ -7,8 +7,8 @@ namespace coro::net { -uint64_t dns_client::m_ares_count{0}; -std::mutex dns_client::m_ares_mutex{}; +uint64_t dns_resolver::m_ares_count{0}; +std::mutex dns_resolver::m_ares_mutex{}; auto ares_dns_callback( void* arg, @@ -53,7 +53,7 @@ dns_result::dns_result(coro::resume_token& token, uint64_t pending_dns_req } -dns_client::dns_client(io_scheduler& scheduler, std::chrono::milliseconds timeout) +dns_resolver::dns_resolver(io_scheduler& scheduler, std::chrono::milliseconds timeout) : m_scheduler(scheduler), m_timeout(timeout) { @@ -77,7 +77,7 @@ dns_client::dns_client(io_scheduler& scheduler, std::chrono::milliseconds timeou } } -dns_client::~dns_client() +dns_resolver::~dns_resolver() { if(m_ares_channel != nullptr) { @@ -95,7 +95,7 @@ dns_client::~dns_client() } } -auto dns_client::host_by_name(const net::hostname& hn) -> coro::task> +auto dns_resolver::host_by_name(const net::hostname& hn) -> coro::task> { auto token = m_scheduler.make_resume_token(); auto result_ptr = std::make_unique(token, 2); @@ -111,7 +111,7 @@ auto dns_client::host_by_name(const net::hostname& hn) -> coro::task void +auto dns_resolver::ares_poll() -> void { std::array ares_sockets{}; std::array poll_ops{}; @@ -158,7 +158,7 @@ auto dns_client::ares_poll() -> void } } -auto dns_client::make_poll_task(io_scheduler::fd_t fd, poll_op ops) -> coro::task +auto dns_resolver::make_poll_task(io_scheduler::fd_t fd, poll_op ops) -> coro::task { auto result = co_await m_scheduler.poll(fd, ops, m_timeout); switch(result) diff --git a/src/net/recv_status.cpp b/src/net/recv_status.cpp new file mode 100644 index 0000000..0fa7681 --- /dev/null +++ b/src/net/recv_status.cpp @@ -0,0 +1,43 @@ +#include "coro/net/recv_status.hpp" + +namespace coro::net +{ + +static const std::string recv_status_ok{"ok"}; +static const std::string recv_status_closed{"closed"}; +static const std::string recv_status_udp_not_bound{"udp_not_bound"}; +// static const std::string recv_status_try_again{"try_again"}; +static const std::string recv_status_would_block{"would_block"}; +static const std::string recv_status_bad_file_descriptor{"bad_file_descriptor"}; +static const std::string recv_status_connection_refused{"connection_refused"}; +static const std::string recv_status_memory_fault{"memory_fault"}; +static const std::string recv_status_interrupted{"interrupted"}; +static const std::string recv_status_invalid_argument{"invalid_argument"}; +static const std::string recv_status_no_memory{"no_memory"}; +static const std::string recv_status_not_connected{"not_connected"}; +static const std::string recv_status_not_a_socket{"not_a_socket"}; +static const std::string recv_status_unknown{"unknown"}; + +auto to_string(recv_status status) -> const std::string& +{ + switch(status) + { + case recv_status::ok: return recv_status_ok; + case recv_status::closed: return recv_status_closed; + case recv_status::udp_not_bound: return recv_status_udp_not_bound; + //case recv_status::try_again: return recv_status_try_again; + case recv_status::would_block: return recv_status_would_block; + case recv_status::bad_file_descriptor: return recv_status_bad_file_descriptor; + case recv_status::connection_refused: return recv_status_connection_refused; + case recv_status::memory_fault: return recv_status_memory_fault; + case recv_status::interrupted: return recv_status_interrupted; + case recv_status::invalid_argument: return recv_status_invalid_argument; + case recv_status::no_memory: return recv_status_no_memory; + case recv_status::not_connected: return recv_status_not_connected; + case recv_status::not_a_socket: return recv_status_not_a_socket; + } + + return recv_status_unknown; +} + +} // namespace coro::net diff --git a/src/net/send_status.cpp b/src/net/send_status.cpp new file mode 100644 index 0000000..336e6b1 --- /dev/null +++ b/src/net/send_status.cpp @@ -0,0 +1,6 @@ +#include "coro/net/send_status.hpp" + +namespace coro::net +{ + +} // namespace coro::net diff --git a/src/net/tcp_client.cpp b/src/net/tcp_client.cpp index e9f0f90..6b9fe8e 100644 --- a/src/net/tcp_client.cpp +++ b/src/net/tcp_client.cpp @@ -10,10 +10,22 @@ using namespace std::chrono_literals; tcp_client::tcp_client(io_scheduler& scheduler, options opts) : m_io_scheduler(scheduler), m_options(std::move(opts)), - m_socket(net::make_socket(net::socket::options{m_options.domain, net::socket::type_t::tcp, net::socket::blocking_t::no})) + m_socket(net::make_socket(net::socket::options{ + m_options.address.domain(), + net::socket::type_t::tcp, + net::socket::blocking_t::no})) { } +tcp_client::tcp_client(io_scheduler& scheduler, net::socket socket, options opts) + : m_io_scheduler(scheduler), + m_options(std::move(opts)), + m_socket(std::move(socket)), + m_connect_status(connect_status::connected) +{ + +} + auto tcp_client::connect(std::chrono::milliseconds timeout) -> coro::task { if(m_connect_status.has_value() && m_connect_status.value() == connect_status::connected) @@ -21,44 +33,10 @@ auto tcp_client::connect(std::chrono::milliseconds timeout) -> coro::task result_ptr{nullptr}; - - // If the user provided a hostname then perform the dns lookup. - if(std::holds_alternative(m_options.address)) - { - if(m_options.dns == nullptr) - { - m_connect_status = connect_status::dns_client_required; - co_return connect_status::dns_client_required; - } - const auto& hn = std::get(m_options.address); - result_ptr = co_await m_options.dns->host_by_name(hn); - if(result_ptr->status() != net::dns_status::complete) - { - m_connect_status = connect_status::dns_lookup_failure; - co_return connect_status::dns_lookup_failure; - } - - if(result_ptr->ip_addresses().empty()) - { - m_connect_status = connect_status::dns_lookup_failure; - co_return connect_status::dns_lookup_failure; - } - - // TODO: for now we'll just take the first ip address given, but should probably allow the - // user to take preference on ipv4/ipv6 addresses. - ip_addr = &result_ptr->ip_addresses().front(); - } - else - { - ip_addr = &std::get(m_options.address); - } - sockaddr_in server{}; - server.sin_family = static_cast(m_options.domain); + server.sin_family = static_cast(m_options.address.domain()); server.sin_port = htons(m_options.port); - server.sin_addr = *reinterpret_cast(ip_addr->data().data()); + server.sin_addr = *reinterpret_cast(m_options.address.data().data()); auto cret = ::connect(m_socket.native_handle(), (struct sockaddr*)&server, sizeof(server)); if (cret == 0) @@ -102,28 +80,4 @@ auto tcp_client::connect(std::chrono::milliseconds timeout) -> coro::task buffer, std::chrono::milliseconds timeout) -> coro::task> -{ - auto pstatus = co_await m_io_scheduler.poll(m_socket, poll_op::read, timeout); - ssize_t bread{0}; - if(pstatus == poll_status::event) - { - bread = ::read(m_socket.native_handle(), buffer.data(), buffer.size()); - } - - co_return {pstatus, bread}; -} - -auto tcp_client::send(const std::span buffer, std::chrono::milliseconds timeout) -> coro::task> -{ - auto pstatus = co_await m_io_scheduler.poll(m_socket, poll_op::write, timeout); - ssize_t bwrite{0}; - if(pstatus == poll_status::event) - { - bwrite = ::write(m_socket.native_handle(), buffer.data(), buffer.size()); - } - - co_return {pstatus, bwrite}; -} - } // namespace coro::net diff --git a/src/net/tcp_server.cpp b/src/net/tcp_server.cpp index 9c2817b..834769e 100644 --- a/src/net/tcp_server.cpp +++ b/src/net/tcp_server.cpp @@ -3,8 +3,8 @@ namespace coro::net { -tcp_server::tcp_server(options opts) - : io_scheduler(std::move(opts.io_options)), +tcp_server::tcp_server(io_scheduler& scheduler, options opts) + : m_io_scheduler(scheduler), m_options(std::move(opts)), m_accept_socket(net::make_accept_socket( net::socket::options{net::domain_t::ipv4, net::socket::type_t::tcp, net::socket::blocking_t::no}, @@ -12,69 +12,33 @@ tcp_server::tcp_server(options opts) m_options.port, m_options.backlog)) { - if (m_options.on_connection == nullptr) - { - throw std::runtime_error{"options::on_connection cannot be nullptr."}; - } - schedule(make_accept_task()); } -tcp_server::~tcp_server() +auto tcp_server::poll(std::chrono::milliseconds timeout) -> coro::task { - shutdown(); + co_return co_await m_io_scheduler.poll(m_accept_socket, coro::poll_op::read, timeout); } -auto tcp_server::shutdown(shutdown_t wait_for_tasks) -> void -{ - if (m_accept_new_connections.exchange(false, std::memory_order::release)) - { - m_accept_socket.shutdown(); // wake it up by shutting down read/write operations. - - while (m_accept_task_exited.load(std::memory_order::acquire) == false) - { - std::this_thread::sleep_for(std::chrono::milliseconds{1}); - } - - io_scheduler::shutdown(wait_for_tasks); - } -} - -auto tcp_server::make_accept_task() -> coro::task +auto tcp_server::accept() -> coro::net::tcp_client { sockaddr_in client{}; constexpr const int len = sizeof(struct sockaddr_in); + net::socket s{::accept(m_accept_socket.native_handle(), (struct sockaddr*)&client, (socklen_t*)&len)}; - std::vector> tasks{}; - tasks.reserve(16); + std::span ip_addr_view{ + reinterpret_cast(&client.sin_addr.s_addr), + sizeof(client.sin_addr.s_addr), + }; - while (m_accept_new_connections.load(std::memory_order::acquire)) - { - auto pstatus = co_await poll(m_accept_socket, coro::poll_op::read, std::chrono::seconds{1}); - if(pstatus == poll_status::event) - { - // On accept socket read drain the listen accept queue. - while (true) - { - net::socket s{::accept(m_accept_socket.native_handle(), (struct sockaddr*)&client, (socklen_t*)&len)}; - if (s.native_handle() < 0) - { - break; - } - - tasks.emplace_back(m_options.on_connection(std::ref(*this), std::move(s))); + return tcp_client{ + m_io_scheduler, + std::move(s), + tcp_client::options{ + .address = net::ip_address{ip_addr_view, static_cast(client.sin_family)}, + .port = ntohs(client.sin_port) } - - if (!tasks.empty()) - { - schedule(std::move(tasks)); - } - } - } - - m_accept_task_exited.exchange(true, std::memory_order::release); - - co_return; + }; }; } // namespace coro::net diff --git a/src/net/udp_client.cpp b/src/net/udp_client.cpp deleted file mode 100644 index 7d0fd5d..0000000 --- a/src/net/udp_client.cpp +++ /dev/null @@ -1,39 +0,0 @@ -#include "coro/net/udp_client.hpp" -#include "coro/io_scheduler.hpp" - -namespace coro::net -{ - -udp_client::udp_client(io_scheduler& scheduler, options opts) - : m_io_scheduler(scheduler), - m_options(std::move(opts)), - m_socket({net::make_socket(net::socket::options{m_options.address.domain(), net::socket::type_t::udp, net::socket::blocking_t::no})}) -{ - -} - -auto udp_client::sendto(const std::span buffer, std::chrono::milliseconds timeout) -> coro::task -{ - auto pstatus = co_await m_io_scheduler.poll(m_socket, poll_op::write, timeout); - if(pstatus != poll_status::event) - { - co_return 0; - } - - sockaddr_in server{}; - server.sin_family = static_cast(m_options.address.domain()); - server.sin_port = htons(m_options.port); - server.sin_addr = *reinterpret_cast(m_options.address.data().data()); - - socklen_t server_len{sizeof(server)}; - - co_return ::sendto( - m_socket.native_handle(), - buffer.data(), - buffer.size(), - 0, - reinterpret_cast(&server), - server_len); -} - -} // namespace coro::net diff --git a/src/net/udp_peer.cpp b/src/net/udp_peer.cpp new file mode 100644 index 0000000..5cad87e --- /dev/null +++ b/src/net/udp_peer.cpp @@ -0,0 +1,32 @@ +#include "coro/net/udp_peer.hpp" + +namespace coro::net +{ + +udp_peer::udp_peer( + io_scheduler& scheduler, + net::domain_t domain) + : m_io_scheduler(scheduler), + m_socket(net::make_socket( + net::socket::options{ + domain, + net::socket::type_t::udp, + net::socket::blocking_t::no})) +{ + +} + +udp_peer::udp_peer( + io_scheduler& scheduler, + const info& bind_info) + : m_io_scheduler(scheduler), + m_socket(net::make_accept_socket( + net::socket::options{bind_info.address.domain(), net::socket::type_t::udp, net::socket::blocking_t::no}, + bind_info.address, + bind_info.port)), + m_bound(true) +{ + +} + +} // namespace coro::net diff --git a/src/net/udp_server.cpp b/src/net/udp_server.cpp deleted file mode 100644 index 05d74a3..0000000 --- a/src/net/udp_server.cpp +++ /dev/null @@ -1,57 +0,0 @@ -#include "coro/net/udp_server.hpp" -#include "coro/io_scheduler.hpp" - -namespace coro::net -{ - -udp_server::udp_server(io_scheduler& io_scheduler, options opts) - : m_io_scheduler(io_scheduler), - m_options(std::move(opts)), - m_accept_socket(net::make_accept_socket( - net::socket::options{m_options.address.domain(), net::socket::type_t::udp, net::socket::blocking_t::no}, - m_options.address, - m_options.port - )) -{ - -} - -auto udp_server::recvfrom(std::span& buffer, std::chrono::milliseconds timeout) -> coro::task> -{ - auto pstatus = co_await m_io_scheduler.poll(m_accept_socket, poll_op::read, timeout); - if(pstatus != poll_status::event) - { - co_return std::nullopt; - } - - sockaddr_in client{}; - - socklen_t client_len{sizeof(client)}; - - auto bytes_read = ::recvfrom( - m_accept_socket.native_handle(), - buffer.data(), - buffer.size(), - 0, - reinterpret_cast(&client), - &client_len); - - if(bytes_read == -1) - { - co_return std::nullopt; - } - - buffer = buffer.subspan(0, bytes_read); - - std::span ip_addr_view{ - reinterpret_cast(&client.sin_addr.s_addr), - sizeof(client.sin_addr.s_addr), - }; - - co_return udp_client::options{ - .address = net::ip_address{ip_addr_view, static_cast(client.sin_family)}, - .port = ntohs(client.sin_port) - }; -} - -} // namespace coro::net diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 9f91c35..9f95a14 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -2,7 +2,7 @@ cmake_minimum_required(VERSION 3.16) project(libcoro_test) set(LIBCORO_TEST_SOURCE_FILES - net/test_dns_client.cpp + net/test_dns_resolver.cpp net/test_ip_address.cpp net/test_tcp_server.cpp net/test_udp_peers.cpp diff --git a/test/bench.cpp b/test/bench.cpp index 788d8ae..9c18f29 100644 --- a/test/bench.cpp +++ b/test/bench.cpp @@ -337,81 +337,95 @@ TEST_CASE("benchmark counter task io_scheduler yield (all) -> resume (all) from REQUIRE(counter == iterations); } -TEST_CASE("benchmark tcp_client and tcp_server") +TEST_CASE("benchmark tcp_server echo server") { /** * This test *requires* two schedulers since polling on read/write of the sockets involved * will reset/trample on each other when each side of the client + server go to poll(). */ - const constexpr std::size_t connections = 256; - const constexpr std::size_t messages_per_connection = 1'000; + const constexpr std::size_t connections = 64; + const constexpr std::size_t messages_per_connection = 10'000; const constexpr std::size_t ops = connections * messages_per_connection; const std::string msg = "im a data point in a stream of bytes"; - const std::string done_msg = "done"; - auto address = coro::net::ip_address::from_string("127.0.0.1"); - auto on_connection = [&msg, &done_msg](coro::net::tcp_server& scheduler, coro::net::socket sock) -> coro::task { + coro::io_scheduler server_scheduler{}; + coro::io_scheduler client_scheduler{}; + + std::atomic listening{false}; + + auto make_on_connection_task = [&](coro::net::tcp_client client) -> coro::task { std::string in(64, '\0'); - do + // Echo the messages until the socket is closed. a 'done' message arrives. + while(true) { - auto [rstatus, rbytes] = co_await scheduler.read(sock, std::span{in.data(), in.size()}); - REQUIRE(rstatus == coro::poll_status::event); + auto pstatus = co_await client.poll(coro::poll_op::read); + REQUIRE(pstatus == coro::poll_status::event); - in.resize(rbytes); + auto [rstatus, rspan] = client.recv(in); + if(rstatus == coro::net::recv_status::closed) + { + REQUIRE(rspan.empty()); + break; + } - auto [wstatus, wbytes] = co_await scheduler.write(sock, std::span(in.data(), in.length())); - REQUIRE(wstatus == coro::poll_status::event); - REQUIRE(wbytes == in.length()); - } while(in != done_msg); + REQUIRE(rstatus == coro::net::recv_status::ok); + + in.resize(rspan.size()); + + auto [sstatus, remaining] = client.send(in); + REQUIRE(sstatus == coro::net::send_status::ok); + REQUIRE(remaining.empty()); + } co_return; }; - coro::net::tcp_server scheduler{coro::net::tcp_server::options{ - .address = coro::net::ip_address::from_string("0.0.0.0"), - .port = 8080, - .backlog = 128, - .on_connection = on_connection, - .io_options = coro::io_scheduler::options{.thread_strategy = coro::io_scheduler::thread_strategy_t::spawn}}}; + auto make_server_task = [&]() -> coro::task { + coro::net::tcp_server server{server_scheduler}; - coro::io_scheduler client_scheduler{ - coro::io_scheduler::options{.thread_strategy = coro::io_scheduler::thread_strategy_t::spawn}}; + listening = true; - auto make_client_task = [&client_scheduler, &address, &msg, &done_msg, &messages_per_connection]() -> coro::task { - coro::net::tcp_client client{ - client_scheduler, - coro::net::tcp_client::options{ - .address = address, - .port = 8080, - .domain = coro::net::domain_t::ipv4}}; + uint64_t accepted{0}; + while(accepted < connections) + { + auto pstatus = co_await server.poll(); + REQUIRE(pstatus == coro::poll_status::event); + + auto client = server.accept(); + REQUIRE(client.socket().is_valid()); + + server_scheduler.schedule(make_on_connection_task(std::move(client))); + + ++accepted; + } + + co_return; + }; + + auto make_client_task = [&]() -> coro::task { + coro::net::tcp_client client{client_scheduler}; auto cstatus = co_await client.connect(); REQUIRE(cstatus == coro::net::connect_status::connected); for(size_t i = 1; i <= messages_per_connection; ++i) { - const std::string* msg_ptr = &msg; - if(i == messages_per_connection) - { - msg_ptr = &done_msg; - } + auto [sstatus, remaining] = client.send(msg); + REQUIRE(sstatus == coro::net::send_status::ok); + REQUIRE(remaining.empty()); - auto [wstatus, wbytes] = co_await client.send(std::span{msg_ptr->data(), msg_ptr->length()}); - - REQUIRE(wstatus == coro::poll_status::event); - REQUIRE(wbytes == msg_ptr->length()); + auto pstatus = co_await client.poll(coro::poll_op::read); + REQUIRE(pstatus == coro::poll_status::event); std::string response(64, '\0'); - - auto [rstatus, rbytes] = co_await client.recv(std::span{response.data(), response.length()}); - - REQUIRE(rstatus == coro::poll_status::event); - REQUIRE(rbytes == msg_ptr->length()); - response.resize(rbytes); - REQUIRE(response == *msg_ptr); + auto [rstatus, rspan] = client.recv(response); + REQUIRE(rstatus == coro::net::recv_status::ok); + REQUIRE(rspan.size() == msg.size()); + response.resize(rspan.size()); + REQUIRE(response == msg); } co_return; @@ -419,11 +433,23 @@ TEST_CASE("benchmark tcp_client and tcp_server") auto start = sc::now(); + // Create the server to accept incoming tcp connections. + server_scheduler.schedule(make_server_task()); + + // The server can take a small bit of time to start up, if we don't wait for it to notify then + // the first few connections can easily fail to connect causing this test to fail. + while(!listening) + { + std::this_thread::sleep_for(std::chrono::milliseconds{1}); + } + + // Spawn N client connections. for(size_t i = 0; i < connections; ++i) { REQUIRE(client_scheduler.schedule(make_client_task())); } + // Wait for all the connections to complete their work. while (!client_scheduler.empty()) { std::this_thread::sleep_for(std::chrono::milliseconds{1}); @@ -432,6 +458,9 @@ TEST_CASE("benchmark tcp_client and tcp_server") auto stop = sc::now(); print_stats("benchmark tcp_client and tcp_server", ops, start, stop); - scheduler.shutdown(); - REQUIRE(scheduler.empty()); + server_scheduler.shutdown(); + REQUIRE(server_scheduler.empty()); + + client_scheduler.shutdown(); + REQUIRE(client_scheduler.empty()); } diff --git a/test/net/test_dns_client.cpp b/test/net/test_dns_resolver.cpp similarity index 82% rename from test/net/test_dns_client.cpp rename to test/net/test_dns_resolver.cpp index a549ecb..71683e1 100644 --- a/test/net/test_dns_client.cpp +++ b/test/net/test_dns_resolver.cpp @@ -4,19 +4,19 @@ #include -TEST_CASE("dns_client basic") +TEST_CASE("dns_resolver basic") { coro::io_scheduler scheduler{ coro::io_scheduler::options{.thread_strategy = coro::io_scheduler::thread_strategy_t::spawn} }; - coro::net::dns_client dns_client{scheduler, std::chrono::milliseconds{5000}}; + coro::net::dns_resolver dns_resolver{scheduler, std::chrono::milliseconds{5000}}; std::atomic done{false}; auto make_host_by_name_task = [&](coro::net::hostname hn) -> coro::task { - auto result_ptr = co_await std::move(dns_client.host_by_name(hn)); + auto result_ptr = co_await std::move(dns_resolver.host_by_name(hn)); if(result_ptr->status() == coro::net::dns_status::complete) { diff --git a/test/net/test_tcp_server.cpp b/test/net/test_tcp_server.cpp index ebbca13..fa95773 100644 --- a/test/net/test_tcp_server.cpp +++ b/test/net/test_tcp_server.cpp @@ -2,88 +2,70 @@ #include -TEST_CASE("tcp_server no on connection throws") +TEST_CASE("tcp_server ping server") { - REQUIRE_THROWS(coro::net::tcp_server{coro::net::tcp_server::options{.on_connection = nullptr}}); -} + const std::string client_msg{"Hello from client"}; + const std::string server_msg{"Reply from server!"}; -static auto tcp_server_echo( - const std::variant address, - const std::string msg -) -> void -{ - auto on_connection = [&msg](coro::net::tcp_server& scheduler, coro::net::socket sock) -> coro::task { - std::string in(64, '\0'); + coro::io_scheduler scheduler{}; - auto [rstatus, rbytes] = co_await scheduler.read(sock, std::span{in.data(), in.size()}); - REQUIRE(rstatus == coro::poll_status::event); - - in.resize(rbytes); - REQUIRE(in == msg); - - auto [wstatus, wbytes] = co_await scheduler.write(sock, std::span(in.data(), in.length())); - REQUIRE(wstatus == coro::poll_status::event); - REQUIRE(wbytes == in.length()); - - co_return; - }; - - coro::net::tcp_server scheduler{coro::net::tcp_server::options{ - .address = coro::net::ip_address::from_string("0.0.0.0"), - .port = 8080, - .backlog = 128, - .on_connection = on_connection, - .io_options = coro::io_scheduler::options{.thread_strategy = coro::io_scheduler::thread_strategy_t::spawn}}}; - - coro::net::dns_client dns_client{scheduler, std::chrono::seconds{5}}; - - auto make_client_task = [&scheduler, &dns_client, &address, &msg]() -> coro::task { - coro::net::tcp_client client{ - scheduler, - coro::net::tcp_client::options{ - .address = address, - .port = 8080, - .domain = coro::net::domain_t::ipv4, - .dns = &dns_client}}; + auto make_client_task = [&]() -> coro::task { + coro::net::tcp_client client{scheduler}; auto cstatus = co_await client.connect(); REQUIRE(cstatus == coro::net::connect_status::connected); - auto [wstatus, wbytes] = co_await client.send(std::span{msg.data(), msg.length()}); + // Skip polling for write, should really only poll if the write is partial, shouldn't be + // required for this test. + auto [sstatus, remaining] = client.send(client_msg); + REQUIRE(sstatus == coro::net::send_status::ok); + REQUIRE(remaining.empty()); - REQUIRE(wstatus == coro::poll_status::event); - REQUIRE(wbytes == msg.length()); + // Poll for the server's response. + auto pstatus = co_await client.poll(coro::poll_op::read); + REQUIRE(pstatus == coro::poll_status::event); - std::string response(64, '\0'); - - auto [rstatus, rbytes] = co_await client.recv(std::span{response.data(), response.length()}); - - REQUIRE(rstatus == coro::poll_status::event); - REQUIRE(rbytes == msg.length()); - response.resize(rbytes); - REQUIRE(response == msg); + std::string buffer(256, '\0'); + auto [rstatus, rspan] = client.recv(buffer); + REQUIRE(rstatus == coro::net::recv_status::ok); + REQUIRE(rspan.size() == server_msg.length()); + buffer.resize(rspan.size()); + REQUIRE(buffer == server_msg); co_return; }; - REQUIRE(scheduler.schedule(make_client_task())); + auto make_server_task = [&]() -> coro::task { + coro::net::tcp_server server{scheduler}; - // Shutting down the scheduler will cause it to stop accepting new connections, to avoid requiring - // another scheduler for this test the main thread can spin sleep until the tcp scheduler reports - // that it is empty. tcp schedulers do not report their accept task as a task in its size/empty count. - while (!scheduler.empty()) + // Poll for client connection. + auto pstatus = co_await server.poll(); + REQUIRE(pstatus == coro::poll_status::event); + auto client = server.accept(); + REQUIRE(client.socket().is_valid()); + + // Poll for client request. + pstatus = co_await client.poll(coro::poll_op::read); + REQUIRE(pstatus == coro::poll_status::event); + + std::string buffer(256, '\0'); + auto [rstatus, rspan] = client.recv(buffer); + REQUIRE(rstatus == coro::net::recv_status::ok); + REQUIRE(rspan.size() == client_msg.size()); + buffer.resize(rspan.size()); + REQUIRE(buffer == client_msg); + + // Respond to client. + auto [sstatus, remaining] = client.send(server_msg); + REQUIRE(sstatus == coro::net::send_status::ok); + REQUIRE(remaining.empty()); + }; + + scheduler.schedule(make_server_task()); + scheduler.schedule(make_client_task()); + + while(!scheduler.empty()) { std::this_thread::sleep_for(std::chrono::milliseconds{1}); } - - scheduler.shutdown(); - REQUIRE(scheduler.empty()); -} - -TEST_CASE("tcp_server echo server") -{ - const std::string msg{"Hello from client"}; - - tcp_server_echo(coro::net::ip_address::from_string("127.0.0.1"), msg); - tcp_server_echo(coro::net::hostname{"localhost"}, msg); } diff --git a/test/net/test_udp_peers.cpp b/test/net/test_udp_peers.cpp index ef9ae93..622c6b1 100644 --- a/test/net/test_udp_peers.cpp +++ b/test/net/test_udp_peers.cpp @@ -2,53 +2,114 @@ #include -TEST_CASE("udp echo peers") +TEST_CASE("udp one way") { - const std::string client_msg{"Hello from client!"}; - const std::string server_msg{"Hello from server!!"}; + const std::string msg{"aaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbcccccccccccccccccc"}; coro::io_scheduler scheduler{}; - auto make_client_task = [&](uint16_t client_port, uint16_t server_port) -> coro::task { - std::string owning_buffer(4096, '\0'); + auto make_send_task = [&]() -> coro::task { + coro::net::udp_peer peer{scheduler}; + coro::net::udp_peer::info peer_info{}; - coro::net::udp_client client{scheduler, coro::net::udp_client::options{.port = client_port}}; - auto wbytes = co_await client.sendto(std::span{client_msg.data(), client_msg.length()}); - REQUIRE(wbytes == client_msg.length()); - - coro::net::udp_server server{scheduler, coro::net::udp_server::options{.port = server_port}}; - std::span buffer{owning_buffer.data(), owning_buffer.length()}; - auto client_opt = co_await server.recvfrom(buffer); - REQUIRE(client_opt.has_value()); - REQUIRE(buffer.size() == server_msg.length()); - owning_buffer.resize(buffer.size()); - REQUIRE(owning_buffer == server_msg); + auto [sstatus, remaining] = peer.sendto(peer_info, msg); + REQUIRE(sstatus == coro::net::send_status::ok); + REQUIRE(remaining.empty()); co_return; }; - auto make_server_task = [&](uint16_t server_port, uint16_t client_port) -> coro::task { - std::string owning_buffer(4096, '\0'); + auto make_recv_task = [&]() -> coro::task { + coro::net::udp_peer::info self_info{ + .address = coro::net::ip_address::from_string("0.0.0.0") + }; - coro::net::udp_server server{scheduler, coro::net::udp_server::options{.port = server_port}}; + coro::net::udp_peer self{scheduler, self_info}; - std::span buffer{owning_buffer.data(), owning_buffer.length()}; - auto client_opt = co_await server.recvfrom(buffer); - REQUIRE(client_opt.has_value()); - REQUIRE(buffer.size() == client_msg.length()); - owning_buffer.resize(buffer.size()); - REQUIRE(owning_buffer == client_msg); + auto pstatus = co_await self.poll(coro::poll_op::read); + REQUIRE(pstatus == coro::poll_status::event); - - auto options = client_opt.value(); - options.port = client_port; // we'll change the port for this test since its the same host - coro::net::udp_client client{scheduler, options}; - auto wbytes = co_await client.sendto(std::span{server_msg.data(), server_msg.length()}); - REQUIRE(wbytes == server_msg.length()); + std::string buffer(64, '\0'); + auto [rstatus, peer_info, rspan] = self.recvfrom(buffer); + REQUIRE(rstatus == coro::net::recv_status::ok); + REQUIRE(peer_info.address == coro::net::ip_address::from_string("127.0.0.1")); + // The peer's port will be randomly picked by the kernel since it wasn't bound. + REQUIRE(rspan.size() == msg.size()); + buffer.resize(rspan.size()); + REQUIRE(buffer == msg); co_return; }; - scheduler.schedule(make_server_task(8080, 8081)); - scheduler.schedule(make_client_task(8080, 8081)); + scheduler.schedule(make_recv_task()); + scheduler.schedule(make_send_task()); +} + +TEST_CASE("udp echo peers") +{ + const std::string peer1_msg{"Hello from peer1!"}; + const std::string peer2_msg{"Hello from peer2!!"}; + + coro::io_scheduler scheduler{}; + + auto make_peer_task = [&scheduler]( + uint16_t my_port, + uint16_t peer_port, + bool send_first, + const std::string my_msg, + const std::string peer_msg) -> coro::task { + + coro::net::udp_peer::info my_info{.address = coro::net::ip_address::from_string("0.0.0.0"), .port = my_port}; + coro::net::udp_peer::info peer_info{.address = coro::net::ip_address::from_string("127.0.0.1"), .port = peer_port}; + + coro::net::udp_peer me{scheduler, my_info}; + + if(send_first) + { + // Send my message to my peer first. + auto [sstatus, remaining] = me.sendto(peer_info, my_msg); + REQUIRE(sstatus == coro::net::send_status::ok); + REQUIRE(remaining.empty()); + } + else + { + // Poll for my peers message first. + auto pstatus = co_await me.poll(coro::poll_op::read); + REQUIRE(pstatus == coro::poll_status::event); + + std::string buffer(64, '\0'); + auto [rstatus, recv_peer_info, rspan] = me.recvfrom(buffer); + REQUIRE(rstatus == coro::net::recv_status::ok); + REQUIRE(recv_peer_info == peer_info); + REQUIRE(rspan.size() == peer_msg.size()); + buffer.resize(rspan.size()); + REQUIRE(buffer == peer_msg); + } + + if(send_first) + { + // I sent first so now I need to await my peer's message. + auto pstatus = co_await me.poll(coro::poll_op::read); + REQUIRE(pstatus == coro::poll_status::event); + + std::string buffer(64, '\0'); + auto [rstatus, recv_peer_info, rspan] = me.recvfrom(buffer); + REQUIRE(rstatus == coro::net::recv_status::ok); + REQUIRE(recv_peer_info == peer_info); + REQUIRE(rspan.size() == peer_msg.size()); + buffer.resize(rspan.size()); + REQUIRE(buffer == peer_msg); + } + else + { + auto [sstatus, remaining] = me.sendto(peer_info, my_msg); + REQUIRE(sstatus == coro::net::send_status::ok); + REQUIRE(remaining.empty()); + } + + co_return; + }; + + scheduler.schedule(make_peer_task(8081, 8080, false, peer2_msg, peer1_msg)); + scheduler.schedule(make_peer_task(8080, 8081, true, peer1_msg, peer2_msg)); } diff --git a/test/test_io_scheduler.cpp b/test/test_io_scheduler.cpp index ebf26d3..c3b6f7c 100644 --- a/test/test_io_scheduler.cpp +++ b/test/test_io_scheduler.cpp @@ -16,7 +16,7 @@ TEST_CASE("io_scheduler sizeof()") std::cerr << "sizeof(coro:task)=[" << sizeof(coro::task) << "]\n"; std::cerr << "sizeof(std::coroutine_handle<>)=[" << sizeof(std::coroutine_handle<>) << "]\n"; - std::cerr << "sizeof(std::variant>)=[" << sizeof(std::variant>) + std::cerr << "sizeof(std::variant, std::coroutine_handle<>>)=[" << sizeof(std::variant, std::coroutine_handle<>>) << "]\n"; REQUIRE(true);