1
0
Fork 0
mirror of synced 2025-03-06 20:59:54 +01:00
linux/tools/testing/selftests/net/mptcp/mptcp_connect.c
Paolo Abeni 6bdb6211a6 mptcp: more stable diag self-tests
During diag self-tests we introduce a long wait in the mptcp test
program to give the script enough time to access the sockets
dump.

Such wait is introduced after shutting down one end of the socket. Since
commit 43b54c6ee3 ("mptcp: Use full MPTCP-level disconnect state
machine"), if both sides shut down, the socket is correctly transitioned
into CLOSED status.

As a side effect some sockets are not dumped via the diag interface,
because the socket state (CLOSED) does not match the default filter, and
this causes self-test instability.

Address the issue by moving the above-mentioned wait before shutting
down the socket.

Closes: https://github.com/multipath-tcp/mptcp_net-next/issues/68
Fixes: df62f2ec3d ("selftests/mptcp: add diag interface tests")
Tested-and-acked-by: Matthieu Baerts <matthieu.baerts@tessares.net>
Signed-off-by: Paolo Abeni <pabeni@redhat.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
2020-08-07 17:25:38 -07:00

903 lines
17 KiB
C

// SPDX-License-Identifier: GPL-2.0
#define _GNU_SOURCE
#include <errno.h>
#include <limits.h>
#include <fcntl.h>
#include <string.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <strings.h>
#include <signal.h>
#include <unistd.h>
#include <sys/poll.h>
#include <sys/sendfile.h>
#include <sys/stat.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/mman.h>
#include <netdb.h>
#include <netinet/in.h>
#include <linux/tcp.h>
/* Index of the next argv element for getopt(); read by parse_opts(). */
extern int optind;
/* Fallback protocol number for headers that do not define IPPROTO_MPTCP. */
#ifndef IPPROTO_MPTCP
#define IPPROTO_MPTCP 262
#endif
/* Fallback option number for headers that do not define TCP_ULP. */
#ifndef TCP_ULP
#define TCP_ULP 31
#endif
/* Timeout handed to poll(), in milliseconds; -t replaces it with
 * sec * 1000, or -1 when that value is not positive.
 */
static int poll_timeout = 10 * 1000;
/* Set by -l: main() listens and accepts instead of connecting. */
static bool listen_mode;
/* Set to true by handle_signal(); not read anywhere in the visible code. */
static bool quit;
/* Data transfer strategies selectable with -m (see parse_mode()). */
enum cfg_mode {
/* interleaved read/write driven by poll() */
CFG_MODE_POLL,
/* send the input file with mmap() + write() */
CFG_MODE_MMAP,
/* send the input file with sendfile() */
CFG_MODE_SENDFILE,
};
/* Transfer strategy used by copyfd_io(). */
static enum cfg_mode cfg_mode = CFG_MODE_POLL;
/* The single positional argument: address to connect to or listen on. */
static const char *cfg_host;
/* Service/port string handed to getaddrinfo(); -p overrides it. */
static const char *cfg_port = "12000";
/* Protocol passed to socket(); -s selects IPPROTO_MPTCP or IPPROTO_TCP. */
static int cfg_sock_proto = IPPROTO_MPTCP;
/* Set by -u: main() only runs sock_test_tcpulp() and exits. */
static bool tcpulp_audit;
/* Address family; switched to AF_INET6 by -6 or when cfg_host contains ':'. */
static int pf = AF_INET;
/* SO_SNDBUF value requested with -S; 0 leaves the socket untouched. */
static int cfg_sndbuf;
/* SO_RCVBUF value requested with -R; 0 leaves the socket untouched. */
static int cfg_rcvbuf;
/* Set by -j: enables the extra pauses in do_rnd_write()/copyfd_io_poll(). */
static bool cfg_join;
/* Delay passed to usleep(), in microseconds; set by -w (sec * 1000000)
 * or to 400000 by -j.
 */
static int cfg_wait;
/*
 * Print the command line synopsis on stderr and terminate the process with
 * exit status 1.  Never returns.
 *
 * The option letters below mirror the getopt() string in parse_opts():
 * -s selects the socket protocol and -m the transfer mode.
 */
static void die_usage(void)
{
	fprintf(stderr, "Usage: mptcp_connect [-6] [-u] [-j] [-l] [-s MPTCP|TCP] [-p port] [-m mode] "
		"[-t sec] [-S num] [-R num] [-w sec] connect_address\n");
	fprintf(stderr, "\t-6 use ipv6\n");
	fprintf(stderr, "\t-j -- join test mode: poll mode with extra pauses at start and before close\n");
	fprintf(stderr, "\t-l -- listen mode: accept one connection on connect_address\n");
	fprintf(stderr, "\t-t num -- set poll timeout to num sec\n");
	fprintf(stderr, "\t-S num -- set SO_SNDBUF to num\n");
	fprintf(stderr, "\t-R num -- set SO_RCVBUF to num\n");
	fprintf(stderr, "\t-p num -- use port num\n");
	fprintf(stderr, "\t-s [MPTCP|TCP] -- use mptcp (default) or tcp sockets\n");
	fprintf(stderr, "\t-m [poll|mmap|sendfile] -- use poll (default), mmap+write or sendfile\n");
	fprintf(stderr, "\t-u -- check mptcp ulp\n");
	fprintf(stderr, "\t-w num -- wait num sec before closing the socket\n");
	fprintf(stderr, "\t-h -- show this help\n");
	exit(1);
}
/* Signal handler installed by main() for SIGUSR1: raises the quit flag. */
static void handle_signal(int nr)
{
	(void)nr;	/* the signal number itself is not needed */
	quit = true;
}
/*
 * Map an error code returned by getaddrinfo()/getnameinfo() to a readable
 * message.  For EAI_SYSTEM the message is taken from the current errno,
 * every other code is handed to gai_strerror().
 */
static const char *getxinfo_strerr(int err)
{
	return err == EAI_SYSTEM ? strerror(errno) : gai_strerror(err);
}
/*
 * getnameinfo() wrapper that never fails: the address is always rendered in
 * numeric form (host and service), and any error is reported on stderr
 * before the process exits with status 1.
 */
static void xgetnameinfo(const struct sockaddr *addr, socklen_t addrlen,
			 char *host, socklen_t hostlen,
			 char *serv, socklen_t servlen)
{
	const int numeric = NI_NUMERICHOST | NI_NUMERICSERV;
	int rc;

	rc = getnameinfo(addr, addrlen, host, hostlen, serv, servlen, numeric);
	if (rc == 0)
		return;

	fprintf(stderr, "Fatal: getnameinfo: %s\n", getxinfo_strerr(rc));
	exit(1);
}
/*
 * getaddrinfo() wrapper that never fails: on error the lookup target and
 * the reason are printed on stderr and the process exits with status 1.
 * On return *res holds the result list, to be released by the caller.
 */
static void xgetaddrinfo(const char *node, const char *service,
			 const struct addrinfo *hints,
			 struct addrinfo **res)
{
	const char *shown_node = node ? node : "";
	const char *shown_service = service ? service : "";
	int rc = getaddrinfo(node, service, hints, res);

	if (rc == 0)
		return;

	fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
		shown_node, shown_service, getxinfo_strerr(rc));
	exit(1);
}
/*
 * Request a receive buffer of @size bytes on socket @fd via SO_RCVBUF.
 * Exits with status 1 if the option cannot be set.
 */
