summaryrefslogtreecommitdiff
path: root/aiohttp_cors/urldispatcher_router_adapter.py
blob: 1a65e99878e1b9748c9a8cbc6055de85852c7aac (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
# Copyright 2015 Vladimir Rutsky <vladimir@rutsky.org>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""AbstractRouterAdapter for aiohttp.web.UrlDispatcher.
"""
import collections

from typing import Union

from aiohttp import web
from aiohttp import hdrs

from .abc import AbstractRouterAdapter
from .mixin import CorsViewMixin


# There are several usage patterns of routes which should be handled
# differently.
#
# 1. Using new Resources:
#
#     resource = app.router.add_resource(path)
#     cors.add(resource, resource_defaults=...)
#     cors.add(resource.add_route(method1, handler1), config=...)
#     cors.add(resource.add_route(method2, handler2), config=...)
#     cors.add(resource.add_route(method3, handler3), config=...)
#
# Here all related Routes (i.e. routes with the same path) are in
# a single Resource.
#
# 2. Using `router.add_static()`:
#
#     route1 = app.router.add_static(
#         "/images", "/usr/share/app/images/")
#     cors.add(route1, config=...)
#
# Here old-style `web.StaticRoute` is created and wrapped with
# `web.ResourceAdapter`.
#
# 3. Using old `router.add_route()`:
#
#     cors.add(app.router.add_route(method1, path, hand1), config=...)
#     cors.add(app.router.add_route(method2, path, hand2), config=...)
#     cors.add(app.router.add_route(method3, path, hand3), config=...)
#
# This creates three Resources with a single Route in each.
#
# 4. Using deprecated `register_route` with manually created
#    `web.Route`:
#
#     route1 = RouteSubclass(...)
#     app.router.register_route(route1)
#     cors.add(route1, config=...)
#
# Here old-style route is wrapped with `web.ResourceAdapter`.
#
# A preflight request is roughly an OPTIONS request asking
# "is this specific HTTP method allowed?".
# In order to properly handle preflight request we need to know which
# routes have enabled CORS on the request path and CORS configuration
# for requested HTTP method.
#
# In case of the new usage pattern it's simple: we need to take a look at
# self._resource_config[resource][method] for the processing resource.
#
# In case of the old usage patterns we need to iterate over routes with
# enabled CORS and check whether the requested path and HTTP method are
# accepted by a route.


class _ResourceConfig:
    def __init__(self, default_config):
        # Resource default config.
        self.default_config = default_config

        # HTTP method to route configuration.
        self.method_config = {}


def _is_web_view(entity, strict=True):
    """Check whether `entity` is a route served by a CORS-aware web view.

    :param entity:
        Any routing entity. Only `web.AbstractRoute` instances can be
        web views; anything else yields ``False``.
    :param strict:
        If true, a route whose handler is a `web.View` subclass that does
        not derive from `CorsViewMixin` is rejected with ``ValueError``.
        If false, such a route is simply reported as not a web view.
    :return:
        ``True`` if the route handler is a class deriving from both
        `web.View` and `CorsViewMixin`, ``False`` otherwise.
    :raises ValueError:
        In strict mode, when the view class lacks `CorsViewMixin`.
    """
    if not isinstance(entity, web.AbstractRoute):
        return False

    handler = entity.handler
    # Class-based views are registered as classes, not callables.
    if not (isinstance(handler, type) and issubclass(handler, web.View)):
        return False

    if issubclass(handler, CorsViewMixin):
        return True

    if strict:
        raise ValueError("web view should be derived from "
                         "aiohttp_cors.CorsViewMixin for working "
                         "with the library")
    return False


class ResourcesUrlDispatcherRouterAdapter(AbstractRouterAdapter):
    """Adapter for `UrlDispatcher` for Resources-based routing only.

    Should be used with routes added in the following way:

        resource = app.router.add_resource(path)
        cors.add(resource, resource_defaults=...)
        cors.add(resource.add_route(method1, handler1), config=...)
        cors.add(resource.add_route(method2, handler2), config=...)
        cors.add(resource.add_route(method3, handler3), config=...)

    CORS configuration is stored per resource: a resource-wide default
    plus optional per-HTTP-method overrides, layered on top of the
    global defaults passed to the constructor.
    """

    def __init__(self,
                 router: web.UrlDispatcher,
                 defaults):
        """
        :param router:
            Application URL dispatcher (stored on the adapter).
        :param defaults:
            Default CORS configuration.
        """
        self._router = router

        # Default configuration for all routes.
        self._default_config = defaults

        # Mapping from Resource to _ResourceConfig.
        self._resource_config = {}

        # Resources for which a preflight handler has been registered.
        self._resources_with_preflight_handlers = set()
        # Routes whose requests are treated as CORS preflight requests.
        self._preflight_routes = set()

    def _register_preflight_route(self, resource, preflight_route):
        """Record `preflight_route` as the preflight route of `resource`."""
        self._preflight_routes.add(preflight_route)
        self._resources_with_preflight_handlers.add(resource)

    def add_preflight_handler(
            self,
            routing_entity: Union[web.Resource, web.StaticResource,
                                  web.ResourceRoute],
            handler):
        """Add OPTIONS handler for all routes defined by `routing_entity`.

        Does nothing if CORS handler already handles routing entity.
        Should fail if there are conflicting user-defined OPTIONS handlers.

        :param routing_entity:
            A resource, a static resource or a route of a resource.
        :param handler:
            Preflight request handler to install.
        :raises ValueError:
            If the resource already has a different OPTIONS handler, has a
            '*' handler that is not a CORS-enabled web view, or if
            `routing_entity` is of an unsupported type.
        """

        if isinstance(routing_entity, web.Resource):
            resource = routing_entity

            # Add preflight handler for Resource, if not yet added.

            if resource in self._resources_with_preflight_handlers:
                # Preflight handler already added for this resource.
                return
            for route_obj in resource:
                if route_obj.method == hdrs.METH_OPTIONS:
                    # Compare with `==`, not `is`: bound methods are
                    # re-created on every attribute access, so two
                    # references to the same method are never identical.
                    if route_obj.handler == handler:
                        # Already added; make sure the route is tracked
                        # so its requests are recognized as preflights.
                        self._register_preflight_route(resource, route_obj)
                        return
                    raise ValueError(
                        "{!r} already has OPTIONS handler {!r}"
                        .format(resource, route_obj.handler))
                elif route_obj.method == hdrs.METH_ANY:
                    if _is_web_view(route_obj):
                        # The CORS-aware view's own route is recorded as
                        # the preflight route instead of adding a
                        # separate OPTIONS route.
                        self._register_preflight_route(resource, route_obj)
                        return
                    raise ValueError("{!r} already has a '*' handler "
                                     "for all methods".format(resource))

            preflight_route = resource.add_route(hdrs.METH_OPTIONS, handler)
            self._register_preflight_route(resource, preflight_route)

        elif isinstance(routing_entity, web.StaticResource):
            resource = routing_entity

            # Add preflight handler for Resource, if not yet added.

            if resource in self._resources_with_preflight_handlers:
                # Preflight handler already added for this resource.
                return

            resource.set_options_route(handler)
            # The route created by `set_options_route()` is fetched from
            # the resource's (private) route table.
            preflight_route = resource._routes[hdrs.METH_OPTIONS]
            self._register_preflight_route(resource, preflight_route)

        elif isinstance(routing_entity, web.ResourceRoute):
            route = routing_entity

            # Preflight handling is per resource: install it on the
            # route's resource unless that was already done.
            if not self.is_cors_for_resource(route.resource):
                self.add_preflight_handler(route.resource, handler)

        else:
            raise ValueError(
                "Resource or ResourceRoute expected, got {!r}".format(
                    routing_entity))

    def is_cors_for_resource(self, resource: web.Resource) -> bool:
        """Is CORS configured for the resource."""
        return resource in self._resources_with_preflight_handlers

    def _request_route(self, request: web.Request) -> web.ResourceRoute:
        """Return the route that matched `request`."""
        match_info = request.match_info
        assert isinstance(match_info, web.UrlMappingMatchInfo)
        return match_info.route

    def _request_resource(self, request: web.Request) -> web.Resource:
        """Return the resource owning the route that matched `request`."""
        return self._request_route(request).resource

    def is_preflight_request(self, request: web.Request) -> bool:
        """Is `request` a CORS preflight request."""
        route = self._request_route(request)
        if _is_web_view(route, strict=False):
            # A CORS-aware view serves all methods from one route, so the
            # request method decides whether this is a preflight.
            return request.method == hdrs.METH_OPTIONS
        return route in self._preflight_routes

    def is_cors_enabled_on_request(self, request: web.Request) -> bool:
        """Is `request` a request for a CORS-enabled resource."""

        return self._request_resource(request) in self._resource_config

    def set_config_for_routing_entity(
            self,
            routing_entity: Union[web.Resource, web.StaticResource,
                                  web.ResourceRoute],
            config):
        """Record configuration for a resource or one of its routes.

        For a route whose resource is not configured yet, `config` is also
        used as that resource's default configuration.

        :raises ValueError:
            If the resource (or the route's method on that resource) is
            already configured, or `routing_entity` has an unsupported
            type.
        """

        if isinstance(routing_entity, (web.Resource, web.StaticResource)):
            resource = routing_entity

            # Add resource configuration or fail if it's already added.
            if resource in self._resource_config:
                raise ValueError(
                    "CORS is already configured for {!r} resource.".format(
                        resource))

            self._resource_config[resource] = _ResourceConfig(
                default_config=config)

        elif isinstance(routing_entity, web.ResourceRoute):
            route = routing_entity

            # Add resource's route configuration or fail if it's already added.
            if route.resource not in self._resource_config:
                self.set_config_for_routing_entity(route.resource, config)

            # Defensive: the call above either registers the resource or
            # raises, so this should not trigger.
            if route.resource not in self._resource_config:
                raise ValueError(
                    "Can't setup CORS for {!r} request, "
                    "CORS must be enabled for route's resource first.".format(
                        route))

            resource_config = self._resource_config[route.resource]

            if route.method in resource_config.method_config:
                raise ValueError(
                    "Can't setup CORS for {!r} route: CORS already "
                    "configured on resource {!r} for {} method".format(
                        route, route.resource, route.method))

            resource_config.method_config[route.method] = config

        else:
            raise ValueError(
                "Resource or ResourceRoute expected, got {!r}".format(
                    routing_entity))

    async def get_preflight_request_config(
            self,
            preflight_request: web.Request,
            origin: str,
            requested_method: str):
        """Get CORS configuration for a preflight request.

        :param preflight_request:
            The preflight request; must satisfy `is_preflight_request`.
        :param origin:
            Request origin; looked up as a key in the configuration, with
            the "*" entry used as a fallback.
        :param requested_method:
            HTTP method the client asks permission for.
        :return:
            A `collections.ChainMap` layering the method's route
            configuration over the resource defaults and global defaults.
        :raises KeyError:
            If the resource has no CORS configuration, or if
            `requested_method` is neither allowed by the defaults nor
            explicitly configured on the resource.
        """
        assert self.is_preflight_request(preflight_request)

        resource = self._request_resource(preflight_request)
        resource_config = self._resource_config[resource]
        defaulted_config = collections.ChainMap(
            resource_config.default_config,
            self._default_config)

        options = defaulted_config.get(origin, defaulted_config.get("*"))
        if options is not None and options.is_method_allowed(requested_method):
            # Requested method enabled for CORS in defaults, override it with
            # explicit route configuration (if any).
            route_config = resource_config.method_config.get(
                requested_method, {})

        else:
            # Requested method is not enabled in defaults.
            # Enable CORS for it only if explicit configuration exists;
            # a missing entry raises KeyError, propagated to the caller.
            route_config = resource_config.method_config[requested_method]

        defaulted_config = collections.ChainMap(route_config, defaulted_config)

        return defaulted_config

    def get_non_preflight_request_config(self, request: web.Request):
        """Get stored CORS configuration for routing entity that handles
        specified request.

        :param request:
            A request for a CORS-enabled resource.
        :return:
            A `collections.ChainMap` of the per-method configuration over
            the resource defaults and the global defaults.
        """

        assert self.is_cors_enabled_on_request(request)

        resource = self._request_resource(request)
        resource_config = self._resource_config[resource]
        # Take Route config (if any) with defaults from Resource CORS
        # configuration and global defaults.
        route = self._request_route(request)
        if _is_web_view(route, strict=False):
            # CORS-aware views provide their own per-method configuration.
            method_config = request.match_info.handler.get_request_config(
                request, request.method)
        else:
            method_config = resource_config.method_config.get(request.method,
                                                              {})
        defaulted_config = collections.ChainMap(
            method_config,
            resource_config.default_config,
            self._default_config)

        return defaulted_config