1
0
Fork 0
mirror of https://gitlab.com/niansa/libcrosscoro.git synced 2025-03-06 20:53:32 +01:00

libc-ares dns client for hostname -> ip addres lookups (#24)

* libc-ares dns client for hostname -> ip addres lookups

* Add tcp_client dns lookup if hostname + dns available
This commit is contained in:
Josh Baldwin 2020-12-29 17:19:26 -07:00 committed by GitHub
parent e11058ef22
commit c02aefe26e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
24 changed files with 887 additions and 141 deletions

View file

@ -12,8 +12,6 @@ jobs:
TZ: America/New_York
DEBIAN_FRONTEND: noninteractive
steps:
- name: Checkout
uses: actions/checkout@v2
- name: apt
run: |
apt-get update
@ -25,6 +23,10 @@ jobs:
git \
ninja-build \
g++-10
- name: Checkout # recurisve checkout requires git to be installed first
uses: actions/checkout@v2
with:
submodules: recursive
- name: build-release-g++
run: |
mkdir build-release-g++
@ -46,8 +48,6 @@ jobs:
container:
image: fedora:32
steps:
- name: Checkout
uses: actions/checkout@v2
- name: dnf
run: |
sudo dnf install -y \
@ -56,6 +56,10 @@ jobs:
ninja-build \
gcc-c++-10.2.1 \
lcov
- name: Checkout # recurisve checkout requires git to be installed first
uses: actions/checkout@v2
with:
submodules: recursive
- name: build-debug-g++
run: |
mkdir build-debug-g++

3
.gitmodules vendored Normal file
View file

@ -0,0 +1,3 @@
[submodule "vendor/c-ares/c-ares"]
path = vendor/c-ares/c-ares
url = git@github.com:c-ares/c-ares.git

View file

@ -7,12 +7,23 @@ option(CORO_CODE_COVERAGE "Enable code coverage, tests must also be enabled, De
message("${PROJECT_NAME} CORO_BUILD_TESTS = ${CORO_BUILD_TESTS}")
message("${PROJECT_NAME} CORO_CODE_COVERAGE = ${CORO_CODE_COVERAGE}")
set(CARES_STATIC ON CACHE INTERNAL "")
set(CARES_SHARED OFF CACHE INTERNAL "")
set(CARES_INSTALL OFF CACHE INTERNAL "")
add_subdirectory(vendor/c-ares/c-ares)
set(LIBCORO_SOURCE_FILES
inc/coro/detail/void_value.hpp
inc/coro/net/hostname.hpp
inc/coro/net/ip_address.hpp src/net/ip_address.cpp
inc/coro/net/socket.hpp
inc/coro/awaitable.hpp
inc/coro/connect.hpp src/connect.cpp
inc/coro/coro.hpp
inc/coro/dns_client.hpp src/dns_client.cpp
inc/coro/event.hpp src/event.cpp
inc/coro/generator.hpp
inc/coro/io_scheduler.hpp
@ -20,11 +31,10 @@ set(LIBCORO_SOURCE_FILES
inc/coro/poll.hpp
inc/coro/promise.hpp
inc/coro/shutdown.hpp
inc/coro/socket.hpp
inc/coro/sync_wait.hpp src/sync_wait.cpp
inc/coro/task.hpp
inc/coro/tcp_client.hpp src/tcp_client.cpp
inc/coro/tcp_scheduler.hpp
inc/coro/tcp_scheduler.hpp src/tcp_scheduler.cpp
inc/coro/thread_pool.hpp src/thread_pool.cpp
inc/coro/when_all.hpp
)
@ -33,7 +43,7 @@ add_library(${PROJECT_NAME} STATIC ${LIBCORO_SOURCE_FILES})
set_target_properties(${PROJECT_NAME} PROPERTIES LINKER_LANGUAGE CXX)
target_compile_features(${PROJECT_NAME} PUBLIC cxx_std_20)
target_include_directories(${PROJECT_NAME} PUBLIC inc)
target_link_libraries(${PROJECT_NAME} PUBLIC pthread)
target_link_libraries(${PROJECT_NAME} PUBLIC pthread c-ares)
if(${CMAKE_CXX_COMPILER_ID} MATCHES "GNU")
if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "10.2.0")

View file

@ -13,9 +13,17 @@ enum class connect_status
/// The connection operation timed out.
timeout,
/// There was an error, use errno to get more information on the specific error.
error
error,
/// The client was given a hostname but no dns client to resolve the ip address.
dns_client_required,
/// The dns hostname lookup failed
dns_lookup_failure
};
/**
* @param status String representation of the connection status.
* @throw std::logic_error If provided an invalid connect_status enum value.
*/
auto to_string(const connect_status& status) -> const std::string&;
} // namespace coro

View file

@ -1,6 +1,12 @@
#pragma once
#include "coro/net/hostname.hpp"
#include "coro/net/ip_address.hpp"
#include "coro/net/socket.hpp"
#include "coro/awaitable.hpp"
#include "coro/connect.hpp"
#include "coro/dns_client.hpp"
#include "coro/event.hpp"
#include "coro/generator.hpp"
#include "coro/io_scheduler.hpp"

87
inc/coro/dns_client.hpp Normal file
View file