static void set_rcvbuf(int fd, unsigned int size)
{
	if (setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &size, sizeof(size)) == 0)
		return;

	perror("set SO_RCVBUF");
	exit(1);
}
/*
 * Request a send buffer of @size bytes on socket @fd via SO_SNDBUF.
 * Exits with status 1 if the option cannot be set.
 */
static void set_sndbuf(int fd, unsigned int size)
{
	if (setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &size, sizeof(size)) == 0)
		return;

	perror("set SO_SNDBUF");
	exit(1);
}
/*
 * Create a listening socket bound to @listenaddr:@port.
 *
 * The address is resolved as a numeric, passive address of family pf; the
 * socket itself is created with protocol cfg_sock_proto.  Each resolved
 * address is tried in turn until bind() succeeds.
 *
 * Returns the listening descriptor, or a negative value on failure (the
 * reason is printed on stderr).
 */
static int sock_listen_mptcp(const char * const listenaddr,
			     const char * const port)
{
	struct addrinfo hints = {
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
		.ai_flags = AI_PASSIVE | AI_NUMERICHOST
	};
	struct addrinfo *a, *addr;
	/* start from "no socket" so the check below is well defined even if
	 * the loop body never runs
	 */
	int sock = -1;
	int one = 1;

	hints.ai_family = pf;
	xgetaddrinfo(listenaddr, port, &hints, &addr);

	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, cfg_sock_proto);
		if (sock < 0)
			continue;

		/* failing to set SO_REUSEADDR is reported but not fatal */
		if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
				     sizeof(one)))
			perror("setsockopt");

		if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		perror("bind");
		close(sock);
		sock = -1;
	}

	freeaddrinfo(addr);

	if (sock < 0) {
		fprintf(stderr, "Could not create listen socket\n");
		return sock;
	}

	if (listen(sock, 20)) {
		perror("listen");
		close(sock);
		return -1;
	}

	return sock;
}
/*
 * Check that a plain TCP socket refuses the "mptcp" upper layer protocol.
 *
 * @remoteaddr:@port is resolved (always as AF_INET, independent of pf) only
 * to obtain socket parameters; no connection is made.  For each resolved
 * address a TCP socket is created and setsockopt(TCP_ULP, "mptcp") is
 * attempted; the test passes as soon as one attempt fails with EOPNOTSUPP.
 *
 * Returns true if the test passed, false otherwise (diagnostics on stderr).
 */
