summaryrefslogtreecommitdiff
path: root/debian/tests/python/ucspi_tcp_test/__main__.py
diff options
context:
space:
mode:
Diffstat (limited to 'debian/tests/python/ucspi_tcp_test/__main__.py')
-rw-r--r--debian/tests/python/ucspi_tcp_test/__main__.py322
1 files changed, 322 insertions, 0 deletions
diff --git a/debian/tests/python/ucspi_tcp_test/__main__.py b/debian/tests/python/ucspi_tcp_test/__main__.py
new file mode 100644
index 0000000..5bddf1e
--- /dev/null
+++ b/debian/tests/python/ucspi_tcp_test/__main__.py
@@ -0,0 +1,322 @@
+"""Test the TCP implementation of UCSPI."""
+
+from __future__ import annotations
+
+import argparse
+import dataclasses
+import enum
+import pathlib
+import socket
+import typing
+
+import netifaces
+import utf8_locale
+
+import ucspi_test
+
+
+if typing.TYPE_CHECKING:
+ from typing import Any, Final
+
+
+@dataclasses.dataclass
+class InvalidPortNumberError(ucspi_test.RunnerError):
+ """Could not connect to a TCP socket."""
+
+ proto: str
+ portstr: str
+ err: Exception
+
+ def __str__(self) -> str:
+ """Provide a human-readable error message."""
+ return f"{self.proto}: could not convert {self.portstr!r} to a number: {self.err}"
+
+
+@dataclasses.dataclass
+class NoAvailablePortsError(ucspi_test.RunnerError):
+ """Could not find an available port number."""
+
+ addr: str
+
+ def __str__(self) -> str:
+ """Provide a human-readable error message."""
+ return f"Could not find a suitable port on {self.addr}"
+
+
+@dataclasses.dataclass
+class SocketCreateError(ucspi_test.RunnerError):
+ """Could not create a TCP socket."""
+
+ err: Exception
+
+ def __str__(self) -> str:
+ """Provide a human-readable error message."""
+ return f"Could not create a TCP socket: {self.err}"
+
+
+@dataclasses.dataclass
+class SocketReuseAddressError(ucspi_test.RunnerError):
+ """Could not create a TCP socket."""
+
+ err: Exception
+
+ def __str__(self) -> str:
+ """Provide a human-readable error message."""
+ return f"Could not set the 'reuse address' option on a TCP socket: {self.err}"
+
+
+@dataclasses.dataclass
+class SocketBindError(ucspi_test.RunnerError):
+ """Could not bind a TCP socket."""
+
+ addr: str
+ port: int
+ err: Exception
+
+ def __str__(self) -> str:
+ """Provide a human-readable error message."""
+ return f"Could not bind to {self.addr}:{self.port}: {self.err}"
+
+
+@dataclasses.dataclass
+class SocketListenError(ucspi_test.RunnerError):
+ """Could not listen on a TCP socket."""
+
+ addr: str
+ port: int
+ err: Exception
+
+ def __str__(self) -> str:
+ """Provide a human-readable error message."""
+ return f"Could not listen on {self.addr}:{self.port}: {self.err}"
+
+
+@dataclasses.dataclass
+class SocketConnectError(ucspi_test.RunnerError):
+ """Could not connect to a TCP socket."""
+
+ addr: str
+ port: int
+ err: Exception
+
+ def __str__(self) -> str:
+ """Provide a human-readable error message."""
+ return f"Could not connect a TCP socket to {self.addr}:{self.port}: {self.err}"
+
+
+@dataclasses.dataclass(frozen=True)
+class Config(ucspi_test.Config):
+ """Runtime configuration for the TCP test runner."""
+
+ listen_addr: str
+ listen_addr_len: set[int]
+ listen_family: socket.AddressFamily
+
+
+class TcpRunner(ucspi_test.Runner):
+ """Run ucspi-tcp tests."""
+
+ def find_listening_address(self) -> list[str]:
+ """Find a local address/port combination."""
+ print(f"{self.proto}.find_listening_address() starting")
+ for port in range(6502, 8086):
+ if not isinstance(self.cfg, Config):
+ raise TypeError(repr(self.cfg))
+ addr = self.cfg.listen_addr
+ sock = socket.socket(self.cfg.listen_family, socket.SOCK_STREAM, socket.IPPROTO_TCP)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ try:
+ sock.bind((addr, port))
+ print(f"- got {port}")
+ sock.close()
+ return [addr, str(port)]
+ except OSError:
+ pass
+
+ raise NoAvailablePortsError(addr)
+
+ def get_listening_socket(self, addr: list[str]) -> socket.socket:
+ """Start listening on the specified address."""
+ if not isinstance(self.cfg, Config):
+ raise TypeError(repr(self.cfg))
+ if len(addr) not in self.cfg.listen_addr_len:
+ raise ucspi_test.SocketAddressLengthError(self.proto, addr)
+ laddr: Final = addr[0]
+ try:
+ lport: Final = int(addr[1])
+ except ValueError as err:
+ raise InvalidPortNumberError(self.proto, addr[1], err) from err
+
+ try:
+ sock: Final = socket.socket(
+ self.cfg.listen_family,
+ socket.SOCK_STREAM,
+ socket.IPPROTO_TCP,
+ )
+ except OSError as err:
+ raise SocketCreateError(err) from err
+ try:
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ except OSError as err:
+ raise SocketReuseAddressError(err) from err
+ try:
+ sock.bind((laddr, lport))
+ except OSError as err:
+ raise SocketBindError(laddr, lport, err) from err
+ try:
+ sock.listen(5)
+ except OSError as err:
+ raise SocketListenError(laddr, lport, err) from err
+
+ return sock
+
+ def get_connected_socket(self, addr: list[str]) -> socket.socket:
+ """Connect to the specified address."""
+ if not isinstance(self.cfg, Config):
+ raise TypeError(repr(self.cfg))
+ if len(addr) not in self.cfg.listen_addr_len:
+ raise ucspi_test.SocketAddressLengthError(self.proto, addr)
+ laddr: Final = addr[0]
+ try:
+ lport: Final = int(addr[1])
+ except ValueError as err:
+ raise InvalidPortNumberError(self.proto, addr[1], err) from err
+
+ try:
+ sock: Final = socket.socket(
+ self.cfg.listen_family,
+ socket.SOCK_STREAM,
+ socket.IPPROTO_TCP,
+ )
+ except OSError as err:
+ raise SocketCreateError(err) from err
+ try:
+ sock.connect((laddr, lport))
+ except OSError as err:
+ raise SocketConnectError(laddr, lport, err) from err
+
+ return sock
+
+ def format_local_addr(self, addr: list[str]) -> str:
+ """Format an address returned by accept(), etc."""
+ if not isinstance(self.cfg, Config):
+ raise TypeError(repr(self.cfg))
+ if len(addr) not in self.cfg.listen_addr_len:
+ raise TypeError(repr(addr))
+ return f"{addr[0]}:{addr[1]}"
+
+ def format_remote_addr(self, addr: Any) -> str: # noqa: ANN401
+ """Format an address returned by accept(), etc."""
+ if not isinstance(self.cfg, Config):
+ raise TypeError(repr(self.cfg))
+ if (
+ not isinstance(addr, tuple)
+ or len(addr) not in self.cfg.listen_addr_len
+ or not isinstance(addr[0], str)
+ or not isinstance(addr[1], int)
+ ):
+ raise TypeError(repr(addr))
+ return f"{addr[0]}:{addr[1]}"
+
+
+class IPVersion(str, enum.Enum):
+ """The IP address family for the listening socket."""
+
+ IPV4: Final = "4"
+ IPV6: Final = "6"
+
+ def __str__(self) -> str:
+ """Return the string value itself."""
+ return self.value
+
+ def addr_len(self) -> set[int]:
+ """Obtain the expected length of an address/port tuple."""
+ match self:
+ case IPVersion.IPV4:
+ return {2}
+
+ case IPVersion.IPV6:
+ return {2, 4}
+
+ def family(self) -> socket.AddressFamily:
+ """Obtain the address family corresponding to this value."""
+ match self:
+ case IPVersion.IPV4:
+ return socket.AF_INET
+
+ case IPVersion.IPV6:
+ return socket.AF_INET6
+
+
+def get_listen_address(ip_version: IPVersion) -> tuple[str, set[int], socket.AddressFamily] | None:
+ """Get a loopback address for the specified address family, if any are configured."""
+ ifaces: Final = netifaces.interfaces()
+ if "lo" not in ifaces:
+ print("No 'lo' interface at all?!")
+ return None
+
+ family: Final = ip_version.family()
+ addrs: Final = netifaces.ifaddresses("lo")
+ candidates: Final = addrs.get(family)
+ if not candidates:
+ print("No addresses for the specified family on the 'lo' interface")
+ return None
+
+ return candidates[0]["addr"], ip_version.addr_len(), family
+
+
+def parse_args() -> Config | None:
+ """Parse the command-line arguments."""
+ parser: Final = argparse.ArgumentParser(prog="uctest")
+
+ parser.add_argument(
+ "-d",
+ "--bindir",
+ type=pathlib.Path,
+ required=True,
+ help="the path to the UCSPI utilities",
+ )
+ parser.add_argument(
+ "-i",
+ "--ip-version",
+ type=IPVersion,
+ default=IPVersion.IPV4,
+ help="the address family to listen on ('4' for IPv4, '6' for IPv6)",
+ choices=["4", "6"],
+ )
+ parser.add_argument(
+ "-p",
+ "--proto",
+ type=str,
+ required=True,
+ help="the UCSPI protocol ('tcp', 'unix', etc)",
+ )
+ args: Final = parser.parse_args()
+
+ listen_data: Final = get_listen_address(args.ip_version)
+ if listen_data is None:
+ return None
+
+ return Config(
+ bindir=args.bindir.absolute(),
+ listen_addr=listen_data[0],
+ listen_addr_len=listen_data[1],
+ listen_family=listen_data[2],
+ proto=args.proto,
+ utf8_env=utf8_locale.UTF8Detect().detect().env,
+ )
+
+
+def main() -> None:
+ """Parse command-line arguments, run the tests."""
+ cfg: Final = parse_args()
+ if cfg is None:
+ print("No loopback interface addresses for the requested family")
+ return
+
+ ucspi_test.add_handler("tcp", TcpRunner)
+ ucspi_test.run_test_handler(cfg)
+
+
+if __name__ == "__main__":
+ main()