summaryrefslogtreecommitdiff
path: root/synapse/crypto/keyring.py
diff options
context:
space:
mode:
authorErik Johnston <erikj@matrix.org>2016-08-24 15:05:56 +0100
committerErik Johnston <erikj@matrix.org>2016-08-24 15:05:56 +0100
commitc1c15ad12f8bda0d65778bd03543ad1f14a1cfc2 (patch)
tree1c843a49d3d5168ff998a54f50d30cdc3814f104 /synapse/crypto/keyring.py
parentcfb5c3f91265d2b9423b47cec2b555b39c46bc4b (diff)
Imported Upstream version 0.17.1
Diffstat (limited to 'synapse/crypto/keyring.py')
-rw-r--r--synapse/crypto/keyring.py166
1 files changed, 86 insertions, 80 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 5012c10e..d7211ee9 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -22,6 +22,7 @@ from synapse.util.logcontext import (
preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
preserve_fn
)
+from synapse.util.metrics import Measure
from twisted.internet import defer
@@ -61,6 +62,10 @@ Attributes:
"""
+class KeyLookupError(ValueError):
+ pass
+
+
class Keyring(object):
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -239,59 +244,60 @@ class Keyring(object):
@defer.inlineCallbacks
def do_iterations():
- merged_results = {}
+ with Measure(self.clock, "get_server_verify_keys"):
+ merged_results = {}
- missing_keys = {}
- for verify_request in verify_requests:
- missing_keys.setdefault(verify_request.server_name, set()).update(
- verify_request.key_ids
- )
-
- for fn in key_fetch_fns:
- results = yield fn(missing_keys.items())
- merged_results.update(results)
-
- # We now need to figure out which verify requests we have keys
- # for and which we don't
missing_keys = {}
- requests_missing_keys = []
for verify_request in verify_requests:
- server_name = verify_request.server_name
- result_keys = merged_results[server_name]
-
- if verify_request.deferred.called:
- # We've already called this deferred, which probably
- # means that we've already found a key for it.
- continue
-
- for key_id in verify_request.key_ids:
- if key_id in result_keys:
- with PreserveLoggingContext():
- verify_request.deferred.callback((
- server_name,
- key_id,
- result_keys[key_id],
- ))
- break
- else:
- # The else block is only reached if the loop above
- # doesn't break.
- missing_keys.setdefault(server_name, set()).update(
- verify_request.key_ids
- )
- requests_missing_keys.append(verify_request)
-
- if not missing_keys:
- break
-
- for verify_request in requests_missing_keys.values():
- verify_request.deferred.errback(SynapseError(
- 401,
- "No key for %s with id %s" % (
- verify_request.server_name, verify_request.key_ids,
- ),
- Codes.UNAUTHORIZED,
- ))
+ missing_keys.setdefault(verify_request.server_name, set()).update(
+ verify_request.key_ids
+ )
+
+ for fn in key_fetch_fns:
+ results = yield fn(missing_keys.items())
+ merged_results.update(results)
+
+ # We now need to figure out which verify requests we have keys
+ # for and which we don't
+ missing_keys = {}
+ requests_missing_keys = []
+ for verify_request in verify_requests:
+ server_name = verify_request.server_name
+ result_keys = merged_results[server_name]
+
+ if verify_request.deferred.called:
+ # We've already called this deferred, which probably
+ # means that we've already found a key for it.
+ continue
+
+ for key_id in verify_request.key_ids:
+ if key_id in result_keys:
+ with PreserveLoggingContext():
+ verify_request.deferred.callback((
+ server_name,
+ key_id,
+ result_keys[key_id],
+ ))
+ break
+ else:
+ # The else block is only reached if the loop above
+ # doesn't break.
+ missing_keys.setdefault(server_name, set()).update(
+ verify_request.key_ids
+ )
+ requests_missing_keys.append(verify_request)
+
+ if not missing_keys:
+ break
+
+ for verify_request in requests_missing_keys.values():
+ verify_request.deferred.errback(SynapseError(
+ 401,
+ "No key for %s with id %s" % (
+ verify_request.server_name, verify_request.key_ids,
+ ),
+ Codes.UNAUTHORIZED,
+ ))
def on_err(err):
for verify_request in verify_requests:
@@ -302,15 +308,15 @@ class Keyring(object):
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
- res = yield defer.gatherResults(
+ res = yield preserve_context_over_deferred(defer.gatherResults(
[
- self.store.get_server_verify_keys(
+ preserve_fn(self.store.get_server_verify_keys)(
server_name, key_ids
).addCallback(lambda ks, server: (server, ks), server_name)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
- ).addErrback(unwrapFirstError)
+ )).addErrback(unwrapFirstError)
defer.returnValue(dict(res))
@@ -331,13 +337,13 @@ class Keyring(object):
)
defer.returnValue({})
- results = yield defer.gatherResults(
+ results = yield preserve_context_over_deferred(defer.gatherResults(
[
- get_key(p_name, p_keys)
+ preserve_fn(get_key)(p_name, p_keys)
for p_name, p_keys in self.perspective_servers.items()
],
consumeErrors=True,
- ).addErrback(unwrapFirstError)
+ )).addErrback(unwrapFirstError)
union_of_keys = {}
for result in results:
@@ -363,7 +369,7 @@ class Keyring(object):
)
except Exception as e:
logger.info(
- "Unable to getting key %r for %r directly: %s %s",
+ "Unable to get key %r for %r directly: %s %s",
key_ids, server_name,
type(e).__name__, str(e.message),
)
@@ -377,13 +383,13 @@ class Keyring(object):
defer.returnValue(keys)
- results = yield defer.gatherResults(
+ results = yield preserve_context_over_deferred(defer.gatherResults(
[
- get_key(server_name, key_ids)
+ preserve_fn(get_key)(server_name, key_ids)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
- ).addErrback(unwrapFirstError)
+ )).addErrback(unwrapFirstError)
merged = {}
for result in results:
@@ -425,7 +431,7 @@ class Keyring(object):
for response in responses:
if (u"signatures" not in response
or perspective_name not in response[u"signatures"]):
- raise ValueError(
+ raise KeyLookupError(
"Key response not signed by perspective server"
" %r" % (perspective_name,)
)
@@ -448,7 +454,7 @@ class Keyring(object):
list(response[u"signatures"][perspective_name]),
list(perspective_keys)
)
- raise ValueError(
+ raise KeyLookupError(
"Response not signed with a known key for perspective"
" server %r" % (perspective_name,)
)
@@ -460,9 +466,9 @@ class Keyring(object):
for server_name, response_keys in processed_response.items():
keys.setdefault(server_name, {}).update(response_keys)
- yield defer.gatherResults(
+ yield preserve_context_over_deferred(defer.gatherResults(
[
- self.store_keys(
+ preserve_fn(self.store_keys)(
server_name=server_name,
from_server=perspective_name,
verify_keys=response_keys,
@@ -470,7 +476,7 @@ class Keyring(object):
for server_name, response_keys in keys.items()
],
consumeErrors=True
- ).addErrback(unwrapFirstError)
+ )).addErrback(unwrapFirstError)
defer.returnValue(keys)
@@ -491,10 +497,10 @@ class Keyring(object):
if (u"signatures" not in response
or server_name not in response[u"signatures"]):
- raise ValueError("Key response not signed by remote server")
+ raise KeyLookupError("Key response not signed by remote server")
if "tls_fingerprints" not in response:
- raise ValueError("Key response missing TLS fingerprints")
+ raise KeyLookupError("Key response missing TLS fingerprints")
certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1, tls_certificate
@@ -508,7 +514,7 @@ class Keyring(object):
response_sha256_fingerprints.add(fingerprint[u"sha256"])
if sha256_fingerprint_b64 not in response_sha256_fingerprints:
- raise ValueError("TLS certificate not allowed by fingerprints")
+ raise KeyLookupError("TLS certificate not allowed by fingerprints")
response_keys = yield self.process_v2_response(
from_server=server_name,
@@ -518,7 +524,7 @@ class Keyring(object):
keys.update(response_keys)
- yield defer.gatherResults(
+ yield preserve_context_over_deferred(defer.gatherResults(
[
preserve_fn(self.store_keys)(
server_name=key_server_name,
@@ -528,7 +534,7 @@ class Keyring(object):
for key_server_name, verify_keys in keys.items()
],
consumeErrors=True
- ).addErrback(unwrapFirstError)
+ )).addErrback(unwrapFirstError)
defer.returnValue(keys)
@@ -560,14 +566,14 @@ class Keyring(object):
server_name = response_json["server_name"]
if only_from_server:
if server_name != from_server:
- raise ValueError(
+ raise KeyLookupError(
"Expected a response for server %r not %r" % (
from_server, server_name
)
)
for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]:
- raise ValueError(
+ raise KeyLookupError(
"Key response must include verification keys for all"
" signatures"
)
@@ -594,7 +600,7 @@ class Keyring(object):
response_keys.update(verify_keys)
response_keys.update(old_verify_keys)
- yield defer.gatherResults(
+ yield preserve_context_over_deferred(defer.gatherResults(
[
preserve_fn(self.store.store_server_keys_json)(
server_name=server_name,
@@ -607,7 +613,7 @@ class Keyring(object):
for key_id in updated_key_ids
],
consumeErrors=True,
- ).addErrback(unwrapFirstError)
+ )).addErrback(unwrapFirstError)
results[server_name] = response_keys
@@ -635,15 +641,15 @@ class Keyring(object):
if ("signatures" not in response
or server_name not in response["signatures"]):
- raise ValueError("Key response not signed by remote server")
+ raise KeyLookupError("Key response not signed by remote server")
if "tls_certificate" not in response:
- raise ValueError("Key response missing TLS certificate")
+ raise KeyLookupError("Key response missing TLS certificate")
tls_certificate_b64 = response["tls_certificate"]
if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
- raise ValueError("TLS certificate doesn't match")
+ raise KeyLookupError("TLS certificate doesn't match")
# Cache the result in the datastore.
@@ -659,7 +665,7 @@ class Keyring(object):
for key_id in response["signatures"][server_name]:
if key_id not in response["verify_keys"]:
- raise ValueError(
+ raise KeyLookupError(
"Key response must include verification keys for all"
" signatures"
)
@@ -696,7 +702,7 @@ class Keyring(object):
A deferred that completes when the keys are stored.
"""
# TODO(markjh): Store whether the keys have expired.
- yield defer.gatherResults(
+ yield preserve_context_over_deferred(defer.gatherResults(
[
preserve_fn(self.store.store_server_verify_key)(
server_name, server_name, key.time_added, key
@@ -704,4 +710,4 @@ class Keyring(object):
for key_id, key in verify_keys.items()
],
consumeErrors=True,
- ).addErrback(unwrapFirstError)
+ )).addErrback(unwrapFirstError)