static bool sock_test_tcpulp(const char * const remoteaddr,
			     const char * const port)
{
	struct addrinfo hints = {
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
	};
	struct addrinfo *a, *addr;
	int sock = -1, ret = 0;
	bool test_pass = false;

	hints.ai_family = AF_INET;
	xgetaddrinfo(remoteaddr, port, &hints, &addr);

	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, IPPROTO_TCP);
		if (sock < 0) {
			perror("socket");
			continue;
		}

		ret = setsockopt(sock, IPPROTO_TCP, TCP_ULP, "mptcp",
				 sizeof("mptcp"));
		/* report before close() gets a chance to overwrite errno */
		if (ret == -1 && errno == EOPNOTSUPP)
			test_pass = true;
		else if (!ret)
			fprintf(stderr,
				"setsockopt(TCP_ULP) returned 0\n");
		else
			perror("setsockopt(TCP_ULP)");

		close(sock);
		if (test_pass)
			break;
	}

	/* the resolver result is no longer needed on any path */
	freeaddrinfo(addr);

	return test_pass;
}
/*
 * Resolve @remoteaddr:@port for address family pf and connect to the first
 * address that accepts the connection, using a stream socket of protocol
 * @proto.
 *
 * Returns the connected descriptor, or a negative value if no address could
 * be connected (each failure is reported with perror()).
 */
static int sock_connect_mptcp(const char * const remoteaddr,
			      const char * const port, int proto)
{
	struct addrinfo hints = {
		.ai_protocol = IPPROTO_TCP,
		.ai_socktype = SOCK_STREAM,
	};
	struct addrinfo *candidates, *cur;
	int fd = -1;

	hints.ai_family = pf;
	xgetaddrinfo(remoteaddr, port, &hints, &candidates);

	for (cur = candidates; cur != NULL; cur = cur->ai_next) {
		fd = socket(cur->ai_family, cur->ai_socktype, proto);
		if (fd < 0) {
			perror("socket");
			continue;
		}

		if (connect(fd, cur->ai_addr, cur->ai_addrlen) != 0) {
			/* this address did not work: drop it, try the next */
			perror("connect()");
			close(fd);
			fd = -1;
			continue;
		}

		break; /* connected */
	}

	freeaddrinfo(candidates);
	return fd;
}
/*
 * Write a randomly sized prefix of @buf (at most @len and at most 0xffff
 * bytes) to @fd with a single write() call.
 *
 * When cfg_join is set, the very first call writes at most 100 bytes and
 * then sleeps 200ms.
 *
 * Returns the number of bytes written, or -1 on error (after perror()).
 * The return type is signed so that the -1 error value is not squeezed
 * through an unsigned type on its way to the caller.
 */
static ssize_t do_rnd_write(const int fd, char *buf, const size_t len)
{
	static bool first = true;
	size_t do_w;
	ssize_t bw;

	do_w = (size_t)rand() & 0xffff;
	if (do_w == 0 || do_w > len)
		do_w = len;

	if (cfg_join && first && do_w > 100)
		do_w = 100;

	bw = write(fd, buf, do_w);
	if (bw < 0)
		perror("write");

	/* let the join handshake complete, before going on */
	if (cfg_join && first) {
		usleep(200000);
		first = false;
	}

	return bw;
}
/*
 * Write all @len bytes of @buf to @fd, looping over short writes.
 *
 * Returns @len once everything has been written, or 0 if write() fails
 * (the error is reported with perror()).
 */
static size_t do_write(const int fd, char *buf, const size_t len)
{
	const char *cursor = buf;
	size_t left = len;

	while (left != 0) {
		ssize_t bw = write(fd, cursor, left);

		if (bw < 0) {
			perror("write");
			return 0;
		}

		cursor += bw;
		left -= (size_t)bw;
	}

	return len;
}
/*
 * Read from @fd into @buf with a randomly chosen size limit: a random
 * value in [0, 0xffff] is drawn, 0 is bumped to 1 and anything above @len
 * is clamped to @len.
 *
 * Returns whatever the single read() call returns.
 */
