author     Erik Johnston <erikj@matrix.org>  2016-10-05 14:44:19 +0100
committer  Erik Johnston <erikj@matrix.org>  2016-10-05 14:44:19 +0100
commit     1be0bf0f238f0880a84fccd3111ebfeef0057d5c (patch)
tree       cc27f059a354379346f9a21ff6b8b661a9407b65
parent     7410e9093d9a27b05f0ee313c6017dfa0053de41 (diff)

Imported Upstream version 0.18.1
-rw-r--r--  CHANGES.rst | 31
-rwxr-xr-x  scripts/synapse_port_db | 9
-rw-r--r--  synapse/__init__.py | 2
-rw-r--r--  synapse/api/auth.py | 30
-rw-r--r--  synapse/app/synchrotron.py | 6
-rw-r--r--  synapse/federation/federation_client.py | 2
-rw-r--r--  synapse/handlers/auth.py | 279
-rw-r--r--  synapse/handlers/federation.py | 15
-rw-r--r--  synapse/handlers/initial_sync.py | 443
-rw-r--r--  synapse/handlers/message.py | 381
-rw-r--r--  synapse/handlers/room_list.py | 3
-rw-r--r--  synapse/handlers/typing.py | 177
-rw-r--r--  synapse/rest/client/v1/initial_sync.py | 5
-rw-r--r--  synapse/rest/client/v1/room.py | 9
-rw-r--r--  synapse/server.py | 5
-rw-r--r--  synapse/storage/event_federation.py | 9
-rw-r--r--  synapse/storage/events.py | 44
-rw-r--r--  synapse/storage/prepare_database.py | 2
-rw-r--r--  synapse/storage/schema/delta/36/readd_public_rooms.sql | 26
-rw-r--r--  synapse/storage/state.py | 52
-rw-r--r--  synapse/types.py | 2
-rw-r--r--  tests/handlers/test_typing.py | 7
-rw-r--r--  tests/rest/client/v1/test_typing.py | 5
-rw-r--r--  tests/utils.py | 9
24 files changed, 942 insertions, 611 deletions
diff --git a/CHANGES.rst b/CHANGES.rst
index 4dcaf117..12abd6cf 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -1,3 +1,32 @@
+Changes in synapse v0.18.1 (2016-10-0)
+======================================
+
+No changes since v0.18.1-rc1
+
+
+Changes in synapse v0.18.1-rc1 (2016-09-30)
+===========================================
+
+Features:
+
+* Add total_room_count_estimate to ``/publicRooms`` (PR #1133)
+
+
+Changes:
+
+* Time out typing over federation (PR #1140)
+* Restructure LDAP authentication (PR #1153)
+
+
+Bug fixes:
+
+* Fix 3pid invites when server is already in the room (PR #1136)
+* Fix upgrading with SQLite taking lots of CPU for a few days
+ after upgrade (PR #1144)
+* Fix upgrading from very old database versions (PR #1145)
+* Fix port script to work with recently added tables (PR #1146)
+
+
Changes in synapse v0.18.0 (2016-09-19)
=======================================
@@ -6,7 +35,7 @@ significantly reduce database size. Synapse will attempt to upgrade the current
data in the background. Servers with large SQLite database may experience
degradation of performance while this upgrade is in progress, therefore you may
want to consider migrating to using Postgres before upgrading very large SQLite
-daabases
+databases
Changes:
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 66c61b01..2cb2eab6 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -39,6 +39,7 @@ BOOLEAN_COLUMNS = {
"event_edges": ["is_state"],
"presence_list": ["accepted"],
"presence_stream": ["currently_active"],
+ "public_room_list_stream": ["visibility"],
}
@@ -71,6 +72,14 @@ APPEND_ONLY_TABLES = [
"event_to_state_groups",
"rejections",
"event_search",
+ "presence_stream",
+ "push_rules_stream",
+ "current_state_resets",
+ "ex_outlier_stream",
+ "cache_invalidation_stream",
+ "public_room_list_stream",
+ "state_group_edges",
+ "stream_ordering_to_exterm",
]
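
The new BOOLEAN_COLUMNS entries matter because SQLite stores booleans as 0/1 integers while the Postgres schema declares real BOOLEAN columns, so the port script has to coerce every listed column when copying rows across. A minimal standalone sketch of that coercion (convert_row is illustrative, not the script's actual helper):

    # Sketch only: coerce SQLite's 0/1 integers to real booleans for the
    # columns the Postgres schema declares as BOOLEAN. Table and column
    # names mirror the dict above; convert_row itself is hypothetical.
    BOOLEAN_COLUMNS = {
        "public_room_list_stream": ["visibility"],
        "presence_stream": ["currently_active"],
    }

    def convert_row(table, headers, row):
        bool_cols = set(BOOLEAN_COLUMNS.get(table, ()))
        return [
            bool(val) if hdr in bool_cols and val is not None else val
            for hdr, val in zip(headers, row)
        ]

    # convert_row("public_room_list_stream",
    #             ["room_id", "visibility"], ["!abc:example.org", 1])
    # -> ["!abc:example.org", True]
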
diff --git a/synapse/__init__.py b/synapse/__init__.py
index 41745170..6dbe8fc7 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
-__version__ = "0.18.0"
+__version__ = "0.18.1"
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 98a50f09..e75fd518 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -72,7 +72,7 @@ class Auth(object):
auth_events = {
(e.type, e.state_key): e for e in auth_events.values()
}
- self.check(event, auth_events=auth_events, do_sig_check=False)
+ self.check(event, auth_events=auth_events, do_sig_check=do_sig_check)
def check(self, event, auth_events, do_sig_check=True):
""" Checks if this event is correctly authed.
@@ -91,11 +91,28 @@ class Auth(object):
if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event)
- sender_domain = get_domain_from_id(event.sender)
+ if do_sig_check:
+ sender_domain = get_domain_from_id(event.sender)
+ event_id_domain = get_domain_from_id(event.event_id)
+
+ is_invite_via_3pid = (
+ event.type == EventTypes.Member
+ and event.membership == Membership.INVITE
+ and "third_party_invite" in event.content
+ )
- # Check the sender's domain has signed the event
- if do_sig_check and not event.signatures.get(sender_domain):
- raise AuthError(403, "Event not signed by sending server")
+ # Check the sender's domain has signed the event
+ if not event.signatures.get(sender_domain):
+ # We allow invites via 3pid to have a sender from a different
+ # HS, as the sender must match the sender of the original
+ # 3pid invite. This is checked further down with the
+ # other dedicated membership checks.
+ if not is_invite_via_3pid:
+ raise AuthError(403, "Event not signed by sender's server")
+
+ # Check the event_id's domain has signed the event
+ if not event.signatures.get(event_id_domain):
+ raise AuthError(403, "Event not signed by sending server")
if auth_events is None:
# Oh, we don't know what the state of the room was, so we
@@ -491,6 +508,9 @@ class Auth(object):
if not invite_event:
return False
+ if invite_event.sender != event.sender:
+ return False
+
if event.user_id != invite_event.user_id:
return False
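
Condensed, the new signature rule is: the server named in the event_id must always have signed the event, and the sender's server must have signed it too unless the event is a membership invite carrying third_party_invite content, whose real sender may live on another homeserver and is instead checked by the dedicated 3pid logic (including the new invite_event.sender != event.sender guard). A self-contained sketch of that decision, with small stand-ins for AuthError and get_domain_from_id:

    # Illustrative stand-ins; Synapse's own AuthError and get_domain_from_id
    # live in synapse.api.errors and synapse.types respectively.
    class AuthError(Exception):
        def __init__(self, code, msg):
            super(AuthError, self).__init__(msg)
            self.code = code

    def get_domain_from_id(string_id):
        # Matrix IDs (@user:hs, $event:hs) carry the domain after the colon.
        return string_id.split(":", 1)[1]

    def check_signatures(event):
        sender_domain = get_domain_from_id(event.sender)
        event_id_domain = get_domain_from_id(event.event_id)

        is_invite_via_3pid = (
            event.type == "m.room.member"
            and event.membership == "invite"
            and "third_party_invite" in event.content
        )

        if not event.signatures.get(sender_domain):
            # 3pid invites may legitimately be sent by another homeserver;
            # the membership checks verify the original invite's sender.
            if not is_invite_via_3pid:
                raise AuthError(403, "Event not signed by sender's server")

        if not event.signatures.get(event_id_domain):
            raise AuthError(403, "Event not signed by sending server")
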
diff --git a/synapse/app/synchrotron.py b/synapse/app/synchrotron.py
index 46d390fd..64b209ff 100644
--- a/synapse/app/synchrotron.py
+++ b/synapse/app/synchrotron.py
@@ -27,6 +27,8 @@ from synapse.http.server import JsonResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.rest.client.v2_alpha import sync
from synapse.rest.client.v1 import events
+from synapse.rest.client.v1.room import RoomInitialSyncRestServlet
+from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
from synapse.replication.slave.storage._base import BaseSlavedStore
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
@@ -37,6 +39,7 @@ from synapse.replication.slave.storage.filtering import SlavedFilteringStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
+from synapse.replication.slave.storage.room import RoomStore
from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
from synapse.storage.engines import create_engine
@@ -74,6 +77,7 @@ class SynchrotronSlavedStore(
SlavedFilteringStore,
SlavedPresenceStore,
SlavedDeviceInboxStore,
+ RoomStore,
BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different
):
@@ -296,6 +300,8 @@ class SynchrotronServer(HomeServer):
resource = JsonResource(self, canonical_json=False)
sync.register_servlets(self, resource)
events.register_servlets(self, resource)
+ InitialSyncRestServlet(self).register(resource)
+ RoomInitialSyncRestServlet(self).register(resource)
resources.update({
"/_matrix/client/r0": resource,
"/_matrix/client/unstable": resource,
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 06d0320b..94e76b19 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -136,9 +136,7 @@ class FederationClient(FederationBase):
sent_edus_counter.inc()
- # TODO, add errback, etc.
self._transaction_queue.enqueue_edu(edu, key=key)
- return defer.succeed(None)
@log_function
def send_device_messages(self, destination):
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 6986930c..3933ce17 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -31,6 +31,7 @@ import simplejson
try:
import ldap3
+ import ldap3.core.exceptions
except ImportError:
ldap3 = None
pass
@@ -504,6 +505,144 @@ class AuthHandler(BaseHandler):
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
defer.returnValue(user_id)
+ def _ldap_simple_bind(self, server, localpart, password):
+ """ Attempt a simple bind with the credentials
+ given by the user against the LDAP server.
+
+ Returns True, LDAP3Connection
+ if the bind was successful
+ Returns False, None
+ if an error occurred
+ """
+
+ try:
+ # bind with the local user's LDAP credentials
+ bind_dn = "{prop}={value},{base}".format(
+ prop=self.ldap_attributes['uid'],
+ value=localpart,
+ base=self.ldap_base
+ )
+ conn = ldap3.Connection(server, bind_dn, password)
+ logger.debug(
+ "Established LDAP connection in simple bind mode: %s",
+ conn
+ )
+
+ if self.ldap_start_tls:
+ conn.start_tls()
+ logger.debug(
+ "Upgraded LDAP connection in simple bind mode through StartTLS: %s",
+ conn
+ )
+
+ if conn.bind():
+ # GOOD: bind okay
+ logger.debug("LDAP Bind successful in simple bind mode.")
+ return True, conn
+
+ # BAD: bind failed
+ logger.info(
+ "Binding against LDAP failed for '%s' failed: %s",
+ localpart, conn.result['description']
+ )
+ conn.unbind()
+ return False, None
+
+ except ldap3.core.exceptions.LDAPException as e:
+ logger.warn("Error during LDAP authentication: %s", e)
+ return False, None
+
+ def _ldap_authenticated_search(self, server, localpart, password):
+ """ Attempt to login with the preconfigured bind_dn
+ and then continue searching and filtering within
+ the base_dn
+
+ Returns (True, LDAP3Connection)
+ if a single matching DN within the base was found
+ that matched the filter expression, and with which
+ a successful bind was achieved
+
+ The LDAP3Connection returned is the instance that was used to
+ verify the password not the one using the configured bind_dn.
+ Returns (False, None)
+ if an error occurred
+ """
+
+ try:
+ conn = ldap3.Connection(
+ server,
+ self.ldap_bind_dn,
+ self.ldap_bind_password
+ )
+ logger.debug(
+ "Established LDAP connection in search mode: %s",
+ conn
+ )
+
+ if self.ldap_start_tls:
+ conn.start_tls()
+ logger.debug(
+ "Upgraded LDAP connection in search mode through StartTLS: %s",
+ conn
+ )
+
+ if not conn.bind():
+ logger.warn(
+ "Binding against LDAP with `bind_dn` failed: %s",
+ conn.result['description']
+ )
+ conn.unbind()
+ return False, None
+
+ # construct search_filter like (uid=localpart)
+ query = "({prop}={value})".format(
+ prop=self.ldap_attributes['uid'],
+ value=localpart
+ )
+ if self.ldap_filter:
+ # combine with the AND expression
+ query = "(&{query}{filter})".format(
+ query=query,
+ filter=self.ldap_filter
+ )
+ logger.debug(
+ "LDAP search filter: %s",
+ query
+ )
+ conn.search(
+ search_base=self.ldap_base,
+ search_filter=query
+ )
+
+ if len(conn.response) == 1:
+ # GOOD: found exactly one result
+ user_dn = conn.response[0]['dn']
+ logger.debug('LDAP search found dn: %s', user_dn)
+
+ # unbind and simple bind with user_dn to verify the password
+ # Note: do not use rebind(), for some reason it did not verify
+ # the password for me!
+ conn.unbind()
+ return self._ldap_simple_bind(server, localpart, password)
+ else:
+ # BAD: found 0 or > 1 results, abort!
+ if len(conn.response) == 0:
+ logger.info(
+ "LDAP search returned no results for '%s'",
+ localpart
+ )
+ else:
+ logger.info(
+ "LDAP search returned too many (%s) results for '%s'",
+ len(conn.response), localpart
+ )
+ conn.unbind()
+ return False, None
+
+ except ldap3.core.exceptions.LDAPException as e:
+ logger.warn("Error during LDAP authentication: %s", e)
+ return False, None
+
@defer.inlineCallbacks
def _check_ldap_password(self, user_id, password):
""" Attempt to authenticate a user against an LDAP Server
@@ -516,106 +655,62 @@ class AuthHandler(BaseHandler):
if not ldap3 or not self.ldap_enabled:
defer.returnValue(False)
- if self.ldap_mode not in LDAPMode.LIST:
- raise RuntimeError(
- 'Invalid ldap mode specified: {mode}'.format(
- mode=self.ldap_mode
- )
- )
+ localpart = UserID.from_string(user_id).localpart
try:
server = ldap3.Server(self.ldap_uri)
logger.debug(
- "Attempting ldap connection with %s",
+ "Attempting LDAP connection with %s",
self.ldap_uri
)
- localpart = UserID.from_string(user_id).localpart
if self.ldap_mode == LDAPMode.SIMPLE:
- # bind with the the local users ldap credentials
- bind_dn = "{prop}={value},{base}".format(
- prop=self.ldap_attributes['uid'],
- value=localpart,
- base=self.ldap_base
+ result, conn = self._ldap_simple_bind(
+ server=server, localpart=localpart, password=password
)
- conn = ldap3.Connection(server, bind_dn, password)
logger.debug(
- "Established ldap connection in simple mode: %s",
+ 'LDAP authentication method simple bind returned: %s (conn: %s)',
+ result,
conn
)
-
- if self.ldap_start_tls:
- conn.start_tls()
- logger.debug(
- "Upgraded ldap connection in simple mode through StartTLS: %s",
- conn
- )
-
- conn.bind()
-
+ if not result:
+ defer.returnValue(False)
elif self.ldap_mode == LDAPMode.SEARCH:
- # connect with preconfigured credentials and search for local user
- conn = ldap3.Connection(
- server,
- self.ldap_bind_dn,
- self.ldap_bind_password
+ result, conn = self._ldap_authenticated_search(
+ server=server, localpart=localpart, password=password
)
logger.debug(
- "Established ldap connection in search mode: %s",
+ 'LDAP auth method authenticated search returned: %s (conn: %s)',
+ result,
conn
)
-
- if self.ldap_start_tls:
- conn.start_tls()
- logger.debug(
- "Upgraded ldap connection in search mode through StartTLS: %s",
- conn
+ if not result:
+ defer.returnValue(False)
+ else:
+ raise RuntimeError(
+ 'Invalid LDAP mode specified: {mode}'.format(
+ mode=self.ldap_mode
)
-
- conn.bind()
-
- # find matching dn
- query = "({prop}={value})".format(
- prop=self.ldap_attributes['uid'],
- value=localpart
)
- if self.ldap_filter:
- query = "(&{query}{filter})".format(
- query=query,
- filter=self.ldap_filter
- )
- logger.debug("ldap search filter: %s", query)
- result = conn.search(self.ldap_base, query)
-
- if result and len(conn.response) == 1:
- # found exactly one result
- user_dn = conn.response[0]['dn']
- logger.debug('ldap search found dn: %s', user_dn)
-
- # unbind and reconnect, rebind with found dn
- conn.unbind()
- conn = ldap3.Connection(
- server,
- user_dn,
- password,
- auto_bind=True
- )
- else:
- # found 0 or > 1 results, abort!
- logger.warn(
- "ldap search returned unexpected (%d!=1) amount of results",
- len(conn.response)
- )
- defer.returnValue(False)
- logger.info(
- "User authenticated against ldap server: %s",
- conn
- )
+ try:
+ logger.info(
+ "User authenticated against LDAP server: %s",
+ conn
+ )
+ except NameError:
+ logger.warn("Authentication method yielded no LDAP connection, aborting!")
+ defer.returnValue(False)
+
+ # check if user with user_id exists
+ if (yield self.check_user_exists(user_id)):
+ # exists, authentication complete
+ conn.unbind()
+ defer.returnValue(True)
- # check for existing account, if none exists, create one
- if not (yield self.check_user_exists(user_id)):
- # query user metadata for account creation
+ else:
+ # does not exist, fetch metadata for account creation from
+ # existing ldap connection
query = "({prop}={value})".format(
prop=self.ldap_attributes['uid'],
value=localpart
@@ -626,9 +721,12 @@ class AuthHandler(BaseHandler):
filter=query,
user_filter=self.ldap_filter
)
- logger.debug("ldap registration filter: %s", query)
+ logger.debug(
+ "ldap registration filter: %s",
+ query
+ )
- result = conn.search(
+ conn.search(
search_base=self.ldap_base,
search_filter=query,
attributes=[
@@ -651,20 +749,27 @@ class AuthHandler(BaseHandler):
# TODO: bind email, set displayname with data from ldap directory
logger.info(
- "ldap registration successful: %d: %s (%s, %)",
+ "Registration based on LDAP data was successful: %d: %s (%s, %)",
user_id,
localpart,
name,
mail
)
+
+ defer.returnValue(True)
else:
- logger.warn(
- "ldap registration failed: unexpected (%d!=1) amount of results",
- len(conn.response)
- )
+ if len(conn.response) == 0:
+ logger.warn("LDAP registration failed, no result.")
+ else:
+ logger.warn(
+ "LDAP registration failed, too many results (%s)",
+ len(conn.response)
+ )
+
defer.returnValue(False)
- defer.returnValue(True)
+ defer.returnValue(False)
+
except ldap3.core.exceptions.LDAPException as e:
logger.warn("Error during ldap authentication: %s", e)
defer.returnValue(False)
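
The refactored handler supports exactly two modes: simple (bind directly as uid=<localpart>,<base> and let the directory verify the password) and search (bind with a configured service account, locate the user's entry under the base DN, then perform a fresh simple bind to verify the password). A standalone sketch of both flows using the same ldap3 calls as the handler; the URI, base DN and attribute name are placeholders for the homeserver's LDAP config:

    import ldap3

    LDAP_URI = "ldap://ldap.example.org"       # placeholder config values
    LDAP_BASE = "ou=users,dc=example,dc=org"
    UID_ATTR = "uid"

    def simple_bind(localpart, password):
        """LDAPMode.SIMPLE: bind directly as the user."""
        server = ldap3.Server(LDAP_URI)
        conn = ldap3.Connection(
            server, "%s=%s,%s" % (UID_ATTR, localpart, LDAP_BASE), password
        )
        if conn.bind():
            return True, conn
        conn.unbind()
        return False, None

    def search_then_bind(localpart, password, bind_dn, bind_password):
        """LDAPMode.SEARCH: service-account bind, search for the user,
        then re-bind as the user to verify the password (mirroring the
        handler's call to _ldap_simple_bind after a unique search hit)."""
        server = ldap3.Server(LDAP_URI)
        conn = ldap3.Connection(server, bind_dn, bind_password)
        if not conn.bind():
            return False, None
        conn.search(
            search_base=LDAP_BASE,
            search_filter="(%s=%s)" % (UID_ATTR, localpart),
        )
        found = len(conn.response)
        conn.unbind()
        if found != 1:      # zero or ambiguous matches: refuse to guess
            return False, None
        return simple_bind(localpart, password)
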
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index f7cb3c1b..2d801bad 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1922,15 +1922,18 @@ class FederationHandler(BaseHandler):
original_invite = yield self.store.get_event(
original_invite_id, allow_none=True
)
- if not original_invite:
+ if original_invite:
+ display_name = original_invite.content["display_name"]
+ event_dict["content"]["third_party_invite"]["display_name"] = display_name
+ else:
logger.info(
- "Could not find invite event for third_party_invite - "
- "discarding: %s" % (event_dict,)
+ "Could not find invite event for third_party_invite: %r",
+ event_dict
)
- return
+ # We don't discard here as this is not the appropriate place to do
+ # auth checks. If we need the invite and don't have it then the
+ # auth check code will explode appropriately.
- display_name = original_invite.content["display_name"]
- event_dict["content"]["third_party_invite"]["display_name"] = display_name
builder = self.event_builder_factory.new(event_dict)
EventValidator().validate_new(builder)
message_handler = self.hs.get_handlers().message_handler
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
new file mode 100644
index 00000000..fbfa5a02
--- /dev/null
+++ b/synapse/handlers/initial_sync.py
@@ -0,0 +1,443 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.errors import AuthError, Codes
+from synapse.events.utils import serialize_event
+from synapse.events.validator import EventValidator
+from synapse.streams.config import PaginationConfig
+from synapse.types import (
+ UserID, StreamToken,
+)
+from synapse.util import unwrapFirstError
+from synapse.util.async import concurrently_execute
+from synapse.util.caches.snapshot_cache import SnapshotCache
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.visibility import filter_events_for_client
+
+from ._base import BaseHandler
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class InitialSyncHandler(BaseHandler):
+ def __init__(self, hs):
+ super(InitialSyncHandler, self).__init__(hs)
+ self.hs = hs
+ self.state = hs.get_state_handler()
+ self.clock = hs.get_clock()
+ self.validator = EventValidator()
+ self.snapshot_cache = SnapshotCache()
+
+ def snapshot_all_rooms(self, user_id=None, pagin_config=None,
+ as_client_event=True, include_archived=False):
+ """Retrieve a snapshot of all rooms the user is invited or has joined.
+
+ This snapshot may include messages for all rooms where the user is
+ joined, depending on the pagination config.
+
+ Args:
+ user_id (str): The ID of the user making the request.
+ pagin_config (synapse.api.streams.PaginationConfig): The pagination
+ config used to determine how many messages *PER ROOM* to return.
+ as_client_event (bool): True to get events in client-server format.
+ include_archived (bool): True to get rooms that the user has left
+ Returns:
+ A list of dicts with "room_id" and "membership" keys for all rooms
+ the user is currently invited or joined in on. Rooms where the user
+ is joined on, may return a "messages" key with messages, depending
+ on the specified PaginationConfig.
+ """
+ key = (
+ user_id,
+ pagin_config.from_token,
+ pagin_config.to_token,
+ pagin_config.direction,
+ pagin_config.limit,
+ as_client_event,
+ include_archived,
+ )
+ now_ms = self.clock.time_msec()
+ result = self.snapshot_cache.get(now_ms, key)
+ if result is not None:
+ return result
+
+ return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms(
+ user_id, pagin_config, as_client_event, include_archived
+ ))
+
+ @defer.inlineCallbacks
+ def _snapshot_all_rooms(self, user_id=None, pagin_config=None,
+ as_client_event=True, include_archived=False):
+
+ memberships = [Membership.INVITE, Membership.JOIN]
+ if include_archived:
+ memberships.append(Membership.LEAVE)
+
+ room_list = yield self.store.get_rooms_for_user_where_membership_is(
+ user_id=user_id, membership_list=memberships
+ )
+
+ user = UserID.from_string(user_id)
+
+ rooms_ret = []
+
+ now_token = yield self.hs.get_event_sources().get_current_token()
+
+ presence_stream = self.hs.get_event_sources().sources["presence"]
+ pagination_config = PaginationConfig(from_token=now_token)
+ presence, _ = yield presence_stream.get_pagination_rows(
+ user, pagination_config.get_source_config("presence"), None
+ )
+
+ receipt_stream = self.hs.get_event_sources().sources["receipt"]
+ receipt, _ = yield receipt_stream.get_pagination_rows(
+ user, pagination_config.get_source_config("receipt"), None
+ )
+
+ tags_by_room = yield self.store.get_tags_for_user(user_id)
+
+ account_data, account_data_by_room = (
+ yield self.store.get_account_data_for_user(user_id)
+ )
+
+ public_room_ids = yield self.store.get_public_room_ids()
+
+ limit = pagin_config.limit
+ if limit is None:
+ limit = 10
+
+ @defer.inlineCallbacks
+ def handle_room(event):
+ d = {
+ "room_id": event.room_id,
+ "membership": event.membership,
+ "visibility": (
+ "public" if event.room_id in public_room_ids
+ else "private"
+ ),
+ }
+
+ if event.membership == Membership.INVITE:
+ time_now = self.clock.time_msec()
+ d["inviter"] = event.sender
+
+ invite_event = yield self.store.get_event(event.event_id)
+ d["invite"] = serialize_event(invite_event, time_now, as_client_event)
+
+ rooms_ret.append(d)
+
+ if event.membership not in (Membership.JOIN, Membership.LEAVE):
+ return
+
+ try:
+ if event.membership == Membership.JOIN:
+ room_end_token = now_token.room_key
+ deferred_room_state = self.state_handler.get_current_state(
+ event.room_id
+ )
+ elif event.membership == Membership.LEAVE:
+ room_end_token = "s%d" % (event.stream_ordering,)
+ deferred_room_state = self.store.get_state_for_events(
+ [event.event_id], None
+ )
+ deferred_room_state.addCallback(
+ lambda states: states[event.event_id]
+ )
+
+ (messages, token), current_state = yield preserve_context_over_deferred(
+ defer.gatherResults(
+ [
+ preserve_fn(self.store.get_recent_events_for_room)(
+ event.room_id,
+ limit=limit,
+ end_token=room_end_token,
+ ),
+ deferred_room_state,
+ ]
+ )
+ ).addErrback(unwrapFirstError)
+
+ messages = yield filter_events_for_client(
+ self.store, user_id, messages
+ )
+
+ start_token = now_token.copy_and_replace("room_key", token[0])
+ end_token = now_token.copy_and_replace("room_key", token[1])
+ time_now = self.clock.time_msec()
+
+ d["messages"] = {
+ "chunk": [
+ serialize_event(m, time_now, as_client_event)
+ for m in messages
+ ],
+ "start": start_token.to_string(),
+ "end": end_token.to_string(),
+ }
+
+ d["state"] = [
+ serialize_event(c, time_now, as_client_event)
+ for c in current_state.values()
+ ]
+
+ account_data_events = []
+ tags = tags_by_room.get(event.room_id)
+ if tags:
+ account_data_events.append({
+ "type": "m.tag",
+ "content": {"tags": tags},
+ })
+
+ account_data = account_data_by_room.get(event.room_id, {})
+ for account_data_type, content in account_data.items():
+ account_data_events.append({
+ "type": account_data_type,
+ "content": content,
+ })
+
+ d["account_data"] = account_data_events
+ except:
+ logger.exception("Failed to get snapshot")
+
+ yield concurrently_execute(handle_room, room_list, 10)
+
+ account_data_events = []
+ for account_data_type, content in account_data.items():
+ account_data_events.append({
+ "type": account_data_type,
+ "content": content,
+ })
+
+ ret = {
+ "rooms": rooms_ret,
+ "presence": presence,
+ "account_data": account_data_events,
+ "receipts": receipt,
+ "end": now_token.to_string(),
+ }
+
+ defer.returnValue(ret)
+
+ @defer.inlineCallbacks
+ def room_initial_sync(self, requester, room_id, pagin_config=None):
+ """Capture the a snapshot of a room. If user is currently a member of
+ the room this will be what is currently in the room. If the user left
+ the room this will be what was in the room when they left.
+
+ Args:
+ requester(Requester): The user to get a snapshot for.
+ room_id(str): The room to get a snapshot of.
+ pagin_config(synapse.streams.config.PaginationConfig):
+ The pagination config used to determine how many messages to
+ return.
+ Raises:
+ AuthError if the user wasn't in the room.
+ Returns:
+ A JSON serialisable dict with the snapshot of the room.
+ """
+
+ user_id = requester.user.to_string()
+
+ membership, member_event_id = yield self._check_in_room_or_world_readable(
+ room_id, user_id,
+ )
+ is_peeking = member_event_id is None
+
+ if membership == Membership.JOIN:
+ result = yield self._room_initial_sync_joined(
+ user_id, room_id, pagin_config, membership, is_peeking
+ )
+ elif membership == Membership.LEAVE:
+ result = yield self._room_initial_sync_parted(
+ user_id, room_id, pagin_config, membership, member_event_id, is_peeking
+ )
+
+ account_data_events = []
+ tags = yield self.store.get_tags_for_room(user_id, room_id)
+ if tags:
+ account_data_events.append({
+ "type": "m.tag",
+ "content": {"tags": tags},
+ })
+
+ account_data = yield self.store.get_account_data_for_room(user_id, room_id)
+ for account_data_type, content in account_data.items():
+ account_data_events.append({
+ "type": account_data_type,
+ "content": content,
+ })
+
+ result["account_data"] = account_data_events
+
+ defer.returnValue(result)
+
+ @defer.inlineCallbacks
+ def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
+ membership, member_event_id, is_peeking):
+ room_state = yield self.store.get_state_for_events(
+ [member_event_id], None
+ )
+
+ room_state = room_state[member_event_id]
+
+ limit = pagin_config.limit if pagin_config else None
+ if limit is None:
+ limit = 10
+
+ stream_token = yield self.store.get_stream_token_for_event(
+ member_event_id
+ )
+
+ messages, token = yield self.store.get_recent_events_for_room(
+ room_id,
+ limit=limit,
+ end_token=stream_token
+ )
+
+ messages = yield filter_events_for_client(
+ self.store, user_id, messages, is_peeking=is_peeking
+ )
+
+ start_token = StreamToken.START.copy_and_replace("room_key", token[0])
+ end_token = StreamToken.START.copy_and_replace("room_key", token[1])
+
+ time_now = self.clock.time_msec()
+
+ defer.returnValue({
+ "membership": membership,
+ "room_id": room_id,
+ "messages": {
+ "chunk": [serialize_event(m, time_now) for m in messages],
+ "start": start_token.to_string(),
+ "end": end_token.to_string(),
+ },
+ "state": [serialize_event(s, time_now) for s in room_state.values()],
+ "presence": [],
+ "receipts": [],
+ })
+
+ @defer.inlineCallbacks
+ def _room_initial_sync_joined(self, user_id, room_id, pagin_config,
+ membership, is_peeking):
+ current_state = yield self.state.get_current_state(
+ room_id=room_id,
+ )
+
+ # TODO: These concurrently
+ time_now = self.clock.time_msec()
+ state = [
+ serialize_event(x, time_now)
+ for x in current_state.values()
+ ]
+
+ now_token = yield self.hs.get_event_sources().get_current_token()
+
+ limit = pagin_config.limit if pagin_config else None
+ if limit is None:
+ limit = 10
+
+ room_members = [
+ m for m in current_state.values()
+ if m.type == EventTypes.Member
+ and m.content["membership"] == Membership.JOIN
+ ]
+
+ presence_handler = self.hs.get_presence_handler()
+
+ @defer.inlineCallbacks
+ def get_presence():
+ states = yield presence_handler.get_states(
+ [m.user_id for m in room_members],
+ as_event=True,
+ )
+
+ defer.returnValue(states)
+
+ @defer.inlineCallbacks
+ def get_receipts():
+ receipts_handler = self.hs.get_handlers().receipts_handler
+ receipts = yield receipts_handler.get_receipts_for_room(
+ room_id,
+ now_token.receipt_key
+ )
+ defer.returnValue(receipts)
+
+ presence, receipts, (messages, token) = yield defer.gatherResults(
+ [
+ preserve_fn(get_presence)(),
+ preserve_fn(get_receipts)(),
+ preserve_fn(self.store.get_recent_events_for_room)(
+ room_id,
+ limit=limit,
+ end_token=now_token.room_key,
+ )
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+
+ messages = yield filter_events_for_client(
+ self.store, user_id, messages, is_peeking=is_peeking,
+ )
+
+ start_token = now_token.copy_and_replace("room_key", token[0])
+ end_token = now_token.copy_and_replace("room_key", token[1])
+
+ time_now = self.clock.time_msec()
+
+ ret = {
+ "room_id": room_id,
+ "messages": {
+ "chunk": [serialize_event(m, time_now) for m in messages],
+ "start": start_token.to_string(),
+ "end": end_token.to_string(),
+ },
+ "state": state,
+ "presence": presence,
+ "receipts": receipts,
+ }
+ if not is_peeking:
+ ret["membership"] = membership
+
+ defer.returnValue(ret)
+
+ @defer.inlineCallbacks
+ def _check_in_room_or_world_readable(self, room_id, user_id):
+ try:
+ # check_user_was_in_room will return the most recent membership
+ # event for the user if:
+ # * The user is a non-guest user, and was ever in the room
+ # * The user is a guest user, and has joined the room
+ # else it will throw.
+ member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
+ defer.returnValue((member_event.membership, member_event.event_id))
+ return
+ except AuthError:
+ visibility = yield self.state_handler.get_current_state(
+ room_id, EventTypes.RoomHistoryVisibility, ""
+ )
+ if (
+ visibility and
+ visibility.content["history_visibility"] == "world_readable"
+ ):
+ defer.returnValue((Membership.JOIN, None))
+ return
+ raise AuthError(
+ 403, "Guest access not allowed", errcode=Codes.GUEST_ACCESS_FORBIDDEN
+ )
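
snapshot_all_rooms() memoises whole /initialSync responses in a SnapshotCache keyed on (user_id, pagination parameters, flags), so identical requests arriving close together share one result instead of recomputing the full snapshot. A generic, self-contained sketch of that get-or-compute pattern (TtlResultCache is an illustrative stand-in, not synapse.util.caches.snapshot_cache):

    import time

    class TtlResultCache(object):
        """Illustrative stand-in for SnapshotCache: a result is shared
        until its expiry, then recomputed on the next request."""
        def __init__(self, ttl_ms=5 * 60 * 1000):
            self._ttl_ms = ttl_ms
            self._entries = {}              # key -> (expiry_ms, result)

        def get(self, now_ms, key):
            entry = self._entries.get(key)
            if entry is None or entry[0] < now_ms:
                return None
            return entry[1]

        def set(self, now_ms, key, result):
            self._entries[key] = (now_ms + self._ttl_ms, result)
            return result

    cache = TtlResultCache()
    key = ("@alice:example.org", None, None, "f", 10, True, False)

    now_ms = int(time.time() * 1000)
    snapshot = cache.get(now_ms, key)
    if snapshot is None:
        # stands in for the expensive _snapshot_all_rooms(...) call
        snapshot = cache.set(now_ms, key, {"rooms": [], "end": "s0"})
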
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 178209a2..30ea9630 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -21,14 +21,11 @@ from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.push.action_generator import ActionGenerator
-from synapse.streams.config import PaginationConfig
from synapse.types import (
- UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id
+ UserID, RoomAlias, RoomStreamToken, get_domain_from_id
)
-from synapse.util import unwrapFirstError
-from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock
-from synapse.util.caches.snapshot_cache import SnapshotCache
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.util.async import run_on_reactor, ReadWriteLock
+from synapse.util.logcontext import preserve_fn
from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client
@@ -49,7 +46,6 @@ class MessageHandler(BaseHandler):
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
- self.snapshot_cache = SnapshotCache()
self.pagination_lock = ReadWriteLock()
@@ -392,377 +388,6 @@ class MessageHandler(BaseHandler):
[serialize_event(c, now) for c in room_state.values()]
)
- def snapshot_all_rooms(self, user_id=None, pagin_config=None,
- as_client_event=True, include_archived=False):
- """Retrieve a snapshot of all rooms the user is invited or has joined.
-
- This snapshot may include messages for all rooms where the user is
- joined, depending on the pagination config.
-
- Args:
- user_id (str): The ID of the user making the request.
- pagin_config (synapse.api.streams.PaginationConfig): The pagination
- config used to determine how many messages *PER ROOM* to return.
- as_client_event (bool): True to get events in client-server format.
- include_archived (bool): True to get rooms that the user has left
- Returns:
- A list of dicts with "room_id" and "membership" keys for all rooms
- the user is currently invited or joined in on. Rooms where the user
- is joined on, may return a "messages" key with messages, depending
- on the specified PaginationConfig.
- """
- key = (
- user_id,
- pagin_config.from_token,
- pagin_config.to_token,
- pagin_config.direction,
- pagin_config.limit,
- as_client_event,
- include_archived,
- )
- now_ms = self.clock.time_msec()
- result = self.snapshot_cache.get(now_ms, key)
- if result is not None:
- return result
-
- return self.snapshot_cache.set(now_ms, key, self._snapshot_all_rooms(
- user_id, pagin_config, as_client_event, include_archived
- ))
-
- @defer.inlineCallbacks
- def _snapshot_all_rooms(self, user_id=None, pagin_config=None,
- as_client_event=True, include_archived=False):
-
- memberships = [Membership.INVITE, Membership.JOIN]
- if include_archived:
- memberships.append(Membership.LEAVE)
-
- room_list = yield self.store.get_rooms_for_user_where_membership_is(
- user_id=user_id, membership_list=memberships
- )
-
- user = UserID.from_string(user_id)
-
- rooms_ret = []
-
- now_token = yield self.hs.get_event_sources().get_current_token()
-
- presence_stream = self.hs.get_event_sources().sources["presence"]
- pagination_config = PaginationConfig(from_token=now_token)
- presence, _ = yield presence_stream.get_pagination_rows(
- user, pagination_config.get_source_config("presence"), None
- )
-
- receipt_stream = self.hs.get_event_sources().sources["receipt"]
- receipt, _ = yield receipt_stream.get_pagination_rows(
- user, pagination_config.get_source_config("receipt"), None
- )
-
- tags_by_room = yield self.store.get_tags_for_user(user_id)
-
- account_data, account_data_by_room = (
- yield self.store.get_account_data_for_user(user_id)
- )
-
- public_room_ids = yield self.store.get_public_room_ids()
-
- limit = pagin_config.limit
- if limit is None:
- limit = 10
-
- @defer.inlineCallbacks
- def handle_room(event):
- d = {
- "room_id": event.room_id,
- "membership": event.membership,
- "visibility": (
- "public" if event.room_id in public_room_ids
- else "private"
- ),
- }
-
- if event.membership == Membership.INVITE:
- time_now = self.clock.time_msec()
- d["inviter"] = event.sender
-
- invite_event = yield self.store.get_event(event.event_id)
- d["invite"] = serialize_event(invite_event, time_now, as_client_event)
-
- rooms_ret.append(d)
-
- if event.membership not in (Membership.JOIN, Membership.LEAVE):
- return
-
- try:
- if event.membership == Membership.JOIN:
- room_end_token = now_token.room_key
- deferred_room_state = self.state_handler.get_current_state(
- event.room_id
- )
- elif event.membership == Membership.LEAVE:
- room_end_token = "s%d" % (event.stream_ordering,)
- deferred_room_state = self.store.get_state_for_events(
- [event.event_id], None
- )
- deferred_room_state.addCallback(
- lambda states: states[event.event_id]
- )
-
- (messages, token), current_state = yield preserve_context_over_deferred(
- defer.gatherResults(
- [
- preserve_fn(self.store.get_recent_events_for_room)(
- event.room_id,
- limit=limit,
- end_token=room_end_token,
- ),
- deferred_room_state,
- ]
- )
- ).addErrback(unwrapFirstError)
-
- messages = yield filter_events_for_client(
- self.store, user_id, messages
- )
-
- start_token = now_token.copy_and_replace("room_key", token[0])
- end_token = now_token.copy_and_replace("room_key", token[1])
- time_now = self.clock.time_msec()
-
- d["messages"] = {
- "chunk": [
- serialize_event(m, time_now, as_client_event)
- for m in messages
- ],
- "start": start_token.to_string(),
- "end": end_token.to_string(),
- }
-
- d["state"] = [
- serialize_event(c, time_now, as_client_event)
- for c in current_state.values()
- ]
-
- account_data_events = []
- tags = tags_by_room.get(event.room_id)
- if tags:
- account_data_events.append({
- "type": "m.tag",
- "content": {"tags": tags},
- })
-
- account_data = account_data_by_room.get(event.room_id, {})
- for account_data_type, content in account_data.items():
- account_data_events.append({
- "type": account_data_type,
- "content": content,
- })
-
- d["account_data"] = account_data_events
- except:
- logger.exception("Failed to get snapshot")
-
- yield concurrently_execute(handle_room, room_list, 10)
-
- account_data_events = []
- for account_data_type, content in account_data.items():
- account_data_events.append({
- "type": account_data_type,
- "content": content,
- })
-
- ret = {
- "rooms": rooms_ret,
- "presence": presence,
- "account_data": account_data_events,
- "receipts": receipt,
- "end": now_token.to_string(),
- }
-
- defer.returnValue(ret)
-
- @defer.inlineCallbacks
- def room_initial_sync(self, requester, room_id, pagin_config=None):
- """Capture the a snapshot of a room. If user is currently a member of
- the room this will be what is currently in the room. If the user left
- the room this will be what was in the room when they left.
-
- Args:
- requester(Requester): The user to get a snapshot for.
- room_id(str): The room to get a snapshot of.
- pagin_config(synapse.streams.config.PaginationConfig):
- The pagination config used to determine how many messages to
- return.
- Raises:
- AuthError if the user wasn't in the room.
- Returns:
- A JSON serialisable dict with the snapshot of the room.
- """
-
- user_id = requester.user.to_string()
-
- membership, member_event_id = yield self._check_in_room_or_world_readable(
- room_id, user_id,
- )
- is_peeking = member_event_id is None
-
- if membership == Membership.JOIN:
- result = yield self._room_initial_sync_joined(
- user_id, room_id, pagin_config, membership, is_peeking
- )
- elif membership == Membership.LEAVE:
- result = yield self._room_initial_sync_parted(
- user_id, room_id, pagin_config, membership, member_event_id, is_peeking
- )
-
- account_data_events = []
- tags = yield self.store.get_tags_for_room(user_id, room_id)
- if tags:
- account_data_events.append({
- "type": "m.tag",
- "content": {"tags": tags},
- })
-
- account_data = yield self.store.get_account_data_for_room(user_id, room_id)
- for account_data_type, content in account_data.items():
- account_data_events.append({
- "type": account_data_type,
- "content": content,
- })
-
- result["account_data"] = account_data_events
-
- defer.returnValue(result)
-
- @defer.inlineCallbacks
- def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
- membership, member_event_id, is_peeking):
- room_state = yield self.store.get_state_for_events(
- [member_event_id], None
- )
-
- room_state = room_state[member_event_id]
-
- limit = pagin_config.limit if pagin_config else None
- if limit is None:
- limit = 10
-
- stream_token = yield self.store.get_stream_token_for_event(
- member_event_id
- )
-
- messages, token = yield self.store.get_recent_events_for_room(
- room_id,
- limit=limit,
- end_token=stream_token
- )
-
- messages = yield filter_events_for_client(
- self.store, user_id, messages, is_peeking=is_peeking
- )
-
- start_token = StreamToken.START.copy_and_replace("room_key", token[0])
- end_token = StreamToken.START.copy_and_replace("room_key", token[1])
-
- time_now = self.clock.time_msec()
-
- defer.returnValue({
- "membership": membership,
- "room_id": room_id,
- "messages": {
- "chunk": [serialize_event(m, time_now) for m in messages],
- "start": start_token.to_string(),
- "end": end_token.to_string(),
- },
- "state": [serialize_event(s, time_now) for s in room_state.values()],
- "presence": [],
- "receipts": [],
- })
-
- @defer.inlineCallbacks
- def _room_initial_sync_joined(self, user_id, room_id, pagin_config,
- membership, is_peeking):
- current_state = yield self.state.get_current_state(
- room_id=room_id,
- )
-
- # TODO: These concurrently
- time_now = self.clock.time_msec()
- state = [
- serialize_event(x, time_now)
- for x in current_state.values()
- ]
-
- now_token = yield self.hs.get_event_sources().get_current_token()
-
- limit = pagin_config.limit if pagin_config else None
- if limit is None:
- limit = 10
-
- room_members = [
- m for m in current_state.values()
- if m.type == EventTypes.Member
- and m.content["membership"] == Membership.JOIN
- ]
-
- presence_handler = self.hs.get_presence_handler()
-
- @defer.inlineCallbacks
- def get_presence():
- states = yield presence_handler.get_states(
- [m.user_id for m in room_members],
- as_event=True,
- )
-
- defer.returnValue(states)
-
- @defer.inlineCallbacks
- def get_receipts():
- receipts_handler = self.hs.get_handlers().receipts_handler
- receipts = yield receipts_handler.get_receipts_for_room(
- room_id,
- now_token.receipt_key
- )
- defer.returnValue(receipts)
-
- presence, receipts, (messages, token) = yield defer.gatherResults(
- [
- preserve_fn(get_presence)(),
- preserve_fn(get_receipts)(),
- preserve_fn(self.store.get_recent_events_for_room)(
- room_id,
- limit=limit,
- end_token=now_token.room_key,
- )
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError)
-
- messages = yield filter_events_for_client(
- self.store, user_id, messages, is_peeking=is_peeking,
- )
-
- start_token = now_token.copy_and_replace("room_key", token[0])
- end_token = now_token.copy_and_replace("room_key", token[1])
-
- time_now = self.clock.time_msec()
-
- ret = {
- "room_id": room_id,
- "messages": {
- "chunk": [serialize_event(m, time_now) for m in messages],
- "start": start_token.to_string(),
- "end": end_token.to_string(),
- },
- "state": state,
- "presence": presence,
- "receipts": receipts,
- }
- if not is_peeking:
- ret["membership"] = membership
-
- defer.returnValue(ret)
-
@measure_func("_create_new_client_event")
@defer.inlineCallbacks
def _create_new_client_event(self, builder, prev_event_ids=None):
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 5a533682..b04aea01 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -125,6 +125,8 @@ class RoomListHandler(BaseHandler):
if r not in newly_unpublished and rooms_to_num_joined[room_id] > 0
]
+ total_room_count = len(rooms_to_scan)
+
if since_token:
# Filter out rooms we've already returned previously
# `since_token.current_limit` is the index of the last room we
@@ -188,6 +190,7 @@ class RoomListHandler(BaseHandler):
results = {
"chunk": chunk,
+ "total_room_count_estimate": total_room_count,
}
if since_token:
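
With this change a /publicRooms response carries the pre-pagination room count alongside the returned page. An illustrative response shape (the room entry and its values are invented for the example; only "chunk" and "total_room_count_estimate" come from this patch):

    example_response = {
        "chunk": [
            {
                "room_id": "!abc:example.org",     # invented example row
                "name": "Example Room",
                "num_joined_members": 42,
                "world_readable": True,
                "guest_can_join": False,
            },
        ],
        "total_room_count_estimate": 1,
    }
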
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 0548b81c..08313417 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -16,10 +16,9 @@
from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError
-from synapse.util.logcontext import (
- PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
-)
+from synapse.util.logcontext import preserve_fn
from synapse.util.metrics import Measure
+from synapse.util.wheel_timer import WheelTimer
from synapse.types import UserID, get_domain_from_id
import logging
@@ -35,6 +34,13 @@ logger = logging.getLogger(__name__)
RoomMember = namedtuple("RoomMember", ("room_id", "user_id"))
+# How long (in ms) before we time out typing notifications received over federation.
+FEDERATION_TIMEOUT = 60 * 1000
+
+# How often to resend typing across federation.
+FEDERATION_PING_INTERVAL = 40 * 1000
+
+
class TypingHandler(object):
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -44,7 +50,10 @@ class TypingHandler(object):
self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
+ self.hs = hs
+
self.clock = hs.get_clock()
+ self.wheel_timer = WheelTimer(bucket_size=5000)
self.federation = hs.get_replication_layer()
@@ -53,7 +62,7 @@ class TypingHandler(object):
hs.get_distributor().observe("user_left_room", self.user_left_room)
self._member_typing_until = {} # clock time we expect to stop
- self._member_typing_timer = {} # deferreds to manage theabove
+ self._member_last_federation_poke = {}
# map room IDs to serial numbers
self._room_serials = {}
@@ -61,12 +70,41 @@ class TypingHandler(object):
# map room IDs to sets of users currently typing
self._room_typing = {}
- def tearDown(self):
- """Cancels all the pending timers.
- Normally this shouldn't be needed, but it's required from unit tests
- to avoid a "Reactor was unclean" warning."""
- for t in self._member_typing_timer.values():
- self.clock.cancel_call_later(t)
+ self.clock.looping_call(
+ self._handle_timeouts,
+ 5000,
+ )
+
+ def _handle_timeouts(self):
+ logger.info("Checking for typing timeouts")
+
+ now = self.clock.time_msec()
+
+ members = set(self.wheel_timer.fetch(now))
+
+ for member in members:
+ if not self.is_typing(member):
+ # Nothing to do if they're no longer typing
+ continue
+
+ until = self._member_typing_until.get(member, None)
+ if not until or until < now:
+ logger.info("Timing out typing for: %s", member.user_id)
+ preserve_fn(self._stopped_typing)(member)
+ continue
+
+ # Check if we need to resend a keep alive over federation for this
+ # user.
+ if self.hs.is_mine_id(member.user_id):
+ last_fed_poke = self._member_last_federation_poke.get(member, None)
+ if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL < now:
+ preserve_fn(self._push_remote)(
+ member=member,
+ typing=True
+ )
+
+ def is_typing(self, member):
+ return member.user_id in self._room_typing.get(member.room_id, [])
@defer.inlineCallbacks
def started_typing(self, target_user, auth_user, room_id, timeout):
@@ -85,23 +123,17 @@ class TypingHandler(object):
"%s has started typing in %s", target_user_id, room_id
)
- until = self.clock.time_msec() + timeout
member = RoomMember(room_id=room_id, user_id=target_user_id)
- was_present = member in self._member_typing_until
-
- if member in self._member_typing_timer:
- self.clock.cancel_call_later(self._member_typing_timer[member])
+ was_present = member.user_id in self._room_typing.get(room_id, set())
- def _cb():
- logger.debug(
- "%s has timed out in %s", target_user.to_string(), room_id
- )
- self._stopped_typing(member)
+ now = self.clock.time_msec()
+ self._member_typing_until[member] = now + timeout
- self._member_typing_until[member] = until
- self._member_typing_timer[member] = self.clock.call_later(
- timeout / 1000.0, _cb
+ self.wheel_timer.insert(
+ now=now,
+ obj=member,
+ then=now + timeout,
)
if was_present:
@@ -109,8 +141,7 @@ class TypingHandler(object):
defer.returnValue(None)
yield self._push_update(
- room_id=room_id,
- user_id=target_user_id,
+ member=member,
typing=True,
)
@@ -133,10 +164,6 @@ class TypingHandler(object):
member = RoomMember(room_id=room_id, user_id=target_user_id)
- if member in self._member_typing_timer:
- self.clock.cancel_call_later(self._member_typing_timer[member])
- del self._member_typing_timer[member]
-
yield self._stopped_typing(member)
@defer.inlineCallbacks
@@ -148,57 +175,61 @@ class TypingHandler(object):
@defer.inlineCallbacks
def _stopped_typing(self, member):
- if member not in self._member_typing_until:
+ if member.user_id not in self._room_typing.get(member.room_id, set()):
# No point
defer.returnValue(None)
+ self._member_typing_until.pop(member, None)
+ self._member_last_federation_poke.pop(member, None)
+
yield self._push_update(
- room_id=member.room_id,
- user_id=member.user_id,
+ member=member,
typing=False,
)
- del self._member_typing_until[member]
-
- if member in self._member_typing_timer:
- # Don't cancel it - either it already expired, or the real
- # stopped_typing() will cancel it
- del self._member_typing_timer[member]
+ @defer.inlineCallbacks
+ def _push_update(self, member, typing):
+ if self.hs.is_mine_id(member.user_id):
+ # Only send updates for changes to our own users.
+ yield self._push_remote(member, typing)
+
+ self._push_update_local(
+ member=member,
+ typing=typing
+ )
@defer.inlineCallbacks
- def _push_update(self, room_id, user_id, typing):
- users = yield self.state.get_current_user_in_room(room_id)
- domains = set(get_domain_from_id(u) for u in users)
+ def _push_remote(self, member, typing):
+ users = yield self.state.get_current_user_in_room(member.room_id)
+ self._member_last_federation_poke[member] = self.clock.time_msec()
+
+ now = self.clock.time_msec()
+ self.wheel_timer.insert(
+ now=now,
+ obj=member,
+ then=now + FEDERATION_PING_INTERVAL,
+ )
- deferreds = []
- for domain in domains:
- if domain == self.server_name:
- preserve_fn(self._push_update_local)(
- room_id=room_id,
- user_id=user_id,
- typing=typing
- )
- else:
- deferreds.append(preserve_fn(self.federation.send_edu)(
+ for domain in set(get_domain_from_id(u) for u in users):
+ if domain != self.server_name:
+ self.federation.send_edu(
destination=domain,
edu_type="m.typing",
content={
- "room_id": room_id,
- "user_id": user_id,
+ "room_id": member.room_id,
+ "user_id": member.user_id,
"typing": typing,
},
- key=(room_id, user_id),
- ))
-
- yield preserve_context_over_deferred(
- defer.DeferredList(deferreds, consumeErrors=True)
- )
+ key=member,
+ )
@defer.inlineCallbacks
def _recv_edu(self, origin, content):
room_id = content["room_id"]
user_id = content["user_id"]
+ member = RoomMember(user_id=user_id, room_id=room_id)
+
# Check that the string is a valid user id
user = UserID.from_string(user_id)
@@ -213,26 +244,32 @@ class TypingHandler(object):
domains = set(get_domain_from_id(u) for u in users)
if self.server_name in domains:
+ logger.info("Got typing update from %s: %r", user_id, content)
+ now = self.clock.time_msec()
+ self._member_typing_until[member] = now + FEDERATION_TIMEOUT
+ self.wheel_timer.insert(
+ now=now,
+ obj=member,
+ then=now + FEDERATION_TIMEOUT,
+ )
self._push_update_local(
- room_id=room_id,
- user_id=user_id,
+ member=member,
typing=content["typing"]
)
- def _push_update_local(self, room_id, user_id, typing):
- room_set = self._room_typing.setdefault(room_id, set())
+ def _push_update_local(self, member, typing):
+ room_set = self._room_typing.setdefault(member.room_id, set())
if typing:
- room_set.add(user_id)
+ room_set.add(member.user_id)
else:
- room_set.discard(user_id)
+ room_set.discard(member.user_id)
self._latest_room_serial += 1
- self._room_serials[room_id] = self._latest_room_serial
+ self._room_serials[member.room_id] = self._latest_room_serial
- with PreserveLoggingContext():
- self.notifier.on_new_event(
- "typing_key", self._latest_room_serial, rooms=[room_id]
- )
+ self.notifier.on_new_event(
+ "typing_key", self._latest_room_serial, rooms=[member.room_id]
+ )
def get_all_typing_updates(self, last_id, current_id):
# TODO: Work out a way to do this without scanning the entire state.
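
The per-member call_later timers are replaced by a single WheelTimer swept from a 5-second looping call: each member is inserted with a deadline, and the sweep either times the typing notification out (once _member_typing_until has passed) or, for local users, re-sends a federation keep-alive every FEDERATION_PING_INTERVAL so remote servers can expire the state after FEDERATION_TIMEOUT. A simplified, self-contained sketch of that bucketed-timer idea (SimpleWheelTimer is illustrative, not synapse.util.wheel_timer):

    import time
    from collections import defaultdict

    class SimpleWheelTimer(object):
        """Buckets objects by deadline; fetch() drains every bucket whose
        window has elapsed, so timeouts cost one periodic sweep instead of
        one reactor timer per typing user."""
        def __init__(self, bucket_size_ms=5000):
            self.bucket_size = bucket_size_ms
            self.buckets = defaultdict(list)    # bucket index -> [obj, ...]

        def insert(self, now, obj, then):
            # `now` is accepted only to mirror the real insert() signature
            self.buckets[then // self.bucket_size].append(obj)

        def fetch(self, now):
            current = now // self.bucket_size
            due = []
            for idx in [i for i in self.buckets if i <= current]:
                due.extend(self.buckets.pop(idx))
            return due

    timer = SimpleWheelTimer()
    now = int(time.time() * 1000)
    timer.insert(now=now, obj=("!room:example.org", "@alice:example.org"),
                 then=now + 30000)

    # Run from a looping call every 5s: expired members either stop typing
    # or get re-inserted with a fresh federation keep-alive deadline.
    for member in set(timer.fetch(int(time.time() * 1000))):
        pass
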
diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py
index 113a49e5..478e21ee 100644
--- a/synapse/rest/client/v1/initial_sync.py
+++ b/synapse/rest/client/v1/initial_sync.py
@@ -25,16 +25,15 @@ class InitialSyncRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(InitialSyncRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.initial_sync_handler = hs.get_initial_sync_handler()
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)
as_client_event = "raw" not in request.args
pagination_config = PaginationConfig.from_request(request)
- handler = self.handlers.message_handler
include_archived = request.args.get("archived", None) == ["true"]
- content = yield handler.snapshot_all_rooms(
+ content = yield self.initial_sync_handler.snapshot_all_rooms(
user_id=requester.user.to_string(),
pagin_config=pagination_config,
as_client_event=as_client_event,
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 45287bf0..010fbc7c 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -456,13 +456,13 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
def __init__(self, hs):
super(RoomInitialSyncRestServlet, self).__init__(hs)
- self.handlers = hs.get_handlers()
+ self.initial_sync_handler = hs.get_initial_sync_handler()
@defer.inlineCallbacks
def on_GET(self, request, room_id):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(request)
- content = yield self.handlers.message_handler.room_initial_sync(
+ content = yield self.initial_sync_handler.room_initial_sync(
room_id=room_id,
requester=requester,
pagin_config=pagination_config,
@@ -705,12 +705,15 @@ class RoomTypingRestServlet(ClientV1RestServlet):
yield self.presence_handler.bump_presence_active_time(requester.user)
+ # Limit timeout to stop people from setting silly typing timeouts.
+ timeout = min(content.get("timeout", 30000), 120000)
+
if content["typing"]:
yield self.typing_handler.started_typing(
target_user=target_user,
auth_user=requester.user,
room_id=room_id,
- timeout=content.get("timeout", 30000),
+ timeout=timeout,
)
else:
yield self.typing_handler.stopped_typing(
diff --git a/synapse/server.py b/synapse/server.py
index 69860f3d..374124a1 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -43,6 +43,7 @@ from synapse.handlers.room_list import RoomListHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler
from synapse.handlers.events import EventHandler, EventStreamHandler
+from synapse.handlers.initial_sync import InitialSyncHandler
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.notifier import Notifier
@@ -98,6 +99,7 @@ class HomeServer(object):
'e2e_keys_handler',
'event_handler',
'event_stream_handler',
+ 'initial_sync_handler',
'application_service_api',
'application_service_scheduler',
'application_service_handler',
@@ -228,6 +230,9 @@ class HomeServer(object):
def build_event_stream_handler(self):
return EventStreamHandler(self)
+ def build_initial_sync_handler(self):
+ return InitialSyncHandler(self)
+
def build_event_sources(self):
return EventSources(self)
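
Registering the handler is only two lines because HomeServer builds its dependencies lazily: every name in the dependency list gets a generated get_<name>() accessor that calls the matching build_<name>() once and caches the result, which is how hs.get_initial_sync_handler() in the servlets resolves. A toy sketch of that convention (MiniHomeServer is illustrative, not synapse.server.HomeServer):

    class MiniHomeServer(object):
        DEPENDENCIES = ["initial_sync_handler"]

        def __init__(self):
            self._cache = {}

        def __getattr__(self, name):
            # synthesise get_<dep>() for each declared dependency
            if name.startswith("get_") and name[4:] in self.DEPENDENCIES:
                dep = name[4:]
                def getter():
                    if dep not in self._cache:
                        self._cache[dep] = getattr(self, "build_" + dep)()
                    return self._cache[dep]
                return getter
            raise AttributeError(name)

        def build_initial_sync_handler(self):
            return object()   # stands in for InitialSyncHandler(self)

    hs = MiniHomeServer()
    assert hs.get_initial_sync_handler() is hs.get_initial_sync_handler()
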
diff --git a/synapse/storage/event_federation.py b/synapse/storage/event_federation.py
index 3d62451d..53feaa19 100644
--- a/synapse/storage/event_federation.py
+++ b/synapse/storage/event_federation.py
@@ -398,12 +398,11 @@ class EventFederationStore(SQLBaseStore):
sql = ("""
DELETE FROM stream_ordering_to_exterm
WHERE
- (
- SELECT max(stream_ordering) AS stream_ordering
+ room_id IN (
+ SELECT room_id
FROM stream_ordering_to_exterm
- WHERE room_id = stream_ordering_to_exterm.room_id
- ) > ?
- AND stream_ordering < ?
+ WHERE stream_ordering > ?
+ ) AND stream_ordering < ?
""")
txn.execute(
sql,
diff --git a/synapse/storage/events.py b/synapse/storage/events.py
index 6dc46fa5..6cf9d117 100644
--- a/synapse/storage/events.py
+++ b/synapse/storage/events.py
@@ -1355,39 +1355,53 @@ class EventsStore(SQLBaseStore):
min_stream_id = rows[-1][0]
event_ids = [row[1] for row in rows]
- events = self._get_events_txn(txn, event_ids)
+ rows_to_update = []
- rows = []
- for event in events:
- try:
- event_id = event.event_id
- origin_server_ts = event.origin_server_ts
- except (KeyError, AttributeError):
- # If the event is missing a necessary field then
- # skip over it.
- continue
+ chunks = [
+ event_ids[i:i + 100]
+ for i in xrange(0, len(event_ids), 100)
+ ]
+ for chunk in chunks:
+ ev_rows = self._simple_select_many_txn(
+ txn,
+ table="event_json",
+ column="event_id",
+ iterable=chunk,
+ retcols=["event_id", "json"],
+ keyvalues={},
+ )
- rows.append((origin_server_ts, event_id))
+ for row in ev_rows:
+ event_id = row["event_id"]
+ event_json = json.loads(row["json"])
+ try:
+ origin_server_ts = event_json["origin_server_ts"]
+ except (KeyError, AttributeError):
+ # If the event is missing a necessary field then
+ # skip over it.
+ continue
+
+ rows_to_update.append((origin_server_ts, event_id))
sql = (
"UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
)
- for index in range(0, len(rows), INSERT_CLUMP_SIZE):
- clump = rows[index:index + INSERT_CLUMP_SIZE]
+ for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
+ clump = rows_to_update[index:index + INSERT_CLUMP_SIZE]
txn.executemany(sql, clump)
progress = {
"target_min_stream_id_inclusive": target_min_stream_id,
"max_stream_id_exclusive": min_stream_id,
- "rows_inserted": rows_inserted + len(rows)
+ "rows_inserted": rows_inserted + len(rows_to_update)
}
self._background_update_progress_txn(
txn, self.EVENT_ORIGIN_SERVER_TS_NAME, progress
)
- return len(rows)
+ return len(rows_to_update)
result = yield self.runInteraction(
self.EVENT_ORIGIN_SERVER_TS_NAME, reindex_search_txn
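
Instead of hydrating full event objects, the background update above now pulls raw event_json rows in chunks of 100 and batches the UPDATEs with executemany. A self-contained sketch of the same chunk-then-batch pattern; the sizes and helper names are illustrative:

    import json

    CHUNK_SIZE = 100        # how many ids to select at a time (illustrative)
    INSERT_CLUMP_SIZE = 50  # how many updates to hand to executemany at once

    def chunk(seq, size):
        """Split a sequence into consecutive slices of at most `size` items."""
        return [seq[i:i + size] for i in range(0, len(seq), size)]

    def extract_updates(event_rows):
        """Build (origin_server_ts, event_id) pairs, skipping malformed rows."""
        updates = []
        for row in event_rows:
            body = json.loads(row["json"])
            try:
                updates.append((body["origin_server_ts"], row["event_id"]))
            except (KeyError, TypeError):
                continue  # skip events missing the field, as the diff does
        return updates

    if __name__ == "__main__":
        rows = [
            {"event_id": "$1", "json": json.dumps({"origin_server_ts": 111})},
            {"event_id": "$2", "json": json.dumps({})},  # skipped
            {"event_id": "$3", "json": json.dumps({"origin_server_ts": 333})},
        ]
        for batch in chunk(extract_updates(rows), INSERT_CLUMP_SIZE):
            print(batch)  # in real code: txn.executemany(UPDATE_SQL, batch)
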
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 7efbe51c..08de3cc4 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 35
+SCHEMA_VERSION = 36
dir_path = os.path.abspath(os.path.dirname(__file__))
diff --git a/synapse/storage/schema/delta/36/readd_public_rooms.sql b/synapse/storage/schema/delta/36/readd_public_rooms.sql
new file mode 100644
index 00000000..90d8fd18
--- /dev/null
+++ b/synapse/storage/schema/delta/36/readd_public_rooms.sql
@@ -0,0 +1,26 @@
+/* Copyright 2016 OpenMarket Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Re-add some entries to stream_ordering_to_exterm that were incorrectly deleted
+INSERT INTO stream_ordering_to_exterm (stream_ordering, room_id, event_id)
+ SELECT
+ (SELECT stream_ordering FROM events where event_id = e.event_id) AS stream_ordering,
+ room_id,
+ event_id
+ FROM event_forward_extremities AS e
+ WHERE NOT EXISTS (
+ SELECT room_id FROM stream_ordering_to_exterm AS s
+ WHERE s.room_id = e.room_id
+ );
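
The delta above repopulates stream_ordering_to_exterm from event_forward_extremities, and the NOT EXISTS guard means rooms that already have entries are left alone, so the repair is safe to run more than once. A small sqlite sketch of that idempotent INSERT ... SELECT pattern; the schema is trimmed down for illustration:

    import sqlite3

    # Illustrative schema: just enough columns to exercise the NOT EXISTS guard.
    conn = sqlite3.connect(":memory:")
    conn.executescript("""
        CREATE TABLE events (event_id TEXT, stream_ordering INTEGER);
        CREATE TABLE event_forward_extremities (room_id TEXT, event_id TEXT);
        CREATE TABLE stream_ordering_to_exterm
            (stream_ordering INTEGER, room_id TEXT, event_id TEXT);
        INSERT INTO events VALUES ('$x', 7);
        INSERT INTO event_forward_extremities VALUES ('!room:hs', '$x');
    """)

    REPAIR_SQL = """
        INSERT INTO stream_ordering_to_exterm (stream_ordering, room_id, event_id)
            SELECT
                (SELECT stream_ordering FROM events WHERE event_id = e.event_id),
                room_id,
                event_id
            FROM event_forward_extremities AS e
            WHERE NOT EXISTS (
                SELECT room_id FROM stream_ordering_to_exterm AS s
                WHERE s.room_id = e.room_id
            )
    """

    # Running the repair twice inserts the row only once: the second pass
    # finds the room already present and does nothing.
    conn.execute(REPAIR_SQL)
    conn.execute(REPAIR_SQL)
    print(conn.execute("SELECT * FROM stream_ordering_to_exterm").fetchall())
    # -> [(7, '!room:hs', '$x')]
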
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 7eb34267..49abf0ac 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -307,6 +307,9 @@ class StateStore(SQLBaseStore):
def _get_state_groups_from_groups_txn(self, txn, groups, types=None):
results = {group: {} for group in groups}
+ if types is not None:
+ types = list(set(types)) # deduplicate types list
+
if isinstance(self.database_engine, PostgresEngine):
# Temporarily disable sequential scans in this transaction. This is
# a temporary hack until we can add the right indices in
@@ -375,10 +378,35 @@ class StateStore(SQLBaseStore):
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups:
- group_tree = [group]
next_group = group
while next_group:
+ # We did this before by getting the list of group ids, and
+ # then passing that list to sqlite to get latest event for
+ # each (type, state_key). However, that was terribly slow
+ # without the right indices (which we can't add until
+ # after we finish deduping state, which requires this func)
+ args = [next_group]
+ if types:
+ args.extend(i for typ in types for i in typ)
+
+ txn.execute(
+ "SELECT type, state_key, event_id FROM state_groups_state"
+ " WHERE state_group = ? %s" % (where_clause,),
+ args
+ )
+ rows = txn.fetchall()
+ results[group].update({
+ (typ, state_key): event_id
+ for typ, state_key, event_id in rows
+ if (typ, state_key) not in results[group]
+ })
+
+ # If the lengths match then we must have all the types,
+ # so there is no need to walk further down the tree.
+ if types is not None and len(results[group]) == len(types):
+ break
+
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
@@ -386,28 +414,6 @@ class StateStore(SQLBaseStore):
retcol="prev_state_group",
allow_none=True,
)
- if next_group:
- group_tree.append(next_group)
-
- sql = ("""
- SELECT type, state_key, event_id FROM state_groups_state
- INNER JOIN (
- SELECT type, state_key, max(state_group) as state_group
- FROM state_groups_state
- WHERE state_group IN (%s) %s
- GROUP BY type, state_key
- ) USING (type, state_key, state_group);
- """) % (",".join("?" for _ in group_tree), where_clause,)
-
- args = list(group_tree)
- if types is not None:
- args.extend([i for typ in types for i in typ])
-
- txn.execute(sql, args)
- rows = self.cursor_to_dict(txn)
- for row in rows:
- key = (row["type"], row["state_key"])
- results[group][key] = row["event_id"]
return results
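
On SQLite the lookup above now walks the state-group chain one group at a time via state_group_edges, filling in any (type, state_key) pair not already resolved and stopping early once every requested type has been found. A pure-Python sketch of that walk over in-memory dicts; the data structures are invented for illustration:

    # Illustrative in-memory model of the per-group walk described above.
    # state_groups_state: group -> {(type, state_key): event_id}
    # state_group_edges:  group -> previous group (absent at the root)

    def get_state_for_group(group, types, state_groups_state, state_group_edges):
        """Walk prev_state_group edges, keeping the first (newest) entry seen."""
        result = {}
        next_group = group
        while next_group is not None:
            rows = state_groups_state.get(next_group, {})
            for key, event_id in rows.items():
                if types is not None and key not in types:
                    continue
                if key not in result:  # older groups never override newer ones
                    result[key] = event_id
            # If every requested key is resolved, stop walking the chain.
            if types is not None and len(result) == len(types):
                break
            next_group = state_group_edges.get(next_group)
        return result

    if __name__ == "__main__":
        state = {
            3: {("m.room.topic", ""): "$topic_v2"},
            2: {("m.room.topic", ""): "$topic_v1", ("m.room.name", ""): "$name"},
            1: {("m.room.create", ""): "$create"},
        }
        edges = {3: 2, 2: 1}
        print(get_state_for_group(3, [("m.room.topic", ""), ("m.room.name", "")],
                                  state, edges))
        # group 1 is never visited because both keys resolve by group 2
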
diff --git a/synapse/types.py b/synapse/types.py
index 9d64e8c4..1694af12 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -56,7 +56,7 @@ def get_domain_from_id(string):
try:
return string.split(":", 1)[1]
except IndexError:
- raise SynapseError(400, "Invalid ID: %r", string)
+ raise SynapseError(400, "Invalid ID: %r" % (string,))
class DomainSpecificString(
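
The types.py fix above is the difference between logger-style argument passing and actual string formatting: SynapseError(400, "Invalid ID: %r", string) hands the exception an extra positional argument and never interpolates the offending ID, while "Invalid ID: %r" % (string,) formats the message up front. A small sketch of why the former loses the detail; the exception class here is a stand-in, not Synapse's:

    # Stand-in for an exception that takes (code, msg) only -- illustrative.
    class FakeSynapseError(Exception):
        def __init__(self, code, msg):
            super(FakeSynapseError, self).__init__(msg)
            self.code = code
            self.msg = msg

    string = "malformed-id"

    # Wrong: the extra positional argument is rejected here (or, with an
    # *args-style constructor, the message would still literally say "%r").
    try:
        raise FakeSynapseError(400, "Invalid ID: %r", string)
    except TypeError as e:
        print("constructor rejected the extra argument:", e)

    # Right: format first, then construct the exception.
    try:
        raise FakeSynapseError(400, "Invalid ID: %r" % (string,))
    except FakeSynapseError as e:
        print(e.msg)  # -> Invalid ID: 'malformed-id'
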
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index ea1f0f7c..c3108f51 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -267,10 +267,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
from synapse.handlers.typing import RoomMember
member = RoomMember(self.room_id, self.u_apple.to_string())
self.handler._member_typing_until[member] = 1002000
- self.handler._member_typing_timer[member] = (
- self.clock.call_later(1002, lambda: 0)
- )
- self.handler._room_typing[self.room_id] = set((self.u_apple.to_string(),))
+ self.handler._room_typing[self.room_id] = set([self.u_apple.to_string()])
self.assertEquals(self.event_source.get_current_key(), 0)
@@ -330,7 +327,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
},
}])
- self.clock.advance_time(11)
+ self.clock.advance_time(16)
self.on_new_event.assert_has_calls([
call('typing_key', 2, rooms=[self.room_id]),
diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py
index 467f253e..a269e6f5 100644
--- a/tests/rest/client/v1/test_typing.py
+++ b/tests/rest/client/v1/test_typing.py
@@ -105,9 +105,6 @@ class RoomTypingTestCase(RestTestCase):
# Need another user to make notifications actually work
yield self.join(self.room_id, user="@jim:red")
- def tearDown(self):
- self.hs.get_typing_handler().tearDown()
-
@defer.inlineCallbacks
def test_set_typing(self):
(code, _) = yield self.mock_resource.trigger(
@@ -147,7 +144,7 @@ class RoomTypingTestCase(RestTestCase):
self.assertEquals(self.event_source.get_current_key(), 1)
- self.clock.advance_time(31)
+ self.clock.advance_time(36)
self.assertEquals(self.event_source.get_current_key(), 2)
diff --git a/tests/utils.py b/tests/utils.py
index 915b934e..92d470cb 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -220,6 +220,7 @@ class MockClock(object):
# list of lists of [absolute_time, callback, expired] in no particular
# order
self.timers = []
+ self.loopers = []
def time(self):
return self.now
@@ -240,7 +241,7 @@ class MockClock(object):
return t
def looping_call(self, function, interval):
- pass
+ self.loopers.append([function, interval / 1000., self.now])
def cancel_call_later(self, timer, ignore_errs=False):
if timer[2]:
@@ -269,6 +270,12 @@ class MockClock(object):
else:
self.timers.append(t)
+ for looped in self.loopers:
+ func, interval, last = looped
+ if last + interval < self.now:
+ func()
+ looped[2] = self.now
+
def advance_time_msec(self, ms):
self.advance_time(ms / 1000.)
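
MockClock now records looping calls and fires any that are due whenever test time advances. That is consistent with the typing tests above advancing by five extra seconds and dropping the per-member _member_typing_timer: timeouts appear to be checked by a periodic call rather than individual timers. A standalone sketch of the looping-call parts of this MockClock change; the class name is illustrative:

    # Trimmed-down mock clock, following the loopers logic added in tests/utils.py.
    class TinyMockClock(object):
        def __init__(self):
            self.now = 1000.0
            self.loopers = []  # [function, interval_in_seconds, last_fired_at]

        def looping_call(self, function, interval_ms):
            self.loopers.append([function, interval_ms / 1000.0, self.now])

        def advance_time(self, secs):
            self.now += secs
            for looper in self.loopers:
                func, interval, last = looper
                if last + interval < self.now:
                    func()                # fire the periodic callback once
                    looper[2] = self.now  # and remember when it last ran

    if __name__ == "__main__":
        fired = []
        clock = TinyMockClock()
        clock.looping_call(lambda: fired.append(clock.now), 5000)  # every 5s
        clock.advance_time(11)  # due -> fires
        clock.advance_time(2)   # not due yet -> does not fire
        clock.advance_time(16)  # due again -> fires
        print(fired)            # two firings recorded
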