1
0
Fork 0
mirror of https://gitlab.com/niansa/libcrosscoro.git synced 2025-03-06 20:53:32 +01:00

Add SSL/TLS support for TCP client/server via OpenSSL (#54)

* Add SSL/TLS support for TCP client/server via OpenSSL

* Comments addressed
This commit is contained in:
Josh Baldwin 2021-02-15 14:01:48 -07:00 committed by GitHub
parent 730928e8b5
commit e1e52b1400
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
20 changed files with 727 additions and 69 deletions

View file

@ -165,6 +165,7 @@ thread pool worker 0 is shutting down.
CMake
make or ninja
pthreads
openssl
gcov/lcov (For generating coverage only)
### Instructions

View file

@ -22,7 +22,8 @@ jobs:
cmake \
git \
ninja-build \
g++-10
g++-10 \
libssl-dev
- name: Checkout # recurisve checkout requires git to be installed first
uses: actions/checkout@v2
with:
@ -55,7 +56,8 @@ jobs:
git \
ninja-build \
gcc-c++-10.2.1 \
lcov
lcov \
openssl-devel
- name: Checkout # recurisve checkout requires git to be installed first
uses: actions/checkout@v2
with:

View file

@ -36,6 +36,8 @@ set(LIBCORO_SOURCE_FILES
inc/coro/net/recv_status.hpp src/net/recv_status.cpp
inc/coro/net/send_status.hpp src/net/send_status.cpp
inc/coro/net/socket.hpp src/net/socket.cpp
inc/coro/net/ssl_context.hpp src/net/ssl_context.cpp
inc/coro/net/ssl_handshake_status.hpp
inc/coro/net/tcp_client.hpp src/net/tcp_client.cpp
inc/coro/net/tcp_server.hpp src/net/tcp_server.cpp
inc/coro/net/udp_peer.hpp src/net/udp_peer.cpp
@ -59,7 +61,7 @@ add_library(${PROJECT_NAME} STATIC ${LIBCORO_SOURCE_FILES})
set_target_properties(${PROJECT_NAME} PROPERTIES LINKER_LANGUAGE CXX)
target_compile_features(${PROJECT_NAME} PUBLIC cxx_std_20)
target_include_directories(${PROJECT_NAME} PUBLIC inc)
target_link_libraries(${PROJECT_NAME} PUBLIC pthread c-ares)
target_link_libraries(${PROJECT_NAME} PUBLIC pthread c-ares ssl crypto)
if(${CMAKE_CXX_COMPILER_ID} MATCHES "GNU")
if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "10.2.0")

View file

@ -462,6 +462,7 @@ thread pool worker 0 is shutting down.
CMake
make or ninja
pthreads
openssl
gcov/lcov (For generating coverage only)
### Instructions

View file

@ -11,6 +11,7 @@
#include "coro/net/recv_status.hpp"
#include "coro/net/send_status.hpp"
#include "coro/net/socket.hpp"
#include "coro/net/ssl_context.hpp"
#include "coro/net/tcp_client.hpp"
#include "coro/net/tcp_server.hpp"
#include "coro/net/udp_peer.hpp"

View file

@ -13,9 +13,7 @@ enum class connect_status
/// The connection operation timed out.
timeout,
/// There was an error, use errno to get more information on the specific error.
error,
/// The dns hostname lookup failed
dns_lookup_failure
error
};
/**

View file

@ -22,7 +22,9 @@ enum class recv_status : int64_t
invalid_argument = EINVAL,
no_memory = ENOMEM,
not_connected = ENOTCONN,
not_a_socket = ENOTSOCK
not_a_socket = ENOTSOCK,
ssl_error = -3
};
auto to_string(recv_status status) -> const std::string&;

View file

@ -24,7 +24,9 @@ enum class send_status : int64_t
not_connected = ENOTCONN,
not_a_socket = ENOTSOCK,
operationg_not_supported = EOPNOTSUPP,
pipe_closed = EPIPE
pipe_closed = EPIPE,
ssl_error = -3
};
} // namespace coro::net

View file

@ -0,0 +1,61 @@
#pragma once
#include <filesystem>
#include <mutex>
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/pem.h>
#include <openssl/ssl.h>
namespace coro::net
{
class tcp_client;
enum class ssl_file_type : int
{
/// The file is of type ASN1
asn1 = SSL_FILETYPE_ASN1,
/// The file is of type PEM
pem = SSL_FILETYPE_PEM
};
/**
* SSL context, used with client or server types to provide secure connections.
*/
class ssl_context
{
public:
/**
* Creates a context with no certificate and no private key, maybe useful for testing.
*/
ssl_context();
/**
* Creates a context with the given certificate and the given private key.
* @param certificate The location of the certificate file.
* @param certificate_type See `ssl_file_type`.
* @param private_key The location of the private key file.
* @param private_key_type See `ssl_file_type`.
*/
ssl_context(
std::filesystem::path certificate,
ssl_file_type certificate_type,
std::filesystem::path private_key,
ssl_file_type private_key_type);
~ssl_context();
private:
static uint64_t m_ssl_context_count;
static std::mutex m_ssl_context_mutex;
SSL_CTX* m_ssl_ctx{nullptr};
/// The following classes use the underlying SSL_CTX* object for performing SSL functions.
friend tcp_client;
auto native_handle() -> SSL_CTX* { return m_ssl_ctx; }
auto native_handle() const -> const SSL_CTX* { return m_ssl_ctx; }
};
} // namespace coro::net

