summaryrefslogtreecommitdiff
path: root/synapse/storage/databases/main/task_scheduler.py
blob: 5c5372a8259df677bfd85a5107d8620e64b3aaa4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
# Copyright 2023 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Dict, List, Optional

from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
    DatabasePool,
    LoggingDatabaseConnection,
    LoggingTransaction,
    make_in_list_sql_clause,
)
from synapse.types import JsonDict, JsonMapping, ScheduledTask, TaskStatus
from synapse.util import json_encoder

if TYPE_CHECKING:
    from synapse.server import HomeServer


class TaskSchedulerWorkerStore(SQLBaseStore):
    """Database accessors for the `scheduled_tasks` table.

    Rows are converted to and from `ScheduledTask` objects. The `params` and
    `result` columns are stored JSON-encoded and decoded on read.
    """

    # The columns read from `scheduled_tasks`. Both read paths select exactly
    # these columns, so each resulting row dict maps 1:1 onto the keyword
    # arguments passed to `ScheduledTask(**row)`. A `SELECT *` would make that
    # call fail as soon as the table gained an extra column.
    _SCHEDULED_TASK_COLUMNS = (
        "id",
        "action",
        "status",
        "timestamp",
        "resource_id",
        "params",
        "result",
        "error",
    )

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

    @staticmethod
    def _convert_row_to_task(row: Dict[str, Any]) -> ScheduledTask:
        """Build a `ScheduledTask` from a DB row dict.

        The `row` dict is modified in place: `status` is converted to a
        `TaskStatus`, and non-NULL `params` / `result` values are JSON-decoded.

        Args:
            row: a dict keyed by the columns in `_SCHEDULED_TASK_COLUMNS`

        Returns: the corresponding `ScheduledTask`
        """
        row["status"] = TaskStatus(row["status"])
        if row["params"] is not None:
            row["params"] = db_to_json(row["params"])
        if row["result"] is not None:
            row["result"] = db_to_json(row["result"])
        return ScheduledTask(**row)

    async def get_scheduled_tasks(
        self,
        *,
        actions: Optional[List[str]] = None,
        resource_id: Optional[str] = None,
        statuses: Optional[List[TaskStatus]] = None,
        max_timestamp: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> List[ScheduledTask]:
        """Get a list of scheduled tasks from the DB.

        Every filter left as `None` is not applied. All supplied filters are
        combined with AND.

        Args:
            actions: Limit the returned tasks to those specific action names
            resource_id: Limit the returned tasks to the specific resource id, if specified
            statuses: Limit the returned tasks to the specific statuses
            max_timestamp: Limit the returned tasks to the ones that have
                a timestamp less than or equal to the specified one
            limit: Only return `limit` number of rows if set.

        Returns: a list of `ScheduledTask`, ordered by increasing timestamps
        """

        def get_scheduled_tasks_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
            clauses: List[str] = []
            args: List[Any] = []
            # Use an explicit `is not None` check, like the other filters, so
            # that an empty-string resource id is still applied as a filter.
            if resource_id is not None:
                clauses.append("resource_id = ?")
                args.append(resource_id)
            if actions is not None:
                clause, temp_args = make_in_list_sql_clause(
                    txn.database_engine, "action", actions
                )
                clauses.append(clause)
                args.extend(temp_args)
            if statuses is not None:
                clause, temp_args = make_in_list_sql_clause(
                    txn.database_engine, "status", statuses
                )
                clauses.append(clause)
                args.extend(temp_args)
            if max_timestamp is not None:
                clauses.append("timestamp <= ?")
                args.append(max_timestamp)

            # Select an explicit column list (see `_SCHEDULED_TASK_COLUMNS`)
            # rather than `*`, so each row maps exactly onto `ScheduledTask`.
            columns = ", ".join(TaskSchedulerWorkerStore._SCHEDULED_TASK_COLUMNS)
            sql = f"SELECT {columns} FROM scheduled_tasks"
            if clauses:
                sql = sql + " WHERE " + " AND ".join(clauses)

            sql = sql + " ORDER BY timestamp"

            if limit is not None:
                sql += " LIMIT ?"
                args.append(limit)

            txn.execute(sql, args)
            return self.db_pool.cursor_to_dict(txn)

        rows = await self.db_pool.runInteraction(
            "get_scheduled_tasks", get_scheduled_tasks_txn
        )
        return [TaskSchedulerWorkerStore._convert_row_to_task(row) for row in rows]

    async def insert_scheduled_task(self, task: ScheduledTask) -> None:
        """Insert a specified `ScheduledTask` in the DB.

        `params` and `result` are JSON-encoded when present and stored as
        NULL otherwise.

        Args:
            task: the `ScheduledTask` to insert
        """
        await self.db_pool.simple_insert(
            "scheduled_tasks",
            {
                "id": task.id,
                "action": task.action,
                "status": task.status,
                "timestamp": task.timestamp,
                "resource_id": task.resource_id,
                "params": None
                if task.params is None
                else json_encoder.encode(task.params),
                "result": None
                if task.result is None
                else json_encoder.encode(task.result),
                "error": task.error,
            },
            desc="insert_scheduled_task",
        )

    async def update_scheduled_task(
        self,
        id: str,
        timestamp: int,
        *,
        status: Optional[TaskStatus] = None,
        result: Optional[JsonMapping] = None,
        error: Optional[str] = None,
    ) -> bool:
        """Update a scheduled task in the DB with some new value(s).

        The timestamp is always written. `status`, `result` and `error` are
        only written when not `None`. As a consequence, this method cannot
        reset `result` or `error` back to NULL.

        Args:
            id: id of the `ScheduledTask` to update
            timestamp: new timestamp of the task
            status: new status of the task
            result: new result of the task (stored JSON-encoded)
            error: new error of the task

        Returns: `False` if no matching row was found, `True` otherwise
        """
        updatevalues: JsonDict = {"timestamp": timestamp}
        if status is not None:
            updatevalues["status"] = status
        if result is not None:
            updatevalues["result"] = json_encoder.encode(result)
        if error is not None:
            updatevalues["error"] = error
        nb_rows = await self.db_pool.simple_update(
            "scheduled_tasks",
            {"id": id},
            updatevalues,
            desc="update_scheduled_task",
        )
        return nb_rows > 0

    async def get_scheduled_task(self, id: str) -> Optional[ScheduledTask]:
        """Get a specific `ScheduledTask` from its id.

        Args:
            id: the id of the task to retrieve

        Returns: the task if available, `None` otherwise
        """
        row = await self.db_pool.simple_select_one(
            table="scheduled_tasks",
            keyvalues={"id": id},
            retcols=TaskSchedulerWorkerStore._SCHEDULED_TASK_COLUMNS,
            # Return None instead of raising when no row matches `id`.
            allow_none=True,
            desc="get_scheduled_task",
        )

        return TaskSchedulerWorkerStore._convert_row_to_task(row) if row else None

    async def delete_scheduled_task(self, id: str) -> None:
        """Delete a specific task from its id.

        Args:
            id: the id of the task to delete
        """
        await self.db_pool.simple_delete(
            "scheduled_tasks",
            keyvalues={"id": id},
            desc="delete_scheduled_task",
        )