static ssize_t do_rnd_read(const int fd, char *buf, const size_t len)
{
	const size_t want = (size_t)rand() & 0xffff;

	if (want == 0)
		return read(fd, buf, 1);

	return read(fd, buf, want > len ? len : want);
}
/*
 * Add O_NONBLOCK to the file status flags of @fd.  Silently does nothing
 * if the current flags cannot be read.
 */
static void set_nonblock(int fd)
{
	const int cur = fcntl(fd, F_GETFL);

	if (cur != -1)
		fcntl(fd, F_SETFL, cur | O_NONBLOCK);
}
/*
 * Bidirectional copy loop driven by poll(): data read from @infd is sent to
 * @peerfd in randomly sized chunks, data received from @peerfd is written
 * to @outfd.  @peerfd is switched to non-blocking mode; it is closed before
 * a successful return.
 *
 * Once @infd is exhausted the write side of @peerfd is shut down (after an
 * optional cfg_wait pause); the loop ends when both directions are done.
 *
 * Returns 0 on success, 1 if poll() fails, 2 if poll() times out, 3 if
 * reading from the peer fails, 4 if reading @infd fails, 5 on
 * POLLERR/POLLNVAL and 111 if writing to the peer fails.
 */
static int copyfd_io_poll(int infd, int peerfd, int outfd)
{
	struct pollfd fds = {
		.fd = peerfd,
		.events = POLLIN | POLLOUT,
	};
	unsigned int woff = 0;
	/* signed: must be able to hold the -1 returned by a failed read() */
	ssize_t wlen = 0;
	char wbuf[8192];

	set_nonblock(peerfd);

	for (;;) {
		char rbuf[8192];
		ssize_t len;

		if (fds.events == 0)
			break;

		switch (poll(&fds, 1, poll_timeout)) {
		case -1:
			if (errno == EINTR)
				continue;
			perror("poll");
			return 1;
		case 0:
			fprintf(stderr, "%s: poll timed out (events: "
				"POLLIN %u, POLLOUT %u)\n", __func__,
				(unsigned int)(fds.events & POLLIN),
				(unsigned int)(fds.events & POLLOUT));
			return 2;
		default:
			break;
		}

		if (fds.revents & POLLIN) {
			len = do_rnd_read(peerfd, rbuf, sizeof(rbuf));
			if (len == 0) {
				/* no more data to receive:
				 * peer has closed its write side
				 */
				fds.events &= ~POLLIN;

				if ((fds.events & POLLOUT) == 0)
					/* and nothing more to send */
					break;

			/* Else, still have data to transmit */
			} else if (len < 0) {
				perror("read");
				return 3;
			}

			/* a zero length makes this a no-op */
			do_write(outfd, rbuf, len);
		}

		if (fds.revents & POLLOUT) {
			if (wlen == 0) {
				woff = 0;
				wlen = read(infd, wbuf, sizeof(wbuf));
			}

			if (wlen > 0) {
				ssize_t bw;

				bw = do_rnd_write(peerfd, wbuf + woff, wlen);
				if (bw < 0)
					return 111;

				woff += bw;
				wlen -= bw;
			} else if (wlen == 0) {
				/* We have no more data to send. */
				fds.events &= ~POLLOUT;

				if ((fds.events & POLLIN) == 0)
					/* ... and peer also closed already */
					break;

				/* ... but we still receive.
				 * Close our write side, ev. give some time
				 * for address notification and/or checking
				 * the current status
				 */
				if (cfg_wait)
					usleep(cfg_wait);
				shutdown(peerfd, SHUT_WR);
			} else {
				/* read(infd) failed: clear the error marker
				 * so the next POLLOUT retries the read
				 * (this does not touch errno)
				 */
				wlen = 0;
				if (errno == EINTR)
					continue;
				perror("read");
				return 4;
			}
		}

		if (fds.revents & (POLLERR | POLLNVAL)) {
			fprintf(stderr, "Unexpected revents: "
				"POLLERR/POLLNVAL(%x)\n", fds.revents);
			return 5;
		}
	}

	/* leave some time for late join/announce */
	if (cfg_join)
		usleep(cfg_wait);

	close(peerfd);
	return 0;
}
/*
 * Copy everything readable from @infd to @outfd, using randomly sized
 * reads, until end of file.
 *
 * Returns 0 at end of file, a negative value if reading fails (reported
 * with perror()), or the size of the last read if the matching write()
 * did not write exactly that many bytes.
 */