@ -0,0 +1,87 @@
#pragma once
#include "coro/io_scheduler.hpp"
#include "coro/net/ip_address.hpp"
#include "coro/net/hostname.hpp"
#include "coro/task.hpp"
#include <ares.h>
#include <mutex>
#include <functional>
#include <vector>
#include <array>
#include <memory>
#include <chrono>
#include <unordered_set>
#include <sys/epoll.h>
namespace coro
{
class dns_client;
enum class dns_status
{
complete,
error
};
class dns_result
{
friend dns_client;
public:
explicit dns_result(coro::resume_token<void>& token, uint64_t pending_dns_requests);
~dns_result() = default;
auto status() const -> dns_status { return m_status; }
auto ip_addresses() const -> const std::vector<coro::net::ip_address>& { return m_ip_addresses; }
private:
coro::resume_token<void>& m_token;
uint64_t m_pending_dns_requests{0};
dns_status m_status{dns_status::complete};
std::vector<coro::net::ip_address> m_ip_addresses{};
friend auto ares_dns_callback(
void* arg,
int status,
int timeouts,
struct hostent* host
) -> void;
};
class dns_client
{
public:
explicit dns_client(io_scheduler& scheduler, std::chrono::milliseconds timeout);
dns_client(const dns_client&) = delete;
dns_client(dns_client&&) = delete;
auto operator=(const dns_client&) noexcept -> dns_client& = delete;
auto operator=(dns_client&&) noexcept -> dns_client& = delete;
~dns_client();
auto host_by_name(const net::hostname& hn) -> coro::task<std::unique_ptr<dns_result>>;
private:
/// The io scheduler to drive the events for dns lookups.
io_scheduler& m_scheduler;
/// The global timeout per dns lookup request.
std::chrono::milliseconds m_timeout{0};
/// The libc-ares channel for looking up dns entries.
ares_channel m_ares_channel{nullptr};
/// This is the set of sockets that are currently being actively polled so multiple poll tasks
/// are not setup when ares_poll() is called.
std::unordered_set<io_scheduler::fd_t> m_active_sockets{};
/// Global count to track if c-ares has been initialized or cleaned up.
static uint64_t m_ares_count;
/// Critical section around the c-ares global init/cleanup to prevent heap corruption.
static std::mutex m_ares_mutex;
auto ares_poll() -> void;
auto make_poll_task(io_scheduler::fd_t fd, poll_op ops) -> coro::task<void>;
};
} // namespace coro

View file

