diff options
Diffstat (limited to 'src/service_identity/_common.py')
-rw-r--r-- | src/service_identity/_common.py | 174 |
1 files changed, 100 insertions, 74 deletions
diff --git a/src/service_identity/_common.py b/src/service_identity/_common.py index fa8a359..9b4e773 100644 --- a/src/service_identity/_common.py +++ b/src/service_identity/_common.py @@ -4,6 +4,7 @@ Common verification code. from __future__ import absolute_import, division, print_function +import ipaddress import re import attr @@ -12,22 +13,25 @@ from ._compat import maketrans, text_type from .exceptions import ( CertificateError, DNSMismatch, + IPAddressMismatch, SRVMismatch, URIMismatch, VerificationError, ) + try: import idna except ImportError: # pragma: nocover idna = None -@attr.s +@attr.s(slots=True) class ServiceMatch(object): """ A match of a service id and a certificate pattern. """ + service_id = attr.ib() cert_pattern = attr.ib() @@ -41,8 +45,9 @@ def verify_service_identity(cert_patterns, obligatory_ids, optional_ids): if a pattern of the respective type is present. """ errors = [] - matches = (_find_matches(cert_patterns, obligatory_ids) + - _find_matches(cert_patterns, optional_ids)) + matches = _find_matches(cert_patterns, obligatory_ids) + _find_matches( + cert_patterns, optional_ids + ) matched_ids = [match.service_id for match in matches] for i in obligatory_ids: @@ -54,9 +59,8 @@ def verify_service_identity(cert_patterns, obligatory_ids, optional_ids): # is a pattern of the same type , it is an error and the verification # fails. Example: the user passes a SRV-ID for "_mail.domain.com" but # the certificate contains an SRV-Pattern for "_xmpp.domain.com". - if ( - i not in matched_ids and - _contains_instance_of(cert_patterns, i.pattern_class) + if i not in matched_ids and _contains_instance_of( + cert_patterns, i.pattern_class ): errors.append(i.error_on_mismatch(mismatched_id=i)) @@ -82,9 +86,7 @@ def _find_matches(cert_patterns, service_ids): for sid in service_ids: for cid in cert_patterns: if sid.verify(cid): - matches.append( - ServiceMatch(cert_pattern=cid, service_id=sid) - ) + matches.append(ServiceMatch(cert_pattern=cid, service_id=sid)) return matches @@ -101,43 +103,42 @@ def _contains_instance_of(seq, cl): return False -_RE_IPv4 = re.compile(br"^([0-9*]{1,3}\.){3}[0-9*]{1,3}$") -_RE_IPv6 = re.compile(br"^([a-f0-9*]{0,4}:)+[a-f0-9*]{1,4}$") -_RE_NUMBER = re.compile(br"^[0-9]+$") - - def _is_ip_address(pattern): """ Check whether *pattern* could be/match an IP address. - Does *not* guarantee that pattern is in fact a valid IP address; especially - the checks for IPv6 are rather coarse. This function is for security - checks, not for validating IP addresses. - :param pattern: A pattern for a host name. :type pattern: `bytes` or `unicode` :return: `True` if *pattern* could be an IP address, else `False`. - :rtype: `bool` + :rtype: bool """ - if isinstance(pattern, text_type): + if isinstance(pattern, bytes): try: - pattern = pattern.encode('ascii') + pattern = pattern.decode("ascii") except UnicodeError: return False - return ( - _RE_IPv4.match(pattern) is not None or - _RE_IPv6.match(pattern) is not None or - _RE_NUMBER.match(pattern) is not None - ) + try: + int(pattern) + return True + except ValueError: + pass + + try: + ipaddress.ip_address(pattern.replace("*", "1")) + except ValueError: + return False + + return True -@attr.s(init=False) +@attr.s(init=False, slots=True) class DNSPattern(object): """ A DNS pattern as extracted from certificates. """ + pattern = attr.ib() _RE_LEGAL_CHARS = re.compile(br"^[a-z0-9\-_.]+$") @@ -157,15 +158,34 @@ class DNSPattern(object): ) self.pattern = pattern.translate(_TRANS_TO_LOWER) - if b'*' in self.pattern: + if b"*" in self.pattern: _validate_pattern(self.pattern) -@attr.s(init=False) +@attr.s(slots=True) +class IPAddressPattern(object): + """ + An IP address pattern as extracted from certificates. + """ + + pattern = attr.ib() + + @classmethod + def from_bytes(cls, bs): + try: + return cls(pattern=ipaddress.ip_address(bs)) + except ValueError: + raise CertificateError( + "Invalid IP address pattern {!r}.".format(bs) + ) + + +@attr.s(init=False, slots=True) class URIPattern(object): """ An URI pattern as extracted from certificates. """ + protocol_pattern = attr.ib() dns_pattern = attr.ib() @@ -178,11 +198,7 @@ class URIPattern(object): pattern = pattern.strip().translate(_TRANS_TO_LOWER) - if ( - b":" not in pattern or - b"*" in pattern or - _is_ip_address(pattern) - ): + if b":" not in pattern or b"*" in pattern or _is_ip_address(pattern): raise CertificateError( "Invalid URI pattern {0!r}.".format(pattern) ) @@ -190,11 +206,12 @@ class URIPattern(object): self.dns_pattern = DNSPattern(hostname) -@attr.s(init=False) +@attr.s(init=False, slots=True) class SRVPattern(object): """ An SRV pattern as extracted from certificates. """ + name_pattern = attr.ib() dns_pattern = attr.ib() @@ -208,10 +225,10 @@ class SRVPattern(object): pattern = pattern.strip().translate(_TRANS_TO_LOWER) if ( - pattern[0] != b"_"[0] or - b"." not in pattern or - b"*" in pattern or - _is_ip_address(pattern) + pattern[0] != b"_"[0] + or b"." not in pattern + or b"*" in pattern + or _is_ip_address(pattern) ): raise CertificateError( "Invalid SRV pattern {0!r}.".format(pattern) @@ -221,11 +238,12 @@ class SRVPattern(object): self.dns_pattern = DNSPattern(hostname) -@attr.s(init=False) +@attr.s(init=False, slots=True) class DNS_ID(object): """ A DNS service ID, aka hostname. """ + hostname = attr.ib() # characters that are legal in a normalized hostname @@ -260,7 +278,7 @@ class DNS_ID(object): def verify(self, pattern): """ - http://tools.ietf.org/search/rfc6125#section-6.4 + https://tools.ietf.org/search/rfc6125#section-6.4 """ if isinstance(pattern, self.pattern_class): return _hostname_matches(pattern.pattern, self.hostname) @@ -268,11 +286,30 @@ class DNS_ID(object): return False -@attr.s(init=False) +@attr.s(slots=True) +class IPAddress_ID(object): + """ + An IP address service ID. + """ + + ip = attr.ib(converter=ipaddress.ip_address) + + pattern_class = IPAddressPattern + error_on_mismatch = IPAddressMismatch + + def verify(self, pattern): + """ + https://tools.ietf.org/search/rfc2818#section-3.1 + """ + return self.ip == pattern.pattern + + +@attr.s(init=False, slots=True) class URI_ID(object): """ An URI service ID. """ + protocol = attr.ib() dns_id = attr.ib() @@ -297,22 +334,23 @@ class URI_ID(object): def verify(self, pattern): """ - http://tools.ietf.org/search/rfc6125#section-6.5.2 + https://tools.ietf.org/search/rfc6125#section-6.5.2 """ if isinstance(pattern, self.pattern_class): return ( - pattern.protocol_pattern == self.protocol and - self.dns_id.verify(pattern.dns_pattern) + pattern.protocol_pattern == self.protocol + and self.dns_id.verify(pattern.dns_pattern) ) else: return False -@attr.s(init=False) +@attr.s(init=False, slots=True) class SRV_ID(object): """ An SRV service ID. """ + name = attr.ib() dns_id = attr.ib() @@ -337,12 +375,11 @@ class SRV_ID(object): def verify(self, pattern): """ - http://tools.ietf.org/search/rfc6125#section-6.5.1 + https://tools.ietf.org/search/rfc6125#section-6.5.1 """ if isinstance(pattern, self.pattern_class): - return ( - self.name == pattern.name_pattern and - self.dns_id.verify(pattern.dns_pattern) + return self.name == pattern.name_pattern and self.dns_id.verify( + pattern.dns_pattern ) else: return False @@ -356,7 +393,7 @@ def _hostname_matches(cert_pattern, actual_hostname): :return: `True` if *cert_pattern* matches *actual_hostname*, else `False`. :rtype: `bool` """ - if b'*' in cert_pattern: + if b"*" in cert_pattern: cert_head, cert_tail = cert_pattern.split(b".", 1) actual_head, actual_tail = actual_hostname.split(b".", 1) if cert_tail != actual_tail: @@ -365,20 +402,7 @@ def _hostname_matches(cert_pattern, actual_hostname): if actual_head.startswith(b"xn--"): return False - if cert_head == b"*": - return True - - start, end = cert_head.split(b"*") - if start == b"": - # *oo - return actual_head.endswith(end) - elif end == b"": - # f* - return actual_head.startswith(start) - else: - # f*o - return actual_head.startswith(start) and actual_head.endswith(end) - + return cert_head == b"*" or cert_head == actual_head else: return cert_pattern == actual_hostname @@ -395,15 +419,15 @@ def _validate_pattern(cert_pattern): cnt = cert_pattern.count(b"*") if cnt > 1: raise CertificateError( - "Certificate's DNS-ID {0!r} contains too many wildcards." - .format(cert_pattern) + "Certificate's DNS-ID {0!r} contains too many wildcards.".format( + cert_pattern + ) ) parts = cert_pattern.split(b".") if len(parts) < 3: raise CertificateError( - "Certificate's DNS-ID {0!r} hast too few host components for " - "wildcard usage." - .format(cert_pattern) + "Certificate's DNS-ID {0!r} has too few host components for " + "wildcard usage.".format(cert_pattern) ) # We assume there will always be only one wildcard allowed. if b"*" not in parts[0]: @@ -413,11 +437,13 @@ def _validate_pattern(cert_pattern): ) if any(not len(p) for p in parts): raise CertificateError( - "Certificate's DNS-ID {0!r} contains empty parts." - .format(cert_pattern) + "Certificate's DNS-ID {0!r} contains empty parts.".format( + cert_pattern + ) ) # Ensure no locale magic interferes. -_TRANS_TO_LOWER = maketrans(b"ABCDEFGHIJKLMNOPQRSTUVWXYZ", - b"abcdefghijklmnopqrstuvwxyz") +_TRANS_TO_LOWER = maketrans( + b"ABCDEFGHIJKLMNOPQRSTUVWXYZ", b"abcdefghijklmnopqrstuvwxyz" +) |