View file

@ -0,0 +1,28 @@
#pragma once
namespace coro::net
{
enum class ssl_handshake_status
{
/// The ssl handshake was successful.
ok,
/// The connection hasn't been established yet, use connect() prior to the ssl_handshake().
not_connected,
/// The connection needs a coro::net::ssl_context to perform the handshake.
ssl_context_required,
/// The internal ssl memory alocation failed.
ssl_resource_allocation_failed,
/// Attempting to set the connections ssl socket/file descriptor failed.
ssl_set_fd_failure,
/// The handshake had an error.
handshake_failed,
/// The handshake timed out.
timeout,
/// An error occurred while polling for read or write operations on the socket.
poll_error,
/// The socket was unexpectedly closed while attempting the handshake.
unexpected_close
};
} // namespace coro::net

View file

@ -7,15 +7,14 @@
#include "coro/net/recv_status.hpp"
#include "coro/net/send_status.hpp"
#include "coro/net/socket.hpp"
#include "coro/net/ssl_context.hpp"
#include "coro/net/ssl_handshake_status.hpp"
#include "coro/poll.hpp"
#include "coro/task.hpp"
#include <chrono>
#include <memory>
#include <optional>
#include <string>
#include <variant>
#include <vector>
namespace coro::net
{
@ -30,6 +29,8 @@ public:
net::ip_address address{net::ip_address::from_string("127.0.0.1")};
/// The port to connect to.
uint16_t port{8080};
/// Should this tcp_client connect using a secure connection SSL/TLS?
ssl_context* ssl_ctx{nullptr};
};
/**
@ -41,12 +42,13 @@ public:
*/
tcp_client(
io_scheduler& scheduler,
options opts = options{.address = {net::ip_address::from_string("127.0.0.1")}, .port = 8080});
options opts = options{
.address = {net::ip_address::from_string("127.0.0.1")}, .port = 8080, .ssl_ctx = nullptr});
tcp_client(const tcp_client&) = delete;
tcp_client(tcp_client&&) = default;
tcp_client(tcp_client&& other);
auto operator=(const tcp_client&) noexcept -> tcp_client& = delete;
auto operator=(tcp_client&&) noexcept -> tcp_client& = default;
~tcp_client() = default;
auto operator =(tcp_client&& other) noexcept -> tcp_client&;
~tcp_client();
/**
* @return The tcp socket this client is using.
@ -64,6 +66,32 @@ public:
*/
auto connect(std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task<net::connect_status>;
/**
* If this client is connected and the connection is SSL/TLS then perform the ssl handshake.
* This must be done after a successful connect() call for clients that are initiating a
* connection to a server. This must be done after a successful accept() call for clients that
* have been accepted by a tcp_server. TCP server 'client's start in the connected state and
* thus skip the connect() call.
*
* tcp_client initiating to a server:
* tcp_client client{...options...};
* co_await client.connect();
* co_await client.ssl_handshake(); // <-- only perform if ssl/tls connection
*
* tcp_server accepting a client connection:
* tcp_server server{...options...};
* co_await server.poll();
* auto client = server.accept();
* if(client.socket().is_valid())
* {
* co_await client.ssl_handshake(); // <-- only perform if ssl/tls connection
* }
* @param timeout How long to allow for the ssl handshake to successfully complete?
* @return The result of the ssl handshake.
*/
auto ssl_handshake(std::chrono::milliseconds timeout = std::chrono::milliseconds{0})
-> coro::task<ssl_handshake_status>;
/**
* Polls for the given operation on this client's tcp socket. This should be done prior to
* calling recv and after a send that doesn't send the entire buffer.
@ -75,7 +103,7 @@ public:
auto poll(coro::poll_op op, std::chrono::milliseconds timeout = std::chrono::milliseconds{0})
-> coro::task<poll_status>
{
return m_io_scheduler.poll(m_socket, op, timeout);
return m_io_scheduler->poll(m_socket, op, timeout);
}
/**
@ -94,21 +122,48 @@ public:
return {recv_status::ok, std::span<char>{}};
}
auto bytes_recv = ::recv(m_socket.native_handle(), buffer.data(), buffer.size(), 0);
if (bytes_recv > 0)
if (m_options.ssl_ctx == nullptr)
{
// Ok, we've recieved some data.
return {recv_status::ok, std::span<char>{buffer.data(), static_cast<size_t>(bytes_recv)}};
}
else if (bytes_recv == 0)
{
// On TCP stream sockets 0 indicates the connection has been closed by the peer.
return {recv_status::closed, std::span<char>{}};
auto bytes_recv = ::recv(m_socket.native_handle(), buffer.data(), buffer.size(), 0);
if (bytes_recv > 0)
{
// Ok, we've recieved some data.
return {recv_status::ok, std::span<char>{buffer.data(), static_cast<size_t>(bytes_recv)}};
}
else if (bytes_recv == 0)
{
// On TCP stream sockets 0 indicates the connection has been closed by the peer.
return {recv_status::closed, std::span<char>{}};
}
else
{
// Report the error to the user.
return {static_cast<recv_status>(errno), std::span<char>{}};
}
}
else
{
// Report the error to the user.
return {static_cast<recv_status>(errno), std::span<char>{}};
ERR_clear_error();
size_t bytes_recv{0};
int r = SSL_read_ex(m_ssl_info.m_ssl_ptr.get(), buffer.data(), buffer.size(), &bytes_recv);
if (r == 0)
{
int err = SSL_get_error(m_ssl_info.m_ssl_ptr.get(), r);
if (err == SSL_ERROR_WANT_READ)
{
return {recv_status::would_block, std::span<char>{}};
}
else
{
// TODO: Flesh out all possible ssl errors:
// https://www.openssl.org/docs/man1.1.1/man3/SSL_get_error.html
return {recv_status::ssl_error, std::span<char>{}};
}
}
else
{
return {recv_status::ok, std::span<char>{buffer.data(), static_cast<size_t>(bytes_recv)}};
}
}
}
@ -130,32 +185,120 @@ public:
return {send_status::ok, std::span<const char>{buffer.data(), buffer.size()}};
}
auto bytes_sent = ::send(m_socket.native_handle(), buffer.data(), buffer.size(), 0);
if (bytes_sent >= 0)
if (m_options.ssl_ctx == nullptr)
{
// Some or all of the bytes were written.
return {send_status::ok, std::span<const char>{buffer.data() + bytes_sent, buffer.size() - bytes_sent}};
auto bytes_sent = ::send(m_socket.native_handle(), buffer.data(), buffer.size(), 0);
if (bytes_sent >= 0)
{
// Some or all of the bytes were written.
return {send_status::ok, std::span<const char>{buffer.data() + bytes_sent, buffer.size() - bytes_sent}};
}
else
{
// Due to the error none of the bytes were written.
return {static_cast<send_status>(errno), std::span<const char>{buffer.data(), buffer.size()}};
}
}
else
{
// Due to the error none of the bytes were written.
return {static_cast<send_status>(errno), std::span<const char>{buffer.data(), buffer.size()}};
ERR_clear_error();
size_t bytes_sent{0};
int r = SSL_write_ex(m_ssl_info.m_ssl_ptr.get(), buffer.data(), buffer.size(), &bytes_sent);
if (r == 0)
{
int err = SSL_get_error(m_ssl_info.m_ssl_ptr.get(), r);
if (err == SSL_ERROR_WANT_WRITE)
{
return {send_status::would_block, std::span<char>{}};
}
else
{
// TODO: Flesh out all possible ssl errors:
// https://www.openssl.org/docs/man1.1.1/man3/SSL_get_error.html
return {send_status::ssl_error, std::span<char>{}};
}
}
else
{
return {send_status::ok, std::span<const char>{buffer.data() + bytes_sent, buffer.size() - bytes_sent}};
}
}
}
private:
struct ssl_deleter
{
auto operator()(SSL* ssl) const -> void { SSL_free(ssl); }
};
using ssl_unique_ptr = std::unique_ptr<SSL, ssl_deleter>;
enum class ssl_connection_type
{
/// This connection is a client connecting to a server.
connect,
/// This connection is an accepted connection on a sever.
accept
};
struct ssl_info
{
ssl_info() {}
explicit ssl_info(ssl_connection_type type) : m_ssl_connection_type(type) {}
ssl_info(const ssl_info&) noexcept = delete;
ssl_info(ssl_info&& other) noexcept
: m_ssl_connection_type(std::exchange(other.m_ssl_connection_type, ssl_connection_type::connect)),
m_ssl_ptr(std::move(other.m_ssl_ptr)),
m_ssl_error(std::exchange(other.m_ssl_error, false)),
m_ssl_handshake_status(std::move(other.m_ssl_handshake_status))
{
}
auto operator=(const ssl_info&) noexcept -> ssl_info& = delete;
auto operator=(ssl_info&& other) noexcept -> ssl_info&
{
if (std::addressof(other) != this)
{
m_ssl_connection_type = std::exchange(other.m_ssl_connection_type, ssl_connection_type::connect);
m_ssl_ptr = std::move(other.m_ssl_ptr);
m_ssl_error = std::exchange(other.m_ssl_error, false);
m_ssl_handshake_status = std::move(other.m_ssl_handshake_status);
}
return *this;
}
/// What kind of connection is this, client initiated connect or server side accept?
ssl_connection_type m_ssl_connection_type{ssl_connection_type::connect};
/// OpenSSL ssl connection.
ssl_unique_ptr m_ssl_ptr{nullptr};
/// Was there an error with the SSL/TLS connection?
bool m_ssl_error{false};
/// The result of the ssl handshake.
std::optional<ssl_handshake_status> m_ssl_handshake_status{std::nullopt};
};
/// The tcp_server creates already connected clients and provides a tcp socket pre-built.
friend tcp_server;
tcp_client(io_scheduler& scheduler, net::socket socket, options opts);
/// The scheduler that will drive this tcp client.
io_scheduler& m_io_scheduler;
io_scheduler* m_io_scheduler{nullptr};
/// Options for what server to connect to.
options m_options{};
/// The tcp socket.
net::socket m_socket{-1};
/// Cache the status of the connect in the event the user calls connect() again.
std::optional<net::connect_status> m_connect_status{std::nullopt};
/// SSL/TLS specific information if m_options.ssl_ctx != nullptr.
ssl_info m_ssl_info{};
private:
static auto ssl_shutdown_and_free(
io_scheduler& io_scheduler,
net::socket s,
ssl_unique_ptr ssl_ptr,
std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task<void>;
};
} // namespace coro::net

View file

@ -1,17 +1,22 @@
#pragma once
#include "coro/io_scheduler.hpp"
#include "coro/net/ip_address.hpp"
#include "coro/net/socket.hpp"
#include "coro/net/tcp_client.hpp"
#include "coro/task.hpp"
#include <fcntl.h>
#include <functional>
#include <sys/socket.h>
namespace coro
{
class io_scheduler;
} // namespace coro
namespace coro::net
{
class ssl_context;
class tcp_server
{
public:
@ -23,17 +28,21 @@ public:
uint16_t port{8080};
/// The kernel backlog of connections to buffer.
int32_t backlog{128};
/// Should this tcp server use TLS/SSL? If provided all accepted connections will use the
/// given SSL certificate and private key to secure the connections.
ssl_context* ssl_ctx{nullptr};
};
tcp_server(
io_scheduler& scheduler,
options opts = options{.address = net::ip_address::from_string("0.0.0.0"), .port = 8080, .backlog = 128});
options opts = options{
.address = net::ip_address::from_string("0.0.0.0"), .port = 8080, .backlog = 128, .ssl_ctx = nullptr});
tcp_server(const tcp_server&) = delete;
tcp_server(tcp_server&&) = delete;
tcp_server(tcp_server&& other);
auto operator=(const tcp_server&) -> tcp_server& = delete;
auto operator=(tcp_server&&) -> tcp_server& = delete;
~tcp_server() = default;
auto operator =(tcp_server&& other) -> tcp_server&;
~tcp_server() = default;
/**
* Polls for new incoming tcp connections.
@ -43,7 +52,7 @@ public:
*/
auto poll(std::chrono::milliseconds timeout = std::chrono::milliseconds{0}) -> coro::task<coro::poll_status>
{
return m_io_scheduler.poll(m_accept_socket, coro::poll_op::read, timeout);
return m_io_scheduler->poll(m_accept_socket, coro::poll_op::read, timeout);
}
/**
@ -55,7 +64,7 @@ public:
private:
/// The io scheduler for awaiting new connections.
io_scheduler& m_io_scheduler;
io_scheduler* m_io_scheduler{nullptr};
/// The bind and listen options for this server.
options m_options;
/// The socket for accepting new tcp connections on.

View file

@ -4,11 +4,10 @@
namespace coro::net
{
static std::string connect_status_connected{"connected"};
static std::string connect_status_invalid_ip_address{"invalid_ip_address"};
static std::string connect_status_timeout{"timeout"};
static std::string connect_status_error{"error"};
static std::string connect_status_dns_lookup_failure{"dns_lookup_failure"};
const static std::string connect_status_connected{"connected"};
const static std::string connect_status_invalid_ip_address{"invalid_ip_address"};
const static std::string connect_status_timeout{"timeout"};
const static std::string connect_status_error{"error"};
auto to_string(const connect_status& status) -> const std::string&
{
@ -22,8 +21,6 @@ auto to_string(const connect_status& status) -> const std::string&
return connect_status_timeout;
case connect_status::error:
return connect_status_error;
case connect_status::dns_lookup_failure:
return connect_status_dns_lookup_failure;
}
throw std::logic_error{"Invalid/unknown connect status."};

View file

@ -51,7 +51,7 @@ dns_resolver::dns_resolver(io_scheduler& scheduler, std::chrono::milliseconds ti
m_timeout(timeout)
{
{
std::lock_guard<std::mutex> g{m_ares_mutex};
std::scoped_lock g{m_ares_mutex};
if (m_ares_count == 0)
{
auto ares_status = ares_library_init(ARES_LIB_INIT_ALL);
@ -79,7 +79,7 @@ dns_resolver::~dns_resolver()
}
{
std::lock_guard<std::mutex> g{m_ares_mutex};
std::scoped_lock g{m_ares_mutex};
--m_ares_count;
if (m_ares_count == 0)
{

View file

@ -17,6 +17,8 @@ static const std::string recv_status_not_connected{"not_connected"};
static const std::string recv_status_not_a_socket{"not_a_socket"};
static const std::string recv_status_unknown{"unknown"};
static const std::string recv_status_ssl_error{"ssl_error"};
auto to_string(recv_status status) -> const std::string&
{
switch (status)
@ -46,6 +48,9 @@ auto to_string(recv_status status) -> const std::string&
return recv_status_not_connected;
case recv_status::not_a_socket:
return recv_status_not_a_socket;
case recv_status::ssl_error:
return recv_status_ssl_error;
}
return recv_status_unknown;

65
src/net/ssl_context.cpp Normal file
View file

@ -0,0 +1,65 @@
#include "coro/net/ssl_context.hpp"
#include <iostream>
namespace coro::net
{
uint64_t ssl_context::m_ssl_context_count{0};
std::mutex ssl_context::m_ssl_context_mutex{};
ssl_context::ssl_context()
{
{
std::scoped_lock g{m_ssl_context_mutex};
if (m_ssl_context_count == 0)
{
OPENSSL_init_ssl(0, nullptr);
}
++m_ssl_context_count;
}
m_ssl_ctx = SSL_CTX_new(TLS_method());
if (m_ssl_ctx == nullptr)
{
throw std::runtime_error{"Failed to initialize OpenSSL Context object."};
}
// Disable SSLv3
SSL_CTX_set_options(m_ssl_ctx, SSL_OP_ALL | SSL_OP_NO_SSLv3);
}
ssl_context::ssl_context(
std::filesystem::path certificate,
ssl_file_type certificate_type,
std::filesystem::path private_key,
ssl_file_type private_key_type)
: ssl_context()
{
if (auto r = SSL_CTX_use_certificate_file(m_ssl_ctx, certificate.c_str(), static_cast<int>(certificate_type));
r != 1)
{
throw std::runtime_error{"Failed to load certificate file " + certificate.string()};
}
if (auto r = SSL_CTX_use_PrivateKey_file(m_ssl_ctx, private_key.c_str(), static_cast<int>(private_key_type));
r != 1)
{
throw std::runtime_error{"Failed to load private key file " + private_key.string()};
}
if (auto r = SSL_CTX_check_private_key(m_ssl_ctx); r != 1)
{
throw std::runtime_error{"Certificate and private key do not match."};
}
}
ssl_context::~ssl_context()
{
if (m_ssl_ctx != nullptr)
{
SSL_CTX_free(m_ssl_ctx);
m_ssl_ctx = nullptr;
}
}
} // namespace coro::net

View file

@ -5,7 +5,7 @@ namespace coro::net
using namespace std::chrono_literals;
tcp_client::tcp_client(io_scheduler& scheduler, options opts)
: m_io_scheduler(scheduler),
: m_io_scheduler(&scheduler),
m_options(std::move(opts)),
m_socket(net::make_socket(
net::socket::options{m_options.address.domain(), net::socket::type_t::tcp, net::socket::blocking_t::no}))
@ -13,20 +13,66 @@ tcp_client::tcp_client(io_scheduler& scheduler, options opts)
}
tcp_client::tcp_client(io_scheduler& scheduler, net::socket socket, options opts)
: m_io_scheduler(scheduler),
: m_io_scheduler(&scheduler),
m_options(std::move(opts)),
m_socket(std::move(socket)),
m_connect_status(connect_status::connected)
m_connect_status(connect_status::connected),
m_ssl_info(ssl_connection_type::accept)
{
// Force the socket to be non-blocking.
m_socket.blocking(coro::net::socket::blocking_t::no);
}
tcp_client::tcp_client(tcp_client&& other)
: m_io_scheduler(std::exchange(other.m_io_scheduler, nullptr)),
m_options(std::move(other.m_options)),
m_socket(std::move(other.m_socket)),
m_connect_status(std::exchange(other.m_connect_status, std::nullopt)),
m_ssl_info(std::move(other.m_ssl_info))
{
}
tcp_client::~tcp_client()
{
// If this tcp client is using SSL and the connection did not have an ssl error, schedule a task
// to shutdown the connection cleanly. This is done on a background scheduled task since the
// tcp client's destructor cannot co_await the SSL_shutdown() read and write poll operations.
if (m_ssl_info.m_ssl_ptr != nullptr && !m_ssl_info.m_ssl_error)
{
// Should the shutdown timeout be configurable?
ssl_shutdown_and_free(
*m_io_scheduler, std::move(m_socket), std::move(m_ssl_info.m_ssl_ptr), std::chrono::seconds{30});
}
}
auto tcp_client::operator=(tcp_client&& other) noexcept -> tcp_client&
{
if (std::addressof(other) != this)
{
m_io_scheduler = std::exchange(other.m_io_scheduler, nullptr);
m_options = std::move(other.m_options);
m_socket = std::move(other.m_socket);
m_connect_status = std::exchange(other.m_connect_status, std::nullopt);
m_ssl_info = std::move(other.m_ssl_info);
}
return *this;
}
auto tcp_client::connect(std::chrono::milliseconds timeout) -> coro::task<connect_status>
{
if (m_connect_status.has_value() && m_connect_status.value() == connect_status::connected)
// Only allow the user to connect per tcp client once, if they need to re-connect they should
// make a new tcp_client.
if (m_connect_status.has_value())
{
co_return m_connect_status.value();
}
// This enforces the connection status is aways set on the client object upon returning.
auto return_value = [this](connect_status s) -> connect_status {
m_connect_status = s;
return s;
};
sockaddr_in server{};
server.sin_family = static_cast<int>(m_options.address.domain());
server.sin_port = htons(m_options.port);
@ -35,9 +81,7 @@ auto tcp_client::connect(std::chrono::milliseconds timeout) -> coro::task<connec
auto cret = ::connect(m_socket.native_handle(), (struct sockaddr*)&server, sizeof(server));
if (cret == 0)
{
// Immediate connect.
m_connect_status = connect_status::connected;
co_return connect_status::connected;
co_return return_value(connect_status::connected);
}
else if (cret == -1)
{
@ -45,7 +89,7 @@ auto tcp_client::connect(std::chrono::milliseconds timeout) -> coro::task<connec
// when the connection is established.
if (errno == EAGAIN || errno == EINPROGRESS)
{
auto pstatus = co_await m_io_scheduler.poll(m_socket.native_handle(), poll_op::write, timeout);
auto pstatus = co_await m_io_scheduler->poll(m_socket, poll_op::write, timeout);
if (pstatus == poll_status::event)
{
int result{0};
@ -57,21 +101,154 @@ auto tcp_client::connect(std::chrono::milliseconds timeout) -> coro::task<connec
if (result == 0)
{
// success, connected
m_connect_status = connect_status::connected;
co_return connect_status::connected;
co_return return_value(connect_status::connected);
}
}
else if (pstatus == poll_status::timeout)
{
m_connect_status = connect_status::timeout;
co_return connect_status::timeout;
co_return return_value(connect_status::timeout);
}
}
}
m_connect_status = connect_status::error;
co_return connect_status::error;
co_return return_value(connect_status::error);
}
auto tcp_client::ssl_handshake(std::chrono::milliseconds timeout) -> coro::task<ssl_handshake_status>
{
if (!m_connect_status.has_value() || m_connect_status.value() != connect_status::connected)
{
// Can't ssl handshake if the connection isn't established.
co_return ssl_handshake_status::not_connected;
}
if (m_options.ssl_ctx == nullptr)
{
// ssl isn't setup
co_return ssl_handshake_status::ssl_context_required;
}
if (m_ssl_info.m_ssl_handshake_status.has_value())
{
// The user has already called this function.
co_return m_ssl_info.m_ssl_handshake_status.value();
}
// Enforce on any return past here to set the cached handshake status.
auto return_value = [this](ssl_handshake_status s) -> ssl_handshake_status {
m_ssl_info.m_ssl_handshake_status = s;
return s;
};
m_ssl_info.m_ssl_ptr = ssl_unique_ptr{SSL_new(m_options.ssl_ctx->native_handle())};
if (m_ssl_info.m_ssl_ptr == nullptr)
{
co_return return_value(ssl_handshake_status::ssl_resource_allocation_failed);
}
if (auto r = SSL_set_fd(m_ssl_info.m_ssl_ptr.get(), m_socket.native_handle()); r == 0)
{
co_return return_value(ssl_handshake_status::ssl_set_fd_failure);
}
if (m_ssl_info.m_ssl_connection_type == ssl_connection_type::connect)
{
SSL_set_connect_state(m_ssl_info.m_ssl_ptr.get());
}
else // ssl_connection_type::accept
{
SSL_set_accept_state(m_ssl_info.m_ssl_ptr.get());
}
int r{0};
ERR_clear_error();
while ((r = SSL_do_handshake(m_ssl_info.m_ssl_ptr.get())) != 1)
{
poll_op op{poll_op::read_write};
int err = SSL_get_error(m_ssl_info.m_ssl_ptr.get(), r);
if (err == SSL_ERROR_WANT_WRITE)
{
op = poll_op::write;
}
else if (err == SSL_ERROR_WANT_READ)
{
op = poll_op::read;
}
else
{
// char error_buffer[256];
// ERR_error_string(err, error_buffer);
// std::cerr << "ssl_handleshake error=[" << error_buffer << "]\n";
co_return return_value(ssl_handshake_status::handshake_failed);
}
// TODO: adjust timeout based on elapsed time so far.
auto pstatus = co_await m_io_scheduler->poll(m_socket, op, timeout);
switch (pstatus)
{
case poll_status::timeout:
co_return return_value(ssl_handshake_status::timeout);
case poll_status::error:
co_return return_value(ssl_handshake_status::poll_error);
case poll_status::closed:
co_return return_value(ssl_handshake_status::unexpected_close);
default:
// Event triggered, continue handshake.
break;
}
}
co_return return_value(ssl_handshake_status::ok);
}
auto tcp_client::ssl_shutdown_and_free(
io_scheduler& io_scheduler, net::socket s, ssl_unique_ptr ssl_ptr, std::chrono::milliseconds timeout)
-> coro::task<void>
{
// Immediately transfer onto the scheduler thread pool for background processing.
co_await io_scheduler.schedule();
while (true)
{
auto r = SSL_shutdown(ssl_ptr.get());
if (r == 1) // shutdown complete
{
co_return;
}
else if (r == 0) // shutdown in progress
{
coro::poll_op op{coro::poll_op::read_write};
auto err = SSL_get_error(ssl_ptr.get(), r);
if (err == SSL_ERROR_WANT_WRITE)
{
op = coro::poll_op::write;
}
else if (err == SSL_ERROR_WANT_READ)
{
op = coro::poll_op::read;
}
else
{
co_return;
}
auto pstatus = co_await io_scheduler.poll(s, op, timeout);
switch (pstatus)
{
case poll_status::timeout:
case poll_status::error:
case poll_status::closed:
co_return;
default:
// continue shutdown.
break;
}
}
else // r < 0 error
{
co_return;
}
}
}
} // namespace coro::net

View file

@ -1,9 +1,11 @@
#include "coro/net/tcp_server.hpp"
#include "coro/io_scheduler.hpp"
namespace coro::net
{
tcp_server::tcp_server(io_scheduler& scheduler, options opts)
: m_io_scheduler(scheduler),
: m_io_scheduler(&scheduler),
m_options(std::move(opts)),
m_accept_socket(net::make_accept_socket(
net::socket::options{net::domain_t::ipv4, net::socket::type_t::tcp, net::socket::blocking_t::no},
@ -13,6 +15,24 @@ tcp_server::tcp_server(io_scheduler& scheduler, options opts)
{
}
tcp_server::tcp_server(tcp_server&& other)
: m_io_scheduler(std::exchange(other.m_io_scheduler, nullptr)),
m_options(std::move(other.m_options)),
m_accept_socket(std::move(other.m_accept_socket))
{
}
auto tcp_server::operator=(tcp_server&& other) -> tcp_server&
{
if (std::addressof(other) != this)
{
m_io_scheduler = std::exchange(other.m_io_scheduler, nullptr);
m_options = std::move(other.m_options);
m_accept_socket = std::move(other.m_accept_socket);
}
return *this;
}
auto tcp_server::accept() -> coro::net::tcp_client
{
sockaddr_in client{};
@ -25,11 +45,12 @@ auto tcp_server::accept() -> coro::net::tcp_client
};
return tcp_client{
m_io_scheduler,
*m_io_scheduler,
std::move(s),
tcp_client::options{
.address = net::ip_address{ip_addr_view, static_cast<net::domain_t>(client.sin_family)},
.port = ntohs(client.sin_port)}};
.port = ntohs(client.sin_port),
.ssl_ctx = m_options.ssl_ctx}};
};
} // namespace coro::net

View file

@ -1,2 +1,30 @@
#define CATCH_CONFIG_MAIN
#include "catch.hpp"
#include <signal.h>
/**
* This structure invokes a constructor to setup some global test settings that are needed prior
* to executing the tests.
*/
struct test_setup
{
test_setup()
{
// Ignore SIGPIPE, the library should be handling these gracefully.
signal(SIGPIPE, SIG_IGN);
// For SSL/TLS tests create a localhost cert.pem and key.pem, tests expected these files
// to be generated into the same directory that the tests are running in.
system(
"openssl req -x509 -newkey rsa:4096 -keyout key.pem -out cert.pem -days 365 -subj '/CN=localhost' -nodes");
}
~test_setup()
{
// Cleanup the temporary key.pem and cert.pem files.
system("rm key.pem cert.pem");
}
};
static test_setup g_test_setup{};

View file

@ -80,3 +80,118 @@ TEST_CASE("tcp_server ping server", "[tcp_server]")
coro::sync_wait(coro::when_all(make_server_task(), make_client_task()));
}
TEST_CASE("tcp_server with ssl", "[tcp_server]")
{
coro::io_scheduler scheduler{coro::io_scheduler::options{.pool = coro::thread_pool::options{.thread_count = 1}}};
coro::net::ssl_context client_ssl_context{};
coro::net::ssl_context server_ssl_context{
"cert.pem", coro::net::ssl_file_type::pem, "key.pem", coro::net::ssl_file_type::pem};
std::string client_msg = "Hello world from SSL client!";
std::string server_msg = "Hello world from SSL server!!";
auto make_client_task = [&]() -> coro::task<void> {
co_await scheduler.schedule();
coro::net::tcp_client client{scheduler, coro::net::tcp_client::options{.ssl_ctx = &client_ssl_context}};
std::cerr << "client.connect()\n";
auto cstatus = co_await client.connect();
REQUIRE(cstatus == coro::net::connect_status::connected);
std::cerr << "client.connected\n";
std::cerr << "client.ssl_handshake()\n";
auto hstatus = co_await client.ssl_handshake();
REQUIRE(hstatus == coro::net::ssl_handshake_status::ok);
std::cerr << "client.poll(write)\n";
auto pstatus = co_await client.poll(coro::poll_op::write);
REQUIRE(pstatus == coro::poll_status::event);
std::cerr << "client.send()\n";
auto [sstatus, remaining] = client.send(client_msg);
REQUIRE(sstatus == coro::net::send_status::ok);
REQUIRE(remaining.empty());
std::string response;
response.resize(256, '\0');
while (true)
{
std::cerr << "client.poll(read)\n";
pstatus = co_await client.poll(coro::poll_op::read);
REQUIRE(pstatus == coro::poll_status::event);
std::cerr << "client.recv()\n";
auto [rstatus, rspan] = client.recv(response);
if (rstatus == coro::net::recv_status::would_block)
{
std::cerr << coro::net::to_string(rstatus) << "\n";
continue;
}
else
{
std::cerr << coro::net::to_string(rstatus) << "\n";
REQUIRE(rstatus == coro::net::recv_status::ok);
REQUIRE(rspan.size() == server_msg.size());
response.resize(rspan.size());
break;
}
}
REQUIRE(response == server_msg);
std::cerr << "client received message: " << response << "\n";
std::cerr << "client finished\n";
co_return;
};
auto make_server_task = [&]() -> coro::task<void> {
co_await scheduler.schedule();
coro::net::tcp_server server{scheduler, coro::net::tcp_server::options{.ssl_ctx = &server_ssl_context}};
std::cerr << "server.poll()\n";
auto pstatus = co_await server.poll();
REQUIRE(pstatus == coro::poll_status::event);
std::cerr << "server.accept()\n";
auto client = server.accept();
REQUIRE(client.socket().is_valid());
std::cerr << "server client.handshake()\n";
auto hstatus = co_await client.ssl_handshake();
REQUIRE(hstatus == coro::net::ssl_handshake_status::ok);
std::cerr << "server client.poll(read)\n";
pstatus = co_await client.poll(coro::poll_op::read);
REQUIRE(pstatus == coro::poll_status::event);
std::string buffer;
buffer.resize(256, '\0');
std::cerr << "server client.recv()\n";
auto [rstatus, rspan] = client.recv(buffer);
REQUIRE(rstatus == coro::net::recv_status::ok);
REQUIRE(rspan.size() == client_msg.size());
buffer.resize(rspan.size());
REQUIRE(buffer == client_msg);
std::cerr << "server received message: " << buffer << "\n";
std::cerr << "server client.poll(write)\n";
pstatus = co_await client.poll(coro::poll_op::write);
REQUIRE(pstatus == coro::poll_status::event);
std::cerr << "server client.send()\n";
auto [sstatus, remaining] = client.send(server_msg);
REQUIRE(sstatus == coro::net::send_status::ok);
REQUIRE(remaining.empty());
std::cerr << "server finished\n";
co_return;
};
coro::sync_wait(coro::when_all(make_server_task(), make_client_task()));
}