@ -3,7 +3,7 @@
#include "coro/awaitable.hpp"
#include "coro/poll.hpp"
#include "coro/shutdown.hpp"
#include "coro/socket.hpp"
#include "coro/net/socket.hpp"
#include "coro/task.hpp"
#include <atomic>
@ -653,7 +653,7 @@ public:
}
auto read(
const coro::socket& sock,
const net::socket& sock,
std::span<char> buffer,
std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task<std::pair<poll_status, ssize_t>>
{
@ -684,7 +684,7 @@ public:
}
auto write(
const coro::socket& sock,
const net::socket& sock,
const std::span<const char> buffer,
std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task<std::pair<poll_status, ssize_t>>
{

30
inc/coro/net/hostname.hpp Normal file
View file

@ -0,0 +1,30 @@
#pragma once
#include <string>
namespace coro::net
{
class hostname
{
public:
hostname() = default;
explicit hostname(std::string hn)
: m_hostname(std::move(hn)) {}
hostname(const hostname&) = default;
hostname(hostname&&) = default;
auto operator=(const hostname&) noexcept -> hostname& = default;
auto operator=(hostname&&) noexcept -> hostname& = default;
~hostname() = default;
auto data() const -> const std::string& { return m_hostname; }
auto operator<=>(const hostname& other) const
{
return m_hostname <=> other.m_hostname;
}
private:
std::string m_hostname;
};
} // namespace coro::net

106
inc/coro/net/ip_address.hpp Normal file
View file

@ -0,0 +1,106 @@
#pragma once
#include <string>
#include <stdexcept>
#include <array>
#include <span>
#include <arpa/inet.h>
#include <cstring>
namespace coro::net
{
enum class domain_t : int
{
ipv4 = AF_INET,
ipv6 = AF_INET6
};
auto to_string(domain_t domain) -> const std::string&;
class ip_address
{
public:
static const constexpr size_t ipv4_len{4};
static const constexpr size_t ipv6_len{16};
ip_address() = default;
ip_address(std::span<const uint8_t> binary_address, domain_t domain = domain_t::ipv4)
: m_domain(domain)
{
if(m_domain == domain_t::ipv4 && binary_address.size() > ipv4_len)
{
throw std::runtime_error{"coro::net::ip_address provided binary ip address is too long"};
}
else if(binary_address.size() > ipv6_len)
{
throw std::runtime_error{"coro::net::ip_address provided binary ip address is too long"};
}
std::copy(binary_address.begin(), binary_address.end(), m_data.begin());
}
ip_address(const ip_address&) = default;
ip_address(ip_address&&) = default;
auto operator=(const ip_address&) noexcept -> ip_address& = default;
auto operator=(ip_address&&) noexcept -> ip_address& = default;
~ip_address() = default;
auto domain() const -> domain_t { return m_domain; }
auto data() const -> std::span<const uint8_t>
{
if(m_domain == domain_t::ipv4)
{
return std::span<const uint8_t>{m_data.data(), ipv4_len};
}
else
{
return std::span<const uint8_t>{m_data.data(), ipv6_len};
}
}
static auto from_string(const std::string& address, domain_t domain = domain_t::ipv4) -> ip_address
{
ip_address addr{};
addr.m_domain = domain;
auto success = inet_pton(static_cast<int>(addr.m_domain), address.data(), addr.m_data.data());
if(success != 1)
{
throw std::runtime_error{"coro::net::ip_address faild to convert from string"};
}
return addr;
}
auto to_string() const -> std::string
{
std::string output;
if(m_domain == domain_t::ipv4)
{
output.resize(INET_ADDRSTRLEN, '\0');
}
else
{
output.resize(INET6_ADDRSTRLEN, '\0');
}
auto success = inet_ntop(static_cast<int>(m_domain), m_data.data(), output.data(), output.length());
if(success != nullptr)
{
auto len = strnlen(success, output.length());
output.resize(len);
}
else
{
throw std::runtime_error{"coro::net::ip_address failed to convert to string representation"};
}
return output;
}
private:
domain_t m_domain{domain_t::ipv4};
std::array<uint8_t, ipv6_len> m_data{};
};
} // namespace coro::net

View file

@ -1,5 +1,6 @@
#pragma once
#include "coro/net/ip_address.hpp"
#include "coro/poll.hpp"
#include <arpa/inet.h>
@ -10,17 +11,11 @@
#include <iostream>
namespace coro
namespace coro::net
{
class socket
{
public:
enum class domain_t
{
ipv4,
ipv6
};
enum class type_t
{
udp,
@ -40,19 +35,6 @@ public:
blocking_t blocking;
};
static auto domain_to_os(const domain_t& domain) -> int
{
switch (domain)
{
case domain_t::ipv4:
return AF_INET;
case domain_t::ipv6:
return AF_INET6;
default:
throw std::runtime_error{"Unknown socket::domain_t."};
}
}
static auto type_to_os(const type_t& type) -> int
{
switch (type)
@ -68,7 +50,7 @@ public:
static auto make_socket(const options& opts) -> socket
{
socket s{::socket(domain_to_os(opts.domain), type_to_os(opts.type), 0)};
socket s{::socket(static_cast<int>(opts.domain), type_to_os(opts.type), 0)};
if (s.native_handle() < 0)
{
throw std::runtime_error{"Failed to create socket."};
@ -86,10 +68,10 @@ public:
}
static auto make_accept_socket(
const options& opts,
const std::string& address, // force string to guarantee null terminated.
uint16_t port,
int32_t backlog = 128) -> socket
const options& opts,
const net::ip_address& address,
uint16_t port,
int32_t backlog = 128) -> socket
{
socket s = make_socket(opts);
@ -100,13 +82,14 @@ public:
}
sockaddr_in server{};
server.sin_family = domain_to_os(opts.domain);
server.sin_family = static_cast<int>(opts.domain);
server.sin_port = htons(port);
server.sin_addr = *reinterpret_cast<const in_addr*>(address.data().data());
if (inet_pton(server.sin_family, address.data(), &server.sin_addr) <= 0)
{
throw std::runtime_error{"Failed to translate IP Address."};
}
// if (inet_pton(server.sin_family, address.data(), &server.sin_addr) <= 0)
// {
// throw std::runtime_error{"Failed to translate IP Address."};
// }
if (bind(s.native_handle(), (struct sockaddr*)&server, sizeof(server)) < 0)
{
@ -202,4 +185,4 @@ private:
int m_fd{-1};
};
} // namespace coro
} // namespace coro::net

View file

@ -4,7 +4,7 @@
namespace coro
{
enum class poll_op
enum class poll_op : uint64_t
{
/// Poll for read operations.
read = EPOLLIN,
@ -14,6 +14,16 @@ enum class poll_op
read_write = EPOLLIN | EPOLLOUT
};
inline auto poll_op_readable(poll_op op) -> bool
{
return (static_cast<uint64_t>(op) & EPOLLIN);
}
inline auto poll_op_writeable(poll_op op) -> bool
{
return (static_cast<uint64_t>(op) & EPOLLOUT);
}
enum class poll_status
{
/// The poll operation was was successful.

View file

@ -196,7 +196,7 @@ public:
return false;
}
auto operator co_await() const noexcept
auto operator co_await() const & noexcept
{
struct awaitable : public awaitable_base
{
@ -206,6 +206,16 @@ public:
return awaitable{m_coroutine};
}
auto operator co_await() const && noexcept
{
struct awaitable : public awaitable_base
{
auto await_resume() -> decltype(auto) { return std::move(this->m_coroutine.promise()).return_value(); }
};
return awaitable{m_coroutine};
}
auto promise() & -> promise_type& { return m_coroutine.promise(); }
auto promise() const& -> const promise_type& { return m_coroutine.promise(); }

View file

@ -1,11 +1,17 @@
#pragma once
#include "coro/net/ip_address.hpp"
#include "coro/net/hostname.hpp"
#include "coro/net/socket.hpp"
#include "coro/connect.hpp"
#include "coro/poll.hpp"
#include "coro/socket.hpp"
#include "coro/task.hpp"
#include "coro/dns_client.hpp"
#include <chrono>
#include <optional>
#include <variant>
#include <memory>
namespace coro
{
@ -16,12 +22,22 @@ class tcp_client
public:
struct options
{
std::string address{"127.0.0.1"};
int16_t port{8080};
socket::domain_t domain{socket::domain_t::ipv4};
/// The hostname or ip address to connect to. If using hostname then a dns client must be provided.
std::variant<net::hostname, net::ip_address> address{net::ip_address::from_string("127.0.0.1")};
/// The port to connect to.
int16_t port{8080};
/// The protocol domain to connect with.
net::domain_t domain{net::domain_t::ipv4};
/// If using a hostname to connect to then provide a dns client to lookup the host's ip address.
/// This is optional if using ip addresses directly.
dns_client* dns{nullptr};
};
tcp_client(io_scheduler& scheduler, options opts = options{"127.0.0.1", 8080, socket::domain_t::ipv4});
tcp_client(io_scheduler& scheduler, options opts = options{
.address = {net::ip_address::from_string("127.0.0.1")},
.port = 8080,
.domain = net::domain_t::ipv4,
.dns = nullptr});
tcp_client(const tcp_client&) = delete;
tcp_client(tcp_client&&) = default;
auto operator=(const tcp_client&) noexcept -> tcp_client& = delete;
@ -30,13 +46,18 @@ public:
auto connect(std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task<connect_status>;
auto socket() const -> const coro::socket& { return m_socket; }
auto socket() -> coro::socket& { return m_socket; }
auto socket() const -> const net::socket& { return m_socket; }
auto socket() -> net::socket& { return m_socket; }
private:
/// The scheduler that will drive this tcp client.
io_scheduler& m_io_scheduler;
options m_options;
coro::socket m_socket;
/// Options for what server to connect to.
options m_options;
/// The tcp socket.
net::socket m_socket;
/// Cache the status of the connect in the event the user calls connect() again.
std::optional<connect_status> m_connect_status{std::nullopt};
};
} // namespace coro

View file

@ -1,61 +1,45 @@
#pragma once
#include "coro/net/ip_address.hpp"
#include "coro/io_scheduler.hpp"
#include "coro/socket.hpp"
#include "coro/net/socket.hpp"
#include "coro/task.hpp"
#include <fcntl.h>
#include <functional>
#include <sys/socket.h>
#include <iostream>
namespace coro
{
class tcp_scheduler : public io_scheduler
{
public:
using on_connection_t = std::function<task<void>(tcp_scheduler&, socket)>;
using on_connection_t = std::function<task<void>(tcp_scheduler&, net::socket)>;
struct options
{
std::string address = "0.0.0.0";
net::ip_address address = net::ip_address::from_string("0.0.0.0");
uint16_t port = 8080;
int32_t backlog = 128;
on_connection_t on_connection = nullptr;
io_scheduler::options io_options{};
};
tcp_scheduler(
explicit tcp_scheduler(
options opts =
options{
"0.0.0.0",
net::ip_address::from_string("0.0.0.0"),
8080,
128,
[](tcp_scheduler&, socket) -> task<void> { co_return; },
io_scheduler::options{9, 2, io_scheduler::thread_strategy_t::spawn}})
: io_scheduler(std::move(opts.io_options)),
m_opts(std::move(opts)),
m_accept_socket(socket::make_accept_socket(
socket::options{socket::domain_t::ipv4, socket::type_t::tcp, socket::blocking_t::no},
m_opts.address,
m_opts.port,
m_opts.backlog))
{
if (m_opts.on_connection == nullptr)
{
throw std::runtime_error{"options::on_connection cannot be nullptr."};
}
schedule(make_accept_task());
}
[](tcp_scheduler&, net::socket) -> task<void> { co_return; },
io_scheduler::options{9, 2, io_scheduler::thread_strategy_t::spawn}});
tcp_scheduler(const tcp_scheduler&) = delete;
tcp_scheduler(tcp_scheduler&&) = delete;
auto operator=(const tcp_scheduler&) -> tcp_scheduler& = delete;
auto operator=(tcp_scheduler&&) -> tcp_scheduler& = delete;
~tcp_scheduler() override { shutdown(); }
~tcp_scheduler() override;
auto empty() const -> bool { return size() == 0; }
@ -66,20 +50,7 @@ public:
return (size > 0) ? size - 1 : 0;
}
auto shutdown(shutdown_t wait_for_tasks = shutdown_t::sync) -> void override
{
if (m_accept_new_connections.exchange(false, std::memory_order_release))
{
m_accept_socket.shutdown(); // wake it up by shutting down read/write operations.
while (m_accept_task_exited.load(std::memory_order::acquire) == false)
{
std::this_thread::sleep_for(std::chrono::milliseconds{1});
}
io_scheduler::shutdown(wait_for_tasks);
}
}
auto shutdown(shutdown_t wait_for_tasks = shutdown_t::sync) -> void override;
private:
options m_opts;
@ -87,45 +58,9 @@ private:
/// Should the accept task continue accepting new connections?
std::atomic<bool> m_accept_new_connections{true};
std::atomic<bool> m_accept_task_exited{false};
socket m_accept_socket{-1};
net::socket m_accept_socket{-1};
auto make_accept_task() -> coro::task<void>
{
sockaddr_in client{};
constexpr const int len = sizeof(struct sockaddr_in);
std::vector<task<void>> tasks{};
tasks.reserve(16);
while (m_accept_new_connections.load(std::memory_order::acquire))
{
co_await poll(m_accept_socket.native_handle(), coro::poll_op::read);
// auto status = co_await poll(m_accept_socket.native_handle(), coro::poll_op::read);
// (void)status; // TODO: introduce timeouts on io_scheduer.poll();
// On accept socket read drain the listen accept queue.
while (true)
{
socket s{::accept(m_accept_socket.native_handle(), (struct sockaddr*)&client, (socklen_t*)&len)};
if (s.native_handle() < 0)
{
break;
}
tasks.emplace_back(m_opts.on_connection(std::ref(*this), std::move(s)));
}
if (!tasks.empty())
{
schedule(tasks);
tasks.clear();
}
}
m_accept_task_exited.exchange(true, std::memory_order::release);
co_return;
};
auto make_accept_task() -> coro::task<void>;
};
} // namespace coro

View file

@ -1,11 +1,15 @@
#include "coro/connect.hpp"
#include <stdexcept>
namespace coro
{
static std::string connect_status_connected{"connected"};
static std::string connect_status_invalid_ip_address{"invalid_ip_address"};
static std::string connect_status_timeout{"timeout"};
static std::string connect_status_error{"error"};
static std::string connect_status_dns_client_required{"dns_client_required"};
static std::string connect_status_dns_lookup_failure{"dns_lookup_failure"};
auto to_string(const connect_status& status) -> const std::string&
{
@ -19,7 +23,13 @@ auto to_string(const connect_status& status) -> const std::string&
return connect_status_timeout;
case connect_status::error:
return connect_status_error;
case connect_status::dns_client_required:
return connect_status_dns_client_required;
case connect_status::dns_lookup_failure:
return connect_status_dns_lookup_failure;
}
throw std::logic_error{"Invalid/unknown connect status."};
}
} // namespace coro

193
src/dns_client.cpp Normal file
View file

@ -0,0 +1,193 @@
#include "coro/dns_client.hpp"
#include <iostream>
#include <netdb.h>
#include <arpa/inet.h>
namespace coro
{
uint64_t dns_client::m_ares_count{0};
std::mutex dns_client::m_ares_mutex{};
auto ares_dns_callback(
void* arg,
int status,
int /*timeouts*/,
struct hostent* host
) -> void
{
auto& result = *static_cast<dns_result*>(arg);
--result.m_pending_dns_requests;
if(host == nullptr || status != ARES_SUCCESS)
{
result.m_status = dns_status::error;
}
else
{
result.m_status = dns_status::complete;
for(size_t i = 0; host->h_addr_list[i] != nullptr; ++i)
{
size_t len = (host->h_addrtype == AF_INET) ? net::ip_address::ipv4_len : net::ip_address::ipv6_len;
net::ip_address ip_addr{
std::span<const uint8_t>{reinterpret_cast<const uint8_t*>(host->h_addr_list[i]), len},
static_cast<net::domain_t>(host->h_addrtype)
};
result.m_ip_addresses.emplace_back(std::move(ip_addr));
}
}
if(result.m_pending_dns_requests == 0)
{
result.m_token.resume();
}
}
dns_result::dns_result(coro::resume_token<void>& token, uint64_t pending_dns_requests)
: m_token(token),
m_pending_dns_requests(pending_dns_requests)
{
}
dns_client::dns_client(io_scheduler& scheduler, std::chrono::milliseconds timeout)
: m_scheduler(scheduler),
m_timeout(timeout)
{
{
std::lock_guard<std::mutex> g{m_ares_mutex};
if(m_ares_count == 0)
{
auto ares_status = ares_library_init(ARES_LIB_INIT_ALL);
if(ares_status != ARES_SUCCESS)
{
throw std::runtime_error{ares_strerror(ares_status)};
}
}
++m_ares_count;
}
auto channel_init_status = ares_init(&m_ares_channel);
if(channel_init_status != ARES_SUCCESS)
{
throw std::runtime_error{ares_strerror(channel_init_status)};
}
}
dns_client::~dns_client()
{
if(m_ares_channel != nullptr)
{
ares_destroy(m_ares_channel);
m_ares_channel = nullptr;
}
{
std::lock_guard<std::mutex> g{m_ares_mutex};
--m_ares_count;
if(m_ares_count == 0)
{
ares_library_cleanup();
}
}
}
auto dns_client::host_by_name(const net::hostname& hn) -> coro::task<std::unique_ptr<dns_result>>
{
auto token = m_scheduler.generate_resume_token<void>();
auto result_ptr = std::make_unique<dns_result>(token, 2);
ares_gethostbyname(m_ares_channel, hn.data().data(), AF_INET, ares_dns_callback, result_ptr.get());
ares_gethostbyname(m_ares_channel, hn.data().data(), AF_INET6, ares_dns_callback, result_ptr.get());
// Add all required poll calls for ares to kick off the dns requests.
ares_poll();
// Suspend until this specific result is completed by ares.
co_await m_scheduler.yield(token);
co_return result_ptr;
}
auto dns_client::ares_poll() -> void
{
std::array<ares_socket_t, ARES_GETSOCK_MAXNUM> ares_sockets{};
std::array<poll_op, ARES_GETSOCK_MAXNUM> poll_ops{};
int bitmask = ares_getsock(m_ares_channel, ares_sockets.data(), ARES_GETSOCK_MAXNUM);
size_t new_sockets{0};
for(size_t i = 0; i < ARES_GETSOCK_MAXNUM; ++i)
{
uint64_t ops{0};
if(ARES_GETSOCK_READABLE(bitmask, i))
{
ops |= static_cast<uint64_t>(poll_op::read);
}
if(ARES_GETSOCK_WRITABLE(bitmask, i))
{
ops |= static_cast<uint64_t>(poll_op::write);
}
if(ops != 0)
{
poll_ops[i] = static_cast<poll_op>(ops);
++new_sockets;
}
else
{
// According to ares usage within curl once a bitmask for a socket is zero the rest of
// the bitmask will also be zero.
break;
}
}
for(size_t i = 0; i < new_sockets; ++i)
{
io_scheduler::fd_t fd = static_cast<io_scheduler::fd_t>(ares_sockets[i]);
// If this socket is not currently actively polling, start polling!
if(m_active_sockets.emplace(fd).second)
{
m_scheduler.schedule(make_poll_task(fd, poll_ops[i]));
}
}
}
auto dns_client::make_poll_task(io_scheduler::fd_t fd, poll_op ops) -> coro::task<void>
{
auto result = co_await m_scheduler.poll(fd, ops, m_timeout);
switch(result)
{
case poll_status::event:
{
auto read_sock = poll_op_readable(ops) ? fd : ARES_SOCKET_BAD;
auto write_sock = poll_op_writeable(ops) ? fd : ARES_SOCKET_BAD;
ares_process_fd(m_ares_channel, read_sock, write_sock);
}
break;
case poll_status::timeout:
ares_process_fd(m_ares_channel, ARES_SOCKET_BAD, ARES_SOCKET_BAD);
break;
case poll_status::closed:
// might need to do something like call with two ARES_SOCKET_BAD?
break;
case poll_status::error:
// might need to do something like call with two ARES_SOCKET_BAD?
break;
}
// Remove from the list of actively polling sockets.
m_active_sockets.erase(fd);
// Re-initialize sockets/polls for ares since this one has now triggered.
ares_poll();
co_return;
};
} // namespace coro

19
src/net/ip_address.cpp Normal file
View file

@ -0,0 +1,19 @@
#include "coro/net/ip_address.hpp"
namespace coro::net
{
static std::string domain_ipv4{"ipv4"};
static std::string domain_ipv6{"ipv6"};
auto to_string(domain_t domain) -> const std::string&
{
switch(domain)
{
case domain_t::ipv4: return domain_ipv4;
case domain_t::ipv6: return domain_ipv6;
}
throw std::runtime_error{"coro::net::to_string(domain_t) unknown domain"};
}
} // namespace coro::net

View file

@ -1,30 +1,67 @@
#include "coro/tcp_client.hpp"
#include "coro/io_scheduler.hpp"
#include <ares.h>
namespace coro
{
tcp_client::tcp_client(io_scheduler& scheduler, options opts)
: m_io_scheduler(scheduler),
m_options(std::move(opts)),
m_socket(socket::make_socket(socket::options{m_options.domain, socket::type_t::tcp, socket::blocking_t::yes}))
m_socket(net::socket::make_socket(net::socket::options{m_options.domain, net::socket::type_t::tcp, net::socket::blocking_t::yes}))
{
}
auto tcp_client::connect(std::chrono::milliseconds timeout) -> coro::task<connect_status>
{
sockaddr_in server{};
server.sin_family = socket::domain_to_os(m_options.domain);
server.sin_port = htons(m_options.port);
if (inet_pton(server.sin_family, m_options.address.data(), &server.sin_addr) <= 0)
if(m_connect_status.has_value() && m_connect_status.value() == connect_status::connected)
{
co_return connect_status::invalid_ip_address;
co_return m_connect_status.value();
}
const net::ip_address* ip_addr{nullptr};
std::unique_ptr<dns_result> result_ptr{nullptr};
if(std::holds_alternative<net::hostname>(m_options.address))
{
if(m_options.dns == nullptr)
{
m_connect_status = connect_status::dns_client_required;
co_return connect_status::dns_client_required;
}
const auto& hn = std::get<net::hostname>(m_options.address);
result_ptr = co_await m_options.dns->host_by_name(hn);
if(result_ptr->status() != dns_status::complete)
{
m_connect_status = connect_status::dns_lookup_failure;
co_return connect_status::dns_lookup_failure;
}
if(result_ptr->ip_addresses().empty())
{
m_connect_status = connect_status::dns_lookup_failure;
co_return connect_status::dns_lookup_failure;
}
// TODO: for now we'll just take the first ip address given, but should probably allow the
// user to take preference on ipv4/ipv6 addresses.
ip_addr = &result_ptr->ip_addresses().front();
}
else
{
ip_addr = &std::get<net::ip_address>(m_options.address);
}
sockaddr_in server{};
server.sin_family = static_cast<int>(m_options.domain);
server.sin_port = htons(m_options.port);
server.sin_addr = *reinterpret_cast<const in_addr*>(ip_addr->data().data());
auto cret = ::connect(m_socket.native_handle(), (struct sockaddr*)&server, sizeof(server));
if (cret == 0)
{
// Immediate connect.
m_connect_status = connect_status::connected;
co_return connect_status::connected;
}
else if (cret == -1)
@ -46,16 +83,19 @@ auto tcp_client::connect(std::chrono::milliseconds timeout) -> coro::task<connec
if (result == 0)
{
// success, connected
m_connect_status = connect_status::connected;
co_return connect_status::connected;
}
}
else if (pstatus == poll_status::timeout)
{
m_connect_status = connect_status::timeout;
co_return connect_status::timeout;
}
}
}
m_connect_status = connect_status::error;
co_return connect_status::error;
}

81
src/tcp_scheduler.cpp Normal file
View file

@ -0,0 +1,81 @@
#include "coro/tcp_scheduler.hpp"
namespace coro
{
tcp_scheduler::tcp_scheduler(options opts)
: io_scheduler(std::move(opts.io_options)),
m_opts(std::move(opts)),
m_accept_socket(net::socket::make_accept_socket(
net::socket::options{net::domain_t::ipv4, net::socket::type_t::tcp, net::socket::blocking_t::no},
m_opts.address,
m_opts.port,
m_opts.backlog))
{
if (m_opts.on_connection == nullptr)
{
throw std::runtime_error{"options::on_connection cannot be nullptr."};
}
schedule(make_accept_task());
}
tcp_scheduler::~tcp_scheduler()
{
shutdown();
}
auto tcp_scheduler::shutdown(shutdown_t wait_for_tasks) -> void
{
if (m_accept_new_connections.exchange(false, std::memory_order::release))
{
m_accept_socket.shutdown(); // wake it up by shutting down read/write operations.
while (m_accept_task_exited.load(std::memory_order::acquire) == false)
{
std::this_thread::sleep_for(std::chrono::milliseconds{1});
}
io_scheduler::shutdown(wait_for_tasks);
}
}
auto tcp_scheduler::make_accept_task() -> coro::task<void>
{
sockaddr_in client{};
constexpr const int len = sizeof(struct sockaddr_in);
std::vector<task<void>> tasks{};
tasks.reserve(16);
while (m_accept_new_connections.load(std::memory_order::acquire))
{
co_await poll(m_accept_socket.native_handle(), coro::poll_op::read);
// auto status = co_await poll(m_accept_socket.native_handle(), coro::poll_op::read);
// (void)status; // TODO: introduce timeouts on io_scheduer.poll();
// On accept socket read drain the listen accept queue.
while (true)
{
net::socket s{::accept(m_accept_socket.native_handle(), (struct sockaddr*)&client, (socklen_t*)&len)};
if (s.native_handle() < 0)
{
break;
}
tasks.emplace_back(m_opts.on_connection(std::ref(*this), std::move(s)));
}
if (!tasks.empty())
{
schedule(tasks);
tasks.clear();
}
}
m_accept_task_exited.exchange(true, std::memory_order::release);
co_return;
};
} // namespace coro

View file

@ -2,7 +2,10 @@ cmake_minimum_required(VERSION 3.16)
project(libcoro_test)
set(LIBCORO_TEST_SOURCE_FILES
net/test_ip_address.cpp
bench.cpp
test_dns_client.cpp
test_event.cpp
test_generator.cpp
test_io_scheduler.cpp
@ -16,6 +19,7 @@ set(LIBCORO_TEST_SOURCE_FILES
add_executable(${PROJECT_NAME} main.cpp ${LIBCORO_TEST_SOURCE_FILES})
target_compile_features(${PROJECT_NAME} PUBLIC cxx_std_20)
target_include_directories(${PROJECT_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_link_libraries(${PROJECT_NAME} PRIVATE coro)
target_compile_options(${PROJECT_NAME} PUBLIC -fcoroutines)

View file

@ -0,0 +1,68 @@
#include "catch.hpp"
#include <coro/coro.hpp>
#include <chrono>
#include <iomanip>
TEST_CASE("net::ip_address from_string() ipv4")
{
{
auto ip_addr = coro::net::ip_address::from_string("127.0.0.1");
REQUIRE(ip_addr.to_string() == "127.0.0.1");
REQUIRE(ip_addr.domain() == coro::net::domain_t::ipv4);
std::array<uint8_t, coro::net::ip_address::ipv4_len> expected{127, 0, 0, 1};
REQUIRE(std::equal(expected.begin(), expected.end(), ip_addr.data().begin()));
}
{
auto ip_addr = coro::net::ip_address::from_string("255.255.0.0");
REQUIRE(ip_addr.to_string() == "255.255.0.0");
REQUIRE(ip_addr.domain() == coro::net::domain_t::ipv4);
std::array<uint8_t, coro::net::ip_address::ipv4_len> expected{255, 255, 0, 0};
REQUIRE(std::equal(expected.begin(), expected.end(), ip_addr.data().begin()));
}
}
TEST_CASE("net::ip_address from_string() ipv6")
{
{
auto ip_addr = coro::net::ip_address::from_string("0123:4567:89ab:cdef:0123:4567:89ab:cdef", coro::net::domain_t::ipv6);
REQUIRE(ip_addr.to_string() == "123:4567:89ab:cdef:123:4567:89ab:cdef");
REQUIRE(ip_addr.domain() == coro::net::domain_t::ipv6);
std::array<uint8_t, coro::net::ip_address::ipv6_len> expected{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef};
REQUIRE(std::equal(expected.begin(), expected.end(), ip_addr.data().begin()));
}
{
auto ip_addr = coro::net::ip_address::from_string("::", coro::net::domain_t::ipv6);
REQUIRE(ip_addr.to_string() == "::");
REQUIRE(ip_addr.domain() == coro::net::domain_t::ipv6);
std::array<uint8_t, coro::net::ip_address::ipv6_len> expected{};
REQUIRE(std::equal(expected.begin(), expected.end(), ip_addr.data().begin()));
}
{
auto ip_addr = coro::net::ip_address::from_string("::1", coro::net::domain_t::ipv6);
REQUIRE(ip_addr.to_string() == "::1");
REQUIRE(ip_addr.domain() == coro::net::domain_t::ipv6);
std::array<uint8_t, coro::net::ip_address::ipv6_len> expected{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01};
REQUIRE(std::equal(expected.begin(), expected.end(), ip_addr.data().begin()));
}
{
auto ip_addr = coro::net::ip_address::from_string("1::1", coro::net::domain_t::ipv6);
REQUIRE(ip_addr.to_string() == "1::1");
REQUIRE(ip_addr.domain() == coro::net::domain_t::ipv6);
std::array<uint8_t, coro::net::ip_address::ipv6_len> expected{0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01};
REQUIRE(std::equal(expected.begin(), expected.end(), ip_addr.data().begin()));
}
{
auto ip_addr = coro::net::ip_address::from_string("1::", coro::net::domain_t::ipv6);
REQUIRE(ip_addr.to_string() == "1::");
REQUIRE(ip_addr.domain() == coro::net::domain_t::ipv6);
std::array<uint8_t, coro::net::ip_address::ipv6_len> expected{0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00};
REQUIRE(std::equal(expected.begin(), expected.end(), ip_addr.data().begin()));
}
}

43
test/test_dns_client.cpp Normal file
View file

@ -0,0 +1,43 @@
#include "catch.hpp"
#include <coro/coro.hpp>
#include <chrono>
TEST_CASE("dns_client basic")
{
coro::io_scheduler scheduler{
coro::io_scheduler::options{.thread_strategy = coro::io_scheduler::thread_strategy_t::spawn}
};
coro::dns_client dns_client{scheduler, std::chrono::milliseconds{5000}};
std::atomic<bool> done{false};
auto make_host_by_name_task = [&](coro::net::hostname hn) -> coro::task<void>
{
auto result_ptr = co_await std::move(dns_client.host_by_name(hn));
if(result_ptr->status() == coro::dns_status::complete)
{
for(const auto& ip_addr : result_ptr->ip_addresses())
{
std::cerr << coro::net::to_string(ip_addr.domain()) << " " << ip_addr.to_string() << "\n";
}
}
done = true;
co_return;
};
scheduler.schedule(make_host_by_name_task(coro::net::hostname{"www.example.com"}));
while(!done)
{
std::this_thread::sleep_for(std::chrono::milliseconds{10});
}
scheduler.shutdown();
REQUIRE(scheduler.empty());
}

View file

@ -7,11 +7,11 @@ TEST_CASE("tcp_scheduler no on connection throws")
REQUIRE_THROWS(coro::tcp_scheduler{coro::tcp_scheduler::options{.on_connection = nullptr}});
}
TEST_CASE("tcp_scheduler echo server")
TEST_CASE("tcp_scheduler echo server ip address")
{
const std::string msg{"Hello from client"};
auto on_connection = [&msg](coro::tcp_scheduler& scheduler, coro::socket sock) -> coro::task<void> {
auto on_connection = [&msg](coro::tcp_scheduler& scheduler, coro::net::socket sock) -> coro::task<void> {
std::string in(64, '\0');
auto [rstatus, rbytes] = co_await scheduler.read(sock, std::span<char>{in.data(), in.size()});
@ -28,7 +28,7 @@ TEST_CASE("tcp_scheduler echo server")
};
coro::tcp_scheduler scheduler{coro::tcp_scheduler::options{
.address = "0.0.0.0",
.address = coro::net::ip_address::from_string("0.0.0.0"),
.port = 8080,
.backlog = 128,
.on_connection = on_connection,
@ -37,7 +37,81 @@ TEST_CASE("tcp_scheduler echo server")
auto make_client_task = [&scheduler, &msg]() -> coro::task<void> {
coro::tcp_client client{
scheduler,
coro::tcp_client::options{.address = "127.0.0.1", .port = 8080, .domain = coro::socket::domain_t::ipv4}};
coro::tcp_client::options{.address = coro::net::ip_address::from_string("127.0.0.1"), .port = 8080, .domain = coro::net::domain_t::ipv4}};
auto cstatus = co_await client.connect();
REQUIRE(cstatus == coro::connect_status::connected);
auto [wstatus, wbytes] =
co_await scheduler.write(client.socket(), std::span<const char>{msg.data(), msg.length()});
REQUIRE(wstatus == coro::poll_status::event);
REQUIRE(wbytes == msg.length());
std::string response(64, '\0');
auto [rstatus, rbytes] =
co_await scheduler.read(client.socket(), std::span<char>{response.data(), response.length()});
REQUIRE(rstatus == coro::poll_status::event);
REQUIRE(rbytes == msg.length());
response.resize(rbytes);
REQUIRE(response == msg);
co_return;
};
scheduler.schedule(make_client_task());
// Shutting down the scheduler will cause it to stop accepting new connections, to avoid requiring
// another scheduler for this test the main thread can spin sleep until the tcp scheduler reports
// that it is empty. tcp schedulers do not report their accept task as a task in its size/empty count.
while (!scheduler.empty())
{
std::this_thread::sleep_for(std::chrono::milliseconds{1});
}
scheduler.shutdown();
REQUIRE(scheduler.empty());
}
TEST_CASE("tcp_scheduler echo server hostname")
{
const std::string msg{"Hello from client with dns lookup"};
auto on_connection = [&msg](coro::tcp_scheduler& scheduler, coro::net::socket sock) -> coro::task<void> {
std::string in(64, '\0');
auto [rstatus, rbytes] = co_await scheduler.read(sock, std::span<char>{in.data(), in.size()});
REQUIRE(rstatus == coro::poll_status::event);
in.resize(rbytes);
REQUIRE(in == msg);
auto [wstatus, wbytes] = co_await scheduler.write(sock, std::span<const char>(in.data(), in.length()));
REQUIRE(wstatus == coro::poll_status::event);
REQUIRE(wbytes == in.length());
co_return;
};
coro::tcp_scheduler scheduler{coro::tcp_scheduler::options{
.address = coro::net::ip_address::from_string("0.0.0.0"),
.port = 8080,
.backlog = 128,
.on_connection = on_connection,
.io_options = coro::io_scheduler::options{.thread_strategy = coro::io_scheduler::thread_strategy_t::spawn}}};
coro::dns_client dns_client{scheduler, std::chrono::seconds{5}};
auto make_client_task = [&scheduler, &dns_client, &msg]() -> coro::task<void> {
coro::tcp_client client{
scheduler,
coro::tcp_client::options{
.address = coro::net::hostname{"localhost"},
.port = 8080,
.domain = coro::net::domain_t::ipv4,
.dns = &dns_client}};
auto cstatus = co_await client.connect();
REQUIRE(cstatus == coro::connect_status::connected);

1
vendor/c-ares/c-ares vendored Submodule

@ -0,0 +1 @@
Subproject commit 799e81d4ace75af7d530857d4f8b35913a27463e