summaryrefslogtreecommitdiff
path: root/synapse/api/ratelimiting.py
blob: 3f9ad4ce8991496046418247631802927695d7c1 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections


class Ratelimiter(object):
    """
    Ratelimit message sending by user.

    Per-user state is stored in ``self.message_counts`` as a tuple
    ``(message_count, time_start, msg_rate_hz)``. ``send_message`` pops the
    caller's entry and re-inserts it, so the ``OrderedDict`` stays ordered
    from least- to most-recently updated user; ``prune_message_counts``
    relies on that ordering.
    """

    def __init__(self):
        # user_id -> (message_count, time_start, msg_rate_hz)
        self.message_counts = collections.OrderedDict()

    def send_message(self, user_id, time_now_s, msg_rate_hz, burst_count):
        """Can the user send a message?
        Args:
            user_id: The user sending a message.
            time_now_s: The time now, in seconds.
            msg_rate_hz: The long term number of messages a user can send in a
                second.
            burst_count: How many messages the user can send before being
                limited.
        Returns:
            A pair of a bool indicating if they can send a message now and a
                time in seconds of when they can next send a message (never
                earlier than ``time_now_s``), or -1 for the time if
                ``msg_rate_hz`` is not positive.
        """
        self.prune_message_counts(time_now_s)
        # Popping (rather than reading) means the re-insert below moves this
        # user to the end of the OrderedDict, i.e. most-recently-updated.
        message_count, time_start, _ignored = self.message_counts.pop(
            user_id, (0., time_now_s, None),
        )
        time_delta = time_now_s - time_start
        # Number of messages still "counted" after decaying at msg_rate_hz.
        sent_count = message_count - time_delta * msg_rate_hz
        if sent_count < 0:
            # Fully decayed: start a fresh window with this message.
            allowed = True
            time_start = time_now_s
            message_count = 1.
        elif sent_count > burst_count - 1.:
            # Sending one more would exceed the burst allowance.
            allowed = False
        else:
            allowed = True
            message_count += 1

        self.message_counts[user_id] = (
            message_count, time_start, msg_rate_hz
        )

        if msg_rate_hz > 0:
            # Time at which the decayed count drops back to burst_count - 1.
            time_allowed = (
                time_start + (message_count - burst_count + 1) / msg_rate_hz
            )
            if time_allowed < time_now_s:
                time_allowed = time_now_s
        else:
            time_allowed = -1

        return allowed, time_allowed

    def prune_message_counts(self, time_now_s):
        """Drop per-user entries whose message count has fully decayed.

        Walks entries from least- to most-recently updated and deletes each
        one whose decayed count ``message_count - elapsed * msg_rate_hz`` is
        no longer positive, stopping at the first entry that is still active.
        The early stop is a heuristic: a later entry with a higher rate may
        also have decayed and will simply be kept until a later prune.

        Args:
            time_now_s: The time now, in seconds.
        """
        # Iterate over a snapshot: deleting from an OrderedDict while
        # iterating a live view of it raises RuntimeError on Python 3.
        for user_id, (message_count, time_start, msg_rate_hz) in list(
            self.message_counts.items()
        ):
            time_delta = time_now_s - time_start
            if message_count - time_delta * msg_rate_hz > 0:
                break
            else:
                del self.message_counts[user_id]