static int do_recvfile(int infd, int outfd)
{
	ssize_t got;

	for (;;) {
		char chunk[16384];

		got = do_rnd_read(infd, chunk, sizeof(chunk));
		if (got < 0) {
			perror("read");
			break;
		}
		if (got == 0)
			break;

		if (write(outfd, chunk, got) != got)
			break;
	}

	return (int)got;
}
/*
 * Send the first @size bytes of the file behind @infd to @outfd by mapping
 * the file read-only and writing the mapping out.
 *
 * A @size of 0 is a successful no-op: mmap() rejects zero-length mappings
 * with EINVAL, and there is nothing to send anyway.
 *
 * Returns 0 when everything was written, 1 if mmap() fails, otherwise the
 * number of bytes left unwritten when write() failed.
 */
static int do_mmap(int infd, int outfd, unsigned int size)
{
	size_t off = 0, rem = size;
	char *inbuf;

	if (size == 0)
		return 0;

	inbuf = mmap(NULL, size, PROT_READ, MAP_SHARED, infd, 0);
	if (inbuf == MAP_FAILED) {
		perror("mmap");
		return 1;
	}

	while (rem > 0) {
		ssize_t ret = write(outfd, inbuf + off, rem);

		if (ret < 0) {
			perror("write");
			break;
		}

		off += ret;
		rem -= ret;
	}

	munmap(inbuf, size);
	return rem;
}
/*
 * Return the size in bytes of the regular file behind @fd.
 *
 * Returns the size on success, -1 if fstat() fails, -2 if @fd is not a
 * regular file and -3 if the size does not fit in an int.  The size is
 * checked at its native off_t width, before any narrowing.
 */
static int get_infd_size(int fd)
{
	struct stat sb;

	if (fstat(fd, &sb) < 0) {
		perror("fstat");
		return -1;
	}

	if (!S_ISREG(sb.st_mode)) {
		fprintf(stderr, "%s: stdin is not a regular file\n", __func__);
		return -2;
	}

	if (sb.st_size > INT_MAX) {
		fprintf(stderr, "File too large: %lld\n",
			(long long)sb.st_size);
		return -3;
	}

	return (int)sb.st_size;
}
/*
 * Transfer @count bytes from @infd (starting at its current file offset)
 * to @outfd with sendfile(), repeating until the whole amount is sent.
 *
 * Returns 0 on success or 3 if sendfile() fails (reported with perror()).
 */
static int do_sendfile(int infd, int outfd, unsigned int count)
{
	unsigned int left;

	for (left = count; left > 0; ) {
		ssize_t sent = sendfile(outfd, infd, NULL, left);

		if (sent < 0) {
			perror("sendfile");
			return 3;
		}

		left -= sent;
	}

	return 0;
}
/*
 * Exchange data with the peer in mmap mode.
 *
 * In listen mode the peer's data is received into @outfd first and the
 * @size bytes of @infd are sent afterwards.  Otherwise the input is sent
 * first, the write side of @peerfd is shut down, and then the peer's data
 * is received.
 *
 * Returns 0 on success or the non-zero result of the step that failed.
 */
static int copyfd_io_mmap(int infd, int peerfd, int outfd,
			  unsigned int size)
{
	int rc;

	if (!listen_mode) {
		rc = do_mmap(infd, peerfd, size);
		if (rc != 0)
			return rc;

		shutdown(peerfd, SHUT_WR);

		return do_recvfile(peerfd, outfd);
	}

	rc = do_recvfile(peerfd, outfd);
	if (rc != 0)
		return rc;

	return do_mmap(infd, peerfd, size);
}
/*
 * Exchange data with the peer in sendfile mode.
 *
 * In listen mode the peer's data is received into @outfd first and the
 * @size bytes of @infd are sent afterwards.  Otherwise the input is sent
 * first, the write side of @peerfd is shut down so that the peer's
 * read-until-EOF loop can terminate (same sequence as copyfd_io_mmap()),
 * and then the peer's data is received.
 *
 * Returns 0 on success or the non-zero result of the step that failed.
 */
static int copyfd_io_sendfile(int infd, int peerfd, int outfd,
			      unsigned int size)
{
	int err;

	if (listen_mode) {
		err = do_recvfile(peerfd, outfd);
		if (err)
			return err;

		err = do_sendfile(infd, peerfd, size);
	} else {
		err = do_sendfile(infd, peerfd, size);
		if (err)
			return err;

		/* signal end of data; do_recvfile() on the other side
		 * only stops at EOF
		 */
		shutdown(peerfd, SHUT_WR);

		err = do_recvfile(peerfd, outfd);
	}

	return err;
}
/*
 * Run the data exchange selected by cfg_mode between @infd/@outfd and the
 * connected socket @peerfd.
 *
 * The mmap and sendfile modes need the input size up front; a negative
 * result from get_infd_size() is returned unchanged.  An unknown mode
 * prints a message and ends in die_usage().
 */
