diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 11f80a6..ee6b286 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -12,8 +12,6 @@ jobs: TZ: America/New_York DEBIAN_FRONTEND: noninteractive steps: - - name: Checkout - uses: actions/checkout@v2 - name: apt run: | apt-get update @@ -25,6 +23,10 @@ jobs: git \ ninja-build \ g++-10 + - name: Checkout # recurisve checkout requires git to be installed first + uses: actions/checkout@v2 + with: + submodules: recursive - name: build-release-g++ run: | mkdir build-release-g++ @@ -46,8 +48,6 @@ jobs: container: image: fedora:32 steps: - - name: Checkout - uses: actions/checkout@v2 - name: dnf run: | sudo dnf install -y \ @@ -56,6 +56,10 @@ jobs: ninja-build \ gcc-c++-10.2.1 \ lcov + - name: Checkout # recurisve checkout requires git to be installed first + uses: actions/checkout@v2 + with: + submodules: recursive - name: build-debug-g++ run: | mkdir build-debug-g++ diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..8c2829b --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "vendor/c-ares/c-ares"] + path = vendor/c-ares/c-ares + url = git@github.com:c-ares/c-ares.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 45417e2..8e1f0a7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,12 +7,23 @@ option(CORO_CODE_COVERAGE "Enable code coverage, tests must also be enabled, De message("${PROJECT_NAME} CORO_BUILD_TESTS = ${CORO_BUILD_TESTS}") message("${PROJECT_NAME} CORO_CODE_COVERAGE = ${CORO_CODE_COVERAGE}") +set(CARES_STATIC ON CACHE INTERNAL "") +set(CARES_SHARED OFF CACHE INTERNAL "") +set(CARES_INSTALL OFF CACHE INTERNAL "") + +add_subdirectory(vendor/c-ares/c-ares) + set(LIBCORO_SOURCE_FILES inc/coro/detail/void_value.hpp + inc/coro/net/hostname.hpp + inc/coro/net/ip_address.hpp src/net/ip_address.cpp + inc/coro/net/socket.hpp + inc/coro/awaitable.hpp inc/coro/connect.hpp src/connect.cpp inc/coro/coro.hpp + inc/coro/dns_client.hpp src/dns_client.cpp inc/coro/event.hpp src/event.cpp inc/coro/generator.hpp inc/coro/io_scheduler.hpp @@ -20,11 +31,10 @@ set(LIBCORO_SOURCE_FILES inc/coro/poll.hpp inc/coro/promise.hpp inc/coro/shutdown.hpp - inc/coro/socket.hpp inc/coro/sync_wait.hpp src/sync_wait.cpp inc/coro/task.hpp inc/coro/tcp_client.hpp src/tcp_client.cpp - inc/coro/tcp_scheduler.hpp + inc/coro/tcp_scheduler.hpp src/tcp_scheduler.cpp inc/coro/thread_pool.hpp src/thread_pool.cpp inc/coro/when_all.hpp ) @@ -33,7 +43,7 @@ add_library(${PROJECT_NAME} STATIC ${LIBCORO_SOURCE_FILES}) set_target_properties(${PROJECT_NAME} PROPERTIES LINKER_LANGUAGE CXX) target_compile_features(${PROJECT_NAME} PUBLIC cxx_std_20) target_include_directories(${PROJECT_NAME} PUBLIC inc) -target_link_libraries(${PROJECT_NAME} PUBLIC pthread) +target_link_libraries(${PROJECT_NAME} PUBLIC pthread c-ares) if(${CMAKE_CXX_COMPILER_ID} MATCHES "GNU") if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "10.2.0") diff --git a/inc/coro/connect.hpp b/inc/coro/connect.hpp index 573c268..c609b67 100644 --- a/inc/coro/connect.hpp +++ b/inc/coro/connect.hpp @@ -13,9 +13,17 @@ enum class connect_status /// The connection operation timed out. timeout, /// There was an error, use errno to get more information on the specific error. - error + error, + /// The client was given a hostname but no dns client to resolve the ip address. + dns_client_required, + /// The dns hostname lookup failed + dns_lookup_failure }; +/** + * @param status String representation of the connection status. + * @throw std::logic_error If provided an invalid connect_status enum value. + */ auto to_string(const connect_status& status) -> const std::string&; } // namespace coro diff --git a/inc/coro/coro.hpp b/inc/coro/coro.hpp index b1ea640..921fbd7 100644 --- a/inc/coro/coro.hpp +++ b/inc/coro/coro.hpp @@ -1,6 +1,12 @@ #pragma once +#include "coro/net/hostname.hpp" +#include "coro/net/ip_address.hpp" +#include "coro/net/socket.hpp" + #include "coro/awaitable.hpp" +#include "coro/connect.hpp" +#include "coro/dns_client.hpp" #include "coro/event.hpp" #include "coro/generator.hpp" #include "coro/io_scheduler.hpp" diff --git a/inc/coro/dns_client.hpp b/inc/coro/dns_client.hpp new file mode 100644 index 0000000..6414118 --- /dev/null +++ b/inc/coro/dns_client.hpp @@ -0,0 +1,87 @@ +#pragma once + +#include "coro/io_scheduler.hpp" +#include "coro/net/ip_address.hpp" +#include "coro/net/hostname.hpp" +#include "coro/task.hpp" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace coro +{ + +class dns_client; + +enum class dns_status +{ + complete, + error +}; + +class dns_result +{ + friend dns_client; +public: + explicit dns_result(coro::resume_token& token, uint64_t pending_dns_requests); + ~dns_result() = default; + + auto status() const -> dns_status { return m_status; } + auto ip_addresses() const -> const std::vector& { return m_ip_addresses; } +private: + coro::resume_token& m_token; + uint64_t m_pending_dns_requests{0}; + dns_status m_status{dns_status::complete}; + std::vector m_ip_addresses{}; + + friend auto ares_dns_callback( + void* arg, + int status, + int timeouts, + struct hostent* host + ) -> void; +}; + +class dns_client +{ +public: + explicit dns_client(io_scheduler& scheduler, std::chrono::milliseconds timeout); + dns_client(const dns_client&) = delete; + dns_client(dns_client&&) = delete; + auto operator=(const dns_client&) noexcept -> dns_client& = delete; + auto operator=(dns_client&&) noexcept -> dns_client& = delete; + ~dns_client(); + + auto host_by_name(const net::hostname& hn) -> coro::task>; +private: + /// The io scheduler to drive the events for dns lookups. + io_scheduler& m_scheduler; + + /// The global timeout per dns lookup request. + std::chrono::milliseconds m_timeout{0}; + + /// The libc-ares channel for looking up dns entries. + ares_channel m_ares_channel{nullptr}; + + /// This is the set of sockets that are currently being actively polled so multiple poll tasks + /// are not setup when ares_poll() is called. + std::unordered_set m_active_sockets{}; + + /// Global count to track if c-ares has been initialized or cleaned up. + static uint64_t m_ares_count; + /// Critical section around the c-ares global init/cleanup to prevent heap corruption. + static std::mutex m_ares_mutex; + + auto ares_poll() -> void; + auto make_poll_task(io_scheduler::fd_t fd, poll_op ops) -> coro::task; +}; + +} // namespace coro diff --git a/inc/coro/io_scheduler.hpp b/inc/coro/io_scheduler.hpp index 753a40a..5b0cf80 100644 --- a/inc/coro/io_scheduler.hpp +++ b/inc/coro/io_scheduler.hpp @@ -3,7 +3,7 @@ #include "coro/awaitable.hpp" #include "coro/poll.hpp" #include "coro/shutdown.hpp" -#include "coro/socket.hpp" +#include "coro/net/socket.hpp" #include "coro/task.hpp" #include @@ -653,7 +653,7 @@ public: } auto read( - const coro::socket& sock, + const net::socket& sock, std::span buffer, std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task> { @@ -684,7 +684,7 @@ public: } auto write( - const coro::socket& sock, + const net::socket& sock, const std::span buffer, std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task> { diff --git a/inc/coro/net/hostname.hpp b/inc/coro/net/hostname.hpp new file mode 100644 index 0000000..8b74048 --- /dev/null +++ b/inc/coro/net/hostname.hpp @@ -0,0 +1,30 @@ +#pragma once + +#include + +namespace coro::net +{ + +class hostname +{ +public: + hostname() = default; + explicit hostname(std::string hn) + : m_hostname(std::move(hn)) {} + hostname(const hostname&) = default; + hostname(hostname&&) = default; + auto operator=(const hostname&) noexcept -> hostname& = default; + auto operator=(hostname&&) noexcept -> hostname& = default; + ~hostname() = default; + + auto data() const -> const std::string& { return m_hostname; } + + auto operator<=>(const hostname& other) const + { + return m_hostname <=> other.m_hostname; + } +private: + std::string m_hostname; +}; + +} // namespace coro::net diff --git a/inc/coro/net/ip_address.hpp b/inc/coro/net/ip_address.hpp new file mode 100644 index 0000000..dfa1761 --- /dev/null +++ b/inc/coro/net/ip_address.hpp @@ -0,0 +1,106 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace coro::net +{ + +enum class domain_t : int +{ + ipv4 = AF_INET, + ipv6 = AF_INET6 +}; + +auto to_string(domain_t domain) -> const std::string&; + +class ip_address +{ +public: + static const constexpr size_t ipv4_len{4}; + static const constexpr size_t ipv6_len{16}; + + ip_address() = default; + ip_address(std::span binary_address, domain_t domain = domain_t::ipv4) + : m_domain(domain) + { + if(m_domain == domain_t::ipv4 && binary_address.size() > ipv4_len) + { + throw std::runtime_error{"coro::net::ip_address provided binary ip address is too long"}; + } + else if(binary_address.size() > ipv6_len) + { + throw std::runtime_error{"coro::net::ip_address provided binary ip address is too long"}; + } + + std::copy(binary_address.begin(), binary_address.end(), m_data.begin()); + } + ip_address(const ip_address&) = default; + ip_address(ip_address&&) = default; + auto operator=(const ip_address&) noexcept -> ip_address& = default; + auto operator=(ip_address&&) noexcept -> ip_address& = default; + ~ip_address() = default; + + auto domain() const -> domain_t { return m_domain; } + auto data() const -> std::span + { + if(m_domain == domain_t::ipv4) + { + return std::span{m_data.data(), ipv4_len}; + } + else + { + return std::span{m_data.data(), ipv6_len}; + } + } + + static auto from_string(const std::string& address, domain_t domain = domain_t::ipv4) -> ip_address + { + ip_address addr{}; + addr.m_domain = domain; + + auto success = inet_pton(static_cast(addr.m_domain), address.data(), addr.m_data.data()); + if(success != 1) + { + throw std::runtime_error{"coro::net::ip_address faild to convert from string"}; + } + + return addr; + } + + auto to_string() const -> std::string + { + std::string output; + if(m_domain == domain_t::ipv4) + { + output.resize(INET_ADDRSTRLEN, '\0'); + } + else + { + output.resize(INET6_ADDRSTRLEN, '\0'); + } + + auto success = inet_ntop(static_cast(m_domain), m_data.data(), output.data(), output.length()); + if(success != nullptr) + { + auto len = strnlen(success, output.length()); + output.resize(len); + } + else + { + throw std::runtime_error{"coro::net::ip_address failed to convert to string representation"}; + } + + return output; + } + +private: + domain_t m_domain{domain_t::ipv4}; + std::array m_data{}; +}; + +} // namespace coro::net diff --git a/inc/coro/socket.hpp b/inc/coro/net/socket.hpp similarity index 80% rename from inc/coro/socket.hpp rename to inc/coro/net/socket.hpp index f4c163c..185f86f 100644 --- a/inc/coro/socket.hpp +++ b/inc/coro/net/socket.hpp @@ -1,5 +1,6 @@ #pragma once +#include "coro/net/ip_address.hpp" #include "coro/poll.hpp" #include @@ -10,17 +11,11 @@ #include -namespace coro +namespace coro::net { class socket { public: - enum class domain_t - { - ipv4, - ipv6 - }; - enum class type_t { udp, @@ -40,19 +35,6 @@ public: blocking_t blocking; }; - static auto domain_to_os(const domain_t& domain) -> int - { - switch (domain) - { - case domain_t::ipv4: - return AF_INET; - case domain_t::ipv6: - return AF_INET6; - default: - throw std::runtime_error{"Unknown socket::domain_t."}; - } - } - static auto type_to_os(const type_t& type) -> int { switch (type) @@ -68,7 +50,7 @@ public: static auto make_socket(const options& opts) -> socket { - socket s{::socket(domain_to_os(opts.domain), type_to_os(opts.type), 0)}; + socket s{::socket(static_cast(opts.domain), type_to_os(opts.type), 0)}; if (s.native_handle() < 0) { throw std::runtime_error{"Failed to create socket."}; @@ -86,10 +68,10 @@ public: } static auto make_accept_socket( - const options& opts, - const std::string& address, // force string to guarantee null terminated. - uint16_t port, - int32_t backlog = 128) -> socket + const options& opts, + const net::ip_address& address, + uint16_t port, + int32_t backlog = 128) -> socket { socket s = make_socket(opts); @@ -100,13 +82,14 @@ public: } sockaddr_in server{}; - server.sin_family = domain_to_os(opts.domain); + server.sin_family = static_cast(opts.domain); server.sin_port = htons(port); + server.sin_addr = *reinterpret_cast(address.data().data()); - if (inet_pton(server.sin_family, address.data(), &server.sin_addr) <= 0) - { - throw std::runtime_error{"Failed to translate IP Address."}; - } + // if (inet_pton(server.sin_family, address.data(), &server.sin_addr) <= 0) + // { + // throw std::runtime_error{"Failed to translate IP Address."}; + // } if (bind(s.native_handle(), (struct sockaddr*)&server, sizeof(server)) < 0) { @@ -202,4 +185,4 @@ private: int m_fd{-1}; }; -} // namespace coro +} // namespace coro::net diff --git a/inc/coro/poll.hpp b/inc/coro/poll.hpp index a19e30f..24ffd91 100644 --- a/inc/coro/poll.hpp +++ b/inc/coro/poll.hpp @@ -4,7 +4,7 @@ namespace coro { -enum class poll_op +enum class poll_op : uint64_t { /// Poll for read operations. read = EPOLLIN, @@ -14,6 +14,16 @@ enum class poll_op read_write = EPOLLIN | EPOLLOUT }; +inline auto poll_op_readable(poll_op op) -> bool +{ + return (static_cast(op) & EPOLLIN); +} + +inline auto poll_op_writeable(poll_op op) -> bool +{ + return (static_cast(op) & EPOLLOUT); +} + enum class poll_status { /// The poll operation was was successful. diff --git a/inc/coro/task.hpp b/inc/coro/task.hpp index bad058e..425d272 100644 --- a/inc/coro/task.hpp +++ b/inc/coro/task.hpp @@ -196,7 +196,7 @@ public: return false; } - auto operator co_await() const noexcept + auto operator co_await() const & noexcept { struct awaitable : public awaitable_base { @@ -206,6 +206,16 @@ public: return awaitable{m_coroutine}; } + auto operator co_await() const && noexcept + { + struct awaitable : public awaitable_base + { + auto await_resume() -> decltype(auto) { return std::move(this->m_coroutine.promise()).return_value(); } + }; + + return awaitable{m_coroutine}; + } + auto promise() & -> promise_type& { return m_coroutine.promise(); } auto promise() const& -> const promise_type& { return m_coroutine.promise(); } diff --git a/inc/coro/tcp_client.hpp b/inc/coro/tcp_client.hpp index d0627cb..7f6d0f7 100644 --- a/inc/coro/tcp_client.hpp +++ b/inc/coro/tcp_client.hpp @@ -1,11 +1,17 @@ #pragma once +#include "coro/net/ip_address.hpp" +#include "coro/net/hostname.hpp" +#include "coro/net/socket.hpp" #include "coro/connect.hpp" #include "coro/poll.hpp" -#include "coro/socket.hpp" #include "coro/task.hpp" +#include "coro/dns_client.hpp" #include +#include +#include +#include namespace coro { @@ -16,12 +22,22 @@ class tcp_client public: struct options { - std::string address{"127.0.0.1"}; - int16_t port{8080}; - socket::domain_t domain{socket::domain_t::ipv4}; + /// The hostname or ip address to connect to. If using hostname then a dns client must be provided. + std::variant address{net::ip_address::from_string("127.0.0.1")}; + /// The port to connect to. + int16_t port{8080}; + /// The protocol domain to connect with. + net::domain_t domain{net::domain_t::ipv4}; + /// If using a hostname to connect to then provide a dns client to lookup the host's ip address. + /// This is optional if using ip addresses directly. + dns_client* dns{nullptr}; }; - tcp_client(io_scheduler& scheduler, options opts = options{"127.0.0.1", 8080, socket::domain_t::ipv4}); + tcp_client(io_scheduler& scheduler, options opts = options{ + .address = {net::ip_address::from_string("127.0.0.1")}, + .port = 8080, + .domain = net::domain_t::ipv4, + .dns = nullptr}); tcp_client(const tcp_client&) = delete; tcp_client(tcp_client&&) = default; auto operator=(const tcp_client&) noexcept -> tcp_client& = delete; @@ -30,13 +46,18 @@ public: auto connect(std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task; - auto socket() const -> const coro::socket& { return m_socket; } - auto socket() -> coro::socket& { return m_socket; } + auto socket() const -> const net::socket& { return m_socket; } + auto socket() -> net::socket& { return m_socket; } private: + /// The scheduler that will drive this tcp client. io_scheduler& m_io_scheduler; - options m_options; - coro::socket m_socket; + /// Options for what server to connect to. + options m_options; + /// The tcp socket. + net::socket m_socket; + /// Cache the status of the connect in the event the user calls connect() again. + std::optional m_connect_status{std::nullopt}; }; } // namespace coro diff --git a/inc/coro/tcp_scheduler.hpp b/inc/coro/tcp_scheduler.hpp index 4dc44ab..e6517ac 100644 --- a/inc/coro/tcp_scheduler.hpp +++ b/inc/coro/tcp_scheduler.hpp @@ -1,61 +1,45 @@ #pragma once +#include "coro/net/ip_address.hpp" #include "coro/io_scheduler.hpp" -#include "coro/socket.hpp" +#include "coro/net/socket.hpp" #include "coro/task.hpp" #include #include #include -#include - namespace coro { class tcp_scheduler : public io_scheduler { public: - using on_connection_t = std::function(tcp_scheduler&, socket)>; + using on_connection_t = std::function(tcp_scheduler&, net::socket)>; struct options { - std::string address = "0.0.0.0"; + net::ip_address address = net::ip_address::from_string("0.0.0.0"); uint16_t port = 8080; int32_t backlog = 128; on_connection_t on_connection = nullptr; io_scheduler::options io_options{}; }; - tcp_scheduler( + explicit tcp_scheduler( options opts = options{ - "0.0.0.0", + net::ip_address::from_string("0.0.0.0"), 8080, 128, - [](tcp_scheduler&, socket) -> task { co_return; }, - io_scheduler::options{9, 2, io_scheduler::thread_strategy_t::spawn}}) - : io_scheduler(std::move(opts.io_options)), - m_opts(std::move(opts)), - m_accept_socket(socket::make_accept_socket( - socket::options{socket::domain_t::ipv4, socket::type_t::tcp, socket::blocking_t::no}, - m_opts.address, - m_opts.port, - m_opts.backlog)) - { - if (m_opts.on_connection == nullptr) - { - throw std::runtime_error{"options::on_connection cannot be nullptr."}; - } - - schedule(make_accept_task()); - } + [](tcp_scheduler&, net::socket) -> task { co_return; }, + io_scheduler::options{9, 2, io_scheduler::thread_strategy_t::spawn}}); tcp_scheduler(const tcp_scheduler&) = delete; tcp_scheduler(tcp_scheduler&&) = delete; auto operator=(const tcp_scheduler&) -> tcp_scheduler& = delete; auto operator=(tcp_scheduler&&) -> tcp_scheduler& = delete; - ~tcp_scheduler() override { shutdown(); } + ~tcp_scheduler() override; auto empty() const -> bool { return size() == 0; } @@ -66,20 +50,7 @@ public: return (size > 0) ? size - 1 : 0; } - auto shutdown(shutdown_t wait_for_tasks = shutdown_t::sync) -> void override - { - if (m_accept_new_connections.exchange(false, std::memory_order_release)) - { - m_accept_socket.shutdown(); // wake it up by shutting down read/write operations. - - while (m_accept_task_exited.load(std::memory_order::acquire) == false) - { - std::this_thread::sleep_for(std::chrono::milliseconds{1}); - } - - io_scheduler::shutdown(wait_for_tasks); - } - } + auto shutdown(shutdown_t wait_for_tasks = shutdown_t::sync) -> void override; private: options m_opts; @@ -87,45 +58,9 @@ private: /// Should the accept task continue accepting new connections? std::atomic m_accept_new_connections{true}; std::atomic m_accept_task_exited{false}; - socket m_accept_socket{-1}; + net::socket m_accept_socket{-1}; - auto make_accept_task() -> coro::task - { - sockaddr_in client{}; - constexpr const int len = sizeof(struct sockaddr_in); - - std::vector> tasks{}; - tasks.reserve(16); - - while (m_accept_new_connections.load(std::memory_order::acquire)) - { - co_await poll(m_accept_socket.native_handle(), coro::poll_op::read); - // auto status = co_await poll(m_accept_socket.native_handle(), coro::poll_op::read); - // (void)status; // TODO: introduce timeouts on io_scheduer.poll(); - - // On accept socket read drain the listen accept queue. - while (true) - { - socket s{::accept(m_accept_socket.native_handle(), (struct sockaddr*)&client, (socklen_t*)&len)}; - if (s.native_handle() < 0) - { - break; - } - - tasks.emplace_back(m_opts.on_connection(std::ref(*this), std::move(s))); - } - - if (!tasks.empty()) - { - schedule(tasks); - tasks.clear(); - } - } - - m_accept_task_exited.exchange(true, std::memory_order::release); - - co_return; - }; + auto make_accept_task() -> coro::task; }; } // namespace coro diff --git a/src/connect.cpp b/src/connect.cpp index 2009610..a227b4a 100644 --- a/src/connect.cpp +++ b/src/connect.cpp @@ -1,11 +1,15 @@ #include "coro/connect.hpp" +#include + namespace coro { static std::string connect_status_connected{"connected"}; static std::string connect_status_invalid_ip_address{"invalid_ip_address"}; static std::string connect_status_timeout{"timeout"}; static std::string connect_status_error{"error"}; +static std::string connect_status_dns_client_required{"dns_client_required"}; +static std::string connect_status_dns_lookup_failure{"dns_lookup_failure"}; auto to_string(const connect_status& status) -> const std::string& { @@ -19,7 +23,13 @@ auto to_string(const connect_status& status) -> const std::string& return connect_status_timeout; case connect_status::error: return connect_status_error; + case connect_status::dns_client_required: + return connect_status_dns_client_required; + case connect_status::dns_lookup_failure: + return connect_status_dns_lookup_failure; } + + throw std::logic_error{"Invalid/unknown connect status."}; } } // namespace coro diff --git a/src/dns_client.cpp b/src/dns_client.cpp new file mode 100644 index 0000000..c4a91dd --- /dev/null +++ b/src/dns_client.cpp @@ -0,0 +1,193 @@ +#include "coro/dns_client.hpp" + +#include +#include +#include + +namespace coro +{ + +uint64_t dns_client::m_ares_count{0}; +std::mutex dns_client::m_ares_mutex{}; + +auto ares_dns_callback( + void* arg, + int status, + int /*timeouts*/, + struct hostent* host +) -> void +{ + auto& result = *static_cast(arg); + --result.m_pending_dns_requests; + + if(host == nullptr || status != ARES_SUCCESS) + { + result.m_status = dns_status::error; + } + else + { + result.m_status = dns_status::complete; + + for(size_t i = 0; host->h_addr_list[i] != nullptr; ++i) + { + size_t len = (host->h_addrtype == AF_INET) ? net::ip_address::ipv4_len : net::ip_address::ipv6_len; + net::ip_address ip_addr{ + std::span{reinterpret_cast(host->h_addr_list[i]), len}, + static_cast(host->h_addrtype) + }; + + result.m_ip_addresses.emplace_back(std::move(ip_addr)); + } + } + + if(result.m_pending_dns_requests == 0) + { + result.m_token.resume(); + } +} + +dns_result::dns_result(coro::resume_token& token, uint64_t pending_dns_requests) + : m_token(token), + m_pending_dns_requests(pending_dns_requests) +{ + +} + +dns_client::dns_client(io_scheduler& scheduler, std::chrono::milliseconds timeout) + : m_scheduler(scheduler), + m_timeout(timeout) +{ + { + std::lock_guard g{m_ares_mutex}; + if(m_ares_count == 0) + { + auto ares_status = ares_library_init(ARES_LIB_INIT_ALL); + if(ares_status != ARES_SUCCESS) + { + throw std::runtime_error{ares_strerror(ares_status)}; + } + } + ++m_ares_count; + } + + auto channel_init_status = ares_init(&m_ares_channel); + if(channel_init_status != ARES_SUCCESS) + { + throw std::runtime_error{ares_strerror(channel_init_status)}; + } +} + +dns_client::~dns_client() +{ + if(m_ares_channel != nullptr) + { + ares_destroy(m_ares_channel); + m_ares_channel = nullptr; + } + + { + std::lock_guard g{m_ares_mutex}; + --m_ares_count; + if(m_ares_count == 0) + { + ares_library_cleanup(); + } + } +} + +auto dns_client::host_by_name(const net::hostname& hn) -> coro::task> +{ + auto token = m_scheduler.generate_resume_token(); + auto result_ptr = std::make_unique(token, 2); + + ares_gethostbyname(m_ares_channel, hn.data().data(), AF_INET, ares_dns_callback, result_ptr.get()); + ares_gethostbyname(m_ares_channel, hn.data().data(), AF_INET6, ares_dns_callback, result_ptr.get()); + + // Add all required poll calls for ares to kick off the dns requests. + ares_poll(); + + // Suspend until this specific result is completed by ares. + co_await m_scheduler.yield(token); + co_return result_ptr; +} + +auto dns_client::ares_poll() -> void +{ + std::array ares_sockets{}; + std::array poll_ops{}; + + int bitmask = ares_getsock(m_ares_channel, ares_sockets.data(), ARES_GETSOCK_MAXNUM); + + size_t new_sockets{0}; + + for(size_t i = 0; i < ARES_GETSOCK_MAXNUM; ++i) + { + uint64_t ops{0}; + + if(ARES_GETSOCK_READABLE(bitmask, i)) + { + ops |= static_cast(poll_op::read); + } + if(ARES_GETSOCK_WRITABLE(bitmask, i)) + { + ops |= static_cast(poll_op::write); + } + + if(ops != 0) + { + poll_ops[i] = static_cast(ops); + ++new_sockets; + } + else + { + // According to ares usage within curl once a bitmask for a socket is zero the rest of + // the bitmask will also be zero. + break; + } + } + + for(size_t i = 0; i < new_sockets; ++i) + { + io_scheduler::fd_t fd = static_cast(ares_sockets[i]); + + // If this socket is not currently actively polling, start polling! + if(m_active_sockets.emplace(fd).second) + { + m_scheduler.schedule(make_poll_task(fd, poll_ops[i])); + } + } +} + +auto dns_client::make_poll_task(io_scheduler::fd_t fd, poll_op ops) -> coro::task +{ + auto result = co_await m_scheduler.poll(fd, ops, m_timeout); + switch(result) + { + case poll_status::event: + { + auto read_sock = poll_op_readable(ops) ? fd : ARES_SOCKET_BAD; + auto write_sock = poll_op_writeable(ops) ? fd : ARES_SOCKET_BAD; + ares_process_fd(m_ares_channel, read_sock, write_sock); + } + break; + case poll_status::timeout: + ares_process_fd(m_ares_channel, ARES_SOCKET_BAD, ARES_SOCKET_BAD); + break; + case poll_status::closed: + // might need to do something like call with two ARES_SOCKET_BAD? + break; + case poll_status::error: + // might need to do something like call with two ARES_SOCKET_BAD? + break; + } + + // Remove from the list of actively polling sockets. + m_active_sockets.erase(fd); + + // Re-initialize sockets/polls for ares since this one has now triggered. + ares_poll(); + + co_return; +}; + +} // namespace coro diff --git a/src/net/ip_address.cpp b/src/net/ip_address.cpp new file mode 100644 index 0000000..74ea25c --- /dev/null +++ b/src/net/ip_address.cpp @@ -0,0 +1,19 @@ +#include "coro/net/ip_address.hpp" + +namespace coro::net +{ + +static std::string domain_ipv4{"ipv4"}; +static std::string domain_ipv6{"ipv6"}; + +auto to_string(domain_t domain) -> const std::string& +{ + switch(domain) + { + case domain_t::ipv4: return domain_ipv4; + case domain_t::ipv6: return domain_ipv6; + } + throw std::runtime_error{"coro::net::to_string(domain_t) unknown domain"}; +} + +} // namespace coro::net diff --git a/src/tcp_client.cpp b/src/tcp_client.cpp index 4c38d02..01755a6 100644 --- a/src/tcp_client.cpp +++ b/src/tcp_client.cpp @@ -1,30 +1,67 @@ #include "coro/tcp_client.hpp" #include "coro/io_scheduler.hpp" +#include + namespace coro { tcp_client::tcp_client(io_scheduler& scheduler, options opts) : m_io_scheduler(scheduler), m_options(std::move(opts)), - m_socket(socket::make_socket(socket::options{m_options.domain, socket::type_t::tcp, socket::blocking_t::yes})) + m_socket(net::socket::make_socket(net::socket::options{m_options.domain, net::socket::type_t::tcp, net::socket::blocking_t::yes})) { } auto tcp_client::connect(std::chrono::milliseconds timeout) -> coro::task { - sockaddr_in server{}; - server.sin_family = socket::domain_to_os(m_options.domain); - server.sin_port = htons(m_options.port); - - if (inet_pton(server.sin_family, m_options.address.data(), &server.sin_addr) <= 0) + if(m_connect_status.has_value() && m_connect_status.value() == connect_status::connected) { - co_return connect_status::invalid_ip_address; + co_return m_connect_status.value(); } + const net::ip_address* ip_addr{nullptr}; + std::unique_ptr result_ptr{nullptr}; + + if(std::holds_alternative(m_options.address)) + { + if(m_options.dns == nullptr) + { + m_connect_status = connect_status::dns_client_required; + co_return connect_status::dns_client_required; + } + const auto& hn = std::get(m_options.address); + result_ptr = co_await m_options.dns->host_by_name(hn); + if(result_ptr->status() != dns_status::complete) + { + m_connect_status = connect_status::dns_lookup_failure; + co_return connect_status::dns_lookup_failure; + } + + if(result_ptr->ip_addresses().empty()) + { + m_connect_status = connect_status::dns_lookup_failure; + co_return connect_status::dns_lookup_failure; + } + + // TODO: for now we'll just take the first ip address given, but should probably allow the + // user to take preference on ipv4/ipv6 addresses. + ip_addr = &result_ptr->ip_addresses().front(); + } + else + { + ip_addr = &std::get(m_options.address); + } + + sockaddr_in server{}; + server.sin_family = static_cast(m_options.domain); + server.sin_port = htons(m_options.port); + server.sin_addr = *reinterpret_cast(ip_addr->data().data()); + auto cret = ::connect(m_socket.native_handle(), (struct sockaddr*)&server, sizeof(server)); if (cret == 0) { // Immediate connect. + m_connect_status = connect_status::connected; co_return connect_status::connected; } else if (cret == -1) @@ -46,16 +83,19 @@ auto tcp_client::connect(std::chrono::milliseconds timeout) -> coro::task void +{ + if (m_accept_new_connections.exchange(false, std::memory_order::release)) + { + m_accept_socket.shutdown(); // wake it up by shutting down read/write operations. + + while (m_accept_task_exited.load(std::memory_order::acquire) == false) + { + std::this_thread::sleep_for(std::chrono::milliseconds{1}); + } + + io_scheduler::shutdown(wait_for_tasks); + } +} + +auto tcp_scheduler::make_accept_task() -> coro::task +{ + sockaddr_in client{}; + constexpr const int len = sizeof(struct sockaddr_in); + + std::vector> tasks{}; + tasks.reserve(16); + + while (m_accept_new_connections.load(std::memory_order::acquire)) + { + co_await poll(m_accept_socket.native_handle(), coro::poll_op::read); + // auto status = co_await poll(m_accept_socket.native_handle(), coro::poll_op::read); + // (void)status; // TODO: introduce timeouts on io_scheduer.poll(); + + // On accept socket read drain the listen accept queue. + while (true) + { + net::socket s{::accept(m_accept_socket.native_handle(), (struct sockaddr*)&client, (socklen_t*)&len)}; + if (s.native_handle() < 0) + { + break; + } + + tasks.emplace_back(m_opts.on_connection(std::ref(*this), std::move(s))); + } + + if (!tasks.empty()) + { + schedule(tasks); + tasks.clear(); + } + } + + m_accept_task_exited.exchange(true, std::memory_order::release); + + co_return; +}; + +} // namespace coro diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index adfaee5..ec63958 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -2,7 +2,10 @@ cmake_minimum_required(VERSION 3.16) project(libcoro_test) set(LIBCORO_TEST_SOURCE_FILES + net/test_ip_address.cpp + bench.cpp + test_dns_client.cpp test_event.cpp test_generator.cpp test_io_scheduler.cpp @@ -16,6 +19,7 @@ set(LIBCORO_TEST_SOURCE_FILES add_executable(${PROJECT_NAME} main.cpp ${LIBCORO_TEST_SOURCE_FILES}) target_compile_features(${PROJECT_NAME} PUBLIC cxx_std_20) +target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}) target_link_libraries(${PROJECT_NAME} PRIVATE coro) target_compile_options(${PROJECT_NAME} PUBLIC -fcoroutines) diff --git a/test/net/test_ip_address.cpp b/test/net/test_ip_address.cpp new file mode 100644 index 0000000..a172df3 --- /dev/null +++ b/test/net/test_ip_address.cpp @@ -0,0 +1,68 @@ +#include "catch.hpp" + +#include + +#include +#include + +TEST_CASE("net::ip_address from_string() ipv4") +{ + { + auto ip_addr = coro::net::ip_address::from_string("127.0.0.1"); + REQUIRE(ip_addr.to_string() == "127.0.0.1"); + REQUIRE(ip_addr.domain() == coro::net::domain_t::ipv4); + std::array expected{127, 0, 0, 1}; + REQUIRE(std::equal(expected.begin(), expected.end(), ip_addr.data().begin())); + } + + { + auto ip_addr = coro::net::ip_address::from_string("255.255.0.0"); + REQUIRE(ip_addr.to_string() == "255.255.0.0"); + REQUIRE(ip_addr.domain() == coro::net::domain_t::ipv4); + std::array expected{255, 255, 0, 0}; + REQUIRE(std::equal(expected.begin(), expected.end(), ip_addr.data().begin())); + } +} + +TEST_CASE("net::ip_address from_string() ipv6") +{ + { + auto ip_addr = coro::net::ip_address::from_string("0123:4567:89ab:cdef:0123:4567:89ab:cdef", coro::net::domain_t::ipv6); + REQUIRE(ip_addr.to_string() == "123:4567:89ab:cdef:123:4567:89ab:cdef"); + REQUIRE(ip_addr.domain() == coro::net::domain_t::ipv6); + std::array expected{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}; + REQUIRE(std::equal(expected.begin(), expected.end(), ip_addr.data().begin())); + } + + { + auto ip_addr = coro::net::ip_address::from_string("::", coro::net::domain_t::ipv6); + REQUIRE(ip_addr.to_string() == "::"); + REQUIRE(ip_addr.domain() == coro::net::domain_t::ipv6); + std::array expected{}; + REQUIRE(std::equal(expected.begin(), expected.end(), ip_addr.data().begin())); + } + + { + auto ip_addr = coro::net::ip_address::from_string("::1", coro::net::domain_t::ipv6); + REQUIRE(ip_addr.to_string() == "::1"); + REQUIRE(ip_addr.domain() == coro::net::domain_t::ipv6); + std::array expected{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}; + REQUIRE(std::equal(expected.begin(), expected.end(), ip_addr.data().begin())); + } + + { + auto ip_addr = coro::net::ip_address::from_string("1::1", coro::net::domain_t::ipv6); + REQUIRE(ip_addr.to_string() == "1::1"); + REQUIRE(ip_addr.domain() == coro::net::domain_t::ipv6); + std::array expected{0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}; + REQUIRE(std::equal(expected.begin(), expected.end(), ip_addr.data().begin())); + } + + { + auto ip_addr = coro::net::ip_address::from_string("1::", coro::net::domain_t::ipv6); + REQUIRE(ip_addr.to_string() == "1::"); + REQUIRE(ip_addr.domain() == coro::net::domain_t::ipv6); + std::array expected{0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; + REQUIRE(std::equal(expected.begin(), expected.end(), ip_addr.data().begin())); + } +} diff --git a/test/test_dns_client.cpp b/test/test_dns_client.cpp new file mode 100644 index 0000000..8cf46a1 --- /dev/null +++ b/test/test_dns_client.cpp @@ -0,0 +1,43 @@ +#include "catch.hpp" + +#include + +#include + +TEST_CASE("dns_client basic") +{ + coro::io_scheduler scheduler{ + coro::io_scheduler::options{.thread_strategy = coro::io_scheduler::thread_strategy_t::spawn} + }; + + coro::dns_client dns_client{scheduler, std::chrono::milliseconds{5000}}; + + std::atomic done{false}; + + auto make_host_by_name_task = [&](coro::net::hostname hn) -> coro::task + { + auto result_ptr = co_await std::move(dns_client.host_by_name(hn)); + + if(result_ptr->status() == coro::dns_status::complete) + { + for(const auto& ip_addr : result_ptr->ip_addresses()) + { + std::cerr << coro::net::to_string(ip_addr.domain()) << " " << ip_addr.to_string() << "\n"; + } + } + + done = true; + + co_return; + }; + + scheduler.schedule(make_host_by_name_task(coro::net::hostname{"www.example.com"})); + + while(!done) + { + std::this_thread::sleep_for(std::chrono::milliseconds{10}); + } + + scheduler.shutdown(); + REQUIRE(scheduler.empty()); +} \ No newline at end of file diff --git a/test/test_tcp_scheduler.cpp b/test/test_tcp_scheduler.cpp index 993710e..caf9353 100644 --- a/test/test_tcp_scheduler.cpp +++ b/test/test_tcp_scheduler.cpp @@ -7,11 +7,11 @@ TEST_CASE("tcp_scheduler no on connection throws") REQUIRE_THROWS(coro::tcp_scheduler{coro::tcp_scheduler::options{.on_connection = nullptr}}); } -TEST_CASE("tcp_scheduler echo server") +TEST_CASE("tcp_scheduler echo server ip address") { const std::string msg{"Hello from client"}; - auto on_connection = [&msg](coro::tcp_scheduler& scheduler, coro::socket sock) -> coro::task { + auto on_connection = [&msg](coro::tcp_scheduler& scheduler, coro::net::socket sock) -> coro::task { std::string in(64, '\0'); auto [rstatus, rbytes] = co_await scheduler.read(sock, std::span{in.data(), in.size()}); @@ -28,7 +28,7 @@ TEST_CASE("tcp_scheduler echo server") }; coro::tcp_scheduler scheduler{coro::tcp_scheduler::options{ - .address = "0.0.0.0", + .address = coro::net::ip_address::from_string("0.0.0.0"), .port = 8080, .backlog = 128, .on_connection = on_connection, @@ -37,7 +37,81 @@ TEST_CASE("tcp_scheduler echo server") auto make_client_task = [&scheduler, &msg]() -> coro::task { coro::tcp_client client{ scheduler, - coro::tcp_client::options{.address = "127.0.0.1", .port = 8080, .domain = coro::socket::domain_t::ipv4}}; + coro::tcp_client::options{.address = coro::net::ip_address::from_string("127.0.0.1"), .port = 8080, .domain = coro::net::domain_t::ipv4}}; + + auto cstatus = co_await client.connect(); + REQUIRE(cstatus == coro::connect_status::connected); + + auto [wstatus, wbytes] = + co_await scheduler.write(client.socket(), std::span{msg.data(), msg.length()}); + + REQUIRE(wstatus == coro::poll_status::event); + REQUIRE(wbytes == msg.length()); + + std::string response(64, '\0'); + + auto [rstatus, rbytes] = + co_await scheduler.read(client.socket(), std::span{response.data(), response.length()}); + + REQUIRE(rstatus == coro::poll_status::event); + REQUIRE(rbytes == msg.length()); + response.resize(rbytes); + REQUIRE(response == msg); + + co_return; + }; + + scheduler.schedule(make_client_task()); + + // Shutting down the scheduler will cause it to stop accepting new connections, to avoid requiring + // another scheduler for this test the main thread can spin sleep until the tcp scheduler reports + // that it is empty. tcp schedulers do not report their accept task as a task in its size/empty count. + while (!scheduler.empty()) + { + std::this_thread::sleep_for(std::chrono::milliseconds{1}); + } + + scheduler.shutdown(); + REQUIRE(scheduler.empty()); +} + +TEST_CASE("tcp_scheduler echo server hostname") +{ + const std::string msg{"Hello from client with dns lookup"}; + + auto on_connection = [&msg](coro::tcp_scheduler& scheduler, coro::net::socket sock) -> coro::task { + std::string in(64, '\0'); + + auto [rstatus, rbytes] = co_await scheduler.read(sock, std::span{in.data(), in.size()}); + REQUIRE(rstatus == coro::poll_status::event); + + in.resize(rbytes); + REQUIRE(in == msg); + + auto [wstatus, wbytes] = co_await scheduler.write(sock, std::span(in.data(), in.length())); + REQUIRE(wstatus == coro::poll_status::event); + REQUIRE(wbytes == in.length()); + + co_return; + }; + + coro::tcp_scheduler scheduler{coro::tcp_scheduler::options{ + .address = coro::net::ip_address::from_string("0.0.0.0"), + .port = 8080, + .backlog = 128, + .on_connection = on_connection, + .io_options = coro::io_scheduler::options{.thread_strategy = coro::io_scheduler::thread_strategy_t::spawn}}}; + + coro::dns_client dns_client{scheduler, std::chrono::seconds{5}}; + + auto make_client_task = [&scheduler, &dns_client, &msg]() -> coro::task { + coro::tcp_client client{ + scheduler, + coro::tcp_client::options{ + .address = coro::net::hostname{"localhost"}, + .port = 8080, + .domain = coro::net::domain_t::ipv4, + .dns = &dns_client}}; auto cstatus = co_await client.connect(); REQUIRE(cstatus == coro::connect_status::connected); diff --git a/vendor/c-ares/c-ares b/vendor/c-ares/c-ares new file mode 160000 index 0000000..799e81d --- /dev/null +++ b/vendor/c-ares/c-ares @@ -0,0 +1 @@ +Subproject commit 799e81d4ace75af7d530857d4f8b35913a27463e