summaryrefslogtreecommitdiff
path: root/src/service_identity/_common.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/service_identity/_common.py')
-rw-r--r--src/service_identity/_common.py174
1 files changed, 100 insertions, 74 deletions
diff --git a/src/service_identity/_common.py b/src/service_identity/_common.py
index fa8a359..9b4e773 100644
--- a/src/service_identity/_common.py
+++ b/src/service_identity/_common.py
@@ -4,6 +4,7 @@ Common verification code.
from __future__ import absolute_import, division, print_function
+import ipaddress
import re
import attr
@@ -12,22 +13,25 @@ from ._compat import maketrans, text_type
from .exceptions import (
CertificateError,
DNSMismatch,
+ IPAddressMismatch,
SRVMismatch,
URIMismatch,
VerificationError,
)
+
try:
import idna
except ImportError: # pragma: nocover
idna = None
-@attr.s
+@attr.s(slots=True)
class ServiceMatch(object):
"""
A match of a service id and a certificate pattern.
"""
+
service_id = attr.ib()
cert_pattern = attr.ib()
@@ -41,8 +45,9 @@ def verify_service_identity(cert_patterns, obligatory_ids, optional_ids):
if a pattern of the respective type is present.
"""
errors = []
- matches = (_find_matches(cert_patterns, obligatory_ids) +
- _find_matches(cert_patterns, optional_ids))
+ matches = _find_matches(cert_patterns, obligatory_ids) + _find_matches(
+ cert_patterns, optional_ids
+ )
matched_ids = [match.service_id for match in matches]
for i in obligatory_ids:
@@ -54,9 +59,8 @@ def verify_service_identity(cert_patterns, obligatory_ids, optional_ids):
# is a pattern of the same type , it is an error and the verification
# fails. Example: the user passes a SRV-ID for "_mail.domain.com" but
# the certificate contains an SRV-Pattern for "_xmpp.domain.com".
- if (
- i not in matched_ids and
- _contains_instance_of(cert_patterns, i.pattern_class)
+ if i not in matched_ids and _contains_instance_of(
+ cert_patterns, i.pattern_class
):
errors.append(i.error_on_mismatch(mismatched_id=i))
@@ -82,9 +86,7 @@ def _find_matches(cert_patterns, service_ids):
for sid in service_ids:
for cid in cert_patterns:
if sid.verify(cid):
- matches.append(
- ServiceMatch(cert_pattern=cid, service_id=sid)
- )
+ matches.append(ServiceMatch(cert_pattern=cid, service_id=sid))
return matches
@@ -101,43 +103,42 @@ def _contains_instance_of(seq, cl):
return False
-_RE_IPv4 = re.compile(br"^([0-9*]{1,3}\.){3}[0-9*]{1,3}$")
-_RE_IPv6 = re.compile(br"^([a-f0-9*]{0,4}:)+[a-f0-9*]{1,4}$")
-_RE_NUMBER = re.compile(br"^[0-9]+$")
-
-
def _is_ip_address(pattern):
"""
Check whether *pattern* could be/match an IP address.
- Does *not* guarantee that pattern is in fact a valid IP address; especially
- the checks for IPv6 are rather coarse. This function is for security
- checks, not for validating IP addresses.
-
:param pattern: A pattern for a host name.
:type pattern: `bytes` or `unicode`
:return: `True` if *pattern* could be an IP address, else `False`.
- :rtype: `bool`
+ :rtype: bool
"""
- if isinstance(pattern, text_type):
+ if isinstance(pattern, bytes):
try:
- pattern = pattern.encode('ascii')
+ pattern = pattern.decode("ascii")
except UnicodeError:
return False
- return (
- _RE_IPv4.match(pattern) is not None or
- _RE_IPv6.match(pattern) is not None or
- _RE_NUMBER.match(pattern) is not None
- )
+ try:
+ int(pattern)
+ return True
+ except ValueError:
+ pass
+
+ try:
+ ipaddress.ip_address(pattern.replace("*", "1"))
+ except ValueError:
+ return False
+
+ return True
-@attr.s(init=False)
+@attr.s(init=False, slots=True)
class DNSPattern(object):
"""
A DNS pattern as extracted from certificates.
"""
+
pattern = attr.ib()
_RE_LEGAL_CHARS = re.compile(br"^[a-z0-9\-_.]+$")
@@ -157,15 +158,34 @@ class DNSPattern(object):
)
self.pattern = pattern.translate(_TRANS_TO_LOWER)
- if b'*' in self.pattern:
+ if b"*" in self.pattern:
_validate_pattern(self.pattern)
-@attr.s(init=False)
+@attr.s(slots=True)
+class IPAddressPattern(object):
+ """
+ An IP address pattern as extracted from certificates.
+ """
+
+ pattern = attr.ib()
+
+ @classmethod
+ def from_bytes(cls, bs):
+ try:
+ return cls(pattern=ipaddress.ip_address(bs))
+ except ValueError:
+ raise CertificateError(
+ "Invalid IP address pattern {!r}.".format(bs)
+ )
+
+
+@attr.s(init=False, slots=True)
class URIPattern(object):
"""
An URI pattern as extracted from certificates.
"""
+
protocol_pattern = attr.ib()
dns_pattern = attr.ib()
@@ -178,11 +198,7 @@ class URIPattern(object):
pattern = pattern.strip().translate(_TRANS_TO_LOWER)
- if (
- b":" not in pattern or
- b"*" in pattern or
- _is_ip_address(pattern)
- ):
+ if b":" not in pattern or b"*" in pattern or _is_ip_address(pattern):
raise CertificateError(
"Invalid URI pattern {0!r}.".format(pattern)
)
@@ -190,11 +206,12 @@ class URIPattern(object):
self.dns_pattern = DNSPattern(hostname)
-@attr.s(init=False)
+@attr.s(init=False, slots=True)
class SRVPattern(object):
"""
An SRV pattern as extracted from certificates.
"""
+
name_pattern = attr.ib()
dns_pattern = attr.ib()
@@ -208,10 +225,10 @@ class SRVPattern(object):
pattern = pattern.strip().translate(_TRANS_TO_LOWER)
if (
- pattern[0] != b"_"[0] or
- b"." not in pattern or
- b"*" in pattern or
- _is_ip_address(pattern)
+ pattern[0] != b"_"[0]
+ or b"." not in pattern
+ or b"*" in pattern
+ or _is_ip_address(pattern)
):
raise CertificateError(
"Invalid SRV pattern {0!r}.".format(pattern)
@@ -221,11 +238,12 @@ class SRVPattern(object):
self.dns_pattern = DNSPattern(hostname)
-@attr.s(init=False)
+@attr.s(init=False, slots=True)
class DNS_ID(object):
"""
A DNS service ID, aka hostname.
"""
+
hostname = attr.ib()
# characters that are legal in a normalized hostname
@@ -260,7 +278,7 @@ class DNS_ID(object):
def verify(self, pattern):
"""
- http://tools.ietf.org/search/rfc6125#section-6.4
+ https://tools.ietf.org/search/rfc6125#section-6.4
"""
if isinstance(pattern, self.pattern_class):
return _hostname_matches(pattern.pattern, self.hostname)
@@ -268,11 +286,30 @@ class DNS_ID(object):
return False
-@attr.s(init=False)
+@attr.s(slots=True)
+class IPAddress_ID(object):
+ """
+ An IP address service ID.
+ """
+
+ ip = attr.ib(converter=ipaddress.ip_address)
+
+ pattern_class = IPAddressPattern
+ error_on_mismatch = IPAddressMismatch
+
+ def verify(self, pattern):
+ """
+ https://tools.ietf.org/search/rfc2818#section-3.1
+ """
+ return self.ip == pattern.pattern
+
+
+@attr.s(init=False, slots=True)
class URI_ID(object):
"""
An URI service ID.
"""
+
protocol = attr.ib()
dns_id = attr.ib()
@@ -297,22 +334,23 @@ class URI_ID(object):
def verify(self, pattern):
"""
- http://tools.ietf.org/search/rfc6125#section-6.5.2
+ https://tools.ietf.org/search/rfc6125#section-6.5.2
"""
if isinstance(pattern, self.pattern_class):
return (
- pattern.protocol_pattern == self.protocol and
- self.dns_id.verify(pattern.dns_pattern)
+ pattern.protocol_pattern == self.protocol
+ and self.dns_id.verify(pattern.dns_pattern)
)
else:
return False
-@attr.s(init=False)
+@attr.s(init=False, slots=True)
class SRV_ID(object):
"""
An SRV service ID.
"""
+
name = attr.ib()
dns_id = attr.ib()
@@ -337,12 +375,11 @@ class SRV_ID(object):
def verify(self, pattern):
"""
- http://tools.ietf.org/search/rfc6125#section-6.5.1
+ https://tools.ietf.org/search/rfc6125#section-6.5.1
"""
if isinstance(pattern, self.pattern_class):
- return (
- self.name == pattern.name_pattern and
- self.dns_id.verify(pattern.dns_pattern)
+ return self.name == pattern.name_pattern and self.dns_id.verify(
+ pattern.dns_pattern
)
else:
return False
@@ -356,7 +393,7 @@ def _hostname_matches(cert_pattern, actual_hostname):
:return: `True` if *cert_pattern* matches *actual_hostname*, else `False`.
:rtype: `bool`
"""
- if b'*' in cert_pattern:
+ if b"*" in cert_pattern:
cert_head, cert_tail = cert_pattern.split(b".", 1)
actual_head, actual_tail = actual_hostname.split(b".", 1)
if cert_tail != actual_tail:
@@ -365,20 +402,7 @@ def _hostname_matches(cert_pattern, actual_hostname):
if actual_head.startswith(b"xn--"):
return False
- if cert_head == b"*":
- return True
-
- start, end = cert_head.split(b"*")
- if start == b"":
- # *oo
- return actual_head.endswith(end)
- elif end == b"":
- # f*
- return actual_head.startswith(start)
- else:
- # f*o
- return actual_head.startswith(start) and actual_head.endswith(end)
-
+ return cert_head == b"*" or cert_head == actual_head
else:
return cert_pattern == actual_hostname
@@ -395,15 +419,15 @@ def _validate_pattern(cert_pattern):
cnt = cert_pattern.count(b"*")
if cnt > 1:
raise CertificateError(
- "Certificate's DNS-ID {0!r} contains too many wildcards."
- .format(cert_pattern)
+ "Certificate's DNS-ID {0!r} contains too many wildcards.".format(
+ cert_pattern
+ )
)
parts = cert_pattern.split(b".")
if len(parts) < 3:
raise CertificateError(
- "Certificate's DNS-ID {0!r} hast too few host components for "
- "wildcard usage."
- .format(cert_pattern)
+ "Certificate's DNS-ID {0!r} has too few host components for "
+ "wildcard usage.".format(cert_pattern)
)
# We assume there will always be only one wildcard allowed.
if b"*" not in parts[0]:
@@ -413,11 +437,13 @@ def _validate_pattern(cert_pattern):
)
if any(not len(p) for p in parts):
raise CertificateError(
- "Certificate's DNS-ID {0!r} contains empty parts."
- .format(cert_pattern)
+ "Certificate's DNS-ID {0!r} contains empty parts.".format(
+ cert_pattern
+ )
)
# Ensure no locale magic interferes.
-_TRANS_TO_LOWER = maketrans(b"ABCDEFGHIJKLMNOPQRSTUVWXYZ",
- b"abcdefghijklmnopqrstuvwxyz")
+_TRANS_TO_LOWER = maketrans(
+ b"ABCDEFGHIJKLMNOPQRSTUVWXYZ", b"abcdefghijklmnopqrstuvwxyz"
+)