From 47af65d464e11e1abd2ff23d02218035680c69fa Mon Sep 17 00:00:00 2001 From: Ruben Undheim Date: Fri, 21 Dec 2018 11:52:13 +0000 Subject: New upstream version 0.7.0 --- .gitignore | 3 + .pyup.yml | 4 + .travis.yml | 15 +- CHANGES.rst | 37 +- LICENSE | 2 +- Makefile | 8 + README.rst | 45 +- aiohttp_cors/__about__.py | 6 +- aiohttp_cors/__init__.py | 3 +- aiohttp_cors/_log.py | 22 - aiohttp_cors/abc.py | 9 +- aiohttp_cors/cors_config.py | 200 +-- aiohttp_cors/mixin.py | 47 + aiohttp_cors/preflight_handler.py | 130 ++ aiohttp_cors/urldispatcher_router_adapter.py | 209 +-- appveyor.yml | 35 - pytest.ini | 3 + requirements-dev.txt | 19 +- setup.cfg | 3 + setup.py | 1 + tests/aio_test_base.py | 63 - tests/doc/test_basic_usage.py | 150 ++- tests/integration/test_main.py | 1577 ++++++++++++----------- tests/integration/test_page.html | 2 +- tests/integration/test_real_browser.py | 209 ++- tests/unit/test___about__.py | 11 +- tests/unit/test_cors_config.py | 173 ++- tests/unit/test_mixin.py | 125 ++ tests/unit/test_preflight_handler.py | 12 + tests/unit/test_resource_options.py | 61 +- tests/unit/test_urldispatcher_router_adapter.py | 153 ++- tox.ini | 1 + 32 files changed, 1712 insertions(+), 1626 deletions(-) create mode 100644 .pyup.yml create mode 100644 Makefile delete mode 100644 aiohttp_cors/_log.py create mode 100644 aiohttp_cors/mixin.py create mode 100644 aiohttp_cors/preflight_handler.py delete mode 100644 appveyor.yml create mode 100644 pytest.ini delete mode 100644 tests/aio_test_base.py create mode 100644 tests/unit/test_mixin.py create mode 100644 tests/unit/test_preflight_handler.py diff --git a/.gitignore b/.gitignore index e2859fe..76419e6 100644 --- a/.gitignore +++ b/.gitignore @@ -48,3 +48,6 @@ docs/_build/ # PyBuilder target/ + +geckodriver.log +.pytest_cache \ No newline at end of file diff --git a/.pyup.yml b/.pyup.yml new file mode 100644 index 0000000..75f9711 --- /dev/null +++ b/.pyup.yml @@ -0,0 +1,4 @@ +# Label PRs with `deps-update` label +label_prs: deps-update + +schedule: every week diff --git a/.travis.yml b/.travis.yml index b42cab6..3120baf 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,7 +1,5 @@ language: python -dist: trusty python: -- 3.4 - 3.5 - 3.6 @@ -13,9 +11,8 @@ before_install: install: - pip install --upgrade pip setuptools wheel -# aiohttp git repo has only *.pyx files, so install cython too. -- '[ -z "$MASTER_AIOHTTP" ] || pip install -U cython git+https://github.com/KeepSafe/aiohttp.git' - pip install -Ur requirements-dev.txt +- pip install codecov before_script: # Start X-server for Selenium tests. @@ -25,7 +22,10 @@ before_script: script: - '[ "$TYPE" != "test" ] || python setup.py test --addopts -v --addopts -s' -- '[ "$TYPE" != "lint" ] || python setup.py check' +- '[ "$TYPE" != "lint" ] || python setup.py check -rms' + +after_success: + codecov env: global: @@ -33,10 +33,7 @@ env: matrix: # PYTHONASYNCIODEBUG environment variable is considered as enabled if it # is any non empty string. - - TYPE=test PYTHONASYNCIODEBUG= MASTER_AIOHTTP= - - TYPE=test PYTHONASYNCIODEBUG=x MASTER_AIOHTTP= - - TYPE=test PYTHONASYNCIODEBUG= MASTER_AIOHTTP=x - - TYPE=test PYTHONASYNCIODEBUG=x MASTER_AIOHTTP=x + - TYPE=test matrix: include: diff --git a/CHANGES.rst b/CHANGES.rst index 3224364..a68c3a8 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,32 +1,47 @@ -CHANGES -======= +========= + CHANGES +========= + +0.7.0 (2018-03-05) +================== + +- Make web view check implicit and type based (#159) + +- Disable Python 3.4 support (#156) + +- Support aiohttp 3.0+ (#155) + +0.6.0 (2017-12-21) +================== + +- Support aiohttp views by ``CorsViewMixin`` (#145) 0.5.3 (2017-04-21) ------------------- +================== -- Fix `typing` being installed on Python 3.6. +- Fix ``typing`` being installed on Python 3.6. 0.5.2 (2017-03-28) ------------------- +================== - Fix tests compatibility with ``aiohttp`` 2.0. This release and release v0.5.0 should work on ``aiohttp`` 2.0. 0.5.1 (2017-03-23) ------------------- +================== - Enforce ``aiohttp`` version to be less than 2.0. Newer ``aiohttp`` releases will be supported in the next release. 0.5.0 (2016-11-18) ------------------- +================== - Fix compatibility with ``aiohttp`` 1.1 0.4.0 (2016-04-04) ------------------- +================== - Fixed support with new Resources objects introduced in ``aiohttp`` 0.21.0. Minimum supported version of ``aiohttp`` is 0.21.4 now. @@ -54,7 +69,7 @@ CHANGES agnostic. 0.3.0 (2016-02-06) ------------------- +================== - Rename ``UrlDistatcherRouterAdapter`` to ``UrlDispatcherRouterAdapter``. @@ -62,7 +77,7 @@ CHANGES details. 0.2.0 (2015-11-30) ------------------- +================== - Move ABCs from ``aiohttp_cors.router_adapter`` to ``aiohttp_cors.abc``. @@ -71,6 +86,6 @@ CHANGES - Fix bug with configuring CORS for named routes. 0.1.0 (2015-11-05) ------------------- +================== * Initial release. diff --git a/LICENSE b/LICENSE index 8f71f43..abd6dd8 100644 --- a/LICENSE +++ b/LICENSE @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright {yyyy} {name of copyright owner} + Copyright 2015-2018 Vladimir Rutsky and aio-libs team Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..2f8e8eb --- /dev/null +++ b/Makefile @@ -0,0 +1,8 @@ +all: test + + +flake: + flake8 aiohttp_cors tests setup.py + +test: flake + pytest tests diff --git a/README.rst b/README.rst index 64a3020..ab4b84e 100644 --- a/README.rst +++ b/README.rst @@ -1,3 +1,4 @@ +======================== CORS support for aiohttp ======================== @@ -126,7 +127,7 @@ from git: $ pip install aiohttp_cors Note that ``aiohttp_cors`` requires versions of Python >= 3.4.1 and -``aiohttp`` >= 0.21.4. +``aiohttp`` >= 1.1. Usage ===== @@ -346,6 +347,46 @@ in the router: for route in list(app.router.routes()): cors.add(route) +You can also use ``CorsViewMixin`` on ``web.View``: + +.. code-block:: python + + class CorsView(web.View, CorsViewMixin): + + cors_config = { + "*": ResourceOption( + allow_credentials=True, + allow_headers="X-Request-ID", + ) + } + + @asyncio.coroutine + def get(self): + return web.Response(text="Done") + + @custom_cors({ + "*": ResourceOption( + allow_credentials=True, + allow_headers="*", + ) + }) + @asyncio.coroutine + def post(self): + return web.Response(text="Done") + + cors = aiohttp_cors.setup(app, defaults={ + "*": aiohttp_cors.ResourceOptions( + allow_credentials=True, + expose_headers="*", + allow_headers="*", + ) + }) + + cors.add( + app.router.add_route("*", "/resource", CorsView), + webview=True) + + Security ======== @@ -460,7 +501,7 @@ Post release steps: Bugs ==== -Please report bugs, issues, feature requests, etc. on +Please report bugs, issues, feature requests, etc. on `GitHub `__. diff --git a/aiohttp_cors/__about__.py b/aiohttp_cors/__about__.py index eb70c30..51c4841 100644 --- a/aiohttp_cors/__about__.py +++ b/aiohttp_cors/__about__.py @@ -19,10 +19,10 @@ This module must be stand-alone executable. """ __title__ = "aiohttp-cors" -__version__ = "0.5.3" -__author__ = "Vladimir Rutsky" +__version__ = "0.7.0" +__author__ = "Vladimir Rutsky and aio-libs team" __email__ = "vladimir@rutsky.org" __summary__ = "CORS support for aiohttp" __uri__ = "https://github.com/aio-libs/aiohttp-cors" __license__ = "Apache License, Version 2.0" -__copyright__ = "2015, 2016, 2017 {}".format(__author__) +__copyright__ = "2015-2018 {}".format(__author__) diff --git a/aiohttp_cors/__init__.py b/aiohttp_cors/__init__.py index 49474c8..cbcc5ef 100644 --- a/aiohttp_cors/__init__.py +++ b/aiohttp_cors/__init__.py @@ -25,11 +25,12 @@ from .__about__ import ( ) from .resource_options import ResourceOptions from .cors_config import CorsConfig +from .mixin import CorsViewMixin, custom_cors __all__ = ( "__title__", "__version__", "__author__", "__email__", "__summary__", "__uri__", "__license__", "__copyright__", - "setup", "CorsConfig", "ResourceOptions", + "setup", "CorsConfig", "ResourceOptions", "CorsViewMixin", "custom_cors" ) diff --git a/aiohttp_cors/_log.py b/aiohttp_cors/_log.py deleted file mode 100644 index 5b216cb..0000000 --- a/aiohttp_cors/_log.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright 2015 Vladimir Rutsky -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""aiohttp_cors logger""" - -import logging - -__all__ = ("logger",) - -# pylint: disable=invalid-name -logger = logging.getLogger("aiohttp_cors") diff --git a/aiohttp_cors/abc.py b/aiohttp_cors/abc.py index cdddd36..5204a83 100644 --- a/aiohttp_cors/abc.py +++ b/aiohttp_cors/abc.py @@ -15,7 +15,6 @@ """Abstract base classes. """ -import asyncio from abc import ABCMeta, abstractmethod from aiohttp import web @@ -51,7 +50,10 @@ class AbstractRouterAdapter(metaclass=ABCMeta): """ @abstractmethod - def add_preflight_handler(self, routing_entity, handler): + def add_preflight_handler(self, + routing_entity, + handler, + webview: bool=False): """Add OPTIONS handler for all routes defined by `routing_entity`. Does nothing if CORS handler already handles routing entity. @@ -79,9 +81,8 @@ class AbstractRouterAdapter(metaclass=ABCMeta): entity. """ - @asyncio.coroutine @abstractmethod - def get_preflight_request_config( + async def get_preflight_request_config( self, preflight_request: web.Request, origin: str, diff --git a/aiohttp_cors/cors_config.py b/aiohttp_cors/cors_config.py index d8017c9..2e1aeea 100644 --- a/aiohttp_cors/cors_config.py +++ b/aiohttp_cors/cors_config.py @@ -15,22 +15,21 @@ """CORS configuration container class definition. """ -import asyncio import collections +import warnings from typing import Mapping, Union, Any from aiohttp import hdrs, web -from .urldispatcher_router_adapter import OldRoutesUrlDispatcherRouterAdapter from .urldispatcher_router_adapter import ResourcesUrlDispatcherRouterAdapter from .abc import AbstractRouterAdapter from .resource_options import ResourceOptions +from .preflight_handler import _PreflightHandler __all__ = ( "CorsConfig", ) - # Positive response to Access-Control-Allow-Credentials _TRUE = "true" # CORS simple response headers: @@ -103,7 +102,7 @@ def _parse_config_options( _ConfigType = Mapping[str, Union[ResourceOptions, Mapping[str, Any]]] -class _CorsConfigImpl: +class _CorsConfigImpl(_PreflightHandler): def __init__(self, app: web.Application, @@ -139,10 +138,9 @@ class _CorsConfigImpl: return routing_entity - @asyncio.coroutine - def _on_response_prepare(self, - request: web.Request, - response: web.StreamResponse): + async def _on_response_prepare(self, + request: web.Request, + response: web.StreamResponse): """Non-preflight CORS request response processor. If request is done on CORS-enabled route, process request parameters @@ -196,127 +194,11 @@ class _CorsConfigImpl: # Set allowed credentials. response.headers[hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS] = _TRUE - @staticmethod - def _parse_request_method(request: web.Request): - """Parse Access-Control-Request-Method header of the preflight request - """ - method = request.headers.get(hdrs.ACCESS_CONTROL_REQUEST_METHOD) - if method is None: - raise web.HTTPForbidden( - text="CORS preflight request failed: " - "'Access-Control-Request-Method' header is not specified") - - # FIXME: validate method string (ABNF: method = token), if parsing - # fails, raise HTTPForbidden. - - return method - - @staticmethod - def _parse_request_headers(request: web.Request): - """Parse Access-Control-Request-Headers header or the preflight request - - Returns set of headers in upper case. - """ - headers = request.headers.get(hdrs.ACCESS_CONTROL_REQUEST_HEADERS) - if headers is None: - return frozenset() - - # FIXME: validate each header string, if parsing fails, raise - # HTTPForbidden. - # FIXME: check, that headers split and stripped correctly (according - # to ABNF). - headers = (h.strip(" \t").upper() for h in headers.split(",")) - # pylint: disable=bad-builtin - return frozenset(filter(None, headers)) - - @asyncio.coroutine - def _preflight_handler(self, request: web.Request): - """CORS preflight request handler""" - - # Handle according to part 6.2 of the CORS specification. - - origin = request.headers.get(hdrs.ORIGIN) - if origin is None: - # Terminate CORS according to CORS 6.2.1. - raise web.HTTPForbidden( - text="CORS preflight request failed: " - "origin header is not specified in the request") - - # CORS 6.2.3. Doing it out of order is not an error. - request_method = self._parse_request_method(request) - - # CORS 6.2.5. Doing it out of order is not an error. - - try: - config = \ - yield from self._router_adapter.get_preflight_request_config( - request, origin, request_method) - except KeyError: - raise web.HTTPForbidden( - text="CORS preflight request failed: " - "request method {!r} is not allowed " - "for {!r} origin".format(request_method, origin)) - - if not config: - # No allowed origins for the route. - # Terminate CORS according to CORS 6.2.1. - raise web.HTTPForbidden( - text="CORS preflight request failed: " - "no origins are allowed") - - options = config.get(origin, config.get("*")) - if options is None: - # No configuration for the origin - deny. - # Terminate CORS according to CORS 6.2.2. - raise web.HTTPForbidden( - text="CORS preflight request failed: " - "origin '{}' is not allowed".format(origin)) - - # CORS 6.2.4 - request_headers = self._parse_request_headers(request) - - # CORS 6.2.6 - if options.allow_headers == "*": - pass - else: - disallowed_headers = request_headers - options.allow_headers - if disallowed_headers: - raise web.HTTPForbidden( - text="CORS preflight request failed: " - "headers are not allowed: {}".format( - ", ".join(disallowed_headers))) - - # Ok, CORS actual request with specified in the preflight request - # parameters is allowed. - # Set appropriate headers and return 200 response. - - response = web.Response() - - # CORS 6.2.7 - response.headers[hdrs.ACCESS_CONTROL_ALLOW_ORIGIN] = origin - if options.allow_credentials: - # Set allowed credentials. - response.headers[hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS] = _TRUE - - # CORS 6.2.8 - if options.max_age is not None: - response.headers[hdrs.ACCESS_CONTROL_MAX_AGE] = \ - str(options.max_age) - - # CORS 6.2.9 - # TODO: more optimal for client preflight request cache would be to - # respond with ALL allowed methods. - response.headers[hdrs.ACCESS_CONTROL_ALLOW_METHODS] = request_method - - # CORS 6.2.10 - if request_headers: - # Note: case of the headers in the request is changed, but this - # shouldn't be a problem, since the headers should be compared in - # the case-insensitive way. - response.headers[hdrs.ACCESS_CONTROL_ALLOW_HEADERS] = \ - ",".join(request_headers) - - return response + async def _get_config(self, request, origin, request_method): + config = \ + await self._router_adapter.get_preflight_request_config( + request, origin, request_method) + return config class CorsConfig: @@ -341,7 +223,7 @@ class CorsConfig: Router adapter. Required if application uses non-default router. """ - defaults = _parse_config_options(defaults) + self.defaults = _parse_config_options(defaults) self._cors_impl = None @@ -350,27 +232,16 @@ class CorsConfig: self._old_routes_cors_impl = None - if router_adapter is not None: - self._cors_impl = _CorsConfigImpl(app, router_adapter) - - elif isinstance(app.router, web.UrlDispatcher): - self._resources_router_adapter = \ - ResourcesUrlDispatcherRouterAdapter(app.router, defaults) - self._resources_cors_impl = _CorsConfigImpl( - app, - self._resources_router_adapter) - self._old_routes_cors_impl = _CorsConfigImpl( - app, - OldRoutesUrlDispatcherRouterAdapter(app.router, defaults)) - else: - raise RuntimeError( - "Router adapter is not specified. " - "Routers other than aiohttp.web.UrlDispatcher requires" - "custom router adapter.") + if router_adapter is None: + router_adapter = \ + ResourcesUrlDispatcherRouterAdapter(app.router, self.defaults) + + self._cors_impl = _CorsConfigImpl(app, router_adapter) def add(self, routing_entity, - config: _ConfigType = None): + config: _ConfigType = None, + webview: bool=False): """Enable CORS for specific route or resource. If route is passed CORS is enabled for route's resource. @@ -382,30 +253,11 @@ class CorsConfig: :return: `routing_entity`. """ - if self._cors_impl is not None: - # Custom router adapter. - return self._cors_impl.add(routing_entity, config) + if webview: + warnings.warn('webview argument is deprecated, ' + 'views are handled authomatically without ' + 'extra settings', + DeprecationWarning, + stacklevel=2) - else: - # UrlDispatcher. - - if isinstance(routing_entity, (web.Resource, web.StaticResource)): - # New Resource - use new router adapter. - return self._resources_cors_impl.add(routing_entity, config) - - elif isinstance(routing_entity, web.AbstractRoute): - if self._resources_router_adapter.is_cors_for_resource( - routing_entity.resource): - # Route which resource has CORS configuration in - # new-style router adapter. - return self._resources_cors_impl.add( - routing_entity, config) - else: - # Route which resource has no CORS configuration, i.e. - # old-style route. - return self._old_routes_cors_impl.add( - routing_entity, config) - - else: - raise ValueError( - "Unknown resource/route type: {!r}".format(routing_entity)) + return self._cors_impl.add(routing_entity, config) diff --git a/aiohttp_cors/mixin.py b/aiohttp_cors/mixin.py new file mode 100644 index 0000000..f5b4506 --- /dev/null +++ b/aiohttp_cors/mixin.py @@ -0,0 +1,47 @@ +import collections + +from .preflight_handler import _PreflightHandler + + +def custom_cors(config): + def wrapper(function): + name = "{}_cors_config".format(function.__name__) + setattr(function, name, config) + return function + return wrapper + + +class CorsViewMixin(_PreflightHandler): + cors_config = None + + @classmethod + def get_request_config(cls, request, request_method): + try: + from . import APP_CONFIG_KEY + cors = request.app[APP_CONFIG_KEY] + except KeyError: + raise ValueError("aiohttp-cors is not configured.") + + method = getattr(cls, request_method.lower(), None) + + if not method: + raise KeyError() + + config_property_key = "{}_cors_config".format(request_method.lower()) + + custom_config = getattr(method, config_property_key, None) + if not custom_config: + custom_config = {} + + class_config = cls.cors_config + if not class_config: + class_config = {} + + return collections.ChainMap(custom_config, class_config, cors.defaults) + + async def _get_config(self, request, origin, request_method): + return self.get_request_config(request, request_method) + + async def options(self): + response = await self._preflight_handler(self.request) + return response diff --git a/aiohttp_cors/preflight_handler.py b/aiohttp_cors/preflight_handler.py new file mode 100644 index 0000000..35e15e1 --- /dev/null +++ b/aiohttp_cors/preflight_handler.py @@ -0,0 +1,130 @@ +from aiohttp import hdrs, web + +# Positive response to Access-Control-Allow-Credentials +_TRUE = "true" + + +class _PreflightHandler: + + @staticmethod + def _parse_request_method(request: web.Request): + """Parse Access-Control-Request-Method header of the preflight request + """ + method = request.headers.get(hdrs.ACCESS_CONTROL_REQUEST_METHOD) + if method is None: + raise web.HTTPForbidden( + text="CORS preflight request failed: " + "'Access-Control-Request-Method' header is not specified") + + # FIXME: validate method string (ABNF: method = token), if parsing + # fails, raise HTTPForbidden. + + return method + + @staticmethod + def _parse_request_headers(request: web.Request): + """Parse Access-Control-Request-Headers header or the preflight request + + Returns set of headers in upper case. + """ + headers = request.headers.get(hdrs.ACCESS_CONTROL_REQUEST_HEADERS) + if headers is None: + return frozenset() + + # FIXME: validate each header string, if parsing fails, raise + # HTTPForbidden. + # FIXME: check, that headers split and stripped correctly (according + # to ABNF). + headers = (h.strip(" \t").upper() for h in headers.split(",")) + # pylint: disable=bad-builtin + return frozenset(filter(None, headers)) + + async def _get_config(self, request, origin, request_method): + raise NotImplementedError() + + async def _preflight_handler(self, request: web.Request): + """CORS preflight request handler""" + + # Handle according to part 6.2 of the CORS specification. + + origin = request.headers.get(hdrs.ORIGIN) + if origin is None: + # Terminate CORS according to CORS 6.2.1. + raise web.HTTPForbidden( + text="CORS preflight request failed: " + "origin header is not specified in the request") + + # CORS 6.2.3. Doing it out of order is not an error. + request_method = self._parse_request_method(request) + + # CORS 6.2.5. Doing it out of order is not an error. + + try: + config = \ + await self._get_config(request, origin, request_method) + except KeyError: + raise web.HTTPForbidden( + text="CORS preflight request failed: " + "request method {!r} is not allowed " + "for {!r} origin".format(request_method, origin)) + + if not config: + # No allowed origins for the route. + # Terminate CORS according to CORS 6.2.1. + raise web.HTTPForbidden( + text="CORS preflight request failed: " + "no origins are allowed") + + options = config.get(origin, config.get("*")) + if options is None: + # No configuration for the origin - deny. + # Terminate CORS according to CORS 6.2.2. + raise web.HTTPForbidden( + text="CORS preflight request failed: " + "origin '{}' is not allowed".format(origin)) + + # CORS 6.2.4 + request_headers = self._parse_request_headers(request) + + # CORS 6.2.6 + if options.allow_headers == "*": + pass + else: + disallowed_headers = request_headers - options.allow_headers + if disallowed_headers: + raise web.HTTPForbidden( + text="CORS preflight request failed: " + "headers are not allowed: {}".format( + ", ".join(disallowed_headers))) + + # Ok, CORS actual request with specified in the preflight request + # parameters is allowed. + # Set appropriate headers and return 200 response. + + response = web.Response() + + # CORS 6.2.7 + response.headers[hdrs.ACCESS_CONTROL_ALLOW_ORIGIN] = origin + if options.allow_credentials: + # Set allowed credentials. + response.headers[hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS] = _TRUE + + # CORS 6.2.8 + if options.max_age is not None: + response.headers[hdrs.ACCESS_CONTROL_MAX_AGE] = \ + str(options.max_age) + + # CORS 6.2.9 + # TODO: more optimal for client preflight request cache would be to + # respond with ALL allowed methods. + response.headers[hdrs.ACCESS_CONTROL_ALLOW_METHODS] = request_method + + # CORS 6.2.10 + if request_headers: + # Note: case of the headers in the request is changed, but this + # shouldn't be a problem, since the headers should be compared in + # the case-insensitive way. + response.headers[hdrs.ACCESS_CONTROL_ALLOW_HEADERS] = \ + ",".join(request_headers) + + return response diff --git a/aiohttp_cors/urldispatcher_router_adapter.py b/aiohttp_cors/urldispatcher_router_adapter.py index 92bfbb3..1a65e99 100644 --- a/aiohttp_cors/urldispatcher_router_adapter.py +++ b/aiohttp_cors/urldispatcher_router_adapter.py @@ -14,9 +14,7 @@ """AbstractRouterAdapter for aiohttp.web.UrlDispatcher. """ -import asyncio import collections -import re from typing import Union @@ -24,6 +22,7 @@ from aiohttp import web from aiohttp import hdrs from .abc import AbstractRouterAdapter +from .mixin import CorsViewMixin # There several usage patterns of routes which should be handled @@ -89,6 +88,22 @@ class _ResourceConfig: self.method_config = {} +def _is_web_view(entity, strict=True): + webview = False + if isinstance(entity, web.AbstractRoute): + handler = entity.handler + if isinstance(handler, type) and issubclass(handler, web.View): + webview = True + if not issubclass(handler, CorsViewMixin): + if strict: + raise ValueError("web view should be derived from " + "aiohttp_cors.WebViewMixig for working " + "with the library") + else: + return False + return webview + + class ResourcesUrlDispatcherRouterAdapter(AbstractRouterAdapter): """Adapter for `UrlDispatcher` for Resources-based routing only. @@ -138,6 +153,22 @@ class ResourcesUrlDispatcherRouterAdapter(AbstractRouterAdapter): if resource in self._resources_with_preflight_handlers: # Preflight handler already added for this resource. return + for route_obj in resource: + if route_obj.method == hdrs.METH_OPTIONS: + if route_obj.handler is handler: + return # already added + else: + raise ValueError( + "{!r} already has OPTIONS handler {!r}" + .format(resource, route_obj.handler)) + elif route_obj.method == hdrs.METH_ANY: + if _is_web_view(route_obj): + self._preflight_routes.add(route_obj) + self._resources_with_preflight_handlers.add(resource) + return + else: + raise ValueError("{!r} already has a '*' handler " + "for all methods".format(resource)) preflight_route = resource.add_route(hdrs.METH_OPTIONS, handler) self._preflight_routes.add(preflight_route) @@ -160,13 +191,8 @@ class ResourcesUrlDispatcherRouterAdapter(AbstractRouterAdapter): elif isinstance(routing_entity, web.ResourceRoute): route = routing_entity - # Preflight handler for Route's Resource already must be - # configured. if not self.is_cors_for_resource(route.resource): - raise ValueError( - "Can't setup CORS for {!r} request, " - "CORS must be enabled for route's resource first.".format( - route)) + self.add_preflight_handler(route.resource, handler) else: raise ValueError( @@ -187,8 +213,10 @@ class ResourcesUrlDispatcherRouterAdapter(AbstractRouterAdapter): def is_preflight_request(self, request: web.Request) -> bool: """Is `request` is a CORS preflight request.""" - - return self._request_route(request) in self._preflight_routes + route = self._request_route(request) + if _is_web_view(route, strict=False): + return request.method == 'OPTIONS' + return route in self._preflight_routes def is_cors_enabled_on_request(self, request: web.Request) -> bool: """Is `request` is a request for CORS-enabled resource.""" @@ -218,6 +246,9 @@ class ResourcesUrlDispatcherRouterAdapter(AbstractRouterAdapter): route = routing_entity # Add resource's route configuration or fail if it's already added. + if route.resource not in self._resource_config: + self.set_config_for_routing_entity(route.resource, config) + if route.resource not in self._resource_config: raise ValueError( "Can't setup CORS for {!r} request, " @@ -239,8 +270,7 @@ class ResourcesUrlDispatcherRouterAdapter(AbstractRouterAdapter): "Resource or ResourceRoute expected, got {!r}".format( routing_entity)) - @asyncio.coroutine - def get_preflight_request_config( + async def get_preflight_request_config( self, preflight_request: web.Request, origin: str, @@ -279,155 +309,16 @@ class ResourcesUrlDispatcherRouterAdapter(AbstractRouterAdapter): resource_config = self._resource_config[resource] # Take Route config (if any) with defaults from Resource CORS # configuration and global defaults. + route = request.match_info.route + if _is_web_view(route, strict=False): + method_config = request.match_info.handler.get_request_config( + request, request.method) + else: + method_config = resource_config.method_config.get(request.method, + {}) defaulted_config = collections.ChainMap( - resource_config.method_config.get(request.method, {}), + method_config, resource_config.default_config, self._default_config) return defaulted_config - - -class OldRoutesUrlDispatcherRouterAdapter(AbstractRouterAdapter): - """Adapter for `UrlDispatcher` for old-style routing only. - - In all use cases when Resource is not explicitly used, - Resource will automatically allocated for old route. - In this case all routes will have it's own resource, and to find - related routes (routes that shares same path) we need to iterate over - all routes with enabled CORS and check is they handle specific path. - - This whole class should go away when user will migrate to proper - Resource/Route usage scheme. - """ - - def __init__(self, - router: web.UrlDispatcher, - defaults): - """ - :param defaults: - Default CORS configuration. - """ - self._router = router - - # Default configuration for all routes. - self._default_config = defaults - - # Mapping from route to config. - self._route_config = collections.OrderedDict() - - self._preflight_routes = set() - - def add_preflight_handler( - self, - route: web.AbstractRoute, - handler): - """Add OPTIONS handler for same paths that `route` handles.""" - - assert isinstance(route, web.AbstractRoute) - - if isinstance(route, web.ResourceRoute): - # New-style route (which Resource is not used explicitly, - # otherwise it would be handled by other adapter). - preflight_route = route.resource.add_route( - hdrs.METH_OPTIONS, handler) - - elif isinstance(route, web.Route): - # Old-style route. - - if isinstance(route, web.StaticRoute): - # TODO: Use custom matches that uses `str.startswith()` - # if regexp performance is not enough. - pattern = re.compile("^" + re.escape(route._prefix)) - preflight_route = web.DynamicRoute( - hdrs.METH_OPTIONS, handler, None, pattern, "") - self._router.register_route(preflight_route) - - elif isinstance(route, web.PlainRoute): - # May occur only if user manually creates PlainRoute. - preflight_route = self._router.add_route( - hdrs.METH_OPTIONS, route._path, handler) - - elif isinstance(route, web.DynamicRoute): - # May occur only if user manually creates DynamicRoute. - preflight_route = web.DynamicRoute( - hdrs.METH_OPTIONS, handler, None, - route._pattern, route._formatter) - self._router.register_route(preflight_route) - - else: - raise RuntimeError( - "Unhandled deprecated route type {!r}".format(route)) - - else: - raise RuntimeError("Unhandled route type {!r}".format(route)) - - self._preflight_routes.add(preflight_route) - - def _request_route(self, request: web.Request) -> web.ResourceRoute: - match_info = request.match_info - assert isinstance(match_info, web.UrlMappingMatchInfo) - return match_info.route - - def is_preflight_request(self, request: web.Request) -> bool: - """Is `request` is a CORS preflight request.""" - - return self._request_route(request) in self._preflight_routes - - def is_cors_enabled_on_request(self, request: web.Request) -> bool: - """Is `request` is a request for CORS-enabled resource.""" - - return self._request_route(request) in self._route_config - - def set_config_for_routing_entity( - self, - route: web.AbstractRoute, - config): - """Record CORS configuration for route.""" - - assert isinstance(route, web.AbstractRoute) - - if any(options.allow_methods is not None - for options in config.values()): - raise ValueError( - "'allow_methods' parameter is not supported on old-style " - "routes. You specified {!r} for {!r}. " - "Use Resources to configure CORS.".format( - config, route)) - - if route in self._route_config: - raise ValueError( - "CORS is already configured for {!r} route.".format( - route)) - - self._route_config[route] = config - - @asyncio.coroutine - def get_preflight_request_config( - self, - preflight_request: web.Request, - origin: str, - requested_method: str): - assert self.is_preflight_request(preflight_request) - - request = preflight_request.clone(method=requested_method) - for route, config in self._route_config.items(): - match_info, allowed_methods = yield from route.resource.resolve( - request) - if match_info is not None: - return collections.ChainMap(config, self._default_config) - else: - raise KeyError - - def get_non_preflight_request_config(self, request: web.Request): - """Get stored CORS configuration for routing entity that handles - specified request.""" - - assert self.is_cors_enabled_on_request(request) - - route = self._request_route(request) - route_config = self._route_config[route] - - defaulted_config = collections.ChainMap( - route_config, self._default_config) - - return defaulted_config diff --git a/appveyor.yml b/appveyor.yml deleted file mode 100644 index 012eecb..0000000 --- a/appveyor.yml +++ /dev/null @@ -1,35 +0,0 @@ -version: 0.0.1.dev{build} - -environment: - matrix: - - PYTHON: "C:\\Python34" - PYTHON_VERSION: "3.4.4" - PYTHON_ARCH: "32" - - PYTHON: "C:\\Python34" - PYTHON_VERSION: "3.4.4" - PYTHON_ARCH: "64" - - PYTHON: "C:\\Python35" - PYTHON_VERSION: "3.5.1" - PYTHON_ARCH: "32" - - PYTHON: "C:\\Python35" - PYTHON_VERSION: "3.5.1" - PYTHON_ARCH: "64" - - PYTHON: "C:\\Python36" - PYTHON_VERSION: "3.6.0" - PYTHON_ARCH: "32" - - PYTHON: "C:\\Python36" - PYTHON_VERSION: "3.6.0" - PYTHON_ARCH: "64" - -install: - - "powershell ./install_python_and_pip.ps1" - - "SET PATH=%PYTHON%;%PYTHON%\\Scripts;%PATH%" - - "python --version" - - "pip install -U pip setuptools wheel" - - "pip list" - - "pip install -r requirements-dev.txt" - - "python setup.py develop" - -build: false -test_script: - - "python setup.py test" diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..e62899b --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +filterwarnings= + error diff --git a/requirements-dev.txt b/requirements-dev.txt index 7a37e87..6331621 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,11 +1,14 @@ -tox==2.7.0 -pytest==3.0.7 -pytest-cov==2.4.0 -pytest-runner==2.11.1 -pytest-flakes==1.0.1 +aiohttp==3.0.5 +tox==2.9.1 +pytest==3.4.0 +pytest-aiohttp==0.3.0 +pytest-cov==2.5.1 +pytest-runner==3.0 +pytest-flakes==2.0.0 pytest-pylint==0.7.1 -flake8==3.3.0 -selenium==3.3.3 -docutils==0.13.1 +pytest-sugar==0.9.1 +flake8==3.5.0 +selenium==3.8.1 +docutils==0.14 pygments==2.2.0 -e . diff --git a/setup.cfg b/setup.cfg index 31ad82b..20d367e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,2 +1,5 @@ [aliases] test = pytest + +[tool:pytest] +addopts= --cov=aiohttp_cors --cov-report=term --cov-report=html --cov-branch --no-cov-on-fail \ No newline at end of file diff --git a/setup.py b/setup.py index c70d98b..ce7d69c 100644 --- a/setup.py +++ b/setup.py @@ -79,6 +79,7 @@ setup( "Programming Language :: Python :: 3.6", "Topic :: Software Development :: Libraries", "Topic :: Internet :: WWW/HTTP", + "Framework :: AsyncIO", "Operating System :: MacOS :: MacOS X", "Operating System :: Microsoft :: Windows", "Operating System :: POSIX", diff --git a/tests/aio_test_base.py b/tests/aio_test_base.py deleted file mode 100644 index 7289b6a..0000000 --- a/tests/aio_test_base.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2015 Vladimir Rutsky -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Base classes and utility functions for testing asyncio-powered code. -""" - -import unittest -import asyncio -import socket -import functools -import concurrent.futures - - -@asyncio.coroutine -def create_server(protocol_factory, loop=None, sock=None): - """Create server listening on random port""" - - if sock is None: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.bind(("127.0.0.1", 0)) - sock.listen(10) - - if loop is None: - loop = asyncio.get_event_loop() - - return (yield from loop.create_server(protocol_factory, sock=sock)) - - -class AioTestBase(unittest.TestCase): - """Base class for tests that need temporary asyncio event loop""" - def setUp(self): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) - - self.thread_pool_executor = concurrent.futures.ThreadPoolExecutor(4) - - def tearDown(self): - self.thread_pool_executor.shutdown() - - self.loop.close() - asyncio.set_event_loop(None) - - -def asynctest(test_method): - """Decorator for coroutine tests. - - To be used with `AioTestBase`-based tests""" - @functools.wraps(test_method) - def wrapper(self): - """Synchronously run test method in the event loop""" - self.loop.run_until_complete(test_method(self)) - return wrapper diff --git a/tests/doc/test_basic_usage.py b/tests/doc/test_basic_usage.py index 899da8d..4eace99 100644 --- a/tests/doc/test_basic_usage.py +++ b/tests/doc/test_basic_usage.py @@ -14,93 +14,87 @@ """Test basic usage.""" -import unittest - - -class TestBasicUsage(unittest.TestCase): - def test_main(self): - # This tests corresponds to example from documentation. - # If you updating it, don't forget to update documentation. - - import asyncio - from aiohttp import web - import aiohttp_cors - - @asyncio.coroutine - def handler(request): - return web.Response( - text="Hello!", - headers={ - "X-Custom-Server-Header": "Custom data", - }) - - app = web.Application() - - # `aiohttp_cors.setup` returns `aiohttp_cors.CorsConfig` instance. - # The `cors` instance will store CORS configuration for the - # application. - cors = aiohttp_cors.setup(app) - - # To enable CORS processing for specific route you need to add - # that route to the CORS configuration object and specify its - # CORS options. - resource = cors.add(app.router.add_resource("/hello")) - route = cors.add( - resource.add_route("GET", handler), { - "http://client.example.org": aiohttp_cors.ResourceOptions( - allow_credentials=True, - expose_headers=("X-Custom-Server-Header",), - allow_headers=("X-Requested-With", "Content-Type"), - max_age=3600, - ) - }) - assert route is not None +async def test_main(): + # This tests corresponds to example from documentation. + # If you updating it, don't forget to update documentation. - def test_defaults(self): - # This tests corresponds to example from documentation. - # If you updating it, don't forget to update documentation. + from aiohttp import web + import aiohttp_cors - import asyncio - from aiohttp import web - import aiohttp_cors + async def handler(request): + return web.Response( + text="Hello!", + headers={ + "X-Custom-Server-Header": "Custom data", + }) - @asyncio.coroutine - def handler(request): - return web.Response( - text="Hello!", - headers={ - "X-Custom-Server-Header": "Custom data", - }) + app = web.Application() + + # `aiohttp_cors.setup` returns `aiohttp_cors.CorsConfig` instance. + # The `cors` instance will store CORS configuration for the + # application. + cors = aiohttp_cors.setup(app) + + # To enable CORS processing for specific route you need to add + # that route to the CORS configuration object and specify its + # CORS options. + resource = cors.add(app.router.add_resource("/hello")) + route = cors.add( + resource.add_route("GET", handler), { + "http://client.example.org": aiohttp_cors.ResourceOptions( + allow_credentials=True, + expose_headers=("X-Custom-Server-Header",), + allow_headers=("X-Requested-With", "Content-Type"), + max_age=3600, + ) + }) + + assert route is not None + + +async def test_defaults(): + # This tests corresponds to example from documentation. + # If you updating it, don't forget to update documentation. + + from aiohttp import web + import aiohttp_cors + + async def handler(request): + return web.Response( + text="Hello!", + headers={ + "X-Custom-Server-Header": "Custom data", + }) - handler_post = handler - handler_put = handler + handler_post = handler + handler_put = handler - app = web.Application() + app = web.Application() - # Example: + # Example: - cors = aiohttp_cors.setup(app, defaults={ - # Allow all to read all CORS-enabled resources from - # http://client.example.org. - "http://client.example.org": aiohttp_cors.ResourceOptions(), - }) + cors = aiohttp_cors.setup(app, defaults={ + # Allow all to read all CORS-enabled resources from + # http://client.example.org. + "http://client.example.org": aiohttp_cors.ResourceOptions(), + }) - # Enable CORS on routes. + # Enable CORS on routes. - # According to defaults POST and PUT will be available only to - # "http://client.example.org". - hello_resource = cors.add(app.router.add_resource("/hello")) - cors.add(hello_resource.add_route("POST", handler_post)) - cors.add(hello_resource.add_route("PUT", handler_put)) + # According to defaults POST and PUT will be available only to + # "http://client.example.org". + hello_resource = cors.add(app.router.add_resource("/hello")) + cors.add(hello_resource.add_route("POST", handler_post)) + cors.add(hello_resource.add_route("PUT", handler_put)) - # In addition to "http://client.example.org", GET request will be - # allowed from "http://other-client.example.org" origin. - cors.add(hello_resource.add_route("GET", handler), { - "http://other-client.example.org": - aiohttp_cors.ResourceOptions(), - }) + # In addition to "http://client.example.org", GET request will be + # allowed from "http://other-client.example.org" origin. + cors.add(hello_resource.add_route("GET", handler), { + "http://other-client.example.org": + aiohttp_cors.ResourceOptions(), + }) - # CORS will be enabled only on the resources added to `CorsConfig`, - # so following resource will be NOT CORS-enabled. - app.router.add_route("GET", "/private", handler) + # CORS will be enabled only on the resources added to `CorsConfig`, + # so following resource will be NOT CORS-enabled. + app.router.add_route("GET", "/private", handler) diff --git a/tests/integration/test_main.py b/tests/integration/test_main.py index 661094e..098c3d3 100644 --- a/tests/integration/test_main.py +++ b/tests/integration/test_main.py @@ -15,18 +15,14 @@ """Test generic usage """ -import asyncio import pathlib -from yarl import URL +import pytest -from tests.aio_test_base import AioTestBase, create_server, asynctest - -import aiohttp from aiohttp import web from aiohttp import hdrs -from aiohttp_cors import setup, ResourceOptions +from aiohttp_cors import setup as _setup, ResourceOptions, CorsViewMixin TEST_BODY = "Hello, world" @@ -34,9 +30,8 @@ SERVER_CUSTOM_HEADER_NAME = "X-Server-Custom-Header" SERVER_CUSTOM_HEADER_VALUE = "some value" -@asyncio.coroutine # pylint: disable=unused-argument -def handler(request: web.Request) -> web.StreamResponse: +async def handler(request: web.Request) -> web.StreamResponse: """Dummy request handler, returning `TEST_BODY`.""" response = web.Response(text=TEST_BODY) @@ -45,800 +40,876 @@ def handler(request: web.Request) -> web.StreamResponse: return response -class AioAiohttpAppTestBase(AioTestBase): - """Base class for tests that create single aiohttp server. +class WebViewHandler(web.View, CorsViewMixin): - Class manages server creation using create_server() method and proper - server shutdown. - """ + async def get(self) -> web.StreamResponse: + """Dummy request handler, returning `TEST_BODY`.""" + response = web.Response(text=TEST_BODY) - def setUp(self): - super().setUp() + response.headers[SERVER_CUSTOM_HEADER_NAME] = \ + SERVER_CUSTOM_HEADER_VALUE - self.handler = None - self.app = None - self.url = None + return response - self.server = None - self.session = aiohttp.ClientSession(loop=self.loop) +@pytest.fixture(params=['resource', 'view', 'route']) +def make_app(request): + def inner(defaults, route_config): + app = web.Application() + cors = _setup(app, defaults=defaults) - def tearDown(self): - self.session.close() + if request.param == 'resource': + resource = cors.add(app.router.add_resource("/resource")) + cors.add(resource.add_route("GET", handler), route_config) + elif request.param == 'view': + WebViewHandler.cors_config = route_config + cors.add( + app.router.add_route("*", "/resource", WebViewHandler)) + elif request.param == 'route': + cors.add( + app.router.add_route("GET", "/resource", handler), + route_config) + else: + raise RuntimeError('unknown parameter {}'.format(request.param)) - if self.server is not None: - self.loop.run_until_complete(self.shutdown_server()) + return app - super().tearDown() + return inner - @asyncio.coroutine - def create_server(self, app: web.Application): - """Create server listening on random port.""" - assert self.app is None - self.app = app +async def test_message_roundtrip(aiohttp_client): + """Test that aiohttp server is correctly setup in the base class.""" - assert self.handler is None - self.handler = app.make_handler() + app = web.Application() + app.router.add_route("GET", "/", handler) - self.server = (yield from create_server(self.handler, self.loop)) + client = await aiohttp_client(app) - return self.server + resp = await client.get('/') + assert resp.status == 200 + data = await resp.text() - @property - def server_url(self): - """Server navigatable URL.""" - assert self.server is not None - hostaddr, port = self.server.sockets[0].getsockname() - return "http://{host}:{port}/".format(host=hostaddr, port=port) + assert data == TEST_BODY - @asyncio.coroutine - def shutdown_server(self): - """Shutdown server.""" - assert self.server is not None - self.server.close() - yield from self.handler.finish_connections() - yield from self.server.wait_closed() - yield from self.app.cleanup() +async def test_dummy_setup(aiohttp_server): + """Test a dummy configuration.""" + app = web.Application() + _setup(app) - self.server = None - self.app = None - self.handler = None + await aiohttp_server(app) -class TestMain(AioAiohttpAppTestBase): - """Tests CORS server by issuing CORS requests.""" +async def test_dummy_setup_roundtrip(aiohttp_client): + """Test a dummy configuration with a message round-trip.""" + app = web.Application() + _setup(app) - @asynctest - @asyncio.coroutine - def test_message_roundtrip(self): - """Test that aiohttp server is correctly setup in the base class.""" + app.router.add_route("GET", "/", handler) - app = web.Application() + client = await aiohttp_client(app) - app.router.add_route("GET", "/", handler) + resp = await client.get('/') + assert resp.status == 200 + data = await resp.text() - yield from self.create_server(app) + assert data == TEST_BODY - response = yield from self.session.request("GET", self.server_url) - self.assertEqual(response.status, 200) - data = yield from response.text() - self.assertEqual(data, TEST_BODY) +async def test_dummy_setup_roundtrip_resource(aiohttp_client): + """Test a dummy configuration with a message round-trip.""" + app = web.Application() + _setup(app) - @asynctest - @asyncio.coroutine - def test_dummy_setup(self): - """Test a dummy configuration.""" - app = web.Application() - setup(app) + app.router.add_resource("/").add_route("GET", handler) - yield from self.create_server(app) + client = await aiohttp_client(app) - @asynctest - @asyncio.coroutine - def test_dummy_setup_roundtrip(self): - """Test a dummy configuration with a message round-trip.""" - app = web.Application() - setup(app) + resp = await client.get('/') + assert resp.status == 200 + data = await resp.text() - app.router.add_route("GET", "/", handler) + assert data == TEST_BODY - yield from self.create_server(app) - response = yield from self.session.request("GET", self.server_url) - self.assertEqual(response.status, 200) - data = yield from response.text() +async def test_simple_no_origin(aiohttp_client, make_app): + app = make_app(None, {"http://client1.example.org": + ResourceOptions()}) - self.assertEqual(data, TEST_BODY) + client = await aiohttp_client(app) - @asynctest - @asyncio.coroutine - def test_dummy_setup_roundtrip_resource(self): - """Test a dummy configuration with a message round-trip.""" - app = web.Application() - setup(app) - - app.router.add_resource("/").add_route("GET", handler) - - yield from self.create_server(app) - - response = yield from self.session.request("GET", self.server_url) - self.assertEqual(response.status, 200) - data = yield from response.text() - - self.assertEqual(data, TEST_BODY) - - @asyncio.coroutine - def _run_simple_requests_tests(self, - tests_descriptions, - use_resources): - """Runs CORS simple requests (without a preflight request) based - on the passed tests descriptions. - """ - - @asyncio.coroutine - def run_test(test): - """Run single test""" - - response = yield from self.session.get( - self.server_url + "resource", - headers=test.get("request_headers", {})) - self.assertEqual(response.status, 200) - self.assertEqual((yield from response.text()), TEST_BODY) - - for header_name, header_value in test.get( - "in_response_headers", {}).items(): - with self.subTest(header_name=header_name): - self.assertEqual( - response.headers.get(header_name), - header_value) - - for header_name in test.get("not_in_request_headers", {}).items(): - self.assertNotIn(header_name, response.headers) - - for test_descr in tests_descriptions: - with self.subTest(group_name=test_descr["name"]): - app = web.Application() - cors = setup(app, defaults=test_descr["defaults"]) - - if use_resources: - resource = cors.add(app.router.add_resource("/resource")) - cors.add(resource.add_route("GET", handler), - test_descr["route_config"]) - - else: - cors.add( - app.router.add_route("GET", "/resource", handler), - test_descr["route_config"]) - - yield from self.create_server(app) - - try: - for test_data in test_descr["tests"]: - with self.subTest(name=test_data["name"]): - yield from run_test(test_data) - finally: - yield from self.shutdown_server() - - @asynctest - @asyncio.coroutine - def test_simple_default(self): - """Test CORS simple requests with a route with the default - configuration. - - The default configuration means that: - * no credentials are allowed, - * no headers are exposed, - * no client headers are allowed. - """ - - client1 = "http://client1.example.org" - client2 = "http://client2.example.org" - client1_80 = "http://client1.example.org:80" - client1_https = "https://client2.example.org" - - tests_descriptions = [ - { - "name": "default", - "defaults": None, - "route_config": - { - client1: ResourceOptions(), - }, - "tests": [ - { - "name": "no origin header", - "not_in_response_headers": { - hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, - hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, - hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, - } - }, - { - "name": "allowed origin", - "request_headers": { - hdrs.ORIGIN: client1, - }, - "in_response_headers": { - hdrs.ACCESS_CONTROL_ALLOW_ORIGIN: client1, - }, - "not_in_response_headers": { - hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, - hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, - } - }, - { - "name": "not allowed origin", - "request_headers": { - hdrs.ORIGIN: client2, - }, - "not_in_response_headers": { - hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, - hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, - hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, - } - }, - { - "name": "explicitly specified default port", - # CORS specification says, that origins may compared - # as strings, so "example.org:80" is not the same as - # "example.org". - "request_headers": { - hdrs.ORIGIN: client1_80, - }, - "not_in_response_headers": { - hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, - hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, - hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, - } - }, - { - "name": "different scheme", - "request_headers": { - hdrs.ORIGIN: client1_https, - }, - "not_in_response_headers": { - hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, - hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, - hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, - } - }, - ], - }, - ] - - yield from self._run_simple_requests_tests(tests_descriptions, False) - yield from self._run_simple_requests_tests(tests_descriptions, True) - - @asynctest - @asyncio.coroutine - def test_simple_with_credentials(self): - """Test CORS simple requests with a route with enabled authorization. - - Route with enabled authorization must return - Origin: - Access-Control-Allow-Credentials: true - """ - - client1 = "http://client1.example.org" - client2 = "http://client2.example.org" - - credential_tests = [ - { - "name": "no origin header", - "not_in_response_headers": { - hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, - hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, - hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, - } - }, - { - "name": "allowed origin", - "request_headers": { - hdrs.ORIGIN: client1, - }, - "in_response_headers": { - hdrs.ACCESS_CONTROL_ALLOW_ORIGIN: client1, - hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS: "true", - }, - "not_in_response_headers": { - hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, - } - }, - { - "name": "disallowed origin", - "request_headers": { - hdrs.ORIGIN: client2, - }, - "not_in_response_headers": { - hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, - hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, - hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, - } - }, - ] - - tests_descriptions = [ - { - "name": "route settings", - "defaults": None, - "route_config": - { - client1: ResourceOptions(allow_credentials=True), - }, - "tests": credential_tests, - }, - { - "name": "cors default settings", - "defaults": - { - client1: ResourceOptions(allow_credentials=True), - }, - "route_config": None, - "tests": credential_tests, - }, - ] - - yield from self._run_simple_requests_tests(tests_descriptions, False) - yield from self._run_simple_requests_tests(tests_descriptions, True) - - @asynctest - @asyncio.coroutine - def test_simple_expose_headers(self): - """Test CORS simple requests with a route that exposes header.""" - - client1 = "http://client1.example.org" - client2 = "http://client2.example.org" - - tests_descriptions = [ - { - "name": "default", - "defaults": None, - "route_config": - { - client1: ResourceOptions( - expose_headers=(SERVER_CUSTOM_HEADER_NAME,)), - }, - "tests": [ - { - "name": "no origin header", - "not_in_response_headers": { - hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, - hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, - hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, - } - }, - { - "name": "allowed origin", - "request_headers": { - hdrs.ORIGIN: client1, - }, - "in_response_headers": { - hdrs.ACCESS_CONTROL_ALLOW_ORIGIN: client1, - hdrs.ACCESS_CONTROL_EXPOSE_HEADERS: - SERVER_CUSTOM_HEADER_NAME, - }, - "not_in_response_headers": { - hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, - } - }, - { - "name": "not allowed origin", - "request_headers": { - hdrs.ORIGIN: client2, - }, - "not_in_response_headers": { - hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, - hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, - hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, - } - }, - ], - }, - ] - - yield from self._run_simple_requests_tests(tests_descriptions, False) - yield from self._run_simple_requests_tests(tests_descriptions, True) - yield from self._run_simple_requests_tests(tests_descriptions, True) - - @asyncio.coroutine - def _run_preflight_requests_tests(self, tests_descriptions, use_resources): - """Runs CORS preflight requests based on the passed tests descriptions. - """ - - @asyncio.coroutine - def run_test(test): - """Run single test""" - - response = yield from self.session.options( - self.server_url + "resource", - headers=test.get("request_headers", {})) - self.assertEqual(response.status, test.get("response_status", 200)) - response_text = yield from response.text() - in_response = test.get("in_response") - if in_response is not None: - self.assertIn(in_response, response_text) - else: - self.assertEqual(response_text, "") - - for header_name, header_value in test.get( - "in_response_headers", {}).items(): - self.assertEqual( - response.headers.get(header_name), - header_value) - - for header_name in test.get("not_in_request_headers", {}).items(): - self.assertNotIn(header_name, response.headers) - - for test_descr in tests_descriptions: - with self.subTest(group_name=test_descr["name"]): - app = web.Application() - cors = setup(app, defaults=test_descr["defaults"]) - - if use_resources: - resource = cors.add(app.router.add_resource("/resource")) - cors.add(resource.add_route("GET", handler), - test_descr["route_config"]) - - else: - cors.add( - app.router.add_route("GET", "/resource", handler), - test_descr["route_config"]) - - yield from self.create_server(app) - - try: - for test_data in test_descr["tests"]: - with self.subTest(name=test_data["name"]): - yield from run_test(test_data) - finally: - yield from self.shutdown_server() - - @asynctest - @asyncio.coroutine - def test_preflight_default(self): - """Test CORS preflight requests with a route with the default - configuration. - - The default configuration means that: - * no credentials are allowed, - * no headers are exposed, - * no client headers are allowed. - """ - - client1 = "http://client1.example.org" - client2 = "http://client2.example.org" - - tests_descriptions = [ - { - "name": "default", - "defaults": None, - "route_config": - { - client1: ResourceOptions(), - }, - "tests": [ - { - "name": "no origin", - "response_status": 403, - "in_response": "origin header is not specified", - "not_in_response_headers": { - hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, - hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, - hdrs.ACCESS_CONTROL_MAX_AGE, - hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, - hdrs.ACCESS_CONTROL_ALLOW_METHODS, - hdrs.ACCESS_CONTROL_ALLOW_HEADERS, - }, - }, - { - "name": "no method", - "request_headers": { - hdrs.ORIGIN: client1, - }, - "response_status": 403, - "in_response": "'Access-Control-Request-Method' " - "header is not specified", - "not_in_response_headers": { - hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, - hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, - hdrs.ACCESS_CONTROL_MAX_AGE, - hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, - hdrs.ACCESS_CONTROL_ALLOW_METHODS, - hdrs.ACCESS_CONTROL_ALLOW_HEADERS, - }, - }, - { - "name": "origin and method", - "request_headers": { - hdrs.ORIGIN: client1, - hdrs.ACCESS_CONTROL_REQUEST_METHOD: "GET", - }, - "in_response_headers": { - hdrs.ACCESS_CONTROL_ALLOW_ORIGIN: client1, - hdrs.ACCESS_CONTROL_ALLOW_METHODS: "GET", - }, - "not_in_response_headers": { - hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, - hdrs.ACCESS_CONTROL_MAX_AGE, - hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, - hdrs.ACCESS_CONTROL_ALLOW_HEADERS, - }, - }, - { - "name": "disallowed origin", - "request_headers": { - hdrs.ORIGIN: client2, - hdrs.ACCESS_CONTROL_REQUEST_METHOD: "GET", - }, - "response_status": 403, - "in_response": "origin '{}' is not allowed".format( - client2), - "not_in_response_headers": { - hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, - hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, - hdrs.ACCESS_CONTROL_MAX_AGE, - hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, - hdrs.ACCESS_CONTROL_ALLOW_METHODS, - hdrs.ACCESS_CONTROL_ALLOW_HEADERS, - }, - }, - { - "name": "disallowed method", - "request_headers": { - hdrs.ORIGIN: client1, - hdrs.ACCESS_CONTROL_REQUEST_METHOD: "POST", - }, - "response_status": 403, - "in_response": "request method 'POST' is not allowed", - "not_in_response_headers": { - hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, - hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, - hdrs.ACCESS_CONTROL_MAX_AGE, - hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, - hdrs.ACCESS_CONTROL_ALLOW_METHODS, - hdrs.ACCESS_CONTROL_ALLOW_HEADERS, - }, - }, - ], - }, - ] - - yield from self._run_preflight_requests_tests( - tests_descriptions, False) - yield from self._run_preflight_requests_tests( - tests_descriptions, True) - - @asynctest - @asyncio.coroutine - def test_preflight_request_multiple_routes_with_one_options(self): - """Test CORS preflight handling on resource that is available through - several routes. - """ - app = web.Application() - cors = setup(app, defaults={ - "*": ResourceOptions( - allow_credentials=True, - expose_headers="*", - allow_headers="*", - ) - }) - - cors.add(app.router.add_route("GET", "/{name}", handler)) - cors.add(app.router.add_route("PUT", "/{name}", handler)) - - yield from self.create_server(app) - - response = yield from self.session.request( - "OPTIONS", self.server_url + "user", - headers={ - hdrs.ORIGIN: "http://example.org", - hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT" - } + resp = await client.get("/resource") + assert resp.status == 200 + resp_text = await resp.text() + assert resp_text == TEST_BODY + + for header_name in { + hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, + hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, + hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, + }: + assert header_name not in resp.headers + + +async def test_simple_allowed_origin(aiohttp_client, make_app): + app = make_app(None, {"http://client1.example.org": + ResourceOptions()}) + + client = await aiohttp_client(app) + + resp = await client.get("/resource", + headers={hdrs.ORIGIN: + 'http://client1.example.org'}) + assert resp.status == 200 + resp_text = await resp.text() + assert resp_text == TEST_BODY + + for hdr, val in { + hdrs.ACCESS_CONTROL_ALLOW_ORIGIN: 'http://client1.example.org', + }.items(): + assert resp.headers.get(hdr) == val + + for header_name in { + hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, + hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, + }: + assert header_name not in resp.headers + + +async def test_simple_not_allowed_origin(aiohttp_client, make_app): + app = make_app(None, {"http://client1.example.org": + ResourceOptions()}) + + client = await aiohttp_client(app) + + resp = await client.get("/resource", + headers={hdrs.ORIGIN: + 'http://client2.example.org'}) + assert resp.status == 200 + resp_text = await resp.text() + assert resp_text == TEST_BODY + + for header_name in { + hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, + hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, + hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, + }: + assert header_name not in resp.headers + + +async def test_simple_explicit_port(aiohttp_client, make_app): + app = make_app(None, {"http://client1.example.org": + ResourceOptions()}) + + client = await aiohttp_client(app) + + resp = await client.get("/resource", + headers={hdrs.ORIGIN: + 'http://client1.example.org:80'}) + assert resp.status == 200 + resp_text = await resp.text() + assert resp_text == TEST_BODY + + for header_name in { + hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, + hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, + hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, + }: + assert header_name not in resp.headers + + +async def test_simple_different_scheme(aiohttp_client, make_app): + app = make_app(None, {"http://client1.example.org": + ResourceOptions()}) + + client = await aiohttp_client(app) + + resp = await client.get("/resource", + headers={hdrs.ORIGIN: + 'https://client1.example.org'}) + assert resp.status == 200 + resp_text = await resp.text() + assert resp_text == TEST_BODY + + for header_name in { + hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, + hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, + hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, + }: + assert header_name not in resp.headers + + +@pytest.fixture(params=[ + (None, + {"http://client1.example.org": ResourceOptions(allow_credentials=True)}), + ({"http://client1.example.org": ResourceOptions(allow_credentials=True)}, + None), +]) +def app_for_credentials(make_app, request): + return make_app(*request.param) + + +async def test_cred_no_origin(aiohttp_client, app_for_credentials): + app = app_for_credentials + + client = await aiohttp_client(app) + + resp = await client.get("/resource") + assert resp.status == 200 + resp_text = await resp.text() + assert resp_text == TEST_BODY + + for header_name in { + hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, + hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, + hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, + }: + assert header_name not in resp.headers + + +async def test_cred_allowed_origin(aiohttp_client, app_for_credentials): + app = app_for_credentials + + client = await aiohttp_client(app) + + resp = await client.get("/resource", + headers={hdrs.ORIGIN: + 'http://client1.example.org'}) + assert resp.status == 200 + resp_text = await resp.text() + assert resp_text == TEST_BODY + + for hdr, val in { + hdrs.ACCESS_CONTROL_ALLOW_ORIGIN: 'http://client1.example.org', + hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS: "true"}.items(): + assert resp.headers.get(hdr) == val + + for header_name in { + hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, + }: + assert header_name not in resp.headers + + +async def test_cred_disallowed_origin(aiohttp_client, app_for_credentials): + app = app_for_credentials + + client = await aiohttp_client(app) + + resp = await client.get("/resource", + headers={hdrs.ORIGIN: + 'http://client2.example.org'}) + assert resp.status == 200 + resp_text = await resp.text() + assert resp_text == TEST_BODY + + for header_name in { + hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, + hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, + hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, + }: + assert header_name not in resp.headers + + +async def test_simple_expose_headers_no_origin(aiohttp_client, make_app): + app = make_app(None, {"http://client1.example.org": + ResourceOptions( + expose_headers=(SERVER_CUSTOM_HEADER_NAME,))}) + + client = await aiohttp_client(app) + + resp = await client.get("/resource") + assert resp.status == 200 + resp_text = await resp.text() + assert resp_text == TEST_BODY + + for header_name in { + hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, + hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, + hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, + }: + assert header_name not in resp.headers + + +async def test_simple_expose_headers_allowed_origin(aiohttp_client, make_app): + app = make_app(None, {"http://client1.example.org": + ResourceOptions( + expose_headers=(SERVER_CUSTOM_HEADER_NAME,))}) + + client = await aiohttp_client(app) + + resp = await client.get("/resource", + headers={hdrs.ORIGIN: + 'http://client1.example.org'}) + assert resp.status == 200 + resp_text = await resp.text() + assert resp_text == TEST_BODY + + for hdr, val in { + hdrs.ACCESS_CONTROL_ALLOW_ORIGIN: 'http://client1.example.org', + hdrs.ACCESS_CONTROL_EXPOSE_HEADERS: + SERVER_CUSTOM_HEADER_NAME}.items(): + assert resp.headers.get(hdr) == val + + for header_name in { + hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, + }: + assert header_name not in resp.headers + + +async def test_simple_expose_headers_not_allowed_origin(aiohttp_client, + make_app): + app = make_app(None, {"http://client1.example.org": + ResourceOptions( + expose_headers=(SERVER_CUSTOM_HEADER_NAME,))}) + + client = await aiohttp_client(app) + + resp = await client.get("/resource", + headers={hdrs.ORIGIN: + 'http://client2.example.org'}) + assert resp.status == 200 + resp_text = await resp.text() + assert resp_text == TEST_BODY + + for header_name in { + hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, + hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, + hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, + }: + assert header_name not in resp.headers + + +async def test_preflight_default_no_origin(aiohttp_client, make_app): + app = make_app(None, {"http://client1.example.org": + ResourceOptions()}) + + client = await aiohttp_client(app) + + resp = await client.options("/resource") + assert resp.status == 403 + resp_text = await resp.text() + assert "origin header is not specified" in resp_text + + for header_name in { + hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, + hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, + hdrs.ACCESS_CONTROL_MAX_AGE, + hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, + hdrs.ACCESS_CONTROL_ALLOW_METHODS, + hdrs.ACCESS_CONTROL_ALLOW_HEADERS, + }: + assert header_name not in resp.headers + + +async def test_preflight_default_no_method(aiohttp_client, make_app): + + app = make_app(None, {"http://client1.example.org": + ResourceOptions()}) + + client = await aiohttp_client(app) + + resp = await client.options("/resource", headers={ + hdrs.ORIGIN: "http://client1.example.org", + }) + assert resp.status == 403 + resp_text = await resp.text() + assert "'Access-Control-Request-Method' header is not specified"\ + in resp_text + + for header_name in { + hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, + hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, + hdrs.ACCESS_CONTROL_MAX_AGE, + hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, + hdrs.ACCESS_CONTROL_ALLOW_METHODS, + hdrs.ACCESS_CONTROL_ALLOW_HEADERS, + }: + assert header_name not in resp.headers + + +async def test_preflight_default_origin_and_method(aiohttp_client, make_app): + + app = make_app(None, {"http://client1.example.org": + ResourceOptions()}) + + client = await aiohttp_client(app) + + resp = await client.options("/resource", headers={ + hdrs.ORIGIN: "http://client1.example.org", + hdrs.ACCESS_CONTROL_REQUEST_METHOD: "GET", + }) + assert resp.status == 200 + resp_text = await resp.text() + assert '' == resp_text + + for hdr, val in { + hdrs.ACCESS_CONTROL_ALLOW_ORIGIN: "http://client1.example.org", + hdrs.ACCESS_CONTROL_ALLOW_METHODS: "GET"}.items(): + assert resp.headers.get(hdr) == val + + for header_name in { + hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, + hdrs.ACCESS_CONTROL_MAX_AGE, + hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, + hdrs.ACCESS_CONTROL_ALLOW_HEADERS, + }: + assert header_name not in resp.headers + + +async def test_preflight_default_disallowed_origin(aiohttp_client, make_app): + + app = make_app(None, {"http://client1.example.org": + ResourceOptions()}) + + client = await aiohttp_client(app) + + resp = await client.options("/resource", headers={ + hdrs.ORIGIN: "http://client2.example.org", + hdrs.ACCESS_CONTROL_REQUEST_METHOD: "GET", + }) + assert resp.status == 403 + resp_text = await resp.text() + assert "origin 'http://client2.example.org' is not allowed" in resp_text + + for header_name in { + hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, + hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, + hdrs.ACCESS_CONTROL_MAX_AGE, + hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, + hdrs.ACCESS_CONTROL_ALLOW_METHODS, + hdrs.ACCESS_CONTROL_ALLOW_HEADERS, + }: + assert header_name not in resp.headers + + +async def test_preflight_default_disallowed_method(aiohttp_client, make_app): + + app = make_app(None, {"http://client1.example.org": + ResourceOptions()}) + + client = await aiohttp_client(app) + + resp = await client.options("/resource", headers={ + hdrs.ORIGIN: "http://client1.example.org", + hdrs.ACCESS_CONTROL_REQUEST_METHOD: "POST", + }) + assert resp.status == 403 + resp_text = await resp.text() + assert ("request method 'POST' is not allowed for " + "'http://client1.example.org' origin" in resp_text) + + for header_name in { + hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, + hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, + hdrs.ACCESS_CONTROL_MAX_AGE, + hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, + hdrs.ACCESS_CONTROL_ALLOW_METHODS, + hdrs.ACCESS_CONTROL_ALLOW_HEADERS, + }: + assert header_name not in resp.headers + + +async def test_preflight_req_multiple_routes_with_one_options(aiohttp_client): + """Test CORS preflight handling on resource that is available through + several routes. + """ + app = web.Application() + cors = _setup(app, defaults={ + "*": ResourceOptions( + allow_credentials=True, + expose_headers="*", + allow_headers="*", ) - self.assertEqual(response.status, 200) + }) - data = yield from response.text() - self.assertEqual(data, "") + cors.add(app.router.add_route("GET", "/{name}", handler)) + cors.add(app.router.add_route("PUT", "/{name}", handler)) - @asynctest - @asyncio.coroutine - def test_preflight_request_multiple_routes_with_one_options_resource(self): - """Test CORS preflight handling on resource that is available through - several routes. - """ - app = web.Application() - cors = setup(app, defaults={ - "*": ResourceOptions( - allow_credentials=True, - expose_headers="*", - allow_headers="*", - ) - }) - - resource = cors.add(app.router.add_resource("/{name}")) - cors.add(resource.add_route("GET", handler)) - cors.add(resource.add_route("PUT", handler)) - - yield from self.create_server(app) - - response = yield from self.session.request( - "OPTIONS", self.server_url + "user", - headers={ - hdrs.ORIGIN: "http://example.org", - hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT" - } + client = await aiohttp_client(app) + + resp = await client.options( + "/user", + headers={ + hdrs.ORIGIN: "http://example.org", + hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT" + } + ) + assert resp.status == 200 + + data = await resp.text() + assert data == "" + + +async def test_preflight_request_mult_routes_with_one_options_resource( + aiohttp_client): + """Test CORS preflight handling on resource that is available through + several routes. + """ + app = web.Application() + cors = _setup(app, defaults={ + "*": ResourceOptions( + allow_credentials=True, + expose_headers="*", + allow_headers="*", ) - self.assertEqual(response.status, 200) + }) - data = yield from response.text() - self.assertEqual(data, "") + resource = cors.add(app.router.add_resource("/{name}")) + cors.add(resource.add_route("GET", handler)) + cors.add(resource.add_route("PUT", handler)) - @asynctest - @asyncio.coroutine - def test_preflight_request_headers_resource(self): - """Test CORS preflight request handlers handling.""" - app = web.Application() - cors = setup(app, defaults={ - "*": ResourceOptions( - allow_credentials=True, - expose_headers="*", - allow_headers=("Content-Type", "X-Header"), - ) - }) - - cors.add(app.router.add_route("PUT", "/", handler)) - - yield from self.create_server(app) - - response = yield from self.session.request( - "OPTIONS", self.server_url, - headers={ - hdrs.ORIGIN: "http://example.org", - hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT", - hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "content-type", - } + client = await aiohttp_client(app) + + resp = await client.options( + "/user", + headers={ + hdrs.ORIGIN: "http://example.org", + hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT" + } + ) + assert resp.status == 200 + + data = await resp.text() + assert data == "" + + +async def test_preflight_request_max_age_resource(aiohttp_client): + """Test CORS preflight handling on resource that is available through + several routes. + """ + app = web.Application() + cors = _setup(app, defaults={ + "*": ResourceOptions( + allow_credentials=True, + expose_headers="*", + allow_headers="*", + max_age=1200 ) - self.assertEqual((yield from response.text()), "") - self.assertEqual(response.status, 200) - # Access-Control-Allow-Headers must be compared in case-insensitive - # way. - self.assertEqual( - response.headers[hdrs.ACCESS_CONTROL_ALLOW_HEADERS].upper(), - "content-type".upper()) + }) + + resource = cors.add(app.router.add_resource("/{name}")) + cors.add(resource.add_route("GET", handler)) + + client = await aiohttp_client(app) + + resp = await client.options( + "/user", + headers={ + hdrs.ORIGIN: "http://example.org", + hdrs.ACCESS_CONTROL_REQUEST_METHOD: "GET" + } + ) + assert resp.status == 200 + assert resp.headers[hdrs.ACCESS_CONTROL_MAX_AGE].upper() == "1200" + + data = await resp.text() + assert data == "" - response = yield from self.session.request( - "OPTIONS", self.server_url, - headers={ - hdrs.ORIGIN: "http://example.org", - hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT", - hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "X-Header,content-type", - } + +async def test_preflight_request_max_age_webview(aiohttp_client): + """Test CORS preflight handling on resource that is available through + several routes. + """ + app = web.Application() + cors = _setup(app, defaults={ + "*": ResourceOptions( + allow_credentials=True, + expose_headers="*", + allow_headers="*", + max_age=1200 ) - self.assertEqual(response.status, 200) - # Access-Control-Allow-Headers must be compared in case-insensitive - # way. - self.assertEqual( - frozenset(response.headers[hdrs.ACCESS_CONTROL_ALLOW_HEADERS] - .upper().split(",")), - {"X-Header".upper(), "content-type".upper()}) - self.assertEqual((yield from response.text()), "") - - response = yield from self.session.request( - "OPTIONS", self.server_url, - headers={ - hdrs.ORIGIN: "http://example.org", - hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT", - hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "content-type,Test", - } + }) + + class TestView(web.View, CorsViewMixin): + async def get(self): + resp = web.Response(text=TEST_BODY) + + resp.headers[SERVER_CUSTOM_HEADER_NAME] = \ + SERVER_CUSTOM_HEADER_VALUE + + return resp + + cors.add(app.router.add_route("*", "/{name}", TestView)) + + client = await aiohttp_client(app) + + resp = await client.options( + "/user", + headers={ + hdrs.ORIGIN: "http://example.org", + hdrs.ACCESS_CONTROL_REQUEST_METHOD: "GET" + } + ) + assert resp.status == 200 + assert resp.headers[hdrs.ACCESS_CONTROL_MAX_AGE].upper() == "1200" + + data = await resp.text() + assert data == "" + + +async def test_preflight_request_mult_routes_with_one_options_webview( + aiohttp_client): + """Test CORS preflight handling on resource that is available through + several routes. + """ + app = web.Application() + cors = _setup(app, defaults={ + "*": ResourceOptions( + allow_credentials=True, + expose_headers="*", + allow_headers="*", ) - self.assertEqual(response.status, 403) - self.assertNotIn( - hdrs.ACCESS_CONTROL_ALLOW_HEADERS, - response.headers) - self.assertIn( - "headers are not allowed: TEST", - (yield from response.text())) - - @asynctest - @asyncio.coroutine - def test_preflight_request_headers(self): - """Test CORS preflight request handlers handling.""" - app = web.Application() - cors = setup(app, defaults={ - "*": ResourceOptions( - allow_credentials=True, - expose_headers="*", - allow_headers=("Content-Type", "X-Header"), - ) - }) - - resource = cors.add(app.router.add_resource("/")) - cors.add(resource.add_route("PUT", handler)) - - yield from self.create_server(app) - - response = yield from self.session.request( - "OPTIONS", self.server_url, - headers={ - hdrs.ORIGIN: "http://example.org", - hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT", - hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "content-type", - } + }) + + class TestView(web.View, CorsViewMixin): + async def get(self): + resp = web.Response(text=TEST_BODY) + + resp.headers[SERVER_CUSTOM_HEADER_NAME] = \ + SERVER_CUSTOM_HEADER_VALUE + + return resp + + put = get + + cors.add(app.router.add_route("*", "/{name}", TestView)) + + client = await aiohttp_client(app) + + resp = await client.options( + "/user", + headers={ + hdrs.ORIGIN: "http://example.org", + hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT" + } + ) + assert resp.status == 200 + + data = await resp.text() + assert data == "" + + +async def test_preflight_request_headers_webview(aiohttp_client): + """Test CORS preflight request handlers handling.""" + app = web.Application() + cors = _setup(app, defaults={ + "*": ResourceOptions( + allow_credentials=True, + expose_headers="*", + allow_headers=("Content-Type", "X-Header"), ) - self.assertEqual((yield from response.text()), "") - self.assertEqual(response.status, 200) - # Access-Control-Allow-Headers must be compared in case-insensitive - # way. - self.assertEqual( - response.headers[hdrs.ACCESS_CONTROL_ALLOW_HEADERS].upper(), + }) + + class TestView(web.View, CorsViewMixin): + async def put(self): + response = web.Response(text=TEST_BODY) + + response.headers[SERVER_CUSTOM_HEADER_NAME] = \ + SERVER_CUSTOM_HEADER_VALUE + + return response + + cors.add(app.router.add_route("*", "/", TestView)) + + client = await aiohttp_client(app) + + resp = await client.options( + '/', + headers={ + hdrs.ORIGIN: "http://example.org", + hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT", + hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "content-type", + } + ) + assert (await resp.text()) == "" + assert resp.status == 200 + # Access-Control-Allow-Headers must be compared in case-insensitive + # way. + assert (resp.headers[hdrs.ACCESS_CONTROL_ALLOW_HEADERS].upper() == "content-type".upper()) - response = yield from self.session.request( - "OPTIONS", self.server_url, - headers={ - hdrs.ORIGIN: "http://example.org", - hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT", - hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "X-Header,content-type", - } + resp = await client.options( + '/', + headers={ + hdrs.ORIGIN: "http://example.org", + hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT", + hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "X-Header,content-type", + } + ) + assert resp.status == 200 + # Access-Control-Allow-Headers must be compared in case-insensitive + # way. + assert ( + frozenset(resp.headers[hdrs.ACCESS_CONTROL_ALLOW_HEADERS] + .upper().split(",")) == + {"X-Header".upper(), "content-type".upper()}) + assert (await resp.text()) == "" + + resp = await client.options( + '/', + headers={ + hdrs.ORIGIN: "http://example.org", + hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT", + hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "content-type,Test", + } + ) + assert resp.status == 403 + assert hdrs.ACCESS_CONTROL_ALLOW_HEADERS not in resp.headers + assert "headers are not allowed: TEST" in (await resp.text()) + + +async def test_preflight_request_headers_resource(aiohttp_client): + """Test CORS preflight request handlers handling.""" + app = web.Application() + cors = _setup(app, defaults={ + "*": ResourceOptions( + allow_credentials=True, + expose_headers="*", + allow_headers=("Content-Type", "X-Header"), ) - self.assertEqual(response.status, 200) - # Access-Control-Allow-Headers must be compared in case-insensitive - # way. - self.assertEqual( - frozenset(response.headers[hdrs.ACCESS_CONTROL_ALLOW_HEADERS] - .upper().split(",")), - {"X-Header".upper(), "content-type".upper()}) - self.assertEqual((yield from response.text()), "") - - response = yield from self.session.request( - "OPTIONS", self.server_url, - headers={ - hdrs.ORIGIN: "http://example.org", - hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT", - hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "content-type,Test", - } + }) + + cors.add(app.router.add_route("PUT", "/", handler)) + + client = await aiohttp_client(app) + + resp = await client.options( + '/', + headers={ + hdrs.ORIGIN: "http://example.org", + hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT", + hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "content-type", + } + ) + assert (await resp.text()) == "" + assert resp.status == 200 + # Access-Control-Allow-Headers must be compared in case-insensitive + # way. + assert ( + resp.headers[hdrs.ACCESS_CONTROL_ALLOW_HEADERS].upper() == + "content-type".upper()) + + resp = await client.options( + '/', + headers={ + hdrs.ORIGIN: "http://example.org", + hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT", + hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "X-Header,content-type", + } + ) + assert resp.status == 200 + # Access-Control-Allow-Headers must be compared in case-insensitive + # way. + assert ( + frozenset(resp.headers[hdrs.ACCESS_CONTROL_ALLOW_HEADERS] + .upper().split(",")) == + {"X-Header".upper(), "content-type".upper()}) + assert (await resp.text()) == "" + + resp = await client.options( + '/', + headers={ + hdrs.ORIGIN: "http://example.org", + hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT", + hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "content-type,Test", + } + ) + assert resp.status == 403 + assert hdrs.ACCESS_CONTROL_ALLOW_HEADERS not in resp.headers + assert "headers are not allowed: TEST" in (await resp.text()) + + +async def test_preflight_request_headers(aiohttp_client): + """Test CORS preflight request handlers handling.""" + app = web.Application() + cors = _setup(app, defaults={ + "*": ResourceOptions( + allow_credentials=True, + expose_headers="*", + allow_headers=("Content-Type", "X-Header"), ) - self.assertEqual(response.status, 403) - self.assertNotIn( - hdrs.ACCESS_CONTROL_ALLOW_HEADERS, - response.headers) - self.assertIn( - "headers are not allowed: TEST", - (yield from response.text())) - - @asynctest - @asyncio.coroutine - def test_static_route(self): - """Test a static route with CORS.""" - app = web.Application() - cors = setup(app, defaults={ - "*": ResourceOptions( - allow_credentials=True, - expose_headers="*", - allow_methods="*", - allow_headers=("Content-Type", "X-Header"), - ) - }) - - test_static_path = pathlib.Path(__file__).parent - cors.add(app.router.add_static("/static", test_static_path, name='static')) - - yield from self.create_server(app) - - response = yield from self.session.request( - "OPTIONS", URL(self.server_url) / "static/test_page.html", - headers={ - hdrs.ORIGIN: "http://example.org", - hdrs.ACCESS_CONTROL_REQUEST_METHOD: "OPTIONS", - hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "content-type", - } + }) + + resource = cors.add(app.router.add_resource("/")) + cors.add(resource.add_route("PUT", handler)) + + client = await aiohttp_client(app) + + resp = await client.options( + '/', + headers={ + hdrs.ORIGIN: "http://example.org", + hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT", + hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "content-type", + } + ) + assert (await resp.text()) == "" + assert resp.status == 200 + # Access-Control-Allow-Headers must be compared in case-insensitive + # way. + assert ( + resp.headers[hdrs.ACCESS_CONTROL_ALLOW_HEADERS].upper() == + "content-type".upper()) + + resp = await client.options( + '/', + headers={ + hdrs.ORIGIN: "http://example.org", + hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT", + hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "X-Header,content-type", + } + ) + assert resp.status == 200 + # Access-Control-Allow-Headers must be compared in case-insensitive + # way. + assert ( + frozenset(resp.headers[hdrs.ACCESS_CONTROL_ALLOW_HEADERS] + .upper().split(",")) == + {"X-Header".upper(), "content-type".upper()}) + assert (await resp.text()) == "" + + resp = await client.options( + '/', + headers={ + hdrs.ORIGIN: "http://example.org", + hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT", + hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "content-type,Test", + } + ) + assert resp.status == 403 + assert hdrs.ACCESS_CONTROL_ALLOW_HEADERS not in resp.headers + assert "headers are not allowed: TEST" in (await resp.text()) + + +async def test_static_route(aiohttp_client): + """Test a static route with CORS.""" + app = web.Application() + cors = _setup(app, defaults={ + "*": ResourceOptions( + allow_credentials=True, + expose_headers="*", + allow_methods="*", + allow_headers=("Content-Type", "X-Header"), ) - data = yield from response.text() - self.assertEqual(response.status, 200) - self.assertEqual(data, '') + }) + + test_static_path = pathlib.Path(__file__).parent + cors.add(app.router.add_static("/static", test_static_path, + name='static')) + + client = await aiohttp_client(app) + + resp = await client.options( + "/static/test_page.html", + headers={ + hdrs.ORIGIN: "http://example.org", + hdrs.ACCESS_CONTROL_REQUEST_METHOD: "OPTIONS", + hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "content-type", + } + ) + data = await resp.text() + assert resp.status == 200 + assert data == '' # TODO: test requesting resources with not configured CORS. diff --git a/tests/integration/test_page.html b/tests/integration/test_page.html index c2f4f7a..240742b 100644 --- a/tests/integration/test_page.html +++ b/tests/integration/test_page.html @@ -406,7 +406,7 @@ xhr.responseText); } else { - log('Received server addressess:', xhr.response); + log('Received server addresses:', xhr.response); setServersUrls(JSON.parse(xhr.responseText)); } diff --git a/tests/integration/test_real_browser.py b/tests/integration/test_real_browser.py index 06a828b..a5c9030 100644 --- a/tests/integration/test_real_browser.py +++ b/tests/integration/test_real_browser.py @@ -19,12 +19,12 @@ import os import json import asyncio import socket -import unittest import pathlib import logging import webbrowser from aiohttp import web, hdrs +import pytest import selenium.common.exceptions from selenium import webdriver @@ -33,9 +33,7 @@ from selenium.webdriver.common.keys import Keys from selenium.webdriver.common.by import By from selenium.webdriver.support import expected_conditions as EC -from aiohttp_cors import setup, ResourceOptions - -from ..aio_test_base import create_server, AioTestBase, asynctest +from aiohttp_cors import setup as _setup, ResourceOptions, CorsViewMixin class _ServerDescr: @@ -52,7 +50,7 @@ class _ServerDescr: class IntegrationServers: """Integration servers starting/stopping manager""" - def __init__(self, use_resources, *, loop=None): + def __init__(self, use_resources, use_webview, *, loop=None): self.servers = {} self.loop = loop @@ -60,6 +58,7 @@ class IntegrationServers: self.loop = asyncio.get_event_loop() self.use_resources = use_resources + self.use_webview = use_webview self._logger = logging.getLogger("IntegrationServers") @@ -67,37 +66,37 @@ class IntegrationServers: def origin_server_url(self): return self.servers["origin"].url - @asyncio.coroutine - def start_servers(self): + async def start_servers(self): test_page_path = pathlib.Path(__file__).with_name("test_page.html") - @asyncio.coroutine - def handle_test_page(request: web.Request) -> web.StreamResponse: + async def handle_test_page(request: web.Request) -> web.StreamResponse: with test_page_path.open("r", encoding="utf-8") as f: return web.Response( text=f.read(), headers={hdrs.CONTENT_TYPE: "text/html"}) - @asyncio.coroutine - def handle_no_cors(request: web.Request) -> web.StreamResponse: + async def handle_no_cors(request: web.Request) -> web.StreamResponse: return web.Response( text="""{"type": "no_cors.json"}""", headers={hdrs.CONTENT_TYPE: "application/json"}) - @asyncio.coroutine - def handle_resource(request: web.Request) -> web.StreamResponse: + async def handle_resource(request: web.Request) -> web.StreamResponse: return web.Response( text="""{"type": "resource"}""", headers={hdrs.CONTENT_TYPE: "application/json"}) - @asyncio.coroutine - def handle_servers_addresses( + async def handle_servers_addresses( request: web.Request) -> web.StreamResponse: servers_addresses = \ {name: descr.url for name, descr in self.servers.items()} return web.Response( text=json.dumps(servers_addresses)) + class ResourceView(web.View, CorsViewMixin): + + async def get(self) -> web.StreamResponse: + return await handle_resource(self.request) + # For most resources: # "origin" server has no CORS configuration. # "allowing" server explicitly allows CORS requests to "origin" server. @@ -137,8 +136,12 @@ class IntegrationServers: for server_name in server_names: app = self.servers[server_name].app app.router.add_route("GET", "/no_cors.json", handle_no_cors) - app.router.add_route("GET", "/cors_resource", handle_resource, - name="cors_resource") + if self.use_webview: + app.router.add_route("*", "/cors_resource", ResourceView, + name="cors_resource") + else: + app.router.add_route("GET", "/cors_resource", handle_resource, + name="cors_resource") cors_default_configs = { "allowing": { @@ -167,7 +170,7 @@ class IntegrationServers: default_config = cors_default_configs.get(server_name) if default_config is None: continue - server_descr.cors = setup( + server_descr.cors = _setup( server_descr.app, defaults=default_config) # Add CORS routes. @@ -182,27 +185,30 @@ class IntegrationServers: server_descr.cors.add(resource) server_descr.cors.add(route) + elif self.use_webview: + server_descr.cors.add(route) + else: server_descr.cors.add(route) # Start servers. for server_name, server_descr in self.servers.items(): handler = server_descr.app.make_handler() - server = yield from create_server(handler, self.loop, - sock=server_sockets[server_name]) + server = await self.loop.create_server( + handler, + sock=server_sockets[server_name]) server_descr.handler = handler server_descr.server = server self._logger.info("Started server '%s' at '%s'", server_name, server_descr.url) - @asyncio.coroutine - def stop_servers(self): + async def stop_servers(self): for server_descr in self.servers.values(): server_descr.server.close() - yield from server_descr.handler.finish_connections() - yield from server_descr.server.wait_closed() - yield from server_descr.app.cleanup() + await server_descr.handler.shutdown() + await server_descr.server.wait_closed() + await server_descr.app.cleanup() self.servers = {} @@ -218,99 +224,64 @@ def _get_chrome_driver(): return driver -class TestInBrowser(AioTestBase): - @asyncio.coroutine - def _test_in_webdriver(self, driver, use_resources): - # TODO: Use pytest's fixtures to test use resources/not use resources. - servers = IntegrationServers(use_resources) - yield from servers.start_servers() - - def selenium_thread(): - driver.get(servers.origin_server_url) - assert "aiohttp_cors" in driver.title - - wait = WebDriverWait(driver, 10) - - run_button = wait.until(EC.element_to_be_clickable( - (By.ID, "runTestsButton"))) - - # Start tests. - run_button.send_keys(Keys.RETURN) - - # Wait while test will finish (until clear button is not - # activated). - wait.until(EC.element_to_be_clickable( - (By.ID, "clearResultsButton"))) - - # Get results json - results_area = driver.find_element_by_id("results") - - return json.loads(results_area.get_attribute("value")) - - try: - results = yield from self.loop.run_in_executor( - self.thread_pool_executor, selenium_thread) - - self.assertEqual(results["status"], "success") - for test_name, test_data in results["data"].items(): - with self.subTest(group_name=test_name): - self.assertEqual(test_data["status"], "success", - msg=(test_name, test_data)) - - finally: - yield from servers.stop_servers() - - @asynctest - @asyncio.coroutine - def test_firefox(self): - try: - driver = webdriver.Firefox() - except selenium.common.exceptions.WebDriverException: - raise unittest.SkipTest - - try: - yield from self._test_in_webdriver(driver, False) - finally: - driver.close() - - @asynctest - @asyncio.coroutine - def test_chromium(self): - try: - driver = _get_chrome_driver() - except selenium.common.exceptions.WebDriverException: - raise unittest.SkipTest - - try: - yield from self._test_in_webdriver(driver, False) - finally: - driver.close() - - @asynctest - @asyncio.coroutine - def test_firefox_resource(self): - try: - driver = webdriver.Firefox() - except selenium.common.exceptions.WebDriverException: - raise unittest.SkipTest - - try: - yield from self._test_in_webdriver(driver, True) - finally: - driver.close() - - @asynctest - @asyncio.coroutine - def test_chromium_resource(self): - try: - driver = _get_chrome_driver() - except selenium.common.exceptions.WebDriverException: - raise unittest.SkipTest - - try: - yield from self._test_in_webdriver(driver, True) - finally: - driver.close() +@pytest.fixture(params=[(False, False), + (True, False), + (False, True)]) +def server(request, loop): + async def inner(): + # to grab implicit loop + return IntegrationServers(*request.param) + return loop.run_until_complete(inner()) + + +@pytest.fixture(params=[webdriver.Firefox, + _get_chrome_driver]) +def driver(request): + try: + driver = request.param() + except selenium.common.exceptions.WebDriverException: + pytest.skip("Driver is not supported") + + yield driver + driver.close() + + +async def test_in_webdriver(driver, server): + loop = asyncio.get_event_loop() + await server.start_servers() + + def selenium_thread(): + driver.get(server.origin_server_url) + assert "aiohttp_cors" in driver.title + + wait = WebDriverWait(driver, 10) + + run_button = wait.until(EC.element_to_be_clickable( + (By.ID, "runTestsButton"))) + + # Start tests. + run_button.send_keys(Keys.RETURN) + + # Wait while test will finish (until clear button is not + # activated). + wait.until(EC.element_to_be_clickable( + (By.ID, "clearResultsButton"))) + + # Get results json + results_area = driver.find_element_by_id("results") + + return json.loads(results_area.get_attribute("value")) + + try: + results = await loop.run_in_executor( + None, selenium_thread) + + assert results["status"] == "success" + for test_name, test_data in results["data"].items(): + assert test_data["status"] == "success" + + finally: + await server.stop_servers() def _run_integration_server(): @@ -322,7 +293,7 @@ def _run_integration_server(): loop = asyncio.get_event_loop() - servers = IntegrationServers() + servers = IntegrationServers(False, True) logger.info("Starting integration servers...") loop.run_until_complete(servers.start_servers()) diff --git a/tests/unit/test___about__.py b/tests/unit/test___about__.py index 99c5672..4b16655 100644 --- a/tests/unit/test___about__.py +++ b/tests/unit/test___about__.py @@ -15,15 +15,12 @@ """Test aiohttp_cors package metainformation. """ -import unittest from pkg_resources import parse_version import aiohttp_cors -class TestMetaInformation(unittest.TestCase): - """Test package metainformation""" - # pylint: disable=no-self-use - def test_version(self): - """Test package version string""" - parse_version(aiohttp_cors.__version__) +def test_version(): + """Test package version string""" + # not raised + parse_version(aiohttp_cors.__version__) diff --git a/tests/unit/test_cors_config.py b/tests/unit/test_cors_config.py index 96e837d..5b8d8f3 100644 --- a/tests/unit/test_cors_config.py +++ b/tests/unit/test_cors_config.py @@ -16,75 +16,120 @@ """ import asyncio -import unittest +import pytest from aiohttp import web -from aiohttp_cors import CorsConfig, ResourceOptions +from aiohttp_cors import CorsConfig, ResourceOptions, CorsViewMixin -def _handler(request): +async def _handler(request): return web.Response(text="Done") -class TestCorsConfig(unittest.TestCase): - """Unit tests for CorsConfig""" - - def setUp(self): - self.loop = asyncio.new_event_loop() - self.app = web.Application(loop=self.loop) - self.cors = CorsConfig(self.app, defaults={ - "*": ResourceOptions() - }) - self.get_route = self.app.router.add_route( - "GET", "/get_path", _handler) - self.options_route = self.app.router.add_route( - "OPTIONS", "/options_path", _handler) - - def tearDown(self): - self.loop.close() - - def test_add_options_route(self): - """Test configuring OPTIONS route""" - - with self.assertRaises(RuntimeError): - self.cors.add(self.options_route.resource) - - def test_plain_named_route(self): - """Test adding plain named route.""" - # Adding CORS routes should not introduce new named routes. - self.assertEqual(len(self.app.router.keys()), 0) - route = self.app.router.add_route( - "GET", "/{name}", _handler, name="dynamic_named_route") - self.assertEqual(len(self.app.router.keys()), 1) - self.cors.add(route) - self.assertEqual(len(self.app.router.keys()), 1) - - def test_dynamic_named_route(self): - """Test adding dynamic named route.""" - self.assertEqual(len(self.app.router.keys()), 0) - route = self.app.router.add_route( - "GET", "/{name}", _handler, name="dynamic_named_route") - self.assertEqual(len(self.app.router.keys()), 1) - self.cors.add(route) - self.assertEqual(len(self.app.router.keys()), 1) - - def test_static_named_route(self): - """Test adding dynamic named route.""" - self.assertEqual(len(self.app.router.keys()), 0) - route = self.app.router.add_static( - "/file", "/", name="dynamic_named_route") - self.assertEqual(len(self.app.router.keys()), 1) - self.cors.add(route) - self.assertEqual(len(self.app.router.keys()), 1) - - def test_static_resource(self): - """Test adding static resource.""" - self.assertEqual(len(self.app.router.keys()), 0) - self.app.router.add_static( - "/file", "/", name="dynamic_named_route") - self.assertEqual(len(self.app.router.keys()), 1) - for resource in list(self.app.router.resources()): - if issubclass(resource, web.StaticResource): - self.cors.add(resource) - self.assertEqual(len(self.app.router.keys()), 1) +class _View(web.View, CorsViewMixin): + + @asyncio.coroutine + def get(self): + return web.Response(text="Done") + + +@pytest.fixture +def app(): + return web.Application() + + +@pytest.fixture +def cors(app): + return CorsConfig(app, defaults={ + "*": ResourceOptions() + }) + + +@pytest.fixture +def get_route(app): + return app.router.add_route( + "GET", "/get_path", _handler) + + +@pytest.fixture +def options_route(app): + return app.router.add_route( + "OPTIONS", "/options_path", _handler) + + +def test_add_options_route(cors, options_route): + """Test configuring OPTIONS route""" + + with pytest.raises(ValueError, + match="/options_path already has OPTIONS handler"): + cors.add(options_route.resource) + + +def test_plain_named_route(app, cors): + """Test adding plain named route.""" + # Adding CORS routes should not introduce new named routes. + assert len(app.router.keys()) == 0 + route = app.router.add_route( + "GET", "/{name}", _handler, name="dynamic_named_route") + assert len(app.router.keys()) == 1 + cors.add(route) + assert len(app.router.keys()) == 1 + + +def test_dynamic_named_route(app, cors): + """Test adding dynamic named route.""" + assert len(app.router.keys()) == 0 + route = app.router.add_route( + "GET", "/{name}", _handler, name="dynamic_named_route") + assert len(app.router.keys()) == 1 + cors.add(route) + assert len(app.router.keys()) == 1 + + +def test_static_named_route(app, cors): + """Test adding dynamic named route.""" + assert len(app.router.keys()) == 0 + route = app.router.add_static( + "/file", "/", name="dynamic_named_route") + assert len(app.router.keys()) == 1 + cors.add(route) + assert len(app.router.keys()) == 1 + + +def test_static_resource(app, cors): + """Test adding static resource.""" + assert len(app.router.keys()) == 0 + app.router.add_static( + "/file", "/", name="dynamic_named_route") + assert len(app.router.keys()) == 1 + for resource in list(app.router.resources()): + if issubclass(resource, web.StaticResource): + cors.add(resource) + assert len(app.router.keys()) == 1 + + +def test_web_view_resource(app, cors): + """Test adding resource with web.View as handler""" + assert len(app.router.keys()) == 0 + route = app.router.add_route( + "GET", "/{name}", _View, name="dynamic_named_route") + assert len(app.router.keys()) == 1 + cors.add(route) + assert len(app.router.keys()) == 1 + + +def test_web_view_warning(app, cors): + """Test adding resource with web.View as handler""" + route = app.router.add_route("*", "/", _View) + with pytest.warns(DeprecationWarning): + cors.add(route, webview=True) + + +def test_disable_bare_view(app, cors): + class View(web.View): + pass + + route = app.router.add_route("*", "/", View) + with pytest.raises(ValueError): + cors.add(route) diff --git a/tests/unit/test_mixin.py b/tests/unit/test_mixin.py new file mode 100644 index 0000000..fb33b2e --- /dev/null +++ b/tests/unit/test_mixin.py @@ -0,0 +1,125 @@ +import asyncio + +from unittest import mock + +import pytest +from aiohttp import web + +from aiohttp_cors import CorsConfig, APP_CONFIG_KEY +from aiohttp_cors import ResourceOptions, CorsViewMixin, custom_cors + + +DEFAULT_CONFIG = { + '*': ResourceOptions() +} + +CLASS_CONFIG = { + '*': ResourceOptions() +} + +CUSTOM_CONFIG = { + 'www.client1.com': ResourceOptions(allow_headers=['X-Host']) +} + + +class SimpleView(web.View, CorsViewMixin): + async def get(self): + return web.Response(text="Done") + + +class SimpleViewWithConfig(web.View, CorsViewMixin): + + cors_config = CLASS_CONFIG + + async def get(self): + return web.Response(text="Done") + + +class CustomMethodView(web.View, CorsViewMixin): + + cors_config = CLASS_CONFIG + + async def get(self): + return web.Response(text="Done") + + @custom_cors(CUSTOM_CONFIG) + async def post(self): + return web.Response(text="Done") + + +@pytest.fixture +def _app(): + return web.Application() + + +@pytest.fixture +def cors(_app): + ret = CorsConfig(_app, defaults=DEFAULT_CONFIG) + _app[APP_CONFIG_KEY] = ret + return ret + + +@pytest.fixture +def app(_app, cors): + # a trick to install a cors into app + return _app + + +def test_raise_exception_when_cors_not_configure(): + request = mock.Mock() + request.app = {} + view = CustomMethodView(request) + + with pytest.raises(ValueError): + view.get_request_config(request, 'post') + + +async def test_raises_forbidden_when_config_not_found(app): + app[APP_CONFIG_KEY].defaults = {} + request = mock.Mock() + request.app = app + request.headers = { + 'Origin': '*', + 'Access-Control-Request-Method': 'GET' + } + view = SimpleView(request) + + with pytest.raises(web.HTTPForbidden): + await view.options() + + +def test_method_with_custom_cors(app): + """Test adding resource with web.View as handler""" + request = mock.Mock() + request.app = app + view = CustomMethodView(request) + + assert hasattr(view.post, 'post_cors_config') + assert asyncio.iscoroutinefunction(view.post) + config = view.get_request_config(request, 'post') + + assert config.get('www.client1.com') == CUSTOM_CONFIG['www.client1.com'] + + +def test_method_with_class_config(app): + """Test adding resource with web.View as handler""" + request = mock.Mock() + request.app = app + view = SimpleViewWithConfig(request) + + assert not hasattr(view.get, 'get_cors_config') + config = view.get_request_config(request, 'get') + + assert config.get('*') == CLASS_CONFIG['*'] + + +def test_method_with_default_config(app): + """Test adding resource with web.View as handler""" + request = mock.Mock() + request.app = app + view = SimpleView(request) + + assert not hasattr(view.get, 'get_cors_config') + config = view.get_request_config(request, 'get') + + assert config.get('*') == DEFAULT_CONFIG['*'] diff --git a/tests/unit/test_preflight_handler.py b/tests/unit/test_preflight_handler.py new file mode 100644 index 0000000..84fc8bc --- /dev/null +++ b/tests/unit/test_preflight_handler.py @@ -0,0 +1,12 @@ +from unittest import mock + +import pytest + +from aiohttp_cors.preflight_handler import _PreflightHandler + + +async def test_raises_when_handler_not_extend(): + request = mock.Mock() + handler = _PreflightHandler() + with pytest.raises(NotImplementedError): + await handler._get_config(request, 'origin', 'GET') diff --git a/tests/unit/test_resource_options.py b/tests/unit/test_resource_options.py index 8c6631e..2ba88c2 100644 --- a/tests/unit/test_resource_options.py +++ b/tests/unit/test_resource_options.py @@ -15,46 +15,37 @@ """aiohttp_cors.resource_options unit tests. """ -import unittest +import pytest from aiohttp_cors.resource_options import ResourceOptions -class TestResourceOptions(unittest.TestCase): - """Unit tests for ResourceOptions class""" - - def test_init_no_args(self): - """Test construction without arguments""" - opts = ResourceOptions() - - self.assertFalse(opts.allow_credentials) - self.assertFalse(opts.expose_headers) - self.assertFalse(opts.allow_headers) - self.assertIsNone(opts.max_age) - - def test_comparison(self): - self.assertTrue(ResourceOptions() == ResourceOptions()) - self.assertFalse(ResourceOptions() != ResourceOptions()) - self.assertFalse( - ResourceOptions(allow_credentials=True) == ResourceOptions()) - self.assertTrue( - ResourceOptions(allow_credentials=True) != ResourceOptions()) - - def test_allow_methods(self): - self.assertIsNone(ResourceOptions().allow_methods) - self.assertEqual( - ResourceOptions(allow_methods='*').allow_methods, - '*') - self.assertEqual( - ResourceOptions(allow_methods=[]).allow_methods, - frozenset()) - self.assertEqual( - ResourceOptions(allow_methods=['get']).allow_methods, +def test_init_no_args(): + """Test construction without arguments""" + opts = ResourceOptions() + + assert not opts.allow_credentials + assert not opts.expose_headers + assert not opts.allow_headers + assert opts.max_age is None + + +def test_comparison(): + assert ResourceOptions() == ResourceOptions() + assert not (ResourceOptions() != ResourceOptions()) + assert not (ResourceOptions(allow_credentials=True) == ResourceOptions()) + assert ResourceOptions(allow_credentials=True) != ResourceOptions() + + +def test_allow_methods(): + assert ResourceOptions().allow_methods is None + assert ResourceOptions(allow_methods='*').allow_methods == '*' + assert ResourceOptions(allow_methods=[]).allow_methods == frozenset() + assert (ResourceOptions(allow_methods=['get']).allow_methods == frozenset(['GET'])) - self.assertEqual( - ResourceOptions(allow_methods=['get', 'Post']).allow_methods, + assert (ResourceOptions(allow_methods=['get', 'Post']).allow_methods == {'GET', 'POST'}) - with self.assertRaises(ValueError): - ResourceOptions(allow_methods='GET') + with pytest.raises(ValueError): + ResourceOptions(allow_methods='GET') # TODO: test arguments parsing diff --git a/tests/unit/test_urldispatcher_router_adapter.py b/tests/unit/test_urldispatcher_router_adapter.py index ad5d9c1..d81d2a3 100644 --- a/tests/unit/test_urldispatcher_router_adapter.py +++ b/tests/unit/test_urldispatcher_router_adapter.py @@ -15,10 +15,9 @@ """aiohttp_cors.urldispatcher_router_adapter unit tests. """ -import asyncio -import unittest from unittest import mock +import pytest from aiohttp import web from aiohttp_cors.urldispatcher_router_adapter import \ @@ -26,90 +25,90 @@ from aiohttp_cors.urldispatcher_router_adapter import \ from aiohttp_cors import ResourceOptions -def _handler(request): +async def _handler(request): return web.Response(text="Done") -class TestResourcesUrlDispatcherRouterAdapter(unittest.TestCase): - """Unit tests for CorsConfig""" - - def setUp(self): - self.loop = asyncio.new_event_loop() - self.app = web.Application(loop=self.loop) - - self.adapter = ResourcesUrlDispatcherRouterAdapter( - self.app.router, defaults={ - "*": ResourceOptions() - }) - self.get_route = self.app.router.add_route( - "GET", "/get_path", _handler) - self.options_route = self.app.router.add_route( - "OPTIONS", "/options_path", _handler) - - def tearDown(self): - self.loop.close() - - def test_add_get_route(self): - """Test configuring GET route""" - result = self.adapter.add_preflight_handler( - self.get_route.resource, _handler) - self.assertIsNone(result) - - self.assertEqual(len(self.adapter._resource_config), 0) - self.assertEqual( - len(self.adapter._resources_with_preflight_handlers), 1) - self.assertEqual(len(self.adapter._preflight_routes), 1) - - def test_add_options_route(self): - """Test configuring OPTIONS route""" - - with self.assertRaisesRegex( - ValueError, - "CORS must be enabled for route's resource first"): - self.adapter.add_preflight_handler(self.options_route, _handler) - - self.assertFalse(self.adapter._resources_with_preflight_handlers) - self.assertFalse(self.adapter._preflight_routes) - - def test_get_non_preflight_request_config(self): - self.adapter.add_preflight_handler( - self.get_route.resource, _handler) - self.adapter.set_config_for_routing_entity( - self.get_route.resource, { - 'http://example.org': ResourceOptions(), - }) - - self.adapter.add_preflight_handler( - self.get_route, _handler) - self.adapter.set_config_for_routing_entity( - self.get_route, { - 'http://test.example.org': ResourceOptions(), - }) - - request = mock.Mock() - - with mock.patch('aiohttp_cors.urldispatcher_router_adapter.' - 'ResourcesUrlDispatcherRouterAdapter.' - 'is_cors_enabled_on_request' - ) as is_cors_enabled_on_request, \ - mock.patch('aiohttp_cors.urldispatcher_router_adapter.' - 'ResourcesUrlDispatcherRouterAdapter.' - '_request_resource' - ) as _request_resource: - is_cors_enabled_on_request.return_value = True - _request_resource.return_value = self.get_route.resource - - self.assertEqual( - self.adapter.get_non_preflight_request_config(request), +@pytest.fixture +def app(): + return web.Application() + + +@pytest.fixture +def adapter(app): + return ResourcesUrlDispatcherRouterAdapter( + app.router, defaults={ + "*": ResourceOptions() + }) + + +@pytest.fixture +def get_route(app): + return app.router.add_route( + "GET", "/get_path", _handler) + + +@pytest.fixture +def options_route(app): + return app.router.add_route( + "OPTIONS", "/options_path", _handler) + + +def test_add_get_route(adapter, get_route): + """Test configuring GET route""" + result = adapter.add_preflight_handler( + get_route.resource, _handler) + assert result is None + + assert len(adapter._resource_config) == 0 + assert len(adapter._resources_with_preflight_handlers) == 1 + assert len(adapter._preflight_routes) == 1 + + +def test_add_options_route(adapter, options_route): + """Test configuring OPTIONS route""" + + adapter.add_preflight_handler(options_route, _handler) + + assert not adapter._resources_with_preflight_handlers + assert not adapter._preflight_routes + + +def test_get_non_preflight_request_config(adapter, get_route): + adapter.add_preflight_handler(get_route.resource, _handler) + adapter.set_config_for_routing_entity( + get_route.resource, { + 'http://example.org': ResourceOptions(), + }) + + adapter.add_preflight_handler(get_route, _handler) + adapter.set_config_for_routing_entity( + get_route, { + 'http://test.example.org': ResourceOptions(), + }) + + request = mock.Mock() + + with mock.patch('aiohttp_cors.urldispatcher_router_adapter.' + 'ResourcesUrlDispatcherRouterAdapter.' + 'is_cors_enabled_on_request' + ) as is_cors_enabled_on_request, \ + mock.patch('aiohttp_cors.urldispatcher_router_adapter.' + 'ResourcesUrlDispatcherRouterAdapter.' + '_request_resource' + ) as _request_resource: + is_cors_enabled_on_request.return_value = True + _request_resource.return_value = get_route.resource + + assert (adapter.get_non_preflight_request_config(request) == { '*': ResourceOptions(), 'http://example.org': ResourceOptions(), }) - request.method = 'GET' + request.method = 'GET' - self.assertEqual( - self.adapter.get_non_preflight_request_config(request), + assert (adapter.get_non_preflight_request_config(request) == { '*': ResourceOptions(), 'http://example.org': ResourceOptions(), diff --git a/tox.ini b/tox.ini index 5a9786a..9668d37 100644 --- a/tox.ini +++ b/tox.ini @@ -18,5 +18,6 @@ commands = [pytest] testpaths = aiohttp_cors tests +addopts= --cov=aiohttp_cors --cov-report=term --cov-report=html --cov-branch --no-cov-on-fail ;addopts = --cov aiohttp_cors ; --pylint-rcfile=.pylintrc --pylint -- cgit v1.2.3