diff options
Diffstat (limited to 'debian/tests/python/ucspi_tcp_test/__main__.py')
-rw-r--r-- | debian/tests/python/ucspi_tcp_test/__main__.py | 322 |
1 files changed, 322 insertions, 0 deletions
diff --git a/debian/tests/python/ucspi_tcp_test/__main__.py b/debian/tests/python/ucspi_tcp_test/__main__.py new file mode 100644 index 0000000..5bddf1e --- /dev/null +++ b/debian/tests/python/ucspi_tcp_test/__main__.py @@ -0,0 +1,322 @@ +"""Test the TCP implementation of UCSPI.""" + +from __future__ import annotations + +import argparse +import dataclasses +import enum +import pathlib +import socket +import typing + +import netifaces +import utf8_locale + +import ucspi_test + + +if typing.TYPE_CHECKING: + from typing import Any, Final + + +@dataclasses.dataclass +class InvalidPortNumberError(ucspi_test.RunnerError): + """Could not connect to a TCP socket.""" + + proto: str + portstr: str + err: Exception + + def __str__(self) -> str: + """Provide a human-readable error message.""" + return f"{self.proto}: could not convert {self.portstr!r} to a number: {self.err}" + + +@dataclasses.dataclass +class NoAvailablePortsError(ucspi_test.RunnerError): + """Could not find an available port number.""" + + addr: str + + def __str__(self) -> str: + """Provide a human-readable error message.""" + return f"Could not find a suitable port on {self.addr}" + + +@dataclasses.dataclass +class SocketCreateError(ucspi_test.RunnerError): + """Could not create a TCP socket.""" + + err: Exception + + def __str__(self) -> str: + """Provide a human-readable error message.""" + return f"Could not create a TCP socket: {self.err}" + + +@dataclasses.dataclass +class SocketReuseAddressError(ucspi_test.RunnerError): + """Could not create a TCP socket.""" + + err: Exception + + def __str__(self) -> str: + """Provide a human-readable error message.""" + return f"Could not set the 'reuse address' option on a TCP socket: {self.err}" + + +@dataclasses.dataclass +class SocketBindError(ucspi_test.RunnerError): + """Could not bind a TCP socket.""" + + addr: str + port: int + err: Exception + + def __str__(self) -> str: + """Provide a human-readable error message.""" + return f"Could not bind to {self.addr}:{self.port}: {self.err}" + + +@dataclasses.dataclass +class SocketListenError(ucspi_test.RunnerError): + """Could not listen on a TCP socket.""" + + addr: str + port: int + err: Exception + + def __str__(self) -> str: + """Provide a human-readable error message.""" + return f"Could not listen on {self.addr}:{self.port}: {self.err}" + + +@dataclasses.dataclass +class SocketConnectError(ucspi_test.RunnerError): + """Could not connect to a TCP socket.""" + + addr: str + port: int + err: Exception + + def __str__(self) -> str: + """Provide a human-readable error message.""" + return f"Could not connect a TCP socket to {self.addr}:{self.port}: {self.err}" + + +@dataclasses.dataclass(frozen=True) +class Config(ucspi_test.Config): + """Runtime configuration for the TCP test runner.""" + + listen_addr: str + listen_addr_len: set[int] + listen_family: socket.AddressFamily + + +class TcpRunner(ucspi_test.Runner): + """Run ucspi-tcp tests.""" + + def find_listening_address(self) -> list[str]: + """Find a local address/port combination.""" + print(f"{self.proto}.find_listening_address() starting") + for port in range(6502, 8086): + if not isinstance(self.cfg, Config): + raise TypeError(repr(self.cfg)) + addr = self.cfg.listen_addr + sock = socket.socket(self.cfg.listen_family, socket.SOCK_STREAM, socket.IPPROTO_TCP) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + try: + sock.bind((addr, port)) + print(f"- got {port}") + sock.close() + return [addr, str(port)] + except OSError: + pass + + raise NoAvailablePortsError(addr) + + def get_listening_socket(self, addr: list[str]) -> socket.socket: + """Start listening on the specified address.""" + if not isinstance(self.cfg, Config): + raise TypeError(repr(self.cfg)) + if len(addr) not in self.cfg.listen_addr_len: + raise ucspi_test.SocketAddressLengthError(self.proto, addr) + laddr: Final = addr[0] + try: + lport: Final = int(addr[1]) + except ValueError as err: + raise InvalidPortNumberError(self.proto, addr[1], err) from err + + try: + sock: Final = socket.socket( + self.cfg.listen_family, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + ) + except OSError as err: + raise SocketCreateError(err) from err + try: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + except OSError as err: + raise SocketReuseAddressError(err) from err + try: + sock.bind((laddr, lport)) + except OSError as err: + raise SocketBindError(laddr, lport, err) from err + try: + sock.listen(5) + except OSError as err: + raise SocketListenError(laddr, lport, err) from err + + return sock + + def get_connected_socket(self, addr: list[str]) -> socket.socket: + """Connect to the specified address.""" + if not isinstance(self.cfg, Config): + raise TypeError(repr(self.cfg)) + if len(addr) not in self.cfg.listen_addr_len: + raise ucspi_test.SocketAddressLengthError(self.proto, addr) + laddr: Final = addr[0] + try: + lport: Final = int(addr[1]) + except ValueError as err: + raise InvalidPortNumberError(self.proto, addr[1], err) from err + + try: + sock: Final = socket.socket( + self.cfg.listen_family, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + ) + except OSError as err: + raise SocketCreateError(err) from err + try: + sock.connect((laddr, lport)) + except OSError as err: + raise SocketConnectError(laddr, lport, err) from err + + return sock + + def format_local_addr(self, addr: list[str]) -> str: + """Format an address returned by accept(), etc.""" + if not isinstance(self.cfg, Config): + raise TypeError(repr(self.cfg)) + if len(addr) not in self.cfg.listen_addr_len: + raise TypeError(repr(addr)) + return f"{addr[0]}:{addr[1]}" + + def format_remote_addr(self, addr: Any) -> str: # noqa: ANN401 + """Format an address returned by accept(), etc.""" + if not isinstance(self.cfg, Config): + raise TypeError(repr(self.cfg)) + if ( + not isinstance(addr, tuple) + or len(addr) not in self.cfg.listen_addr_len + or not isinstance(addr[0], str) + or not isinstance(addr[1], int) + ): + raise TypeError(repr(addr)) + return f"{addr[0]}:{addr[1]}" + + +class IPVersion(str, enum.Enum): + """The IP address family for the listening socket.""" + + IPV4: Final = "4" + IPV6: Final = "6" + + def __str__(self) -> str: + """Return the string value itself.""" + return self.value + + def addr_len(self) -> set[int]: + """Obtain the expected length of an address/port tuple.""" + match self: + case IPVersion.IPV4: + return {2} + + case IPVersion.IPV6: + return {2, 4} + + def family(self) -> socket.AddressFamily: + """Obtain the address family corresponding to this value.""" + match self: + case IPVersion.IPV4: + return socket.AF_INET + + case IPVersion.IPV6: + return socket.AF_INET6 + + +def get_listen_address(ip_version: IPVersion) -> tuple[str, set[int], socket.AddressFamily] | None: + """Get a loopback address for the specified address family, if any are configured.""" + ifaces: Final = netifaces.interfaces() + if "lo" not in ifaces: + print("No 'lo' interface at all?!") + return None + + family: Final = ip_version.family() + addrs: Final = netifaces.ifaddresses("lo") + candidates: Final = addrs.get(family) + if not candidates: + print("No addresses for the specified family on the 'lo' interface") + return None + + return candidates[0]["addr"], ip_version.addr_len(), family + + +def parse_args() -> Config | None: + """Parse the command-line arguments.""" + parser: Final = argparse.ArgumentParser(prog="uctest") + + parser.add_argument( + "-d", + "--bindir", + type=pathlib.Path, + required=True, + help="the path to the UCSPI utilities", + ) + parser.add_argument( + "-i", + "--ip-version", + type=IPVersion, + default=IPVersion.IPV4, + help="the address family to listen on ('4' for IPv4, '6' for IPv6)", + choices=["4", "6"], + ) + parser.add_argument( + "-p", + "--proto", + type=str, + required=True, + help="the UCSPI protocol ('tcp', 'unix', etc)", + ) + args: Final = parser.parse_args() + + listen_data: Final = get_listen_address(args.ip_version) + if listen_data is None: + return None + + return Config( + bindir=args.bindir.absolute(), + listen_addr=listen_data[0], + listen_addr_len=listen_data[1], + listen_family=listen_data[2], + proto=args.proto, + utf8_env=utf8_locale.UTF8Detect().detect().env, + ) + + +def main() -> None: + """Parse command-line arguments, run the tests.""" + cfg: Final = parse_args() + if cfg is None: + print("No loopback interface addresses for the requested family") + return + + ucspi_test.add_handler("tcp", TcpRunner) + ucspi_test.run_test_handler(cfg) + + +if __name__ == "__main__": + main() |