summaryrefslogtreecommitdiff
path: root/synapse/crypto/keyring.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/crypto/keyring.py')
-rw-r--r--synapse/crypto/keyring.py686
1 files changed, 686 insertions, 0 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
new file mode 100644
index 00000000..8b6a5986
--- /dev/null
+++ b/synapse/crypto/keyring.py
@@ -0,0 +1,686 @@
+# -*- coding: utf-8 -*-
+# Copyright 2014, 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.crypto.keyclient import fetch_server_key
+from synapse.api.errors import SynapseError, Codes
+from synapse.util.retryutils import get_retry_limiter
+from synapse.util import unwrapFirstError
+from synapse.util.async import ObservableDeferred
+
+from twisted.internet import defer
+
+from signedjson.sign import (
+ verify_signed_json, signature_ids, sign_json, encode_canonical_json
+)
+from signedjson.key import (
+ is_signing_algorithm_supported, decode_verify_key_bytes
+)
+from unpaddedbase64 import decode_base64, encode_base64
+
+from OpenSSL import crypto
+
+from collections import namedtuple
+import urllib
+import hashlib
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
+
+
+class Keyring(object):
+ def __init__(self, hs):
+ self.store = hs.get_datastore()
+ self.clock = hs.get_clock()
+ self.client = hs.get_http_client()
+ self.config = hs.get_config()
+ self.perspective_servers = self.config.perspectives
+ self.hs = hs
+
+ self.key_downloads = {}
+
+ def verify_json_for_server(self, server_name, json_object):
+ return self.verify_json_objects_for_server(
+ [(server_name, json_object)]
+ )[0]
+
+ def verify_json_objects_for_server(self, server_and_json):
+ """Bulk verfies signatures of json objects, bulk fetching keys as
+ necessary.
+
+ Args:
+ server_and_json (list): List of pairs of (server_name, json_object)
+
+ Returns:
+ list of deferreds indicating success or failure to verify each
+ json object's signature for the given server_name.
+ """
+ group_id_to_json = {}
+ group_id_to_group = {}
+ group_ids = []
+
+ next_group_id = 0
+ deferreds = {}
+
+ for server_name, json_object in server_and_json:
+ logger.debug("Verifying for %s", server_name)
+ group_id = next_group_id
+ next_group_id += 1
+ group_ids.append(group_id)
+
+ key_ids = signature_ids(json_object, server_name)
+ if not key_ids:
+ deferreds[group_id] = defer.fail(SynapseError(
+ 400,
+ "Not signed with a supported algorithm",
+ Codes.UNAUTHORIZED,
+ ))
+ else:
+ deferreds[group_id] = defer.Deferred()
+
+ group = KeyGroup(server_name, group_id, key_ids)
+
+ group_id_to_group[group_id] = group
+ group_id_to_json[group_id] = json_object
+
+ @defer.inlineCallbacks
+ def handle_key_deferred(group, deferred):
+ server_name = group.server_name
+ try:
+ _, _, key_id, verify_key = yield deferred
+ except IOError as e:
+ logger.warn(
+ "Got IOError when downloading keys for %s: %s %s",
+ server_name, type(e).__name__, str(e.message),
+ )
+ raise SynapseError(
+ 502,
+ "Error downloading keys for %s" % (server_name,),
+ Codes.UNAUTHORIZED,
+ )
+ except Exception as e:
+ logger.exception(
+ "Got Exception when downloading keys for %s: %s %s",
+ server_name, type(e).__name__, str(e.message),
+ )
+ raise SynapseError(
+ 401,
+ "No key for %s with id %s" % (server_name, key_ids),
+ Codes.UNAUTHORIZED,
+ )
+
+ json_object = group_id_to_json[group.group_id]
+
+ try:
+ verify_signed_json(json_object, server_name, verify_key)
+ except:
+ raise SynapseError(
+ 401,
+ "Invalid signature for server %s with key %s:%s" % (
+ server_name, verify_key.alg, verify_key.version
+ ),
+ Codes.UNAUTHORIZED,
+ )
+
+ server_to_deferred = {
+ server_name: defer.Deferred()
+ for server_name, _ in server_and_json
+ }
+
+ # We want to wait for any previous lookups to complete before
+ # proceeding.
+ wait_on_deferred = self.wait_for_previous_lookups(
+ [server_name for server_name, _ in server_and_json],
+ server_to_deferred,
+ )
+
+ # Actually start fetching keys.
+ wait_on_deferred.addBoth(
+ lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
+ )
+
+ # When we've finished fetching all the keys for a given server_name,
+ # resolve the deferred passed to `wait_for_previous_lookups` so that
+ # any lookups waiting will proceed.
+ server_to_gids = {}
+
+ def remove_deferreds(res, server_name, group_id):
+ server_to_gids[server_name].discard(group_id)
+ if not server_to_gids[server_name]:
+ d = server_to_deferred.pop(server_name, None)
+ if d:
+ d.callback(None)
+ return res
+
+ for g_id, deferred in deferreds.items():
+ server_name = group_id_to_group[g_id].server_name
+ server_to_gids.setdefault(server_name, set()).add(g_id)
+ deferred.addBoth(remove_deferreds, server_name, g_id)
+
+ # Pass those keys to handle_key_deferred so that the json object
+ # signatures can be verified
+ return [
+ handle_key_deferred(
+ group_id_to_group[g_id],
+ deferreds[g_id],
+ )
+ for g_id in group_ids
+ ]
+
+ @defer.inlineCallbacks
+ def wait_for_previous_lookups(self, server_names, server_to_deferred):
+ """Waits for any previous key lookups for the given servers to finish.
+
+ Args:
+ server_names (list): list of server_names we want to lookup
+ server_to_deferred (dict): server_name to deferred which gets
+ resolved once we've finished looking up keys for that server
+ """
+ while True:
+ wait_on = [
+ self.key_downloads[server_name]
+ for server_name in server_names
+ if server_name in self.key_downloads
+ ]
+ if wait_on:
+ yield defer.DeferredList(wait_on)
+ else:
+ break
+
+ for server_name, deferred in server_to_deferred.items():
+ d = ObservableDeferred(deferred)
+ self.key_downloads[server_name] = d
+
+ def rm(r, server_name):
+ self.key_downloads.pop(server_name, None)
+ return r
+
+ d.addBoth(rm, server_name)
+
+ def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
+ """Takes a dict of KeyGroups and tries to find at least one key for
+ each group.
+ """
+
+ # These are functions that produce keys given a list of key ids
+ key_fetch_fns = (
+ self.get_keys_from_store, # First try the local store
+ self.get_keys_from_perspectives, # Then try via perspectives
+ self.get_keys_from_server, # Then try directly
+ )
+
+ @defer.inlineCallbacks
+ def do_iterations():
+ merged_results = {}
+
+ missing_keys = {}
+ for group in group_id_to_group.values():
+ missing_keys.setdefault(group.server_name, set()).union(group.key_ids)
+
+ for fn in key_fetch_fns:
+ results = yield fn(missing_keys.items())
+ merged_results.update(results)
+
+ # We now need to figure out which groups we have keys for
+ # and which we don't
+ missing_groups = {}
+ for group in group_id_to_group.values():
+ for key_id in group.key_ids:
+ if key_id in merged_results[group.server_name]:
+ group_id_to_deferred[group.group_id].callback((
+ group.group_id,
+ group.server_name,
+ key_id,
+ merged_results[group.server_name][key_id],
+ ))
+ break
+ else:
+ missing_groups.setdefault(
+ group.server_name, []
+ ).append(group)
+
+ if not missing_groups:
+ break
+
+ missing_keys = {
+ server_name: set(
+ key_id for group in groups for key_id in group.key_ids
+ )
+ for server_name, groups in missing_groups.items()
+ }
+
+ for group in missing_groups.values():
+ group_id_to_deferred[group.group_id].errback(SynapseError(
+ 401,
+ "No key for %s with id %s" % (
+ group.server_name, group.key_ids,
+ ),
+ Codes.UNAUTHORIZED,
+ ))
+
+ def on_err(err):
+ for deferred in group_id_to_deferred.values():
+ if not deferred.called:
+ deferred.errback(err)
+
+ do_iterations().addErrback(on_err)
+
+ return group_id_to_deferred
+
+ @defer.inlineCallbacks
+ def get_keys_from_store(self, server_name_and_key_ids):
+ res = yield defer.gatherResults(
+ [
+ self.store.get_server_verify_keys(
+ server_name, key_ids
+ ).addCallback(lambda ks, server: (server, ks), server_name)
+ for server_name, key_ids in server_name_and_key_ids
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+
+ defer.returnValue(dict(res))
+
+ @defer.inlineCallbacks
+ def get_keys_from_perspectives(self, server_name_and_key_ids):
+ @defer.inlineCallbacks
+ def get_key(perspective_name, perspective_keys):
+ try:
+ result = yield self.get_server_verify_key_v2_indirect(
+ server_name_and_key_ids, perspective_name, perspective_keys
+ )
+ defer.returnValue(result)
+ except Exception as e:
+ logger.exception(
+ "Unable to get key from %r: %s %s",
+ perspective_name,
+ type(e).__name__, str(e.message),
+ )
+ defer.returnValue({})
+
+ results = yield defer.gatherResults(
+ [
+ get_key(p_name, p_keys)
+ for p_name, p_keys in self.perspective_servers.items()
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+
+ union_of_keys = {}
+ for result in results:
+ for server_name, keys in result.items():
+ union_of_keys.setdefault(server_name, {}).update(keys)
+
+ defer.returnValue(union_of_keys)
+
+ @defer.inlineCallbacks
+ def get_keys_from_server(self, server_name_and_key_ids):
+ @defer.inlineCallbacks
+ def get_key(server_name, key_ids):
+ limiter = yield get_retry_limiter(
+ server_name,
+ self.clock,
+ self.store,
+ )
+ with limiter:
+ keys = None
+ try:
+ keys = yield self.get_server_verify_key_v2_direct(
+ server_name, key_ids
+ )
+ except Exception as e:
+ logger.info(
+ "Unable to getting key %r for %r directly: %s %s",
+ key_ids, server_name,
+ type(e).__name__, str(e.message),
+ )
+
+ if not keys:
+ keys = yield self.get_server_verify_key_v1_direct(
+ server_name, key_ids
+ )
+
+ keys = {server_name: keys}
+
+ defer.returnValue(keys)
+
+ results = yield defer.gatherResults(
+ [
+ get_key(server_name, key_ids)
+ for server_name, key_ids in server_name_and_key_ids
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+
+ merged = {}
+ for result in results:
+ merged.update(result)
+
+ defer.returnValue({
+ server_name: keys
+ for server_name, keys in merged.items()
+ if keys
+ })
+
+ @defer.inlineCallbacks
+ def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
+ perspective_name,
+ perspective_keys):
+ limiter = yield get_retry_limiter(
+ perspective_name, self.clock, self.store
+ )
+
+ with limiter:
+ # TODO(mark): Set the minimum_valid_until_ts to that needed by
+ # the events being validated or the current time if validating
+ # an incoming request.
+ query_response = yield self.client.post_json(
+ destination=perspective_name,
+ path=b"/_matrix/key/v2/query",
+ data={
+ u"server_keys": {
+ server_name: {
+ key_id: {
+ u"minimum_valid_until_ts": 0
+ } for key_id in key_ids
+ }
+ for server_name, key_ids in server_names_and_key_ids
+ }
+ },
+ )
+
+ keys = {}
+
+ responses = query_response["server_keys"]
+
+ for response in responses:
+ if (u"signatures" not in response
+ or perspective_name not in response[u"signatures"]):
+ raise ValueError(
+ "Key response not signed by perspective server"
+ " %r" % (perspective_name,)
+ )
+
+ verified = False
+ for key_id in response[u"signatures"][perspective_name]:
+ if key_id in perspective_keys:
+ verify_signed_json(
+ response,
+ perspective_name,
+ perspective_keys[key_id]
+ )
+ verified = True
+
+ if not verified:
+ logging.info(
+ "Response from perspective server %r not signed with a"
+ " known key, signed with: %r, known keys: %r",
+ perspective_name,
+ list(response[u"signatures"][perspective_name]),
+ list(perspective_keys)
+ )
+ raise ValueError(
+ "Response not signed with a known key for perspective"
+ " server %r" % (perspective_name,)
+ )
+
+ processed_response = yield self.process_v2_response(
+ perspective_name, response
+ )
+
+ for server_name, response_keys in processed_response.items():
+ keys.setdefault(server_name, {}).update(response_keys)
+
+ yield defer.gatherResults(
+ [
+ self.store_keys(
+ server_name=server_name,
+ from_server=perspective_name,
+ verify_keys=response_keys,
+ )
+ for server_name, response_keys in keys.items()
+ ],
+ consumeErrors=True
+ ).addErrback(unwrapFirstError)
+
+ defer.returnValue(keys)
+
+ @defer.inlineCallbacks
+ def get_server_verify_key_v2_direct(self, server_name, key_ids):
+ keys = {}
+
+ for requested_key_id in key_ids:
+ if requested_key_id in keys:
+ continue
+
+ (response, tls_certificate) = yield fetch_server_key(
+ server_name, self.hs.tls_server_context_factory,
+ path=(b"/_matrix/key/v2/server/%s" % (
+ urllib.quote(requested_key_id),
+ )).encode("ascii"),
+ )
+
+ if (u"signatures" not in response
+ or server_name not in response[u"signatures"]):
+ raise ValueError("Key response not signed by remote server")
+
+ if "tls_fingerprints" not in response:
+ raise ValueError("Key response missing TLS fingerprints")
+
+ certificate_bytes = crypto.dump_certificate(
+ crypto.FILETYPE_ASN1, tls_certificate
+ )
+ sha256_fingerprint = hashlib.sha256(certificate_bytes).digest()
+ sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)
+
+ response_sha256_fingerprints = set()
+ for fingerprint in response[u"tls_fingerprints"]:
+ if u"sha256" in fingerprint:
+ response_sha256_fingerprints.add(fingerprint[u"sha256"])
+
+ if sha256_fingerprint_b64 not in response_sha256_fingerprints:
+ raise ValueError("TLS certificate not allowed by fingerprints")
+
+ response_keys = yield self.process_v2_response(
+ from_server=server_name,
+ requested_ids=[requested_key_id],
+ response_json=response,
+ )
+
+ keys.update(response_keys)
+
+ yield defer.gatherResults(
+ [
+ self.store_keys(
+ server_name=key_server_name,
+ from_server=server_name,
+ verify_keys=verify_keys,
+ )
+ for key_server_name, verify_keys in keys.items()
+ ],
+ consumeErrors=True
+ ).addErrback(unwrapFirstError)
+
+ defer.returnValue(keys)
+
+ @defer.inlineCallbacks
+ def process_v2_response(self, from_server, response_json,
+ requested_ids=[]):
+ time_now_ms = self.clock.time_msec()
+ response_keys = {}
+ verify_keys = {}
+ for key_id, key_data in response_json["verify_keys"].items():
+ if is_signing_algorithm_supported(key_id):
+ key_base64 = key_data["key"]
+ key_bytes = decode_base64(key_base64)
+ verify_key = decode_verify_key_bytes(key_id, key_bytes)
+ verify_key.time_added = time_now_ms
+ verify_keys[key_id] = verify_key
+
+ old_verify_keys = {}
+ for key_id, key_data in response_json["old_verify_keys"].items():
+ if is_signing_algorithm_supported(key_id):
+ key_base64 = key_data["key"]
+ key_bytes = decode_base64(key_base64)
+ verify_key = decode_verify_key_bytes(key_id, key_bytes)
+ verify_key.expired = key_data["expired_ts"]
+ verify_key.time_added = time_now_ms
+ old_verify_keys[key_id] = verify_key
+
+ results = {}
+ server_name = response_json["server_name"]
+ for key_id in response_json["signatures"].get(server_name, {}):
+ if key_id not in response_json["verify_keys"]:
+ raise ValueError(
+ "Key response must include verification keys for all"
+ " signatures"
+ )
+ if key_id in verify_keys:
+ verify_signed_json(
+ response_json,
+ server_name,
+ verify_keys[key_id]
+ )
+
+ signed_key_json = sign_json(
+ response_json,
+ self.config.server_name,
+ self.config.signing_key[0],
+ )
+
+ signed_key_json_bytes = encode_canonical_json(signed_key_json)
+ ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
+
+ updated_key_ids = set(requested_ids)
+ updated_key_ids.update(verify_keys)
+ updated_key_ids.update(old_verify_keys)
+
+ response_keys.update(verify_keys)
+ response_keys.update(old_verify_keys)
+
+ yield defer.gatherResults(
+ [
+ self.store.store_server_keys_json(
+ server_name=server_name,
+ key_id=key_id,
+ from_server=server_name,
+ ts_now_ms=time_now_ms,
+ ts_expires_ms=ts_valid_until_ms,
+ key_json_bytes=signed_key_json_bytes,
+ )
+ for key_id in updated_key_ids
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+
+ results[server_name] = response_keys
+
+ defer.returnValue(results)
+
+ @defer.inlineCallbacks
+ def get_server_verify_key_v1_direct(self, server_name, key_ids):
+ """Finds a verification key for the server with one of the key ids.
+ Args:
+ server_name (str): The name of the server to fetch a key for.
+ keys_ids (list of str): The key_ids to check for.
+ """
+
+ # Try to fetch the key from the remote server.
+
+ (response, tls_certificate) = yield fetch_server_key(
+ server_name, self.hs.tls_server_context_factory
+ )
+
+ # Check the response.
+
+ x509_certificate_bytes = crypto.dump_certificate(
+ crypto.FILETYPE_ASN1, tls_certificate
+ )
+
+ if ("signatures" not in response
+ or server_name not in response["signatures"]):
+ raise ValueError("Key response not signed by remote server")
+
+ if "tls_certificate" not in response:
+ raise ValueError("Key response missing TLS certificate")
+
+ tls_certificate_b64 = response["tls_certificate"]
+
+ if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
+ raise ValueError("TLS certificate doesn't match")
+
+ # Cache the result in the datastore.
+
+ time_now_ms = self.clock.time_msec()
+
+ verify_keys = {}
+ for key_id, key_base64 in response["verify_keys"].items():
+ if is_signing_algorithm_supported(key_id):
+ key_bytes = decode_base64(key_base64)
+ verify_key = decode_verify_key_bytes(key_id, key_bytes)
+ verify_key.time_added = time_now_ms
+ verify_keys[key_id] = verify_key
+
+ for key_id in response["signatures"][server_name]:
+ if key_id not in response["verify_keys"]:
+ raise ValueError(
+ "Key response must include verification keys for all"
+ " signatures"
+ )
+ if key_id in verify_keys:
+ verify_signed_json(
+ response,
+ server_name,
+ verify_keys[key_id]
+ )
+
+ yield self.store.store_server_certificate(
+ server_name,
+ server_name,
+ time_now_ms,
+ tls_certificate,
+ )
+
+ yield self.store_keys(
+ server_name=server_name,
+ from_server=server_name,
+ verify_keys=verify_keys,
+ )
+
+ defer.returnValue(verify_keys)
+
+ @defer.inlineCallbacks
+ def store_keys(self, server_name, from_server, verify_keys):
+ """Store a collection of verify keys for a given server
+ Args:
+ server_name(str): The name of the server the keys are for.
+ from_server(str): The server the keys were downloaded from.
+ verify_keys(dict): A mapping of key_id to VerifyKey.
+ Returns:
+ A deferred that completes when the keys are stored.
+ """
+ # TODO(markjh): Store whether the keys have expired.
+ yield defer.gatherResults(
+ [
+ self.store.store_server_verify_key(
+ server_name, server_name, key.time_added, key
+ )
+ for key_id, key in verify_keys.items()
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)