static int copyfd_io(int infd, int peerfd, int outfd)
{
	int file_size;

	if (cfg_mode == CFG_MODE_POLL)
		return copyfd_io_poll(infd, peerfd, outfd);

	if (cfg_mode != CFG_MODE_MMAP && cfg_mode != CFG_MODE_SENDFILE) {
		fprintf(stderr, "Invalid mode %d\n", cfg_mode);
		die_usage();
		return 1;
	}

	file_size = get_infd_size(infd);
	if (file_size < 0)
		return file_size;

	if (cfg_mode == CFG_MODE_MMAP)
		return copyfd_io_mmap(infd, peerfd, outfd, file_size);

	return copyfd_io_sendfile(infd, peerfd, outfd, file_size);
}
/*
 * Sanity-check an address returned by accept().
 *
 * @pf:    expected address family (AF_INET or AF_INET6)
 * @ss:    address filled in by accept()
 * @salen: address length reported by accept()
 *
 * Complains on stderr if the source port is 0, if @salen does not match
 * the sockaddr size for @pf, or if the family stored in @ss differs from
 * @pf.  Purely diagnostic: never fails and never exits.
 */
static void check_sockaddr(int pf, struct sockaddr_storage *ss,
			   socklen_t salen)
{
	struct sockaddr_in6 *sin6;
	struct sockaddr_in *sin;
	socklen_t wanted_size = 0;

	switch (pf) {
	case AF_INET:
		wanted_size = sizeof(*sin);
		sin = (void *)ss;
		if (!sin->sin_port)
			fprintf(stderr, "accept: something wrong: ip connection from port 0\n");
		break;
	case AF_INET6:
		wanted_size = sizeof(*sin6);
		sin6 = (void *)ss;
		if (!sin6->sin6_port)
			fprintf(stderr, "accept: something wrong: ipv6 connection from port 0\n");
		break;
	default:
		fprintf(stderr, "accept: Unknown pf %d, salen %u\n", pf, salen);
		return;
	}

	if (salen != wanted_size)
		fprintf(stderr, "accept: size mismatch, got %d expected %d\n",
			(int)salen, (int)wanted_size);

	/* arguments follow the order of the message: expected, then actual */
	if (ss->ss_family != pf)
		fprintf(stderr, "accept: pf mismatch, expect %d, ss_family is %d\n",
			pf, (int)ss->ss_family);
}
/*
 * Verify that getpeername() on @fd reports the same address that accept()
 * returned in @ss/@salen.
 *
 * Any discrepancy (length or content) is reported on stderr; a content
 * mismatch is printed with both addresses in numeric form.  Purely
 * diagnostic apart from xgetnameinfo(), which exits on failure.
 */
static void check_getpeername(int fd, struct sockaddr_storage *ss, socklen_t salen)
{
	struct sockaddr_storage peerss;
	socklen_t peersalen = sizeof(peerss);
	char accept_host[INET6_ADDRSTRLEN];
	char accept_serv[INET6_ADDRSTRLEN];
	char peer_host[INET6_ADDRSTRLEN];
	char peer_serv[INET6_ADDRSTRLEN];

	if (getpeername(fd, (struct sockaddr *)&peerss, &peersalen) < 0) {
		perror("getpeername");
		return;
	}

	if (peersalen != salen) {
		fprintf(stderr, "%s: %d vs %d\n", __func__, peersalen, salen);
		return;
	}

	if (memcmp(ss, &peerss, peersalen) == 0)
		return;

	xgetnameinfo((struct sockaddr *)ss, salen,
		     accept_host, sizeof(accept_host),
		     accept_serv, sizeof(accept_serv));
	xgetnameinfo((struct sockaddr *)&peerss, peersalen,
		     peer_host, sizeof(peer_host),
		     peer_serv, sizeof(peer_serv));

	fprintf(stderr, "%s: memcmp failure: accept %s vs peername %s, %s vs %s salen %d vs %d\n",
		__func__, accept_host, peer_host, accept_serv, peer_serv,
		peersalen, salen);
}
/*
 * Verify that the peer address of the connected socket @fd, rendered in
 * numeric form, matches the cfg_host/cfg_port strings given on the command
 * line.  A mismatch is reported on stderr only.
 */
