diff options
Diffstat (limited to 'tools/testing/selftests/net/rds/test.py')
-rw-r--r-- | tools/testing/selftests/net/rds/test.py | 262 |
1 files changed, 262 insertions, 0 deletions
diff --git a/tools/testing/selftests/net/rds/test.py b/tools/testing/selftests/net/rds/test.py new file mode 100644 index 000000000000..e6bb109bcead --- /dev/null +++ b/tools/testing/selftests/net/rds/test.py @@ -0,0 +1,262 @@ +#! /usr/bin/env python3 +# SPDX-License-Identifier: GPL-2.0 + +import argparse +import ctypes +import errno +import hashlib +import os +import select +import signal +import socket +import subprocess +import sys +import atexit +from pwd import getpwuid +from os import stat +from lib.py import ip + + +libc = ctypes.cdll.LoadLibrary('libc.so.6') +setns = libc.setns + +net0 = 'net0' +net1 = 'net1' + +veth0 = 'veth0' +veth1 = 'veth1' + +# Helper function for creating a socket inside a network namespace. +# We need this because otherwise RDS will detect that the two TCP +# sockets are on the same interface and use the loop transport instead +# of the TCP transport. +def netns_socket(netns, *args): + u0, u1 = socket.socketpair(socket.AF_UNIX, socket.SOCK_SEQPACKET) + + child = os.fork() + if child == 0: + # change network namespace + with open(f'/var/run/netns/{netns}') as f: + try: + ret = setns(f.fileno(), 0) + except IOError as e: + print(e.errno) + print(e) + + # create socket in target namespace + s = socket.socket(*args) + + # send resulting socket to parent + socket.send_fds(u0, [], [s.fileno()]) + + sys.exit(0) + + # receive socket from child + _, s, _, _ = socket.recv_fds(u1, 0, 1) + os.waitpid(child, 0) + u0.close() + u1.close() + return socket.fromfd(s[0], *args) + +def signal_handler(sig, frame): + print('Test timed out') + sys.exit(1) + +#Parse out command line arguments. We take an optional +# timeout parameter and an optional log output folder +parser = argparse.ArgumentParser(description="init script args", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) +parser.add_argument("-d", "--logdir", action="store", + help="directory to store logs", default="/tmp") +parser.add_argument('--timeout', help="timeout to terminate hung test", + type=int, default=0) +parser.add_argument('-l', '--loss', help="Simulate tcp packet loss", + type=int, default=0) +parser.add_argument('-c', '--corruption', help="Simulate tcp packet corruption", + type=int, default=0) +parser.add_argument('-u', '--duplicate', help="Simulate tcp packet duplication", + type=int, default=0) +args = parser.parse_args() +logdir=args.logdir +packet_loss=str(args.loss)+'%' +packet_corruption=str(args.corruption)+'%' +packet_duplicate=str(args.duplicate)+'%' + +ip(f"netns add {net0}") +ip(f"netns add {net1}") +ip(f"link add type veth") + +addrs = [ + # we technically don't need different port numbers, but this will + # help identify traffic in the network analyzer + ('10.0.0.1', 10000), + ('10.0.0.2', 20000), +] + +# move interfaces to separate namespaces so they can no longer be +# bound directly; this prevents rds from switching over from the tcp +# transport to the loop transport. +ip(f"link set {veth0} netns {net0} up") +ip(f"link set {veth1} netns {net1} up") + + + +# add addresses +ip(f"-n {net0} addr add {addrs[0][0]}/32 dev {veth0}") +ip(f"-n {net1} addr add {addrs[1][0]}/32 dev {veth1}") + +# add routes +ip(f"-n {net0} route add {addrs[1][0]}/32 dev {veth0}") +ip(f"-n {net1} route add {addrs[0][0]}/32 dev {veth1}") + +# sanity check that our two interfaces/addresses are correctly set up +# and communicating by doing a single ping +ip(f"netns exec {net0} ping -c 1 {addrs[1][0]}") + +# Start a packet capture on each network +for net in [net0, net1]: + tcpdump_pid = os.fork() + if tcpdump_pid == 0: + pcap = logdir+'/'+net+'.pcap' + subprocess.check_call(['touch', pcap]) + user = getpwuid(stat(pcap).st_uid).pw_name + ip(f"netns exec {net} /usr/sbin/tcpdump -Z {user} -i any -w {pcap}") + sys.exit(0) + +# simulate packet loss, duplication and corruption +for net, iface in [(net0, veth0), (net1, veth1)]: + ip(f"netns exec {net} /usr/sbin/tc qdisc add dev {iface} root netem \ + corrupt {packet_corruption} loss {packet_loss} duplicate \ + {packet_duplicate}") + +# add a timeout +if args.timeout > 0: + signal.alarm(args.timeout) + signal.signal(signal.SIGALRM, signal_handler) + +sockets = [ + netns_socket(net0, socket.AF_RDS, socket.SOCK_SEQPACKET), + netns_socket(net1, socket.AF_RDS, socket.SOCK_SEQPACKET), +] + +for s, addr in zip(sockets, addrs): + s.bind(addr) + s.setblocking(0) + +fileno_to_socket = { + s.fileno(): s for s in sockets +} + +addr_to_socket = { + addr: s for addr, s in zip(addrs, sockets) +} + +socket_to_addr = { + s: addr for addr, s in zip(addrs, sockets) +} + +send_hashes = {} +recv_hashes = {} + +ep = select.epoll() + +for s in sockets: + ep.register(s, select.EPOLLRDNORM) + +n = 50000 +nr_send = 0 +nr_recv = 0 + +while nr_send < n: + # Send as much as we can without blocking + print("sending...", nr_send, nr_recv) + while nr_send < n: + send_data = hashlib.sha256( + f'packet {nr_send}'.encode('utf-8')).hexdigest().encode('utf-8') + + # pseudo-random send/receive pattern + sender = sockets[nr_send % 2] + receiver = sockets[1 - (nr_send % 3) % 2] + + try: + sender.sendto(send_data, socket_to_addr[receiver]) + send_hashes.setdefault((sender.fileno(), receiver.fileno()), + hashlib.sha256()).update(f'<{send_data}>'.encode('utf-8')) + nr_send = nr_send + 1 + except BlockingIOError as e: + break + except OSError as e: + if e.errno in [errno.ENOBUFS, errno.ECONNRESET, errno.EPIPE]: + break + raise + + # Receive as much as we can without blocking + print("receiving...", nr_send, nr_recv) + while nr_recv < nr_send: + for fileno, eventmask in ep.poll(): + receiver = fileno_to_socket[fileno] + + if eventmask & select.EPOLLRDNORM: + while True: + try: + recv_data, address = receiver.recvfrom(1024) + sender = addr_to_socket[address] + recv_hashes.setdefault((sender.fileno(), + receiver.fileno()), hashlib.sha256()).update( + f'<{recv_data}>'.encode('utf-8')) + nr_recv = nr_recv + 1 + except BlockingIOError as e: + break + + # exercise net/rds/tcp.c:rds_tcp_sysctl_reset() + for net in [net0, net1]: + ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_rcvbuf=10000") + ip(f"netns exec {net} /usr/sbin/sysctl net.rds.tcp.rds_tcp_sndbuf=10000") + +print("done", nr_send, nr_recv) + +# the Python socket module doesn't know these +RDS_INFO_FIRST = 10000 +RDS_INFO_LAST = 10017 + +nr_success = 0 +nr_error = 0 + +for s in sockets: + for optname in range(RDS_INFO_FIRST, RDS_INFO_LAST + 1): + # Sigh, the Python socket module doesn't allow us to pass + # buffer lengths greater than 1024 for some reason. RDS + # wants multiple pages. + try: + s.getsockopt(socket.SOL_RDS, optname, 1024) + nr_success = nr_success + 1 + except OSError as e: + nr_error = nr_error + 1 + if e.errno == errno.ENOSPC: + # ignore + pass + +print(f"getsockopt(): {nr_success}/{nr_error}") + +print("Stopping network packet captures") +subprocess.check_call(['killall', '-q', 'tcpdump']) + +# We're done sending and receiving stuff, now let's check if what +# we received is what we sent. +for (sender, receiver), send_hash in send_hashes.items(): + recv_hash = recv_hashes.get((sender, receiver)) + + if recv_hash is None: + print("FAIL: No data received") + sys.exit(1) + + if send_hash.hexdigest() != recv_hash.hexdigest(): + print("FAIL: Send/recv mismatch") + print("hash expected:", send_hash.hexdigest()) + print("hash received:", recv_hash.hexdigest()) + sys.exit(1) + + print(f"{sender}/{receiver}: ok") + +print("Success") +sys.exit(0) |