summaryrefslogtreecommitdiff
path: root/src/service_identity/_common.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/service_identity/_common.py')
-rw-r--r--src/service_identity/_common.py423
1 files changed, 423 insertions, 0 deletions
diff --git a/src/service_identity/_common.py b/src/service_identity/_common.py
new file mode 100644
index 0000000..fa8a359
--- /dev/null
+++ b/src/service_identity/_common.py
@@ -0,0 +1,423 @@
+"""
+Common verification code.
+"""
+
+from __future__ import absolute_import, division, print_function
+
+import re
+
+import attr
+
+from ._compat import maketrans, text_type
+from .exceptions import (
+ CertificateError,
+ DNSMismatch,
+ SRVMismatch,
+ URIMismatch,
+ VerificationError,
+)
+
+try:
+ import idna
+except ImportError: # pragma: nocover
+ idna = None
+
+
+@attr.s
+class ServiceMatch(object):
+ """
+ A match of a service id and a certificate pattern.
+ """
+ service_id = attr.ib()
+ cert_pattern = attr.ib()
+
+
+def verify_service_identity(cert_patterns, obligatory_ids, optional_ids):
+ """
+ Verify whether *cert_patterns* are valid for *obligatory_ids* and
+ *optional_ids*.
+
+ *obligatory_ids* must be both present and match. *optional_ids* must match
+ if a pattern of the respective type is present.
+ """
+ errors = []
+ matches = (_find_matches(cert_patterns, obligatory_ids) +
+ _find_matches(cert_patterns, optional_ids))
+
+ matched_ids = [match.service_id for match in matches]
+ for i in obligatory_ids:
+ if i not in matched_ids:
+ errors.append(i.error_on_mismatch(mismatched_id=i))
+
+ for i in optional_ids:
+ # If an optional ID is not matched by a certificate pattern *but* there
+ # is a pattern of the same type , it is an error and the verification
+ # fails. Example: the user passes a SRV-ID for "_mail.domain.com" but
+ # the certificate contains an SRV-Pattern for "_xmpp.domain.com".
+ if (
+ i not in matched_ids and
+ _contains_instance_of(cert_patterns, i.pattern_class)
+ ):
+ errors.append(i.error_on_mismatch(mismatched_id=i))
+
+ if errors:
+ raise VerificationError(errors=errors)
+
+ return matches
+
+
+def _find_matches(cert_patterns, service_ids):
+ """
+ Search for matching certificate patterns and service_ids.
+
+ :param cert_ids: List certificate IDs like DNSPattern.
+ :type cert_ids: `list`
+
+ :param service_ids: List of service IDs like DNS_ID.
+ :type service_ids: `list`
+
+ :rtype: `list` of `ServiceMatch`
+ """
+ matches = []
+ for sid in service_ids:
+ for cid in cert_patterns:
+ if sid.verify(cid):
+ matches.append(
+ ServiceMatch(cert_pattern=cid, service_id=sid)
+ )
+ return matches
+
+
+def _contains_instance_of(seq, cl):
+ """
+ :type seq: iterable
+ :type cl: type
+
+ :rtype: bool
+ """
+ for e in seq:
+ if isinstance(e, cl):
+ return True
+ return False
+
+
+_RE_IPv4 = re.compile(br"^([0-9*]{1,3}\.){3}[0-9*]{1,3}$")
+_RE_IPv6 = re.compile(br"^([a-f0-9*]{0,4}:)+[a-f0-9*]{1,4}$")
+_RE_NUMBER = re.compile(br"^[0-9]+$")
+
+
+def _is_ip_address(pattern):
+ """
+ Check whether *pattern* could be/match an IP address.
+
+ Does *not* guarantee that pattern is in fact a valid IP address; especially
+ the checks for IPv6 are rather coarse. This function is for security
+ checks, not for validating IP addresses.
+
+ :param pattern: A pattern for a host name.
+ :type pattern: `bytes` or `unicode`
+
+ :return: `True` if *pattern* could be an IP address, else `False`.
+ :rtype: `bool`
+ """
+ if isinstance(pattern, text_type):
+ try:
+ pattern = pattern.encode('ascii')
+ except UnicodeError:
+ return False
+
+ return (
+ _RE_IPv4.match(pattern) is not None or
+ _RE_IPv6.match(pattern) is not None or
+ _RE_NUMBER.match(pattern) is not None
+ )
+
+
+@attr.s(init=False)
+class DNSPattern(object):
+ """
+ A DNS pattern as extracted from certificates.
+ """
+ pattern = attr.ib()
+
+ _RE_LEGAL_CHARS = re.compile(br"^[a-z0-9\-_.]+$")
+
+ def __init__(self, pattern):
+ """
+ :type pattern: `bytes`
+ """
+ if not isinstance(pattern, bytes):
+ raise TypeError("The DNS pattern must be a bytes string.")
+
+ pattern = pattern.strip()
+
+ if pattern == b"" or _is_ip_address(pattern) or b"\0" in pattern:
+ raise CertificateError(
+ "Invalid DNS pattern {0!r}.".format(pattern)
+ )
+
+ self.pattern = pattern.translate(_TRANS_TO_LOWER)
+ if b'*' in self.pattern:
+ _validate_pattern(self.pattern)
+
+
+@attr.s(init=False)
+class URIPattern(object):
+ """
+ An URI pattern as extracted from certificates.
+ """
+ protocol_pattern = attr.ib()
+ dns_pattern = attr.ib()
+
+ def __init__(self, pattern):
+ """
+ :type pattern: `bytes`
+ """
+ if not isinstance(pattern, bytes):
+ raise TypeError("The URI pattern must be a bytes string.")
+
+ pattern = pattern.strip().translate(_TRANS_TO_LOWER)
+
+ if (
+ b":" not in pattern or
+ b"*" in pattern or
+ _is_ip_address(pattern)
+ ):
+ raise CertificateError(
+ "Invalid URI pattern {0!r}.".format(pattern)
+ )
+ self.protocol_pattern, hostname = pattern.split(b":")
+ self.dns_pattern = DNSPattern(hostname)
+
+
+@attr.s(init=False)
+class SRVPattern(object):
+ """
+ An SRV pattern as extracted from certificates.
+ """
+ name_pattern = attr.ib()
+ dns_pattern = attr.ib()
+
+ def __init__(self, pattern):
+ """
+ :type pattern: `bytes`
+ """
+ if not isinstance(pattern, bytes):
+ raise TypeError("The SRV pattern must be a bytes string.")
+
+ pattern = pattern.strip().translate(_TRANS_TO_LOWER)
+
+ if (
+ pattern[0] != b"_"[0] or
+ b"." not in pattern or
+ b"*" in pattern or
+ _is_ip_address(pattern)
+ ):
+ raise CertificateError(
+ "Invalid SRV pattern {0!r}.".format(pattern)
+ )
+ name, hostname = pattern.split(b".", 1)
+ self.name_pattern = name[1:]
+ self.dns_pattern = DNSPattern(hostname)
+
+
+@attr.s(init=False)
+class DNS_ID(object):
+ """
+ A DNS service ID, aka hostname.
+ """
+ hostname = attr.ib()
+
+ # characters that are legal in a normalized hostname
+ _RE_LEGAL_CHARS = re.compile(br"^[a-z0-9\-_.]+$")
+ pattern_class = DNSPattern
+ error_on_mismatch = DNSMismatch
+
+ def __init__(self, hostname):
+ """
+ :type hostname: `unicode`
+ """
+ if not isinstance(hostname, text_type):
+ raise TypeError("DNS-ID must be a unicode string.")
+
+ hostname = hostname.strip()
+ if hostname == u"" or _is_ip_address(hostname):
+ raise ValueError("Invalid DNS-ID.")
+
+ if any(ord(c) > 127 for c in hostname):
+ if idna:
+ ascii_id = idna.encode(hostname)
+ else:
+ raise ImportError(
+ "idna library is required for non-ASCII IDs."
+ )
+ else:
+ ascii_id = hostname.encode("ascii")
+
+ self.hostname = ascii_id.translate(_TRANS_TO_LOWER)
+ if self._RE_LEGAL_CHARS.match(self.hostname) is None:
+ raise ValueError("Invalid DNS-ID.")
+
+ def verify(self, pattern):
+ """
+ http://tools.ietf.org/search/rfc6125#section-6.4
+ """
+ if isinstance(pattern, self.pattern_class):
+ return _hostname_matches(pattern.pattern, self.hostname)
+ else:
+ return False
+
+
+@attr.s(init=False)
+class URI_ID(object):
+ """
+ An URI service ID.
+ """
+ protocol = attr.ib()
+ dns_id = attr.ib()
+
+ pattern_class = URIPattern
+ error_on_mismatch = URIMismatch
+
+ def __init__(self, uri):
+ """
+ :type uri: `unicode`
+ """
+ if not isinstance(uri, text_type):
+ raise TypeError("URI-ID must be a unicode string.")
+
+ uri = uri.strip()
+ if u":" not in uri or _is_ip_address(uri):
+ raise ValueError("Invalid URI-ID.")
+
+ prot, hostname = uri.split(u":")
+
+ self.protocol = prot.encode("ascii").translate(_TRANS_TO_LOWER)
+ self.dns_id = DNS_ID(hostname.strip(u"/"))
+
+ def verify(self, pattern):
+ """
+ http://tools.ietf.org/search/rfc6125#section-6.5.2
+ """
+ if isinstance(pattern, self.pattern_class):
+ return (
+ pattern.protocol_pattern == self.protocol and
+ self.dns_id.verify(pattern.dns_pattern)
+ )
+ else:
+ return False
+
+
+@attr.s(init=False)
+class SRV_ID(object):
+ """
+ An SRV service ID.
+ """
+ name = attr.ib()
+ dns_id = attr.ib()
+
+ pattern_class = SRVPattern
+ error_on_mismatch = SRVMismatch
+
+ def __init__(self, srv):
+ """
+ :type srv: `unicode`
+ """
+ if not isinstance(srv, text_type):
+ raise TypeError("SRV-ID must be a unicode string.")
+
+ srv = srv.strip()
+ if u"." not in srv or _is_ip_address(srv) or srv[0] != u"_":
+ raise ValueError("Invalid SRV-ID.")
+
+ name, hostname = srv.split(u".", 1)
+
+ self.name = name[1:].encode("ascii").translate(_TRANS_TO_LOWER)
+ self.dns_id = DNS_ID(hostname)
+
+ def verify(self, pattern):
+ """
+ http://tools.ietf.org/search/rfc6125#section-6.5.1
+ """
+ if isinstance(pattern, self.pattern_class):
+ return (
+ self.name == pattern.name_pattern and
+ self.dns_id.verify(pattern.dns_pattern)
+ )
+ else:
+ return False
+
+
+def _hostname_matches(cert_pattern, actual_hostname):
+ """
+ :type cert_pattern: `bytes`
+ :type actual_hostname: `bytes`
+
+ :return: `True` if *cert_pattern* matches *actual_hostname*, else `False`.
+ :rtype: `bool`
+ """
+ if b'*' in cert_pattern:
+ cert_head, cert_tail = cert_pattern.split(b".", 1)
+ actual_head, actual_tail = actual_hostname.split(b".", 1)
+ if cert_tail != actual_tail:
+ return False
+ # No patterns for IDNA
+ if actual_head.startswith(b"xn--"):
+ return False
+
+ if cert_head == b"*":
+ return True
+
+ start, end = cert_head.split(b"*")
+ if start == b"":
+ # *oo
+ return actual_head.endswith(end)
+ elif end == b"":
+ # f*
+ return actual_head.startswith(start)
+ else:
+ # f*o
+ return actual_head.startswith(start) and actual_head.endswith(end)
+
+ else:
+ return cert_pattern == actual_hostname
+
+
+def _validate_pattern(cert_pattern):
+ """
+ Check whether the usage of wildcards within *cert_pattern* conforms with
+ our expectations.
+
+ :type hostname: `bytes`
+
+ :return: None
+ """
+ cnt = cert_pattern.count(b"*")
+ if cnt > 1:
+ raise CertificateError(
+ "Certificate's DNS-ID {0!r} contains too many wildcards."
+ .format(cert_pattern)
+ )
+ parts = cert_pattern.split(b".")
+ if len(parts) < 3:
+ raise CertificateError(
+ "Certificate's DNS-ID {0!r} hast too few host components for "
+ "wildcard usage."
+ .format(cert_pattern)
+ )
+ # We assume there will always be only one wildcard allowed.
+ if b"*" not in parts[0]:
+ raise CertificateError(
+ "Certificate's DNS-ID {0!r} has a wildcard outside the left-most "
+ "part.".format(cert_pattern)
+ )
+ if any(not len(p) for p in parts):
+ raise CertificateError(
+ "Certificate's DNS-ID {0!r} contains empty parts."
+ .format(cert_pattern)
+ )
+
+
+# Ensure no locale magic interferes.
+_TRANS_TO_LOWER = maketrans(b"ABCDEFGHIJKLMNOPQRSTUVWXYZ",
+ b"abcdefghijklmnopqrstuvwxyz")