static void check_getpeername_connect(int fd)
{
	char peer_host[INET6_ADDRSTRLEN];
	char peer_serv[INET6_ADDRSTRLEN];
	struct sockaddr_storage peer;
	socklen_t peerlen = sizeof(peer);

	if (getpeername(fd, (struct sockaddr *)&peer, &peerlen) < 0) {
		perror("getpeername");
		return;
	}

	xgetnameinfo((struct sockaddr *)&peer, peerlen,
		     peer_host, sizeof(peer_host),
		     peer_serv, sizeof(peer_serv));

	if (strcmp(cfg_host, peer_host) == 0 && strcmp(cfg_port, peer_serv) == 0)
		return;

	fprintf(stderr, "%s: %s vs %s, %s vs %s\n", __func__,
		cfg_host, peer_host, cfg_port, peer_serv);
}
/*
 * Close @fd depending on the low bit of a rand() draw, unless cfg_join is
 * set, in which case @fd is always left open.
 */
static void maybe_close(int fd)
{
	/* rand() is consumed even when cfg_join makes the result unused */
	const unsigned int coin = rand();

	if (cfg_join)
		return;

	if (coin & 1)
		close(fd);
}
/*
 * Server side: wait up to poll_timeout for one incoming connection on
 * @listensock, run the address sanity checks on it and exchange data
 * between stdin/stdout and the accepted socket.
 *
 * Returns the result of copyfd_io(), 1 if poll() or accept() fails, or 2
 * if no connection arrived in time (the listener is closed in that case).
 */
int main_loop_s(int listensock)
{
	struct pollfd polls = {
		.fd = listensock,
		.events = POLLIN,
	};
	struct sockaddr_storage ss;
	socklen_t salen = sizeof(ss);
	int remotesock;
	int ready;

	ready = poll(&polls, 1, poll_timeout);
	if (ready == -1) {
		perror("poll");
		return 1;
	}
	if (ready == 0) {
		fprintf(stderr, "%s: timed out\n", __func__);
		close(listensock);
		return 2;
	}

	remotesock = accept(listensock, (struct sockaddr *)&ss, &salen);
	if (remotesock < 0) {
		perror("accept");
		return 1;
	}

	maybe_close(listensock);
	check_sockaddr(pf, &ss, salen);
	check_getpeername(remotesock, &ss, salen);

	return copyfd_io(STDIN_FILENO, remotesock, STDOUT_FILENO);
}
/*
 * Seed rand() from /dev/urandom.
 *
 * If the device cannot be opened or does not deliver a full seed, the
 * process id is used instead, so srand() never receives an uninitialized
 * value.  Descriptor 0 is accepted as a valid result of open().
 */
static void init_rng(void)
{
	unsigned int seed = (unsigned int)getpid();
	int fd = open("/dev/urandom", O_RDONLY);

	if (fd >= 0) {
		unsigned int rnd;

		/* only trust the buffer after a complete read */
		if (read(fd, &rnd, sizeof(rnd)) == (ssize_t)sizeof(rnd))
			seed = rnd;
		close(fd);
	}

	srand(seed);
}
/*
 * Client side: connect to cfg_host:cfg_port with protocol cfg_sock_proto,
 * check the peer address, apply the optional socket buffer sizes and
 * exchange data between stdin/stdout and the socket.
 *
 * Returns 2 if the connection cannot be established, otherwise the result
 * of copyfd_io().
 */
int main_loop(void)
{
	/* listener is ready. */
	const int sock = sock_connect_mptcp(cfg_host, cfg_port, cfg_sock_proto);

	if (sock < 0)
		return 2;

	check_getpeername_connect(sock);

	if (cfg_rcvbuf != 0)
		set_rcvbuf(sock, cfg_rcvbuf);
	if (cfg_sndbuf != 0)
		set_sndbuf(sock, cfg_sndbuf);

	return copyfd_io(STDIN_FILENO, sock, STDOUT_FILENO);
}
/*
 * Translate the -s argument into a socket protocol number.
 *
 * Accepts "MPTCP" and "TCP", case-insensitively.  Any other value prints
 * an error and ends in die_usage().
 */
