summaryrefslogtreecommitdiff
path: root/synapse/crypto/keyring.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/crypto/keyring.py')
-rw-r--r--synapse/crypto/keyring.py106
1 files changed, 48 insertions, 58 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 341c8631..6c3e885e 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -238,27 +238,9 @@ class Keyring(object):
"""
try:
- # create a deferred for each server we're going to look up the keys
- # for; we'll resolve them once we have completed our lookups.
- # These will be passed into wait_for_previous_lookups to block
- # any other lookups until we have finished.
- # The deferreds are called with no logcontext.
- server_to_deferred = {
- rq.server_name: defer.Deferred() for rq in verify_requests
- }
-
- # We want to wait for any previous lookups to complete before
- # proceeding.
- yield self.wait_for_previous_lookups(server_to_deferred)
+ ctx = LoggingContext.current_context()
- # Actually start fetching keys.
- self._get_server_verify_keys(verify_requests)
-
- # When we've finished fetching all the keys for a given server_name,
- # resolve the deferred passed to `wait_for_previous_lookups` so that
- # any lookups waiting will proceed.
- #
- # map from server name to a set of request ids
+ # map from server name to a set of outstanding request ids
server_to_request_ids = {}
for verify_request in verify_requests:
@@ -266,40 +248,61 @@ class Keyring(object):
request_id = id(verify_request)
server_to_request_ids.setdefault(server_name, set()).add(request_id)
- def remove_deferreds(res, verify_request):
+ # Wait for any previous lookups to complete before proceeding.
+ yield self.wait_for_previous_lookups(server_to_request_ids.keys())
+
+ # take out a lock on each of the servers by sticking a Deferred in
+ # key_downloads
+ for server_name in server_to_request_ids.keys():
+ self.key_downloads[server_name] = defer.Deferred()
+ logger.debug("Got key lookup lock on %s", server_name)
+
+ # When we've finished fetching all the keys for a given server_name,
+ # drop the lock by resolving the deferred in key_downloads.
+ def drop_server_lock(server_name):
+ d = self.key_downloads.pop(server_name)
+ d.callback(None)
+
+ def lookup_done(res, verify_request):
server_name = verify_request.server_name
- request_id = id(verify_request)
- server_to_request_ids[server_name].discard(request_id)
- if not server_to_request_ids[server_name]:
- d = server_to_deferred.pop(server_name, None)
- if d:
- d.callback(None)
+ server_requests = server_to_request_ids[server_name]
+ server_requests.remove(id(verify_request))
+
+ # if there are no more requests for this server, we can drop the lock.
+ if not server_requests:
+ with PreserveLoggingContext(ctx):
+ logger.debug("Releasing key lookup lock on %s", server_name)
+
+ # ... but not immediately, as that can cause stack explosions if
+ # we get a long queue of lookups.
+ self.clock.call_later(0, drop_server_lock, server_name)
+
return res
for verify_request in verify_requests:
- verify_request.key_ready.addBoth(remove_deferreds, verify_request)
+ verify_request.key_ready.addBoth(lookup_done, verify_request)
+
+ # Actually start fetching keys.
+ self._get_server_verify_keys(verify_requests)
except Exception:
logger.exception("Error starting key lookups")
@defer.inlineCallbacks
- def wait_for_previous_lookups(self, server_to_deferred):
+ def wait_for_previous_lookups(self, server_names):
"""Waits for any previous key lookups for the given servers to finish.
Args:
- server_to_deferred (dict[str, Deferred]): server_name to deferred which gets
- resolved once we've finished looking up keys for that server.
- The Deferreds should be regular twisted ones which call their
- callbacks with no logcontext.
-
- Returns: a Deferred which resolves once all key lookups for the given
- servers have completed. Follows the synapse rules of logcontext
- preservation.
+ server_names (Iterable[str]): list of servers which we want to look up
+
+ Returns:
+ Deferred[None]: resolves once all key lookups for the given servers have
+ completed. Follows the synapse rules of logcontext preservation.
"""
loop_count = 1
while True:
wait_on = [
(server_name, self.key_downloads[server_name])
- for server_name in server_to_deferred.keys()
+ for server_name in server_names
if server_name in self.key_downloads
]
if not wait_on:
@@ -314,19 +317,6 @@ class Keyring(object):
loop_count += 1
- ctx = LoggingContext.current_context()
-
- def rm(r, server_name_):
- with PreserveLoggingContext(ctx):
- logger.debug("Releasing key lookup lock on %s", server_name_)
- self.key_downloads.pop(server_name_, None)
- return r
-
- for server_name, deferred in server_to_deferred.items():
- logger.debug("Got key lookup lock on %s", server_name)
- self.key_downloads[server_name] = deferred
- deferred.addBoth(rm, server_name)
-
def _get_server_verify_keys(self, verify_requests):
"""Tries to find at least one key for each verify request
@@ -472,7 +462,7 @@ class StoreKeyFetcher(KeyFetcher):
keys = {}
for (server_name, key_id), key in res.items():
keys.setdefault(server_name, {})[key_id] = key
- defer.returnValue(keys)
+ return keys
class BaseV2KeyFetcher(object):
@@ -576,7 +566,7 @@ class BaseV2KeyFetcher(object):
).addErrback(unwrapFirstError)
)
- defer.returnValue(verify_keys)
+ return verify_keys
class PerspectivesKeyFetcher(BaseV2KeyFetcher):
@@ -598,7 +588,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
result = yield self.get_server_verify_key_v2_indirect(
keys_to_fetch, key_server
)
- defer.returnValue(result)
+ return result
except KeyLookupError as e:
logger.warning(
"Key lookup failed from %r: %s", key_server.server_name, e
@@ -611,7 +601,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
str(e),
)
- defer.returnValue({})
+ return {}
results = yield make_deferred_yieldable(
defer.gatherResults(
@@ -625,7 +615,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
for server_name, keys in result.items():
union_of_keys.setdefault(server_name, {}).update(keys)
- defer.returnValue(union_of_keys)
+ return union_of_keys
@defer.inlineCallbacks
def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
@@ -711,7 +701,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
perspective_name, time_now_ms, added_keys
)
- defer.returnValue(keys)
+ return keys
def _validate_perspectives_response(self, key_server, response):
"""Optionally check the signature on the result of a /key/query request
@@ -853,7 +843,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
)
keys.update(response_keys)
- defer.returnValue(keys)
+ return keys
@defer.inlineCallbacks