mirror of
https://gitlab.com/niansa/libcrosscoro.git
synced 2025-03-06 20:53:32 +01:00
* std::shared_ptr<executor_type> for coro::shared_mutex * implement remaining types that leverage executor or io_scheduler
259 lines
8.6 KiB
C++
259 lines
8.6 KiB
C++
#include "coro/net/tcp_client.hpp"
|
|
|
|
namespace coro::net
|
|
{
|
|
using namespace std::chrono_literals;
|
|
|
|
tcp_client::tcp_client(std::shared_ptr<io_scheduler> scheduler, options opts)
|
|
: m_io_scheduler(std::move(scheduler)),
|
|
m_options(std::move(opts)),
|
|
m_socket(net::make_socket(
|
|
net::socket::options{m_options.address.domain(), net::socket::type_t::tcp, net::socket::blocking_t::no}))
|
|
{
|
|
if (m_io_scheduler == nullptr)
|
|
{
|
|
throw std::runtime_error{"tcp_client cannot have nullptr io_scheduler"};
|
|
}
|
|
}
|
|
|
|
tcp_client::tcp_client(std::shared_ptr<io_scheduler> scheduler, net::socket socket, options opts)
|
|
: m_io_scheduler(std::move(scheduler)),
|
|
m_options(std::move(opts)),
|
|
m_socket(std::move(socket)),
|
|
m_connect_status(connect_status::connected),
|
|
m_ssl_info(ssl_connection_type::accept)
|
|
{
|
|
// io_scheduler is assumed good since it comes from a tcp_server.
|
|
|
|
// Force the socket to be non-blocking.
|
|
m_socket.blocking(coro::net::socket::blocking_t::no);
|
|
}
|
|
|
|
tcp_client::tcp_client(tcp_client&& other)
|
|
: m_io_scheduler(std::move(other.m_io_scheduler)),
|
|
m_options(std::move(other.m_options)),
|
|
m_socket(std::move(other.m_socket)),
|
|
m_connect_status(std::exchange(other.m_connect_status, std::nullopt)),
|
|
m_ssl_info(std::move(other.m_ssl_info))
|
|
{
|
|
}
|
|
|
|
tcp_client::~tcp_client()
|
|
{
|
|
// If this tcp client is using SSL and the connection did not have an ssl error, schedule a task
|
|
// to shutdown the connection cleanly. This is done on a background scheduled task since the
|
|
// tcp client's destructor cannot co_await the SSL_shutdown() read and write poll operations.
|
|
if (m_ssl_info.m_ssl_ptr != nullptr && !m_ssl_info.m_ssl_error)
|
|
{
|
|
// Should the shutdown timeout be configurable?
|
|
m_io_scheduler->schedule(ssl_shutdown_and_free(
|
|
m_io_scheduler, std::move(m_socket), std::move(m_ssl_info.m_ssl_ptr), std::chrono::seconds{30}));
|
|
}
|
|
}
|
|
|
|
auto tcp_client::operator=(tcp_client&& other) noexcept -> tcp_client&
|
|
{
|
|
if (std::addressof(other) != this)
|
|
{
|
|
m_io_scheduler = std::move(other.m_io_scheduler);
|
|
m_options = std::move(other.m_options);
|
|
m_socket = std::move(other.m_socket);
|
|
m_connect_status = std::exchange(other.m_connect_status, std::nullopt);
|
|
m_ssl_info = std::move(other.m_ssl_info);
|
|
}
|
|
return *this;
|
|
}
|
|
|
|
auto tcp_client::connect(std::chrono::milliseconds timeout) -> coro::task<connect_status>
|
|
{
|
|
// Only allow the user to connect per tcp client once, if they need to re-connect they should
|
|
// make a new tcp_client.
|
|
if (m_connect_status.has_value())
|
|
{
|
|
co_return m_connect_status.value();
|
|
}
|
|
|
|
// This enforces the connection status is aways set on the client object upon returning.
|
|
auto return_value = [this](connect_status s) -> connect_status {
|
|
m_connect_status = s;
|
|
return s;
|
|
};
|
|
|
|
sockaddr_in server{};
|
|
server.sin_family = static_cast<int>(m_options.address.domain());
|
|
server.sin_port = htons(m_options.port);
|
|
server.sin_addr = *reinterpret_cast<const in_addr*>(m_options.address.data().data());
|
|
|
|
auto cret = ::connect(m_socket.native_handle(), (struct sockaddr*)&server, sizeof(server));
|
|
if (cret == 0)
|
|
{
|
|
co_return return_value(connect_status::connected);
|
|
}
|
|
else if (cret == -1)
|
|
{
|
|
// If the connect is happening in the background poll for write on the socket to trigger
|
|
// when the connection is established.
|
|
if (errno == EAGAIN || errno == EINPROGRESS)
|
|
{
|
|
auto pstatus = co_await m_io_scheduler->poll(m_socket, poll_op::write, timeout);
|
|
if (pstatus == poll_status::event)
|
|
{
|
|
int result{0};
|
|
socklen_t result_length{sizeof(result)};
|
|
if (getsockopt(m_socket.native_handle(), SOL_SOCKET, SO_ERROR, &result, &result_length) < 0)
|
|
{
|
|
std::cerr << "connect failed to getsockopt after write poll event\n";
|
|
}
|
|
|
|
if (result == 0)
|
|
{
|
|
co_return return_value(connect_status::connected);
|
|
}
|
|
}
|
|
else if (pstatus == poll_status::timeout)
|
|
{
|
|
co_return return_value(connect_status::timeout);
|
|
}
|
|
}
|
|
}
|
|
|
|
co_return return_value(connect_status::error);
|
|
}
|
|
|
|
auto tcp_client::ssl_handshake(std::chrono::milliseconds timeout) -> coro::task<ssl_handshake_status>
|
|
{
|
|
if (!m_connect_status.has_value() || m_connect_status.value() != connect_status::connected)
|
|
{
|
|
// Can't ssl handshake if the connection isn't established.
|
|
co_return ssl_handshake_status::not_connected;
|
|
}
|
|
|
|
if (m_options.ssl_ctx == nullptr)
|
|
{
|
|
// ssl isn't setup
|
|
co_return ssl_handshake_status::ssl_context_required;
|
|
}
|
|
|
|
if (m_ssl_info.m_ssl_handshake_status.has_value())
|
|
{
|
|
// The user has already called this function.
|
|
co_return m_ssl_info.m_ssl_handshake_status.value();
|
|
}
|
|
|
|
// Enforce on any return past here to set the cached handshake status.
|
|
auto return_value = [this](ssl_handshake_status s) -> ssl_handshake_status {
|
|
m_ssl_info.m_ssl_handshake_status = s;
|
|
return s;
|
|
};
|
|
|
|
m_ssl_info.m_ssl_ptr = ssl_unique_ptr{SSL_new(m_options.ssl_ctx->native_handle())};
|
|
if (m_ssl_info.m_ssl_ptr == nullptr)
|
|
{
|
|
co_return return_value(ssl_handshake_status::ssl_resource_allocation_failed);
|
|
}
|
|
|
|
if (auto r = SSL_set_fd(m_ssl_info.m_ssl_ptr.get(), m_socket.native_handle()); r == 0)
|
|
{
|
|
co_return return_value(ssl_handshake_status::ssl_set_fd_failure);
|
|
}
|
|
|
|
if (m_ssl_info.m_ssl_connection_type == ssl_connection_type::connect)
|
|
{
|
|
SSL_set_connect_state(m_ssl_info.m_ssl_ptr.get());
|
|
}
|
|
else // ssl_connection_type::accept
|
|
{
|
|
SSL_set_accept_state(m_ssl_info.m_ssl_ptr.get());
|
|
}
|
|
|
|
int r{0};
|
|
ERR_clear_error();
|
|
while ((r = SSL_do_handshake(m_ssl_info.m_ssl_ptr.get())) != 1)
|
|
{
|
|
poll_op op{poll_op::read_write};
|
|
int err = SSL_get_error(m_ssl_info.m_ssl_ptr.get(), r);
|
|
if (err == SSL_ERROR_WANT_WRITE)
|
|
{
|
|
op = poll_op::write;
|
|
}
|
|
else if (err == SSL_ERROR_WANT_READ)
|
|
{
|
|
op = poll_op::read;
|
|
}
|
|
else
|
|
{
|
|
// char error_buffer[256];
|
|
// ERR_error_string(err, error_buffer);
|
|
// std::cerr << "ssl_handleshake error=[" << error_buffer << "]\n";
|
|
co_return return_value(ssl_handshake_status::handshake_failed);
|
|
}
|
|
|
|
// TODO: adjust timeout based on elapsed time so far.
|
|
auto pstatus = co_await m_io_scheduler->poll(m_socket, op, timeout);
|
|
switch (pstatus)
|
|
{
|
|
case poll_status::timeout:
|
|
co_return return_value(ssl_handshake_status::timeout);
|
|
case poll_status::error:
|
|
co_return return_value(ssl_handshake_status::poll_error);
|
|
case poll_status::closed:
|
|
co_return return_value(ssl_handshake_status::unexpected_close);
|
|
default:
|
|
// Event triggered, continue handshake.
|
|
break;
|
|
}
|
|
}
|
|
|
|
co_return return_value(ssl_handshake_status::ok);
|
|
}
|
|
|
|
auto tcp_client::ssl_shutdown_and_free(
|
|
std::shared_ptr<io_scheduler> io_scheduler,
|
|
net::socket s,
|
|
ssl_unique_ptr ssl_ptr,
|
|
std::chrono::milliseconds timeout) -> coro::task<void>
|
|
{
|
|
while (true)
|
|
{
|
|
auto r = SSL_shutdown(ssl_ptr.get());
|
|
if (r == 1) // shutdown complete
|
|
{
|
|
co_return;
|
|
}
|
|
else if (r == 0) // shutdown in progress
|
|
{
|
|
coro::poll_op op{coro::poll_op::read_write};
|
|
auto err = SSL_get_error(ssl_ptr.get(), r);
|
|
if (err == SSL_ERROR_WANT_WRITE)
|
|
{
|
|
op = coro::poll_op::write;
|
|
}
|
|
else if (err == SSL_ERROR_WANT_READ)
|
|
{
|
|
op = coro::poll_op::read;
|
|
}
|
|
else
|
|
{
|
|
co_return;
|
|
}
|
|
|
|
auto pstatus = co_await io_scheduler->poll(s, op, timeout);
|
|
switch (pstatus)
|
|
{
|
|
case poll_status::timeout:
|
|
case poll_status::error:
|
|
case poll_status::closed:
|
|
co_return;
|
|
default:
|
|
// continue shutdown.
|
|
break;
|
|
}
|
|
}
|
|
else // r < 0 error
|
|
{
|
|
co_return;
|
|
}
|
|
}
|
|
}
|
|
|
|
} // namespace coro::net
|