int parse_proto(const char *proto)
{
	if (!strcasecmp(proto, "MPTCP"))
		return IPPROTO_MPTCP;
	if (!strcasecmp(proto, "TCP"))
		return IPPROTO_TCP;

	fprintf(stderr, "Unknown protocol: %s\n", proto);
	die_usage();

	/* silence compiler warning */
	return 0;
}
int parse_mode(const char *mode)
{
if (!strcasecmp(mode, "poll"))
return CFG_MODE_POLL;
if (!strcasecmp(mode, "mmap"))
return CFG_MODE_MMAP;
if (!strcasecmp(mode, "sendfile"))
return CFG_MODE_SENDFILE;
fprintf(stderr, "Unknown test mode: %s\n", mode);
fprintf(stderr, "Supported modes are:\n");
fprintf(stderr, "\t\t\"poll\" - interleaved read/write using poll()\n");
fprintf(stderr, "\t\t\"mmap\" - send entire input file (mmap+write), then read response (-l will read input first)\n");
fprintf(stderr, "\t\t\"sendfile\" - send entire input file (sendfile), then read response (-l will read input first)\n");
die_usage();
/* silence compiler warning */
return 0;
}
/*
 * Parse a non-negative buffer size given on the command line (-S / -R).
 *
 * The whole string must be a number in any base understood by strtoul()
 * with base 0, and the value must fit in an int.  Empty input, trailing
 * garbage and out-of-range values print an error and end in die_usage().
 */
static int parse_int(const char *size)
{
	unsigned long s;
	char *end;

	errno = 0;
	s = strtoul(size, &end, 0);
	if (errno) {
		fprintf(stderr, "Invalid buffer size %s (%s)\n",
			size, strerror(errno));
		die_usage();
	}

	/* strtoul() reports "no digits" or trailing junk only through end */
	if (end == size || *end != '\0') {
		fprintf(stderr, "Invalid buffer size %s (%s)\n",
			size, strerror(EINVAL));
		die_usage();
	}

	if (s > INT_MAX) {
		fprintf(stderr, "Invalid buffer size %s (%s)\n",
			size, strerror(ERANGE));
		die_usage();
	}

	return (int)s;
}
/*
 * Parse the command line into the cfg_* globals.
 *
 * Exactly one positional argument (the address) must follow the options;
 * an address containing ':' switches pf to AF_INET6.  Unknown options,
 * missing option arguments and a wrong number of positional arguments end
 * in die_usage().
 */
static void parse_opts(int argc, char **argv)
{
	int c;

	while ((c = getopt(argc, argv, "6jlp:s:hut:m:S:R:w:")) != -1) {
		switch (c) {
		case 'j':
			cfg_join = true;
			cfg_mode = CFG_MODE_POLL;
			cfg_wait = 400000;
			break;
		case 'l':
			listen_mode = true;
			break;
		case 'p':
			cfg_port = optarg;
			break;
		case 's':
			cfg_sock_proto = parse_proto(optarg);
			break;
		case 'h':
			die_usage();
			break;
		case 'u':
			tcpulp_audit = true;
			break;
		case '6':
			pf = AF_INET6;
			break;
		case 't': {
			int secs = atoi(optarg);

			/* seconds -> milliseconds; non-positive values and
			 * values that would overflow mean "no timeout"
			 */
			if (secs > 0 && secs <= INT_MAX / 1000)
				poll_timeout = secs * 1000;
			else
				poll_timeout = -1;
			break;
		}
		case 'm':
			cfg_mode = parse_mode(optarg);
			break;
		case 'S':
			cfg_sndbuf = parse_int(optarg);
			break;
		case 'R':
			cfg_rcvbuf = parse_int(optarg);
			break;
		case 'w': {
			int secs = atoi(optarg);

			/* seconds -> microseconds for usleep(); reject
			 * values that are negative or would overflow
			 */
			if (secs < 0 || secs > INT_MAX / 1000000)
				die_usage();
			cfg_wait = secs * 1000000;
			break;
		}
		default:
			/* getopt() already printed what was wrong */
			die_usage();
			break;
		}
	}

	if (optind + 1 != argc)
		die_usage();
	cfg_host = argv[optind];

	if (strchr(cfg_host, ':'))
		pf = AF_INET6;
}
/*
 * Entry point: seed the RNG, install the SIGUSR1 handler, parse the
 * options and then run one of three modes:
 *  - -u: only the TCP_ULP check, exit status 0 if it passed, 1 otherwise;
 *  - -l: listen, apply the optional buffer sizes to the listener and serve
 *        one connection;
 *  - default: connect to the given address and exchange data.
 */
int main(int argc, char *argv[])
{
	int listener;

	init_rng();
	signal(SIGUSR1, handle_signal);
	parse_opts(argc, argv);

	if (tcpulp_audit) {
		if (sock_test_tcpulp(cfg_host, cfg_port))
			return 0;
		return 1;
	}

	if (!listen_mode)
		return main_loop();

	listener = sock_listen_mptcp(cfg_host, cfg_port);
	if (listener < 0)
		return 1;

	if (cfg_rcvbuf != 0)
		set_rcvbuf(listener, cfg_rcvbuf);
	if (cfg_sndbuf != 0)
		set_sndbuf(listener, cfg_sndbuf);

	return main_loop_s(listener);
}