summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorRuben Undheim <ruben.undheim@gmail.com>2018-12-21 11:52:13 +0000
committerRuben Undheim <ruben.undheim@gmail.com>2018-12-21 11:52:13 +0000
commit47af65d464e11e1abd2ff23d02218035680c69fa (patch)
tree1ff7e7da314f6901e4340978f0ef47c07fed9a6e
parent094f4e8aeb46fc6bb07936901e228cc4f38dbec6 (diff)
New upstream version 0.7.0
-rw-r--r--.gitignore3
-rw-r--r--.pyup.yml4
-rw-r--r--.travis.yml15
-rw-r--r--CHANGES.rst37
-rw-r--r--LICENSE2
-rw-r--r--Makefile8
-rw-r--r--README.rst45
-rw-r--r--aiohttp_cors/__about__.py6
-rw-r--r--aiohttp_cors/__init__.py3
-rw-r--r--aiohttp_cors/_log.py22
-rw-r--r--aiohttp_cors/abc.py9
-rw-r--r--aiohttp_cors/cors_config.py200
-rw-r--r--aiohttp_cors/mixin.py47
-rw-r--r--aiohttp_cors/preflight_handler.py130
-rw-r--r--aiohttp_cors/urldispatcher_router_adapter.py209
-rw-r--r--appveyor.yml35
-rw-r--r--pytest.ini3
-rw-r--r--requirements-dev.txt19
-rw-r--r--setup.cfg3
-rw-r--r--setup.py1
-rw-r--r--tests/aio_test_base.py63
-rw-r--r--tests/doc/test_basic_usage.py150
-rw-r--r--tests/integration/test_main.py1577
-rw-r--r--tests/integration/test_page.html2
-rw-r--r--tests/integration/test_real_browser.py209
-rw-r--r--tests/unit/test___about__.py11
-rw-r--r--tests/unit/test_cors_config.py173
-rw-r--r--tests/unit/test_mixin.py125
-rw-r--r--tests/unit/test_preflight_handler.py12
-rw-r--r--tests/unit/test_resource_options.py61
-rw-r--r--tests/unit/test_urldispatcher_router_adapter.py153
-rw-r--r--tox.ini1
32 files changed, 1712 insertions, 1626 deletions
diff --git a/.gitignore b/.gitignore
index e2859fe..76419e6 100644
--- a/.gitignore
+++ b/.gitignore
@@ -48,3 +48,6 @@ docs/_build/
# PyBuilder
target/
+
+geckodriver.log
+.pytest_cache \ No newline at end of file
diff --git a/.pyup.yml b/.pyup.yml
new file mode 100644
index 0000000..75f9711
--- /dev/null
+++ b/.pyup.yml
@@ -0,0 +1,4 @@
+# Label PRs with `deps-update` label
+label_prs: deps-update
+
+schedule: every week
diff --git a/.travis.yml b/.travis.yml
index b42cab6..3120baf 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,7 +1,5 @@
language: python
-dist: trusty
python:
-- 3.4
- 3.5
- 3.6
@@ -13,9 +11,8 @@ before_install:
install:
- pip install --upgrade pip setuptools wheel
-# aiohttp git repo has only *.pyx files, so install cython too.
-- '[ -z "$MASTER_AIOHTTP" ] || pip install -U cython git+https://github.com/KeepSafe/aiohttp.git'
- pip install -Ur requirements-dev.txt
+- pip install codecov
before_script:
# Start X-server for Selenium tests.
@@ -25,7 +22,10 @@ before_script:
script:
- '[ "$TYPE" != "test" ] || python setup.py test --addopts -v --addopts -s'
-- '[ "$TYPE" != "lint" ] || python setup.py check'
+- '[ "$TYPE" != "lint" ] || python setup.py check -rms'
+
+after_success:
+ codecov
env:
global:
@@ -33,10 +33,7 @@ env:
matrix:
# PYTHONASYNCIODEBUG environment variable is considered as enabled if it
# is any non empty string.
- - TYPE=test PYTHONASYNCIODEBUG= MASTER_AIOHTTP=
- - TYPE=test PYTHONASYNCIODEBUG=x MASTER_AIOHTTP=
- - TYPE=test PYTHONASYNCIODEBUG= MASTER_AIOHTTP=x
- - TYPE=test PYTHONASYNCIODEBUG=x MASTER_AIOHTTP=x
+ - TYPE=test
matrix:
include:
diff --git a/CHANGES.rst b/CHANGES.rst
index 3224364..a68c3a8 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -1,32 +1,47 @@
-CHANGES
-=======
+=========
+ CHANGES
+=========
+
+0.7.0 (2018-03-05)
+==================
+
+- Make web view check implicit and type based (#159)
+
+- Disable Python 3.4 support (#156)
+
+- Support aiohttp 3.0+ (#155)
+
+0.6.0 (2017-12-21)
+==================
+
+- Support aiohttp views by ``CorsViewMixin`` (#145)
0.5.3 (2017-04-21)
-------------------
+==================
-- Fix `typing` being installed on Python 3.6.
+- Fix ``typing`` being installed on Python 3.6.
0.5.2 (2017-03-28)
-------------------
+==================
- Fix tests compatibility with ``aiohttp`` 2.0.
This release and release v0.5.0 should work on ``aiohttp`` 2.0.
0.5.1 (2017-03-23)
-------------------
+==================
- Enforce ``aiohttp`` version to be less than 2.0.
Newer ``aiohttp`` releases will be supported in the next release.
0.5.0 (2016-11-18)
-------------------
+==================
- Fix compatibility with ``aiohttp`` 1.1
0.4.0 (2016-04-04)
-------------------
+==================
- Fixed support with new Resources objects introduced in ``aiohttp`` 0.21.0.
Minimum supported version of ``aiohttp`` is 0.21.4 now.
@@ -54,7 +69,7 @@ CHANGES
agnostic.
0.3.0 (2016-02-06)
-------------------
+==================
- Rename ``UrlDistatcherRouterAdapter`` to ``UrlDispatcherRouterAdapter``.
@@ -62,7 +77,7 @@ CHANGES
details.
0.2.0 (2015-11-30)
-------------------
+==================
- Move ABCs from ``aiohttp_cors.router_adapter`` to ``aiohttp_cors.abc``.
@@ -71,6 +86,6 @@ CHANGES
- Fix bug with configuring CORS for named routes.
0.1.0 (2015-11-05)
-------------------
+==================
* Initial release.
diff --git a/LICENSE b/LICENSE
index 8f71f43..abd6dd8 100644
--- a/LICENSE
+++ b/LICENSE
@@ -186,7 +186,7 @@
same "printed page" as the copyright notice for easier
identification within third-party archives.
- Copyright {yyyy} {name of copyright owner}
+ Copyright 2015-2018 Vladimir Rutsky and aio-libs team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..2f8e8eb
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,8 @@
+all: test
+
+
+flake:
+ flake8 aiohttp_cors tests setup.py
+
+test: flake
+ pytest tests
diff --git a/README.rst b/README.rst
index 64a3020..ab4b84e 100644
--- a/README.rst
+++ b/README.rst
@@ -1,3 +1,4 @@
+========================
CORS support for aiohttp
========================
@@ -126,7 +127,7 @@ from git:
$ pip install aiohttp_cors
Note that ``aiohttp_cors`` requires versions of Python >= 3.4.1 and
-``aiohttp`` >= 0.21.4.
+``aiohttp`` >= 1.1.
Usage
=====
@@ -346,6 +347,46 @@ in the router:
for route in list(app.router.routes()):
cors.add(route)
+You can also use ``CorsViewMixin`` on ``web.View``:
+
+.. code-block:: python
+
+ class CorsView(web.View, CorsViewMixin):
+
+ cors_config = {
+ "*": ResourceOption(
+ allow_credentials=True,
+ allow_headers="X-Request-ID",
+ )
+ }
+
+ @asyncio.coroutine
+ def get(self):
+ return web.Response(text="Done")
+
+ @custom_cors({
+ "*": ResourceOption(
+ allow_credentials=True,
+ allow_headers="*",
+ )
+ })
+ @asyncio.coroutine
+ def post(self):
+ return web.Response(text="Done")
+
+ cors = aiohttp_cors.setup(app, defaults={
+ "*": aiohttp_cors.ResourceOptions(
+ allow_credentials=True,
+ expose_headers="*",
+ allow_headers="*",
+ )
+ })
+
+ cors.add(
+ app.router.add_route("*", "/resource", CorsView),
+ webview=True)
+
+
Security
========
@@ -460,7 +501,7 @@ Post release steps:
Bugs
====
-Please report bugs, issues, feature requests, etc. on
+Please report bugs, issues, feature requests, etc. on
`GitHub <https://github.com/aio-libs/aiohttp_cors/issues>`__.
diff --git a/aiohttp_cors/__about__.py b/aiohttp_cors/__about__.py
index eb70c30..51c4841 100644
--- a/aiohttp_cors/__about__.py
+++ b/aiohttp_cors/__about__.py
@@ -19,10 +19,10 @@ This module must be stand-alone executable.
"""
__title__ = "aiohttp-cors"
-__version__ = "0.5.3"
-__author__ = "Vladimir Rutsky"
+__version__ = "0.7.0"
+__author__ = "Vladimir Rutsky and aio-libs team"
__email__ = "vladimir@rutsky.org"
__summary__ = "CORS support for aiohttp"
__uri__ = "https://github.com/aio-libs/aiohttp-cors"
__license__ = "Apache License, Version 2.0"
-__copyright__ = "2015, 2016, 2017 {}".format(__author__)
+__copyright__ = "2015-2018 {}".format(__author__)
diff --git a/aiohttp_cors/__init__.py b/aiohttp_cors/__init__.py
index 49474c8..cbcc5ef 100644
--- a/aiohttp_cors/__init__.py
+++ b/aiohttp_cors/__init__.py
@@ -25,11 +25,12 @@ from .__about__ import (
)
from .resource_options import ResourceOptions
from .cors_config import CorsConfig
+from .mixin import CorsViewMixin, custom_cors
__all__ = (
"__title__", "__version__", "__author__", "__email__", "__summary__",
"__uri__", "__license__", "__copyright__",
- "setup", "CorsConfig", "ResourceOptions",
+ "setup", "CorsConfig", "ResourceOptions", "CorsViewMixin", "custom_cors"
)
diff --git a/aiohttp_cors/_log.py b/aiohttp_cors/_log.py
deleted file mode 100644
index 5b216cb..0000000
--- a/aiohttp_cors/_log.py
+++ /dev/null
@@ -1,22 +0,0 @@
-# Copyright 2015 Vladimir Rutsky <vladimir@rutsky.org>
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""aiohttp_cors logger"""
-
-import logging
-
-__all__ = ("logger",)
-
-# pylint: disable=invalid-name
-logger = logging.getLogger("aiohttp_cors")
diff --git a/aiohttp_cors/abc.py b/aiohttp_cors/abc.py
index cdddd36..5204a83 100644
--- a/aiohttp_cors/abc.py
+++ b/aiohttp_cors/abc.py
@@ -15,7 +15,6 @@
"""Abstract base classes.
"""
-import asyncio
from abc import ABCMeta, abstractmethod
from aiohttp import web
@@ -51,7 +50,10 @@ class AbstractRouterAdapter(metaclass=ABCMeta):
"""
@abstractmethod
- def add_preflight_handler(self, routing_entity, handler):
+ def add_preflight_handler(self,
+ routing_entity,
+ handler,
+ webview: bool=False):
"""Add OPTIONS handler for all routes defined by `routing_entity`.
Does nothing if CORS handler already handles routing entity.
@@ -79,9 +81,8 @@ class AbstractRouterAdapter(metaclass=ABCMeta):
entity.
"""
- @asyncio.coroutine
@abstractmethod
- def get_preflight_request_config(
+ async def get_preflight_request_config(
self,
preflight_request: web.Request,
origin: str,
diff --git a/aiohttp_cors/cors_config.py b/aiohttp_cors/cors_config.py
index d8017c9..2e1aeea 100644
--- a/aiohttp_cors/cors_config.py
+++ b/aiohttp_cors/cors_config.py
@@ -15,22 +15,21 @@
"""CORS configuration container class definition.
"""
-import asyncio
import collections
+import warnings
from typing import Mapping, Union, Any
from aiohttp import hdrs, web
-from .urldispatcher_router_adapter import OldRoutesUrlDispatcherRouterAdapter
from .urldispatcher_router_adapter import ResourcesUrlDispatcherRouterAdapter
from .abc import AbstractRouterAdapter
from .resource_options import ResourceOptions
+from .preflight_handler import _PreflightHandler
__all__ = (
"CorsConfig",
)
-
# Positive response to Access-Control-Allow-Credentials
_TRUE = "true"
# CORS simple response headers:
@@ -103,7 +102,7 @@ def _parse_config_options(
_ConfigType = Mapping[str, Union[ResourceOptions, Mapping[str, Any]]]
-class _CorsConfigImpl:
+class _CorsConfigImpl(_PreflightHandler):
def __init__(self,
app: web.Application,
@@ -139,10 +138,9 @@ class _CorsConfigImpl:
return routing_entity
- @asyncio.coroutine
- def _on_response_prepare(self,
- request: web.Request,
- response: web.StreamResponse):
+ async def _on_response_prepare(self,
+ request: web.Request,
+ response: web.StreamResponse):
"""Non-preflight CORS request response processor.
If request is done on CORS-enabled route, process request parameters
@@ -196,127 +194,11 @@ class _CorsConfigImpl:
# Set allowed credentials.
response.headers[hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS] = _TRUE
- @staticmethod
- def _parse_request_method(request: web.Request):
- """Parse Access-Control-Request-Method header of the preflight request
- """
- method = request.headers.get(hdrs.ACCESS_CONTROL_REQUEST_METHOD)
- if method is None:
- raise web.HTTPForbidden(
- text="CORS preflight request failed: "
- "'Access-Control-Request-Method' header is not specified")
-
- # FIXME: validate method string (ABNF: method = token), if parsing
- # fails, raise HTTPForbidden.
-
- return method
-
- @staticmethod
- def _parse_request_headers(request: web.Request):
- """Parse Access-Control-Request-Headers header or the preflight request
-
- Returns set of headers in upper case.
- """
- headers = request.headers.get(hdrs.ACCESS_CONTROL_REQUEST_HEADERS)
- if headers is None:
- return frozenset()
-
- # FIXME: validate each header string, if parsing fails, raise
- # HTTPForbidden.
- # FIXME: check, that headers split and stripped correctly (according
- # to ABNF).
- headers = (h.strip(" \t").upper() for h in headers.split(","))
- # pylint: disable=bad-builtin
- return frozenset(filter(None, headers))
-
- @asyncio.coroutine
- def _preflight_handler(self, request: web.Request):
- """CORS preflight request handler"""
-
- # Handle according to part 6.2 of the CORS specification.
-
- origin = request.headers.get(hdrs.ORIGIN)
- if origin is None:
- # Terminate CORS according to CORS 6.2.1.
- raise web.HTTPForbidden(
- text="CORS preflight request failed: "
- "origin header is not specified in the request")
-
- # CORS 6.2.3. Doing it out of order is not an error.
- request_method = self._parse_request_method(request)
-
- # CORS 6.2.5. Doing it out of order is not an error.
-
- try:
- config = \
- yield from self._router_adapter.get_preflight_request_config(
- request, origin, request_method)
- except KeyError:
- raise web.HTTPForbidden(
- text="CORS preflight request failed: "
- "request method {!r} is not allowed "
- "for {!r} origin".format(request_method, origin))
-
- if not config:
- # No allowed origins for the route.
- # Terminate CORS according to CORS 6.2.1.
- raise web.HTTPForbidden(
- text="CORS preflight request failed: "
- "no origins are allowed")
-
- options = config.get(origin, config.get("*"))
- if options is None:
- # No configuration for the origin - deny.
- # Terminate CORS according to CORS 6.2.2.
- raise web.HTTPForbidden(
- text="CORS preflight request failed: "
- "origin '{}' is not allowed".format(origin))
-
- # CORS 6.2.4
- request_headers = self._parse_request_headers(request)
-
- # CORS 6.2.6
- if options.allow_headers == "*":
- pass
- else:
- disallowed_headers = request_headers - options.allow_headers
- if disallowed_headers:
- raise web.HTTPForbidden(
- text="CORS preflight request failed: "
- "headers are not allowed: {}".format(
- ", ".join(disallowed_headers)))
-
- # Ok, CORS actual request with specified in the preflight request
- # parameters is allowed.
- # Set appropriate headers and return 200 response.
-
- response = web.Response()
-
- # CORS 6.2.7
- response.headers[hdrs.ACCESS_CONTROL_ALLOW_ORIGIN] = origin
- if options.allow_credentials:
- # Set allowed credentials.
- response.headers[hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS] = _TRUE
-
- # CORS 6.2.8
- if options.max_age is not None:
- response.headers[hdrs.ACCESS_CONTROL_MAX_AGE] = \
- str(options.max_age)
-
- # CORS 6.2.9
- # TODO: more optimal for client preflight request cache would be to
- # respond with ALL allowed methods.
- response.headers[hdrs.ACCESS_CONTROL_ALLOW_METHODS] = request_method
-
- # CORS 6.2.10
- if request_headers:
- # Note: case of the headers in the request is changed, but this
- # shouldn't be a problem, since the headers should be compared in
- # the case-insensitive way.
- response.headers[hdrs.ACCESS_CONTROL_ALLOW_HEADERS] = \
- ",".join(request_headers)
-
- return response
+ async def _get_config(self, request, origin, request_method):
+ config = \
+ await self._router_adapter.get_preflight_request_config(
+ request, origin, request_method)
+ return config
class CorsConfig:
@@ -341,7 +223,7 @@ class CorsConfig:
Router adapter. Required if application uses non-default router.
"""
- defaults = _parse_config_options(defaults)
+ self.defaults = _parse_config_options(defaults)
self._cors_impl = None
@@ -350,27 +232,16 @@ class CorsConfig:
self._old_routes_cors_impl = None
- if router_adapter is not None:
- self._cors_impl = _CorsConfigImpl(app, router_adapter)
-
- elif isinstance(app.router, web.UrlDispatcher):
- self._resources_router_adapter = \
- ResourcesUrlDispatcherRouterAdapter(app.router, defaults)
- self._resources_cors_impl = _CorsConfigImpl(
- app,
- self._resources_router_adapter)
- self._old_routes_cors_impl = _CorsConfigImpl(
- app,
- OldRoutesUrlDispatcherRouterAdapter(app.router, defaults))
- else:
- raise RuntimeError(
- "Router adapter is not specified. "
- "Routers other than aiohttp.web.UrlDispatcher requires"
- "custom router adapter.")
+ if router_adapter is None:
+ router_adapter = \
+ ResourcesUrlDispatcherRouterAdapter(app.router, self.defaults)
+
+ self._cors_impl = _CorsConfigImpl(app, router_adapter)
def add(self,
routing_entity,
- config: _ConfigType = None):
+ config: _ConfigType = None,
+ webview: bool=False):
"""Enable CORS for specific route or resource.
If route is passed CORS is enabled for route's resource.
@@ -382,30 +253,11 @@ class CorsConfig:
:return: `routing_entity`.
"""
- if self._cors_impl is not None:
- # Custom router adapter.
- return self._cors_impl.add(routing_entity, config)
+ if webview:
+ warnings.warn('webview argument is deprecated, '
+ 'views are handled authomatically without '
+ 'extra settings',
+ DeprecationWarning,
+ stacklevel=2)
- else:
- # UrlDispatcher.
-
- if isinstance(routing_entity, (web.Resource, web.StaticResource)):
- # New Resource - use new router adapter.
- return self._resources_cors_impl.add(routing_entity, config)
-
- elif isinstance(routing_entity, web.AbstractRoute):
- if self._resources_router_adapter.is_cors_for_resource(
- routing_entity.resource):
- # Route which resource has CORS configuration in
- # new-style router adapter.
- return self._resources_cors_impl.add(
- routing_entity, config)
- else:
- # Route which resource has no CORS configuration, i.e.
- # old-style route.
- return self._old_routes_cors_impl.add(
- routing_entity, config)
-
- else:
- raise ValueError(
- "Unknown resource/route type: {!r}".format(routing_entity))
+ return self._cors_impl.add(routing_entity, config)
diff --git a/aiohttp_cors/mixin.py b/aiohttp_cors/mixin.py
new file mode 100644
index 0000000..f5b4506
--- /dev/null
+++ b/aiohttp_cors/mixin.py
@@ -0,0 +1,47 @@
+import collections
+
+from .preflight_handler import _PreflightHandler
+
+
+def custom_cors(config):
+ def wrapper(function):
+ name = "{}_cors_config".format(function.__name__)
+ setattr(function, name, config)
+ return function
+ return wrapper
+
+
+class CorsViewMixin(_PreflightHandler):
+ cors_config = None
+
+ @classmethod
+ def get_request_config(cls, request, request_method):
+ try:
+ from . import APP_CONFIG_KEY
+ cors = request.app[APP_CONFIG_KEY]
+ except KeyError:
+ raise ValueError("aiohttp-cors is not configured.")
+
+ method = getattr(cls, request_method.lower(), None)
+
+ if not method:
+ raise KeyError()
+
+ config_property_key = "{}_cors_config".format(request_method.lower())
+
+ custom_config = getattr(method, config_property_key, None)
+ if not custom_config:
+ custom_config = {}
+
+ class_config = cls.cors_config
+ if not class_config:
+ class_config = {}
+
+ return collections.ChainMap(custom_config, class_config, cors.defaults)
+
+ async def _get_config(self, request, origin, request_method):
+ return self.get_request_config(request, request_method)
+
+ async def options(self):
+ response = await self._preflight_handler(self.request)
+ return response
diff --git a/aiohttp_cors/preflight_handler.py b/aiohttp_cors/preflight_handler.py
new file mode 100644
index 0000000..35e15e1
--- /dev/null
+++ b/aiohttp_cors/preflight_handler.py
@@ -0,0 +1,130 @@
+from aiohttp import hdrs, web
+
+# Positive response to Access-Control-Allow-Credentials
+_TRUE = "true"
+
+
+class _PreflightHandler:
+
+ @staticmethod
+ def _parse_request_method(request: web.Request):
+ """Parse Access-Control-Request-Method header of the preflight request
+ """
+ method = request.headers.get(hdrs.ACCESS_CONTROL_REQUEST_METHOD)
+ if method is None:
+ raise web.HTTPForbidden(
+ text="CORS preflight request failed: "
+ "'Access-Control-Request-Method' header is not specified")
+
+ # FIXME: validate method string (ABNF: method = token), if parsing
+ # fails, raise HTTPForbidden.
+
+ return method
+
+ @staticmethod
+ def _parse_request_headers(request: web.Request):
+ """Parse Access-Control-Request-Headers header or the preflight request
+
+ Returns set of headers in upper case.
+ """
+ headers = request.headers.get(hdrs.ACCESS_CONTROL_REQUEST_HEADERS)
+ if headers is None:
+ return frozenset()
+
+ # FIXME: validate each header string, if parsing fails, raise
+ # HTTPForbidden.
+ # FIXME: check, that headers split and stripped correctly (according
+ # to ABNF).
+ headers = (h.strip(" \t").upper() for h in headers.split(","))
+ # pylint: disable=bad-builtin
+ return frozenset(filter(None, headers))
+
+ async def _get_config(self, request, origin, request_method):
+ raise NotImplementedError()
+
+ async def _preflight_handler(self, request: web.Request):
+ """CORS preflight request handler"""
+
+ # Handle according to part 6.2 of the CORS specification.
+
+ origin = request.headers.get(hdrs.ORIGIN)
+ if origin is None:
+ # Terminate CORS according to CORS 6.2.1.
+ raise web.HTTPForbidden(
+ text="CORS preflight request failed: "
+ "origin header is not specified in the request")
+
+ # CORS 6.2.3. Doing it out of order is not an error.
+ request_method = self._parse_request_method(request)
+
+ # CORS 6.2.5. Doing it out of order is not an error.
+
+ try:
+ config = \
+ await self._get_config(request, origin, request_method)
+ except KeyError:
+ raise web.HTTPForbidden(
+ text="CORS preflight request failed: "
+ "request method {!r} is not allowed "
+ "for {!r} origin".format(request_method, origin))
+
+ if not config:
+ # No allowed origins for the route.
+ # Terminate CORS according to CORS 6.2.1.
+ raise web.HTTPForbidden(
+ text="CORS preflight request failed: "
+ "no origins are allowed")
+
+ options = config.get(origin, config.get("*"))
+ if options is None:
+ # No configuration for the origin - deny.
+ # Terminate CORS according to CORS 6.2.2.
+ raise web.HTTPForbidden(
+ text="CORS preflight request failed: "
+ "origin '{}' is not allowed".format(origin))
+
+ # CORS 6.2.4
+ request_headers = self._parse_request_headers(request)
+
+ # CORS 6.2.6
+ if options.allow_headers == "*":
+ pass
+ else:
+ disallowed_headers = request_headers - options.allow_headers
+ if disallowed_headers:
+ raise web.HTTPForbidden(
+ text="CORS preflight request failed: "
+ "headers are not allowed: {}".format(
+ ", ".join(disallowed_headers)))
+
+ # Ok, CORS actual request with specified in the preflight request
+ # parameters is allowed.
+ # Set appropriate headers and return 200 response.
+
+ response = web.Response()
+
+ # CORS 6.2.7
+ response.headers[hdrs.ACCESS_CONTROL_ALLOW_ORIGIN] = origin
+ if options.allow_credentials:
+ # Set allowed credentials.
+ response.headers[hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS] = _TRUE
+
+ # CORS 6.2.8
+ if options.max_age is not None:
+ response.headers[hdrs.ACCESS_CONTROL_MAX_AGE] = \
+ str(options.max_age)
+
+ # CORS 6.2.9
+ # TODO: more optimal for client preflight request cache would be to
+ # respond with ALL allowed methods.
+ response.headers[hdrs.ACCESS_CONTROL_ALLOW_METHODS] = request_method
+
+ # CORS 6.2.10
+ if request_headers:
+ # Note: case of the headers in the request is changed, but this
+ # shouldn't be a problem, since the headers should be compared in
+ # the case-insensitive way.
+ response.headers[hdrs.ACCESS_CONTROL_ALLOW_HEADERS] = \
+ ",".join(request_headers)
+
+ return response
diff --git a/aiohttp_cors/urldispatcher_router_adapter.py b/aiohttp_cors/urldispatcher_router_adapter.py
index 92bfbb3..1a65e99 100644
--- a/aiohttp_cors/urldispatcher_router_adapter.py
+++ b/aiohttp_cors/urldispatcher_router_adapter.py
@@ -14,9 +14,7 @@
"""AbstractRouterAdapter for aiohttp.web.UrlDispatcher.
"""
-import asyncio
import collections
-import re
from typing import Union
@@ -24,6 +22,7 @@ from aiohttp import web
from aiohttp import hdrs
from .abc import AbstractRouterAdapter
+from .mixin import CorsViewMixin
# There several usage patterns of routes which should be handled
@@ -89,6 +88,22 @@ class _ResourceConfig:
self.method_config = {}
+def _is_web_view(entity, strict=True):
+ webview = False
+ if isinstance(entity, web.AbstractRoute):
+ handler = entity.handler
+ if isinstance(handler, type) and issubclass(handler, web.View):
+ webview = True
+ if not issubclass(handler, CorsViewMixin):
+ if strict:
+ raise ValueError("web view should be derived from "
+ "aiohttp_cors.WebViewMixig for working "
+ "with the library")
+ else:
+ return False
+ return webview
+
+
class ResourcesUrlDispatcherRouterAdapter(AbstractRouterAdapter):
"""Adapter for `UrlDispatcher` for Resources-based routing only.
@@ -138,6 +153,22 @@ class ResourcesUrlDispatcherRouterAdapter(AbstractRouterAdapter):
if resource in self._resources_with_preflight_handlers:
# Preflight handler already added for this resource.
return
+ for route_obj in resource:
+ if route_obj.method == hdrs.METH_OPTIONS:
+ if route_obj.handler is handler:
+ return # already added
+ else:
+ raise ValueError(
+ "{!r} already has OPTIONS handler {!r}"
+ .format(resource, route_obj.handler))
+ elif route_obj.method == hdrs.METH_ANY:
+ if _is_web_view(route_obj):
+ self._preflight_routes.add(route_obj)
+ self._resources_with_preflight_handlers.add(resource)
+ return
+ else:
+ raise ValueError("{!r} already has a '*' handler "
+ "for all methods".format(resource))
preflight_route = resource.add_route(hdrs.METH_OPTIONS, handler)
self._preflight_routes.add(preflight_route)
@@ -160,13 +191,8 @@ class ResourcesUrlDispatcherRouterAdapter(AbstractRouterAdapter):
elif isinstance(routing_entity, web.ResourceRoute):
route = routing_entity
- # Preflight handler for Route's Resource already must be
- # configured.
if not self.is_cors_for_resource(route.resource):
- raise ValueError(
- "Can't setup CORS for {!r} request, "
- "CORS must be enabled for route's resource first.".format(
- route))
+ self.add_preflight_handler(route.resource, handler)
else:
raise ValueError(
@@ -187,8 +213,10 @@ class ResourcesUrlDispatcherRouterAdapter(AbstractRouterAdapter):
def is_preflight_request(self, request: web.Request) -> bool:
"""Is `request` is a CORS preflight request."""
-
- return self._request_route(request) in self._preflight_routes
+ route = self._request_route(request)
+ if _is_web_view(route, strict=False):
+ return request.method == 'OPTIONS'
+ return route in self._preflight_routes
def is_cors_enabled_on_request(self, request: web.Request) -> bool:
"""Is `request` is a request for CORS-enabled resource."""
@@ -219,6 +247,9 @@ class ResourcesUrlDispatcherRouterAdapter(AbstractRouterAdapter):
# Add resource's route configuration or fail if it's already added.
if route.resource not in self._resource_config:
+ self.set_config_for_routing_entity(route.resource, config)
+
+ if route.resource not in self._resource_config:
raise ValueError(
"Can't setup CORS for {!r} request, "
"CORS must be enabled for route's resource first.".format(
@@ -239,8 +270,7 @@ class ResourcesUrlDispatcherRouterAdapter(AbstractRouterAdapter):
"Resource or ResourceRoute expected, got {!r}".format(
routing_entity))
- @asyncio.coroutine
- def get_preflight_request_config(
+ async def get_preflight_request_config(
self,
preflight_request: web.Request,
origin: str,
@@ -279,155 +309,16 @@ class ResourcesUrlDispatcherRouterAdapter(AbstractRouterAdapter):
resource_config = self._resource_config[resource]
# Take Route config (if any) with defaults from Resource CORS
# configuration and global defaults.
+ route = request.match_info.route
+ if _is_web_view(route, strict=False):
+ method_config = request.match_info.handler.get_request_config(
+ request, request.method)
+ else:
+ method_config = resource_config.method_config.get(request.method,
+ {})
defaulted_config = collections.ChainMap(
- resource_config.method_config.get(request.method, {}),
+ method_config,
resource_config.default_config,
self._default_config)
return defaulted_config
-
-
-class OldRoutesUrlDispatcherRouterAdapter(AbstractRouterAdapter):
- """Adapter for `UrlDispatcher` for old-style routing only.
-
- In all use cases when Resource is not explicitly used,
- Resource will automatically allocated for old route.
- In this case all routes will have it's own resource, and to find
- related routes (routes that shares same path) we need to iterate over
- all routes with enabled CORS and check is they handle specific path.
-
- This whole class should go away when user will migrate to proper
- Resource/Route usage scheme.
- """
-
- def __init__(self,
- router: web.UrlDispatcher,
- defaults):
- """
- :param defaults:
- Default CORS configuration.
- """
- self._router = router
-
- # Default configuration for all routes.
- self._default_config = defaults
-
- # Mapping from route to config.
- self._route_config = collections.OrderedDict()
-
- self._preflight_routes = set()
-
- def add_preflight_handler(
- self,
- route: web.AbstractRoute,
- handler):
- """Add OPTIONS handler for same paths that `route` handles."""
-
- assert isinstance(route, web.AbstractRoute)
-
- if isinstance(route, web.ResourceRoute):
- # New-style route (which Resource is not used explicitly,
- # otherwise it would be handled by other adapter).
- preflight_route = route.resource.add_route(
- hdrs.METH_OPTIONS, handler)
-
- elif isinstance(route, web.Route):
- # Old-style route.
-
- if isinstance(route, web.StaticRoute):
- # TODO: Use custom matches that uses `str.startswith()`
- # if regexp performance is not enough.
- pattern = re.compile("^" + re.escape(route._prefix))
- preflight_route = web.DynamicRoute(
- hdrs.METH_OPTIONS, handler, None, pattern, "")
- self._router.register_route(preflight_route)
-
- elif isinstance(route, web.PlainRoute):
- # May occur only if user manually creates PlainRoute.
- preflight_route = self._router.add_route(
- hdrs.METH_OPTIONS, route._path, handler)
-
- elif isinstance(route, web.DynamicRoute):
- # May occur only if user manually creates DynamicRoute.
- preflight_route = web.DynamicRoute(
- hdrs.METH_OPTIONS, handler, None,
- route._pattern, route._formatter)
- self._router.register_route(preflight_route)
-
- else:
- raise RuntimeError(
- "Unhandled deprecated route type {!r}".format(route))
-
- else:
- raise RuntimeError("Unhandled route type {!r}".format(route))
-
- self._preflight_routes.add(preflight_route)
-
- def _request_route(self, request: web.Request) -> web.ResourceRoute:
- match_info = request.match_info
- assert isinstance(match_info, web.UrlMappingMatchInfo)
- return match_info.route
-
- def is_preflight_request(self, request: web.Request) -> bool:
- """Is `request` is a CORS preflight request."""
-
- return self._request_route(request) in self._preflight_routes
-
- def is_cors_enabled_on_request(self, request: web.Request) -> bool:
- """Is `request` is a request for CORS-enabled resource."""
-
- return self._request_route(request) in self._route_config
-
- def set_config_for_routing_entity(
- self,
- route: web.AbstractRoute,
- config):
- """Record CORS configuration for route."""
-
- assert isinstance(route, web.AbstractRoute)
-
- if any(options.allow_methods is not None
- for options in config.values()):
- raise ValueError(
- "'allow_methods' parameter is not supported on old-style "
- "routes. You specified {!r} for {!r}. "
- "Use Resources to configure CORS.".format(
- config, route))
-
- if route in self._route_config:
- raise ValueError(
- "CORS is already configured for {!r} route.".format(
- route))
-
- self._route_config[route] = config
-
- @asyncio.coroutine
- def get_preflight_request_config(
- self,
- preflight_request: web.Request,
- origin: str,
- requested_method: str):
- assert self.is_preflight_request(preflight_request)
-
- request = preflight_request.clone(method=requested_method)
- for route, config in self._route_config.items():
- match_info, allowed_methods = yield from route.resource.resolve(
- request)
- if match_info is not None:
- return collections.ChainMap(config, self._default_config)
- else:
- raise KeyError
-
- def get_non_preflight_request_config(self, request: web.Request):
- """Get stored CORS configuration for routing entity that handles
- specified request."""
-
- assert self.is_cors_enabled_on_request(request)
-
- route = self._request_route(request)
- route_config = self._route_config[route]
-
- defaulted_config = collections.ChainMap(
- route_config, self._default_config)
-
- return defaulted_config
diff --git a/appveyor.yml b/appveyor.yml
deleted file mode 100644
index 012eecb..0000000
--- a/appveyor.yml
+++ /dev/null
@@ -1,35 +0,0 @@
-version: 0.0.1.dev{build}
-
-environment:
- matrix:
- - PYTHON: "C:\\Python34"
- PYTHON_VERSION: "3.4.4"
- PYTHON_ARCH: "32"
- - PYTHON: "C:\\Python34"
- PYTHON_VERSION: "3.4.4"
- PYTHON_ARCH: "64"
- - PYTHON: "C:\\Python35"
- PYTHON_VERSION: "3.5.1"
- PYTHON_ARCH: "32"
- - PYTHON: "C:\\Python35"
- PYTHON_VERSION: "3.5.1"
- PYTHON_ARCH: "64"
- - PYTHON: "C:\\Python36"
- PYTHON_VERSION: "3.6.0"
- PYTHON_ARCH: "32"
- - PYTHON: "C:\\Python36"
- PYTHON_VERSION: "3.6.0"
- PYTHON_ARCH: "64"
-
-install:
- - "powershell ./install_python_and_pip.ps1"
- - "SET PATH=%PYTHON%;%PYTHON%\\Scripts;%PATH%"
- - "python --version"
- - "pip install -U pip setuptools wheel"
- - "pip list"
- - "pip install -r requirements-dev.txt"
- - "python setup.py develop"
-
-build: false
-test_script:
- - "python setup.py test"
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 0000000..e62899b
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,3 @@
+[pytest]
+filterwarnings=
+ error
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 7a37e87..6331621 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,11 +1,14 @@
-tox==2.7.0
-pytest==3.0.7
-pytest-cov==2.4.0
-pytest-runner==2.11.1
-pytest-flakes==1.0.1
+aiohttp==3.0.5
+tox==2.9.1
+pytest==3.4.0
+pytest-aiohttp==0.3.0
+pytest-cov==2.5.1
+pytest-runner==3.0
+pytest-flakes==2.0.0
pytest-pylint==0.7.1
-flake8==3.3.0
-selenium==3.3.3
-docutils==0.13.1
+pytest-sugar==0.9.1
+flake8==3.5.0
+selenium==3.8.1
+docutils==0.14
pygments==2.2.0
-e .
diff --git a/setup.cfg b/setup.cfg
index 31ad82b..20d367e 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,2 +1,5 @@
[aliases]
test = pytest
+
+[tool:pytest]
+addopts= --cov=aiohttp_cors --cov-report=term --cov-report=html --cov-branch --no-cov-on-fail \ No newline at end of file
diff --git a/setup.py b/setup.py
index c70d98b..ce7d69c 100644
--- a/setup.py
+++ b/setup.py
@@ -79,6 +79,7 @@ setup(
"Programming Language :: Python :: 3.6",
"Topic :: Software Development :: Libraries",
"Topic :: Internet :: WWW/HTTP",
+ "Framework :: AsyncIO",
"Operating System :: MacOS :: MacOS X",
"Operating System :: Microsoft :: Windows",
"Operating System :: POSIX",
diff --git a/tests/aio_test_base.py b/tests/aio_test_base.py
deleted file mode 100644
index 7289b6a..0000000
--- a/tests/aio_test_base.py
+++ /dev/null
@@ -1,63 +0,0 @@
-# Copyright 2015 Vladimir Rutsky <vladimir@rutsky.org>
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Base classes and utility functions for testing asyncio-powered code.
-"""
-
-import unittest
-import asyncio
-import socket
-import functools
-import concurrent.futures
-
-
-@asyncio.coroutine
-def create_server(protocol_factory, loop=None, sock=None):
- """Create server listening on random port"""
-
- if sock is None:
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- sock.bind(("127.0.0.1", 0))
- sock.listen(10)
-
- if loop is None:
- loop = asyncio.get_event_loop()
-
- return (yield from loop.create_server(protocol_factory, sock=sock))
-
-
-class AioTestBase(unittest.TestCase):
- """Base class for tests that need temporary asyncio event loop"""
- def setUp(self):
- self.loop = asyncio.new_event_loop()
- asyncio.set_event_loop(self.loop)
-
- self.thread_pool_executor = concurrent.futures.ThreadPoolExecutor(4)
-
- def tearDown(self):
- self.thread_pool_executor.shutdown()
-
- self.loop.close()
- asyncio.set_event_loop(None)
-
-
-def asynctest(test_method):
- """Decorator for coroutine tests.
-
- To be used with `AioTestBase`-based tests"""
- @functools.wraps(test_method)
- def wrapper(self):
- """Synchronously run test method in the event loop"""
- self.loop.run_until_complete(test_method(self))
- return wrapper
diff --git a/tests/doc/test_basic_usage.py b/tests/doc/test_basic_usage.py
index 899da8d..4eace99 100644
--- a/tests/doc/test_basic_usage.py
+++ b/tests/doc/test_basic_usage.py
@@ -14,93 +14,87 @@
"""Test basic usage."""
-import unittest
-
-
-class TestBasicUsage(unittest.TestCase):
- def test_main(self):
- # This tests corresponds to example from documentation.
- # If you updating it, don't forget to update documentation.
-
- import asyncio
- from aiohttp import web
- import aiohttp_cors
-
- @asyncio.coroutine
- def handler(request):
- return web.Response(
- text="Hello!",
- headers={
- "X-Custom-Server-Header": "Custom data",
- })
-
- app = web.Application()
-
- # `aiohttp_cors.setup` returns `aiohttp_cors.CorsConfig` instance.
- # The `cors` instance will store CORS configuration for the
- # application.
- cors = aiohttp_cors.setup(app)
-
- # To enable CORS processing for specific route you need to add
- # that route to the CORS configuration object and specify its
- # CORS options.
- resource = cors.add(app.router.add_resource("/hello"))
- route = cors.add(
- resource.add_route("GET", handler), {
- "http://client.example.org": aiohttp_cors.ResourceOptions(
- allow_credentials=True,
- expose_headers=("X-Custom-Server-Header",),
- allow_headers=("X-Requested-With", "Content-Type"),
- max_age=3600,
- )
- })
- assert route is not None
+async def test_main():
+ # This tests corresponds to example from documentation.
+ # If you updating it, don't forget to update documentation.
- def test_defaults(self):
- # This tests corresponds to example from documentation.
- # If you updating it, don't forget to update documentation.
+ from aiohttp import web
+ import aiohttp_cors
- import asyncio
- from aiohttp import web
- import aiohttp_cors
+ async def handler(request):
+ return web.Response(
+ text="Hello!",
+ headers={
+ "X-Custom-Server-Header": "Custom data",
+ })
- @asyncio.coroutine
- def handler(request):
- return web.Response(
- text="Hello!",
- headers={
- "X-Custom-Server-Header": "Custom data",
- })
+ app = web.Application()
+
+ # `aiohttp_cors.setup` returns `aiohttp_cors.CorsConfig` instance.
+ # The `cors` instance will store CORS configuration for the
+ # application.
+ cors = aiohttp_cors.setup(app)
+
+ # To enable CORS processing for specific route you need to add
+ # that route to the CORS configuration object and specify its
+ # CORS options.
+ resource = cors.add(app.router.add_resource("/hello"))
+ route = cors.add(
+ resource.add_route("GET", handler), {
+ "http://client.example.org": aiohttp_cors.ResourceOptions(
+ allow_credentials=True,
+ expose_headers=("X-Custom-Server-Header",),
+ allow_headers=("X-Requested-With", "Content-Type"),
+ max_age=3600,
+ )
+ })
+
+ assert route is not None
+
+
+async def test_defaults():
+ # This tests corresponds to example from documentation.
+ # If you updating it, don't forget to update documentation.
+
+ from aiohttp import web
+ import aiohttp_cors
+
+ async def handler(request):
+ return web.Response(
+ text="Hello!",
+ headers={
+ "X-Custom-Server-Header": "Custom data",
+ })
- handler_post = handler
- handler_put = handler
+ handler_post = handler
+ handler_put = handler
- app = web.Application()
+ app = web.Application()
- # Example:
+ # Example:
- cors = aiohttp_cors.setup(app, defaults={
- # Allow all to read all CORS-enabled resources from
- # http://client.example.org.
- "http://client.example.org": aiohttp_cors.ResourceOptions(),
- })
+ cors = aiohttp_cors.setup(app, defaults={
+ # Allow all to read all CORS-enabled resources from
+ # http://client.example.org.
+ "http://client.example.org": aiohttp_cors.ResourceOptions(),
+ })
- # Enable CORS on routes.
+ # Enable CORS on routes.
- # According to defaults POST and PUT will be available only to
- # "http://client.example.org".
- hello_resource = cors.add(app.router.add_resource("/hello"))
- cors.add(hello_resource.add_route("POST", handler_post))
- cors.add(hello_resource.add_route("PUT", handler_put))
+ # According to defaults POST and PUT will be available only to
+ # "http://client.example.org".
+ hello_resource = cors.add(app.router.add_resource("/hello"))
+ cors.add(hello_resource.add_route("POST", handler_post))
+ cors.add(hello_resource.add_route("PUT", handler_put))
- # In addition to "http://client.example.org", GET request will be
- # allowed from "http://other-client.example.org" origin.
- cors.add(hello_resource.add_route("GET", handler), {
- "http://other-client.example.org":
- aiohttp_cors.ResourceOptions(),
- })
+ # In addition to "http://client.example.org", GET request will be
+ # allowed from "http://other-client.example.org" origin.
+ cors.add(hello_resource.add_route("GET", handler), {
+ "http://other-client.example.org":
+ aiohttp_cors.ResourceOptions(),
+ })
- # CORS will be enabled only on the resources added to `CorsConfig`,
- # so following resource will be NOT CORS-enabled.
- app.router.add_route("GET", "/private", handler)
+ # CORS will be enabled only on the resources added to `CorsConfig`,
+ # so following resource will be NOT CORS-enabled.
+ app.router.add_route("GET", "/private", handler)
diff --git a/tests/integration/test_main.py b/tests/integration/test_main.py
index 661094e..098c3d3 100644
--- a/tests/integration/test_main.py
+++ b/tests/integration/test_main.py
@@ -15,18 +15,14 @@
"""Test generic usage
"""
-import asyncio
import pathlib
-from yarl import URL
+import pytest
-from tests.aio_test_base import AioTestBase, create_server, asynctest
-
-import aiohttp
from aiohttp import web
from aiohttp import hdrs
-from aiohttp_cors import setup, ResourceOptions
+from aiohttp_cors import setup as _setup, ResourceOptions, CorsViewMixin
TEST_BODY = "Hello, world"
@@ -34,9 +30,8 @@ SERVER_CUSTOM_HEADER_NAME = "X-Server-Custom-Header"
SERVER_CUSTOM_HEADER_VALUE = "some value"
-@asyncio.coroutine
# pylint: disable=unused-argument
-def handler(request: web.Request) -> web.StreamResponse:
+async def handler(request: web.Request) -> web.StreamResponse:
"""Dummy request handler, returning `TEST_BODY`."""
response = web.Response(text=TEST_BODY)
@@ -45,800 +40,876 @@ def handler(request: web.Request) -> web.StreamResponse:
return response
-class AioAiohttpAppTestBase(AioTestBase):
- """Base class for tests that create single aiohttp server.
+class WebViewHandler(web.View, CorsViewMixin):
- Class manages server creation using create_server() method and proper
- server shutdown.
- """
+ async def get(self) -> web.StreamResponse:
+ """Dummy request handler, returning `TEST_BODY`."""
+ response = web.Response(text=TEST_BODY)
- def setUp(self):
- super().setUp()
+ response.headers[SERVER_CUSTOM_HEADER_NAME] = \
+ SERVER_CUSTOM_HEADER_VALUE
- self.handler = None
- self.app = None
- self.url = None
+ return response
- self.server = None
- self.session = aiohttp.ClientSession(loop=self.loop)
+@pytest.fixture(params=['resource', 'view', 'route'])
+def make_app(request):
+ def inner(defaults, route_config):
+ app = web.Application()
+ cors = _setup(app, defaults=defaults)
- def tearDown(self):
- self.session.close()
+ if request.param == 'resource':
+ resource = cors.add(app.router.add_resource("/resource"))
+ cors.add(resource.add_route("GET", handler), route_config)
+ elif request.param == 'view':
+ WebViewHandler.cors_config = route_config
+ cors.add(
+ app.router.add_route("*", "/resource", WebViewHandler))
+ elif request.param == 'route':
+ cors.add(
+ app.router.add_route("GET", "/resource", handler),
+ route_config)
+ else:
+ raise RuntimeError('unknown parameter {}'.format(request.param))
- if self.server is not None:
- self.loop.run_until_complete(self.shutdown_server())
+ return app
- super().tearDown()
+ return inner
- @asyncio.coroutine
- def create_server(self, app: web.Application):
- """Create server listening on random port."""
- assert self.app is None
- self.app = app
+async def test_message_roundtrip(aiohttp_client):
+ """Test that aiohttp server is correctly setup in the base class."""
- assert self.handler is None
- self.handler = app.make_handler()
+ app = web.Application()
+ app.router.add_route("GET", "/", handler)
- self.server = (yield from create_server(self.handler, self.loop))
+ client = await aiohttp_client(app)
- return self.server
+ resp = await client.get('/')
+ assert resp.status == 200
+ data = await resp.text()
- @property
- def server_url(self):
- """Server navigatable URL."""
- assert self.server is not None
- hostaddr, port = self.server.sockets[0].getsockname()
- return "http://{host}:{port}/".format(host=hostaddr, port=port)
+ assert data == TEST_BODY
- @asyncio.coroutine
- def shutdown_server(self):
- """Shutdown server."""
- assert self.server is not None
- self.server.close()
- yield from self.handler.finish_connections()
- yield from self.server.wait_closed()
- yield from self.app.cleanup()
+async def test_dummy_setup(aiohttp_server):
+ """Test a dummy configuration."""
+ app = web.Application()
+ _setup(app)
- self.server = None
- self.app = None
- self.handler = None
+ await aiohttp_server(app)
-class TestMain(AioAiohttpAppTestBase):
- """Tests CORS server by issuing CORS requests."""
+async def test_dummy_setup_roundtrip(aiohttp_client):
+ """Test a dummy configuration with a message round-trip."""
+ app = web.Application()
+ _setup(app)
- @asynctest
- @asyncio.coroutine
- def test_message_roundtrip(self):
- """Test that aiohttp server is correctly setup in the base class."""
+ app.router.add_route("GET", "/", handler)
- app = web.Application()
+ client = await aiohttp_client(app)
- app.router.add_route("GET", "/", handler)
+ resp = await client.get('/')
+ assert resp.status == 200
+ data = await resp.text()
- yield from self.create_server(app)
+ assert data == TEST_BODY
- response = yield from self.session.request("GET", self.server_url)
- self.assertEqual(response.status, 200)
- data = yield from response.text()
- self.assertEqual(data, TEST_BODY)
+async def test_dummy_setup_roundtrip_resource(aiohttp_client):
+ """Test a dummy configuration with a message round-trip."""
+ app = web.Application()
+ _setup(app)
- @asynctest
- @asyncio.coroutine
- def test_dummy_setup(self):
- """Test a dummy configuration."""
- app = web.Application()
- setup(app)
+ app.router.add_resource("/").add_route("GET", handler)
- yield from self.create_server(app)
+ client = await aiohttp_client(app)
- @asynctest
- @asyncio.coroutine
- def test_dummy_setup_roundtrip(self):
- """Test a dummy configuration with a message round-trip."""
- app = web.Application()
- setup(app)
+ resp = await client.get('/')
+ assert resp.status == 200
+ data = await resp.text()
- app.router.add_route("GET", "/", handler)
+ assert data == TEST_BODY
- yield from self.create_server(app)
- response = yield from self.session.request("GET", self.server_url)
- self.assertEqual(response.status, 200)
- data = yield from response.text()
+async def test_simple_no_origin(aiohttp_client, make_app):
+ app = make_app(None, {"http://client1.example.org":
+ ResourceOptions()})
- self.assertEqual(data, TEST_BODY)
+ client = await aiohttp_client(app)
- @asynctest
- @asyncio.coroutine
- def test_dummy_setup_roundtrip_resource(self):
- """Test a dummy configuration with a message round-trip."""
- app = web.Application()
- setup(app)
-
- app.router.add_resource("/").add_route("GET", handler)
-
- yield from self.create_server(app)
-
- response = yield from self.session.request("GET", self.server_url)
- self.assertEqual(response.status, 200)
- data = yield from response.text()
-
- self.assertEqual(data, TEST_BODY)
-
- @asyncio.coroutine
- def _run_simple_requests_tests(self,
- tests_descriptions,
- use_resources):
- """Runs CORS simple requests (without a preflight request) based
- on the passed tests descriptions.
- """
-
- @asyncio.coroutine
- def run_test(test):
- """Run single test"""
-
- response = yield from self.session.get(
- self.server_url + "resource",
- headers=test.get("request_headers", {}))
- self.assertEqual(response.status, 200)
- self.assertEqual((yield from response.text()), TEST_BODY)
-
- for header_name, header_value in test.get(
- "in_response_headers", {}).items():
- with self.subTest(header_name=header_name):
- self.assertEqual(
- response.headers.get(header_name),
- header_value)
-
- for header_name in test.get("not_in_request_headers", {}).items():
- self.assertNotIn(header_name, response.headers)
-
- for test_descr in tests_descriptions:
- with self.subTest(group_name=test_descr["name"]):
- app = web.Application()
- cors = setup(app, defaults=test_descr["defaults"])
-
- if use_resources:
- resource = cors.add(app.router.add_resource("/resource"))
- cors.add(resource.add_route("GET", handler),
- test_descr["route_config"])
-
- else:
- cors.add(
- app.router.add_route("GET", "/resource", handler),
- test_descr["route_config"])
-
- yield from self.create_server(app)
-
- try:
- for test_data in test_descr["tests"]:
- with self.subTest(name=test_data["name"]):
- yield from run_test(test_data)
- finally:
- yield from self.shutdown_server()
-
- @asynctest
- @asyncio.coroutine
- def test_simple_default(self):
- """Test CORS simple requests with a route with the default
- configuration.
-
- The default configuration means that:
- * no credentials are allowed,
- * no headers are exposed,
- * no client headers are allowed.
- """
-
- client1 = "http://client1.example.org"
- client2 = "http://client2.example.org"
- client1_80 = "http://client1.example.org:80"
- client1_https = "https://client2.example.org"
-
- tests_descriptions = [
- {
- "name": "default",
- "defaults": None,
- "route_config":
- {
- client1: ResourceOptions(),
- },
- "tests": [
- {
- "name": "no origin header",
- "not_in_response_headers": {
- hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
- hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
- hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
- }
- },
- {
- "name": "allowed origin",
- "request_headers": {
- hdrs.ORIGIN: client1,
- },
- "in_response_headers": {
- hdrs.ACCESS_CONTROL_ALLOW_ORIGIN: client1,
- },
- "not_in_response_headers": {
- hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
- hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
- }
- },
- {
- "name": "not allowed origin",
- "request_headers": {
- hdrs.ORIGIN: client2,
- },
- "not_in_response_headers": {
- hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
- hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
- hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
- }
- },
- {
- "name": "explicitly specified default port",
- # CORS specification says, that origins may compared
- # as strings, so "example.org:80" is not the same as
- # "example.org".
- "request_headers": {
- hdrs.ORIGIN: client1_80,
- },
- "not_in_response_headers": {
- hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
- hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
- hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
- }
- },
- {
- "name": "different scheme",
- "request_headers": {
- hdrs.ORIGIN: client1_https,
- },
- "not_in_response_headers": {
- hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
- hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
- hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
- }
- },
- ],
- },
- ]
-
- yield from self._run_simple_requests_tests(tests_descriptions, False)
- yield from self._run_simple_requests_tests(tests_descriptions, True)
-
- @asynctest
- @asyncio.coroutine
- def test_simple_with_credentials(self):
- """Test CORS simple requests with a route with enabled authorization.
-
- Route with enabled authorization must return
- Origin: <origin as requested, NOT "*">
- Access-Control-Allow-Credentials: true
- """
-
- client1 = "http://client1.example.org"
- client2 = "http://client2.example.org"
-
- credential_tests = [
- {
- "name": "no origin header",
- "not_in_response_headers": {
- hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
- hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
- hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
- }
- },
- {
- "name": "allowed origin",
- "request_headers": {
- hdrs.ORIGIN: client1,
- },
- "in_response_headers": {
- hdrs.ACCESS_CONTROL_ALLOW_ORIGIN: client1,
- hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS: "true",
- },
- "not_in_response_headers": {
- hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
- }
- },
- {
- "name": "disallowed origin",
- "request_headers": {
- hdrs.ORIGIN: client2,
- },
- "not_in_response_headers": {
- hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
- hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
- hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
- }
- },
- ]
-
- tests_descriptions = [
- {
- "name": "route settings",
- "defaults": None,
- "route_config":
- {
- client1: ResourceOptions(allow_credentials=True),
- },
- "tests": credential_tests,
- },
- {
- "name": "cors default settings",
- "defaults":
- {
- client1: ResourceOptions(allow_credentials=True),
- },
- "route_config": None,
- "tests": credential_tests,
- },
- ]
-
- yield from self._run_simple_requests_tests(tests_descriptions, False)
- yield from self._run_simple_requests_tests(tests_descriptions, True)
-
- @asynctest
- @asyncio.coroutine
- def test_simple_expose_headers(self):
- """Test CORS simple requests with a route that exposes header."""
-
- client1 = "http://client1.example.org"
- client2 = "http://client2.example.org"
-
- tests_descriptions = [
- {
- "name": "default",
- "defaults": None,
- "route_config":
- {
- client1: ResourceOptions(
- expose_headers=(SERVER_CUSTOM_HEADER_NAME,)),
- },
- "tests": [
- {
- "name": "no origin header",
- "not_in_response_headers": {
- hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
- hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
- hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
- }
- },
- {
- "name": "allowed origin",
- "request_headers": {
- hdrs.ORIGIN: client1,
- },
- "in_response_headers": {
- hdrs.ACCESS_CONTROL_ALLOW_ORIGIN: client1,
- hdrs.ACCESS_CONTROL_EXPOSE_HEADERS:
- SERVER_CUSTOM_HEADER_NAME,
- },
- "not_in_response_headers": {
- hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
- }
- },
- {
- "name": "not allowed origin",
- "request_headers": {
- hdrs.ORIGIN: client2,
- },
- "not_in_response_headers": {
- hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
- hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
- hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
- }
- },
- ],
- },
- ]
-
- yield from self._run_simple_requests_tests(tests_descriptions, False)
- yield from self._run_simple_requests_tests(tests_descriptions, True)
- yield from self._run_simple_requests_tests(tests_descriptions, True)
-
- @asyncio.coroutine
- def _run_preflight_requests_tests(self, tests_descriptions, use_resources):
- """Runs CORS preflight requests based on the passed tests descriptions.
- """
-
- @asyncio.coroutine
- def run_test(test):
- """Run single test"""
-
- response = yield from self.session.options(
- self.server_url + "resource",
- headers=test.get("request_headers", {}))
- self.assertEqual(response.status, test.get("response_status", 200))
- response_text = yield from response.text()
- in_response = test.get("in_response")
- if in_response is not None:
- self.assertIn(in_response, response_text)
- else:
- self.assertEqual(response_text, "")
-
- for header_name, header_value in test.get(
- "in_response_headers", {}).items():
- self.assertEqual(
- response.headers.get(header_name),
- header_value)
-
- for header_name in test.get("not_in_request_headers", {}).items():
- self.assertNotIn(header_name, response.headers)
-
- for test_descr in tests_descriptions:
- with self.subTest(group_name=test_descr["name"]):
- app = web.Application()
- cors = setup(app, defaults=test_descr["defaults"])
-
- if use_resources:
- resource = cors.add(app.router.add_resource("/resource"))
- cors.add(resource.add_route("GET", handler),
- test_descr["route_config"])
-
- else:
- cors.add(
- app.router.add_route("GET", "/resource", handler),
- test_descr["route_config"])
-
- yield from self.create_server(app)
-
- try:
- for test_data in test_descr["tests"]:
- with self.subTest(name=test_data["name"]):
- yield from run_test(test_data)
- finally:
- yield from self.shutdown_server()
-
- @asynctest
- @asyncio.coroutine
- def test_preflight_default(self):
- """Test CORS preflight requests with a route with the default
- configuration.
-
- The default configuration means that:
- * no credentials are allowed,
- * no headers are exposed,
- * no client headers are allowed.
- """
-
- client1 = "http://client1.example.org"
- client2 = "http://client2.example.org"
-
- tests_descriptions = [
- {
- "name": "default",
- "defaults": None,
- "route_config":
- {
- client1: ResourceOptions(),
- },
- "tests": [
- {
- "name": "no origin",
- "response_status": 403,
- "in_response": "origin header is not specified",
- "not_in_response_headers": {
- hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
- hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
- hdrs.ACCESS_CONTROL_MAX_AGE,
- hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
- hdrs.ACCESS_CONTROL_ALLOW_METHODS,
- hdrs.ACCESS_CONTROL_ALLOW_HEADERS,
- },
- },
- {
- "name": "no method",
- "request_headers": {
- hdrs.ORIGIN: client1,
- },
- "response_status": 403,
- "in_response": "'Access-Control-Request-Method' "
- "header is not specified",
- "not_in_response_headers": {
- hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
- hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
- hdrs.ACCESS_CONTROL_MAX_AGE,
- hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
- hdrs.ACCESS_CONTROL_ALLOW_METHODS,
- hdrs.ACCESS_CONTROL_ALLOW_HEADERS,
- },
- },
- {
- "name": "origin and method",
- "request_headers": {
- hdrs.ORIGIN: client1,
- hdrs.ACCESS_CONTROL_REQUEST_METHOD: "GET",
- },
- "in_response_headers": {
- hdrs.ACCESS_CONTROL_ALLOW_ORIGIN: client1,
- hdrs.ACCESS_CONTROL_ALLOW_METHODS: "GET",
- },
- "not_in_response_headers": {
- hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
- hdrs.ACCESS_CONTROL_MAX_AGE,
- hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
- hdrs.ACCESS_CONTROL_ALLOW_HEADERS,
- },
- },
- {
- "name": "disallowed origin",
- "request_headers": {
- hdrs.ORIGIN: client2,
- hdrs.ACCESS_CONTROL_REQUEST_METHOD: "GET",
- },
- "response_status": 403,
- "in_response": "origin '{}' is not allowed".format(
- client2),
- "not_in_response_headers": {
- hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
- hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
- hdrs.ACCESS_CONTROL_MAX_AGE,
- hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
- hdrs.ACCESS_CONTROL_ALLOW_METHODS,
- hdrs.ACCESS_CONTROL_ALLOW_HEADERS,
- },
- },
- {
- "name": "disallowed method",
- "request_headers": {
- hdrs.ORIGIN: client1,
- hdrs.ACCESS_CONTROL_REQUEST_METHOD: "POST",
- },
- "response_status": 403,
- "in_response": "request method 'POST' is not allowed",
- "not_in_response_headers": {
- hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
- hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
- hdrs.ACCESS_CONTROL_MAX_AGE,
- hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
- hdrs.ACCESS_CONTROL_ALLOW_METHODS,
- hdrs.ACCESS_CONTROL_ALLOW_HEADERS,
- },
- },
- ],
- },
- ]
-
- yield from self._run_preflight_requests_tests(
- tests_descriptions, False)
- yield from self._run_preflight_requests_tests(
- tests_descriptions, True)
-
- @asynctest
- @asyncio.coroutine
- def test_preflight_request_multiple_routes_with_one_options(self):
- """Test CORS preflight handling on resource that is available through
- several routes.
- """
- app = web.Application()
- cors = setup(app, defaults={
- "*": ResourceOptions(
- allow_credentials=True,
- expose_headers="*",
- allow_headers="*",
- )
- })
-
- cors.add(app.router.add_route("GET", "/{name}", handler))
- cors.add(app.router.add_route("PUT", "/{name}", handler))
-
- yield from self.create_server(app)
-
- response = yield from self.session.request(
- "OPTIONS", self.server_url + "user",
- headers={
- hdrs.ORIGIN: "http://example.org",
- hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT"
- }
+ resp = await client.get("/resource")
+ assert resp.status == 200
+ resp_text = await resp.text()
+ assert resp_text == TEST_BODY
+
+ for header_name in {
+ hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
+ hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
+ hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
+ }:
+ assert header_name not in resp.headers
+
+
+async def test_simple_allowed_origin(aiohttp_client, make_app):
+ app = make_app(None, {"http://client1.example.org":
+ ResourceOptions()})
+
+ client = await aiohttp_client(app)
+
+ resp = await client.get("/resource",
+ headers={hdrs.ORIGIN:
+ 'http://client1.example.org'})
+ assert resp.status == 200
+ resp_text = await resp.text()
+ assert resp_text == TEST_BODY
+
+ for hdr, val in {
+ hdrs.ACCESS_CONTROL_ALLOW_ORIGIN: 'http://client1.example.org',
+ }.items():
+ assert resp.headers.get(hdr) == val
+
+ for header_name in {
+ hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
+ hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
+ }:
+ assert header_name not in resp.headers
+
+
+async def test_simple_not_allowed_origin(aiohttp_client, make_app):
+ app = make_app(None, {"http://client1.example.org":
+ ResourceOptions()})
+
+ client = await aiohttp_client(app)
+
+ resp = await client.get("/resource",
+ headers={hdrs.ORIGIN:
+ 'http://client2.example.org'})
+ assert resp.status == 200
+ resp_text = await resp.text()
+ assert resp_text == TEST_BODY
+
+ for header_name in {
+ hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
+ hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
+ hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
+ }:
+ assert header_name not in resp.headers
+
+
+async def test_simple_explicit_port(aiohttp_client, make_app):
+ app = make_app(None, {"http://client1.example.org":
+ ResourceOptions()})
+
+ client = await aiohttp_client(app)
+
+ resp = await client.get("/resource",
+ headers={hdrs.ORIGIN:
+ 'http://client1.example.org:80'})
+ assert resp.status == 200
+ resp_text = await resp.text()
+ assert resp_text == TEST_BODY
+
+ for header_name in {
+ hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
+ hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
+ hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
+ }:
+ assert header_name not in resp.headers
+
+
+async def test_simple_different_scheme(aiohttp_client, make_app):
+ app = make_app(None, {"http://client1.example.org":
+ ResourceOptions()})
+
+ client = await aiohttp_client(app)
+
+ resp = await client.get("/resource",
+ headers={hdrs.ORIGIN:
+ 'https://client1.example.org'})
+ assert resp.status == 200
+ resp_text = await resp.text()
+ assert resp_text == TEST_BODY
+
+ for header_name in {
+ hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
+ hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
+ hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
+ }:
+ assert header_name not in resp.headers
+
+
+@pytest.fixture(params=[
+ (None,
+ {"http://client1.example.org": ResourceOptions(allow_credentials=True)}),
+ ({"http://client1.example.org": ResourceOptions(allow_credentials=True)},
+ None),
+])
+def app_for_credentials(make_app, request):
+ return make_app(*request.param)
+
+
+async def test_cred_no_origin(aiohttp_client, app_for_credentials):
+ app = app_for_credentials
+
+ client = await aiohttp_client(app)
+
+ resp = await client.get("/resource")
+ assert resp.status == 200
+ resp_text = await resp.text()
+ assert resp_text == TEST_BODY
+
+ for header_name in {
+ hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
+ hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
+ hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
+ }:
+ assert header_name not in resp.headers
+
+
+async def test_cred_allowed_origin(aiohttp_client, app_for_credentials):
+ app = app_for_credentials
+
+ client = await aiohttp_client(app)
+
+ resp = await client.get("/resource",
+ headers={hdrs.ORIGIN:
+ 'http://client1.example.org'})
+ assert resp.status == 200
+ resp_text = await resp.text()
+ assert resp_text == TEST_BODY
+
+ for hdr, val in {
+ hdrs.ACCESS_CONTROL_ALLOW_ORIGIN: 'http://client1.example.org',
+ hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS: "true"}.items():
+ assert resp.headers.get(hdr) == val
+
+ for header_name in {
+ hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
+ }:
+ assert header_name not in resp.headers
+
+
+async def test_cred_disallowed_origin(aiohttp_client, app_for_credentials):
+ app = app_for_credentials
+
+ client = await aiohttp_client(app)
+
+ resp = await client.get("/resource",
+ headers={hdrs.ORIGIN:
+ 'http://client2.example.org'})
+ assert resp.status == 200
+ resp_text = await resp.text()
+ assert resp_text == TEST_BODY
+
+ for header_name in {
+ hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
+ hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
+ hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
+ }:
+ assert header_name not in resp.headers
+
+
+async def test_simple_expose_headers_no_origin(aiohttp_client, make_app):
+ app = make_app(None, {"http://client1.example.org":
+ ResourceOptions(
+ expose_headers=(SERVER_CUSTOM_HEADER_NAME,))})
+
+ client = await aiohttp_client(app)
+
+ resp = await client.get("/resource")
+ assert resp.status == 200
+ resp_text = await resp.text()
+ assert resp_text == TEST_BODY
+
+ for header_name in {
+ hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
+ hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
+ hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
+ }:
+ assert header_name not in resp.headers
+
+
+async def test_simple_expose_headers_allowed_origin(aiohttp_client, make_app):
+ app = make_app(None, {"http://client1.example.org":
+ ResourceOptions(
+ expose_headers=(SERVER_CUSTOM_HEADER_NAME,))})
+
+ client = await aiohttp_client(app)
+
+ resp = await client.get("/resource",
+ headers={hdrs.ORIGIN:
+ 'http://client1.example.org'})
+ assert resp.status == 200
+ resp_text = await resp.text()
+ assert resp_text == TEST_BODY
+
+ for hdr, val in {
+ hdrs.ACCESS_CONTROL_ALLOW_ORIGIN: 'http://client1.example.org',
+ hdrs.ACCESS_CONTROL_EXPOSE_HEADERS:
+ SERVER_CUSTOM_HEADER_NAME}.items():
+ assert resp.headers.get(hdr) == val
+
+ for header_name in {
+ hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
+ }:
+ assert header_name not in resp.headers
+
+
+async def test_simple_expose_headers_not_allowed_origin(aiohttp_client,
+ make_app):
+ app = make_app(None, {"http://client1.example.org":
+ ResourceOptions(
+ expose_headers=(SERVER_CUSTOM_HEADER_NAME,))})
+
+ client = await aiohttp_client(app)
+
+ resp = await client.get("/resource",
+ headers={hdrs.ORIGIN:
+ 'http://client2.example.org'})
+ assert resp.status == 200
+ resp_text = await resp.text()
+ assert resp_text == TEST_BODY
+
+ for header_name in {
+ hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
+ hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
+ hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
+ }:
+ assert header_name not in resp.headers
+
+
+async def test_preflight_default_no_origin(aiohttp_client, make_app):
+ app = make_app(None, {"http://client1.example.org":
+ ResourceOptions()})
+
+ client = await aiohttp_client(app)
+
+ resp = await client.options("/resource")
+ assert resp.status == 403
+ resp_text = await resp.text()
+ assert "origin header is not specified" in resp_text
+
+ for header_name in {
+ hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
+ hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
+ hdrs.ACCESS_CONTROL_MAX_AGE,
+ hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
+ hdrs.ACCESS_CONTROL_ALLOW_METHODS,
+ hdrs.ACCESS_CONTROL_ALLOW_HEADERS,
+ }:
+ assert header_name not in resp.headers
+
+
+async def test_preflight_default_no_method(aiohttp_client, make_app):
+
+ app = make_app(None, {"http://client1.example.org":
+ ResourceOptions()})
+
+ client = await aiohttp_client(app)
+
+ resp = await client.options("/resource", headers={
+ hdrs.ORIGIN: "http://client1.example.org",
+ })
+ assert resp.status == 403
+ resp_text = await resp.text()
+ assert "'Access-Control-Request-Method' header is not specified"\
+ in resp_text
+
+ for header_name in {
+ hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
+ hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
+ hdrs.ACCESS_CONTROL_MAX_AGE,
+ hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
+ hdrs.ACCESS_CONTROL_ALLOW_METHODS,
+ hdrs.ACCESS_CONTROL_ALLOW_HEADERS,
+ }:
+ assert header_name not in resp.headers
+
+
+async def test_preflight_default_origin_and_method(aiohttp_client, make_app):
+
+ app = make_app(None, {"http://client1.example.org":
+ ResourceOptions()})
+
+ client = await aiohttp_client(app)
+
+ resp = await client.options("/resource", headers={
+ hdrs.ORIGIN: "http://client1.example.org",
+ hdrs.ACCESS_CONTROL_REQUEST_METHOD: "GET",
+ })
+ assert resp.status == 200
+ resp_text = await resp.text()
+ assert '' == resp_text
+
+ for hdr, val in {
+ hdrs.ACCESS_CONTROL_ALLOW_ORIGIN: "http://client1.example.org",
+ hdrs.ACCESS_CONTROL_ALLOW_METHODS: "GET"}.items():
+ assert resp.headers.get(hdr) == val
+
+ for header_name in {
+ hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
+ hdrs.ACCESS_CONTROL_MAX_AGE,
+ hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
+ hdrs.ACCESS_CONTROL_ALLOW_HEADERS,
+ }:
+ assert header_name not in resp.headers
+
+
+async def test_preflight_default_disallowed_origin(aiohttp_client, make_app):
+
+ app = make_app(None, {"http://client1.example.org":
+ ResourceOptions()})
+
+ client = await aiohttp_client(app)
+
+ resp = await client.options("/resource", headers={
+ hdrs.ORIGIN: "http://client2.example.org",
+ hdrs.ACCESS_CONTROL_REQUEST_METHOD: "GET",
+ })
+ assert resp.status == 403
+ resp_text = await resp.text()
+ assert "origin 'http://client2.example.org' is not allowed" in resp_text
+
+ for header_name in {
+ hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
+ hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
+ hdrs.ACCESS_CONTROL_MAX_AGE,
+ hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
+ hdrs.ACCESS_CONTROL_ALLOW_METHODS,
+ hdrs.ACCESS_CONTROL_ALLOW_HEADERS,
+ }:
+ assert header_name not in resp.headers
+
+
+async def test_preflight_default_disallowed_method(aiohttp_client, make_app):
+
+ app = make_app(None, {"http://client1.example.org":
+ ResourceOptions()})
+
+ client = await aiohttp_client(app)
+
+ resp = await client.options("/resource", headers={
+ hdrs.ORIGIN: "http://client1.example.org",
+ hdrs.ACCESS_CONTROL_REQUEST_METHOD: "POST",
+ })
+ assert resp.status == 403
+ resp_text = await resp.text()
+ assert ("request method 'POST' is not allowed for "
+ "'http://client1.example.org' origin" in resp_text)
+
+ for header_name in {
+ hdrs.ACCESS_CONTROL_ALLOW_ORIGIN,
+ hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS,
+ hdrs.ACCESS_CONTROL_MAX_AGE,
+ hdrs.ACCESS_CONTROL_EXPOSE_HEADERS,
+ hdrs.ACCESS_CONTROL_ALLOW_METHODS,
+ hdrs.ACCESS_CONTROL_ALLOW_HEADERS,
+ }:
+ assert header_name not in resp.headers
+
+
+async def test_preflight_req_multiple_routes_with_one_options(aiohttp_client):
+ """Test CORS preflight handling on resource that is available through
+ several routes.
+ """
+ app = web.Application()
+ cors = _setup(app, defaults={
+ "*": ResourceOptions(
+ allow_credentials=True,
+ expose_headers="*",
+ allow_headers="*",
)
- self.assertEqual(response.status, 200)
+ })
- data = yield from response.text()
- self.assertEqual(data, "")
+ cors.add(app.router.add_route("GET", "/{name}", handler))
+ cors.add(app.router.add_route("PUT", "/{name}", handler))
- @asynctest
- @asyncio.coroutine
- def test_preflight_request_multiple_routes_with_one_options_resource(self):
- """Test CORS preflight handling on resource that is available through
- several routes.
- """
- app = web.Application()
- cors = setup(app, defaults={
- "*": ResourceOptions(
- allow_credentials=True,
- expose_headers="*",
- allow_headers="*",
- )
- })
-
- resource = cors.add(app.router.add_resource("/{name}"))
- cors.add(resource.add_route("GET", handler))
- cors.add(resource.add_route("PUT", handler))
-
- yield from self.create_server(app)
-
- response = yield from self.session.request(
- "OPTIONS", self.server_url + "user",
- headers={
- hdrs.ORIGIN: "http://example.org",
- hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT"
- }
+ client = await aiohttp_client(app)
+
+ resp = await client.options(
+ "/user",
+ headers={
+ hdrs.ORIGIN: "http://example.org",
+ hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT"
+ }
+ )
+ assert resp.status == 200
+
+ data = await resp.text()
+ assert data == ""
+
+
+async def test_preflight_request_mult_routes_with_one_options_resource(
+ aiohttp_client):
+ """Test CORS preflight handling on resource that is available through
+ several routes.
+ """
+ app = web.Application()
+ cors = _setup(app, defaults={
+ "*": ResourceOptions(
+ allow_credentials=True,
+ expose_headers="*",
+ allow_headers="*",
)
- self.assertEqual(response.status, 200)
+ })
- data = yield from response.text()
- self.assertEqual(data, "")
+ resource = cors.add(app.router.add_resource("/{name}"))
+ cors.add(resource.add_route("GET", handler))
+ cors.add(resource.add_route("PUT", handler))
- @asynctest
- @asyncio.coroutine
- def test_preflight_request_headers_resource(self):
- """Test CORS preflight request handlers handling."""
- app = web.Application()
- cors = setup(app, defaults={
- "*": ResourceOptions(
- allow_credentials=True,
- expose_headers="*",
- allow_headers=("Content-Type", "X-Header"),
- )
- })
-
- cors.add(app.router.add_route("PUT", "/", handler))
-
- yield from self.create_server(app)
-
- response = yield from self.session.request(
- "OPTIONS", self.server_url,
- headers={
- hdrs.ORIGIN: "http://example.org",
- hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT",
- hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "content-type",
- }
+ client = await aiohttp_client(app)
+
+ resp = await client.options(
+ "/user",
+ headers={
+ hdrs.ORIGIN: "http://example.org",
+ hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT"
+ }
+ )
+ assert resp.status == 200
+
+ data = await resp.text()
+ assert data == ""
+
+
+async def test_preflight_request_max_age_resource(aiohttp_client):
+ """Test CORS preflight handling on resource that is available through
+ several routes.
+ """
+ app = web.Application()
+ cors = _setup(app, defaults={
+ "*": ResourceOptions(
+ allow_credentials=True,
+ expose_headers="*",
+ allow_headers="*",
+ max_age=1200
)
- self.assertEqual((yield from response.text()), "")
- self.assertEqual(response.status, 200)
- # Access-Control-Allow-Headers must be compared in case-insensitive
- # way.
- self.assertEqual(
- response.headers[hdrs.ACCESS_CONTROL_ALLOW_HEADERS].upper(),
- "content-type".upper())
+ })
+
+ resource = cors.add(app.router.add_resource("/{name}"))
+ cors.add(resource.add_route("GET", handler))
+
+ client = await aiohttp_client(app)
+
+ resp = await client.options(
+ "/user",
+ headers={
+ hdrs.ORIGIN: "http://example.org",
+ hdrs.ACCESS_CONTROL_REQUEST_METHOD: "GET"
+ }
+ )
+ assert resp.status == 200
+ assert resp.headers[hdrs.ACCESS_CONTROL_MAX_AGE].upper() == "1200"
+
+ data = await resp.text()
+ assert data == ""
- response = yield from self.session.request(
- "OPTIONS", self.server_url,
- headers={
- hdrs.ORIGIN: "http://example.org",
- hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT",
- hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "X-Header,content-type",
- }
+
+async def test_preflight_request_max_age_webview(aiohttp_client):
+ """Test CORS preflight handling on resource that is available through
+ several routes.
+ """
+ app = web.Application()
+ cors = _setup(app, defaults={
+ "*": ResourceOptions(
+ allow_credentials=True,
+ expose_headers="*",
+ allow_headers="*",
+ max_age=1200
)
- self.assertEqual(response.status, 200)
- # Access-Control-Allow-Headers must be compared in case-insensitive
- # way.
- self.assertEqual(
- frozenset(response.headers[hdrs.ACCESS_CONTROL_ALLOW_HEADERS]
- .upper().split(",")),
- {"X-Header".upper(), "content-type".upper()})
- self.assertEqual((yield from response.text()), "")
-
- response = yield from self.session.request(
- "OPTIONS", self.server_url,
- headers={
- hdrs.ORIGIN: "http://example.org",
- hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT",
- hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "content-type,Test",
- }
+ })
+
+ class TestView(web.View, CorsViewMixin):
+ async def get(self):
+ resp = web.Response(text=TEST_BODY)
+
+ resp.headers[SERVER_CUSTOM_HEADER_NAME] = \
+ SERVER_CUSTOM_HEADER_VALUE
+
+ return resp
+
+ cors.add(app.router.add_route("*", "/{name}", TestView))
+
+ client = await aiohttp_client(app)
+
+ resp = await client.options(
+ "/user",
+ headers={
+ hdrs.ORIGIN: "http://example.org",
+ hdrs.ACCESS_CONTROL_REQUEST_METHOD: "GET"
+ }
+ )
+ assert resp.status == 200
+ assert resp.headers[hdrs.ACCESS_CONTROL_MAX_AGE].upper() == "1200"
+
+ data = await resp.text()
+ assert data == ""
+
+
+async def test_preflight_request_mult_routes_with_one_options_webview(
+ aiohttp_client):
+ """Test CORS preflight handling on resource that is available through
+ several routes.
+ """
+ app = web.Application()
+ cors = _setup(app, defaults={
+ "*": ResourceOptions(
+ allow_credentials=True,
+ expose_headers="*",
+ allow_headers="*",
)
- self.assertEqual(response.status, 403)
- self.assertNotIn(
- hdrs.ACCESS_CONTROL_ALLOW_HEADERS,
- response.headers)
- self.assertIn(
- "headers are not allowed: TEST",
- (yield from response.text()))
-
- @asynctest
- @asyncio.coroutine
- def test_preflight_request_headers(self):
- """Test CORS preflight request handlers handling."""
- app = web.Application()
- cors = setup(app, defaults={
- "*": ResourceOptions(
- allow_credentials=True,
- expose_headers="*",
- allow_headers=("Content-Type", "X-Header"),
- )
- })
-
- resource = cors.add(app.router.add_resource("/"))
- cors.add(resource.add_route("PUT", handler))
-
- yield from self.create_server(app)
-
- response = yield from self.session.request(
- "OPTIONS", self.server_url,
- headers={
- hdrs.ORIGIN: "http://example.org",
- hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT",
- hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "content-type",
- }
+ })
+
+ class TestView(web.View, CorsViewMixin):
+ async def get(self):
+ resp = web.Response(text=TEST_BODY)
+
+ resp.headers[SERVER_CUSTOM_HEADER_NAME] = \
+ SERVER_CUSTOM_HEADER_VALUE
+
+ return resp
+
+ put = get
+
+ cors.add(app.router.add_route("*", "/{name}", TestView))
+
+ client = await aiohttp_client(app)
+
+ resp = await client.options(
+ "/user",
+ headers={
+ hdrs.ORIGIN: "http://example.org",
+ hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT"
+ }
+ )
+ assert resp.status == 200
+
+ data = await resp.text()
+ assert data == ""
+
+
+async def test_preflight_request_headers_webview(aiohttp_client):
+ """Test CORS preflight request handlers handling."""
+ app = web.Application()
+ cors = _setup(app, defaults={
+ "*": ResourceOptions(
+ allow_credentials=True,
+ expose_headers="*",
+ allow_headers=("Content-Type", "X-Header"),
)
- self.assertEqual((yield from response.text()), "")
- self.assertEqual(response.status, 200)
- # Access-Control-Allow-Headers must be compared in case-insensitive
- # way.
- self.assertEqual(
- response.headers[hdrs.ACCESS_CONTROL_ALLOW_HEADERS].upper(),
+ })
+
+ class TestView(web.View, CorsViewMixin):
+ async def put(self):
+ response = web.Response(text=TEST_BODY)
+
+ response.headers[SERVER_CUSTOM_HEADER_NAME] = \
+ SERVER_CUSTOM_HEADER_VALUE
+
+ return response
+
+ cors.add(app.router.add_route("*", "/", TestView))
+
+ client = await aiohttp_client(app)
+
+ resp = await client.options(
+ '/',
+ headers={
+ hdrs.ORIGIN: "http://example.org",
+ hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT",
+ hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "content-type",
+ }
+ )
+ assert (await resp.text()) == ""
+ assert resp.status == 200
+ # Access-Control-Allow-Headers must be compared in case-insensitive
+ # way.
+ assert (resp.headers[hdrs.ACCESS_CONTROL_ALLOW_HEADERS].upper() ==
"content-type".upper())
- response = yield from self.session.request(
- "OPTIONS", self.server_url,
- headers={
- hdrs.ORIGIN: "http://example.org",
- hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT",
- hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "X-Header,content-type",
- }
+ resp = await client.options(
+ '/',
+ headers={
+ hdrs.ORIGIN: "http://example.org",
+ hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT",
+ hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "X-Header,content-type",
+ }
+ )
+ assert resp.status == 200
+ # Access-Control-Allow-Headers must be compared in case-insensitive
+ # way.
+ assert (
+ frozenset(resp.headers[hdrs.ACCESS_CONTROL_ALLOW_HEADERS]
+ .upper().split(",")) ==
+ {"X-Header".upper(), "content-type".upper()})
+ assert (await resp.text()) == ""
+
+ resp = await client.options(
+ '/',
+ headers={
+ hdrs.ORIGIN: "http://example.org",
+ hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT",
+ hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "content-type,Test",
+ }
+ )
+ assert resp.status == 403
+ assert hdrs.ACCESS_CONTROL_ALLOW_HEADERS not in resp.headers
+ assert "headers are not allowed: TEST" in (await resp.text())
+
+
+async def test_preflight_request_headers_resource(aiohttp_client):
+ """Test CORS preflight request handlers handling."""
+ app = web.Application()
+ cors = _setup(app, defaults={
+ "*": ResourceOptions(
+ allow_credentials=True,
+ expose_headers="*",
+ allow_headers=("Content-Type", "X-Header"),
)
- self.assertEqual(response.status, 200)
- # Access-Control-Allow-Headers must be compared in case-insensitive
- # way.
- self.assertEqual(
- frozenset(response.headers[hdrs.ACCESS_CONTROL_ALLOW_HEADERS]
- .upper().split(",")),
- {"X-Header".upper(), "content-type".upper()})
- self.assertEqual((yield from response.text()), "")
-
- response = yield from self.session.request(
- "OPTIONS", self.server_url,
- headers={
- hdrs.ORIGIN: "http://example.org",
- hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT",
- hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "content-type,Test",
- }
+ })
+
+ cors.add(app.router.add_route("PUT", "/", handler))
+
+ client = await aiohttp_client(app)
+
+ resp = await client.options(
+ '/',
+ headers={
+ hdrs.ORIGIN: "http://example.org",
+ hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT",
+ hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "content-type",
+ }
+ )
+ assert (await resp.text()) == ""
+ assert resp.status == 200
+ # Access-Control-Allow-Headers must be compared in case-insensitive
+ # way.
+ assert (
+ resp.headers[hdrs.ACCESS_CONTROL_ALLOW_HEADERS].upper() ==
+ "content-type".upper())
+
+ resp = await client.options(
+ '/',
+ headers={
+ hdrs.ORIGIN: "http://example.org",
+ hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT",
+ hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "X-Header,content-type",
+ }
+ )
+ assert resp.status == 200
+ # Access-Control-Allow-Headers must be compared in case-insensitive
+ # way.
+ assert (
+ frozenset(resp.headers[hdrs.ACCESS_CONTROL_ALLOW_HEADERS]
+ .upper().split(",")) ==
+ {"X-Header".upper(), "content-type".upper()})
+ assert (await resp.text()) == ""
+
+ resp = await client.options(
+ '/',
+ headers={
+ hdrs.ORIGIN: "http://example.org",
+ hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT",
+ hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "content-type,Test",
+ }
+ )
+ assert resp.status == 403
+ assert hdrs.ACCESS_CONTROL_ALLOW_HEADERS not in resp.headers
+ assert "headers are not allowed: TEST" in (await resp.text())
+
+
+async def test_preflight_request_headers(aiohttp_client):
+ """Test CORS preflight request handlers handling."""
+ app = web.Application()
+ cors = _setup(app, defaults={
+ "*": ResourceOptions(
+ allow_credentials=True,
+ expose_headers="*",
+ allow_headers=("Content-Type", "X-Header"),
)
- self.assertEqual(response.status, 403)
- self.assertNotIn(
- hdrs.ACCESS_CONTROL_ALLOW_HEADERS,
- response.headers)
- self.assertIn(
- "headers are not allowed: TEST",
- (yield from response.text()))
-
- @asynctest
- @asyncio.coroutine
- def test_static_route(self):
- """Test a static route with CORS."""
- app = web.Application()
- cors = setup(app, defaults={
- "*": ResourceOptions(
- allow_credentials=True,
- expose_headers="*",
- allow_methods="*",
- allow_headers=("Content-Type", "X-Header"),
- )
- })
-
- test_static_path = pathlib.Path(__file__).parent
- cors.add(app.router.add_static("/static", test_static_path, name='static'))
-
- yield from self.create_server(app)
-
- response = yield from self.session.request(
- "OPTIONS", URL(self.server_url) / "static/test_page.html",
- headers={
- hdrs.ORIGIN: "http://example.org",
- hdrs.ACCESS_CONTROL_REQUEST_METHOD: "OPTIONS",
- hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "content-type",
- }
+ })
+
+ resource = cors.add(app.router.add_resource("/"))
+ cors.add(resource.add_route("PUT", handler))
+
+ client = await aiohttp_client(app)
+
+ resp = await client.options(
+ '/',
+ headers={
+ hdrs.ORIGIN: "http://example.org",
+ hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT",
+ hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "content-type",
+ }
+ )
+ assert (await resp.text()) == ""
+ assert resp.status == 200
+ # Access-Control-Allow-Headers must be compared in case-insensitive
+ # way.
+ assert (
+ resp.headers[hdrs.ACCESS_CONTROL_ALLOW_HEADERS].upper() ==
+ "content-type".upper())
+
+ resp = await client.options(
+ '/',
+ headers={
+ hdrs.ORIGIN: "http://example.org",
+ hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT",
+ hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "X-Header,content-type",
+ }
+ )
+ assert resp.status == 200
+ # Access-Control-Allow-Headers must be compared in case-insensitive
+ # way.
+ assert (
+ frozenset(resp.headers[hdrs.ACCESS_CONTROL_ALLOW_HEADERS]
+ .upper().split(",")) ==
+ {"X-Header".upper(), "content-type".upper()})
+ assert (await resp.text()) == ""
+
+ resp = await client.options(
+ '/',
+ headers={
+ hdrs.ORIGIN: "http://example.org",
+ hdrs.ACCESS_CONTROL_REQUEST_METHOD: "PUT",
+ hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "content-type,Test",
+ }
+ )
+ assert resp.status == 403
+ assert hdrs.ACCESS_CONTROL_ALLOW_HEADERS not in resp.headers
+ assert "headers are not allowed: TEST" in (await resp.text())
+
+
+async def test_static_route(aiohttp_client):
+ """Test a static route with CORS."""
+ app = web.Application()
+ cors = _setup(app, defaults={
+ "*": ResourceOptions(
+ allow_credentials=True,
+ expose_headers="*",
+ allow_methods="*",
+ allow_headers=("Content-Type", "X-Header"),
)
- data = yield from response.text()
- self.assertEqual(response.status, 200)
- self.assertEqual(data, '')
+ })
+
+ test_static_path = pathlib.Path(__file__).parent
+ cors.add(app.router.add_static("/static", test_static_path,
+ name='static'))
+
+ client = await aiohttp_client(app)
+
+ resp = await client.options(
+ "/static/test_page.html",
+ headers={
+ hdrs.ORIGIN: "http://example.org",
+ hdrs.ACCESS_CONTROL_REQUEST_METHOD: "OPTIONS",
+ hdrs.ACCESS_CONTROL_REQUEST_HEADERS: "content-type",
+ }
+ )
+ data = await resp.text()
+ assert resp.status == 200
+ assert data == ''
# TODO: test requesting resources with not configured CORS.
diff --git a/tests/integration/test_page.html b/tests/integration/test_page.html
index c2f4f7a..240742b 100644
--- a/tests/integration/test_page.html
+++ b/tests/integration/test_page.html
@@ -406,7 +406,7 @@
xhr.responseText);
} else {
- log('Received server addressess:', xhr.response);
+ log('Received server addresses:', xhr.response);
setServersUrls(JSON.parse(xhr.responseText));
}
diff --git a/tests/integration/test_real_browser.py b/tests/integration/test_real_browser.py
index 06a828b..a5c9030 100644
--- a/tests/integration/test_real_browser.py
+++ b/tests/integration/test_real_browser.py
@@ -19,12 +19,12 @@ import os
import json
import asyncio
import socket
-import unittest
import pathlib
import logging
import webbrowser
from aiohttp import web, hdrs
+import pytest
import selenium.common.exceptions
from selenium import webdriver
@@ -33,9 +33,7 @@ from selenium.webdriver.common.keys import Keys
from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions as EC
-from aiohttp_cors import setup, ResourceOptions
-
-from ..aio_test_base import create_server, AioTestBase, asynctest
+from aiohttp_cors import setup as _setup, ResourceOptions, CorsViewMixin
class _ServerDescr:
@@ -52,7 +50,7 @@ class _ServerDescr:
class IntegrationServers:
"""Integration servers starting/stopping manager"""
- def __init__(self, use_resources, *, loop=None):
+ def __init__(self, use_resources, use_webview, *, loop=None):
self.servers = {}
self.loop = loop
@@ -60,6 +58,7 @@ class IntegrationServers:
self.loop = asyncio.get_event_loop()
self.use_resources = use_resources
+ self.use_webview = use_webview
self._logger = logging.getLogger("IntegrationServers")
@@ -67,37 +66,37 @@ class IntegrationServers:
def origin_server_url(self):
return self.servers["origin"].url
- @asyncio.coroutine
- def start_servers(self):
+ async def start_servers(self):
test_page_path = pathlib.Path(__file__).with_name("test_page.html")
- @asyncio.coroutine
- def handle_test_page(request: web.Request) -> web.StreamResponse:
+ async def handle_test_page(request: web.Request) -> web.StreamResponse:
with test_page_path.open("r", encoding="utf-8") as f:
return web.Response(
text=f.read(),
headers={hdrs.CONTENT_TYPE: "text/html"})
- @asyncio.coroutine
- def handle_no_cors(request: web.Request) -> web.StreamResponse:
+ async def handle_no_cors(request: web.Request) -> web.StreamResponse:
return web.Response(
text="""{"type": "no_cors.json"}""",
headers={hdrs.CONTENT_TYPE: "application/json"})
- @asyncio.coroutine
- def handle_resource(request: web.Request) -> web.StreamResponse:
+ async def handle_resource(request: web.Request) -> web.StreamResponse:
return web.Response(
text="""{"type": "resource"}""",
headers={hdrs.CONTENT_TYPE: "application/json"})
- @asyncio.coroutine
- def handle_servers_addresses(
+ async def handle_servers_addresses(
request: web.Request) -> web.StreamResponse:
servers_addresses = \
{name: descr.url for name, descr in self.servers.items()}
return web.Response(
text=json.dumps(servers_addresses))
+ class ResourceView(web.View, CorsViewMixin):
+
+ async def get(self) -> web.StreamResponse:
+ return await handle_resource(self.request)
+
# For most resources:
# "origin" server has no CORS configuration.
# "allowing" server explicitly allows CORS requests to "origin" server.
@@ -137,8 +136,12 @@ class IntegrationServers:
for server_name in server_names:
app = self.servers[server_name].app
app.router.add_route("GET", "/no_cors.json", handle_no_cors)
- app.router.add_route("GET", "/cors_resource", handle_resource,
- name="cors_resource")
+ if self.use_webview:
+ app.router.add_route("*", "/cors_resource", ResourceView,
+ name="cors_resource")
+ else:
+ app.router.add_route("GET", "/cors_resource", handle_resource,
+ name="cors_resource")
cors_default_configs = {
"allowing": {
@@ -167,7 +170,7 @@ class IntegrationServers:
default_config = cors_default_configs.get(server_name)
if default_config is None:
continue
- server_descr.cors = setup(
+ server_descr.cors = _setup(
server_descr.app, defaults=default_config)
# Add CORS routes.
@@ -182,27 +185,30 @@ class IntegrationServers:
server_descr.cors.add(resource)
server_descr.cors.add(route)
+ elif self.use_webview:
+ server_descr.cors.add(route)
+
else:
server_descr.cors.add(route)
# Start servers.
for server_name, server_descr in self.servers.items():
handler = server_descr.app.make_handler()
- server = yield from create_server(handler, self.loop,
- sock=server_sockets[server_name])
+ server = await self.loop.create_server(
+ handler,
+ sock=server_sockets[server_name])
server_descr.handler = handler
server_descr.server = server
self._logger.info("Started server '%s' at '%s'",
server_name, server_descr.url)
- @asyncio.coroutine
- def stop_servers(self):
+ async def stop_servers(self):
for server_descr in self.servers.values():
server_descr.server.close()
- yield from server_descr.handler.finish_connections()
- yield from server_descr.server.wait_closed()
- yield from server_descr.app.cleanup()
+ await server_descr.handler.shutdown()
+ await server_descr.server.wait_closed()
+ await server_descr.app.cleanup()
self.servers = {}
@@ -218,99 +224,64 @@ def _get_chrome_driver():
return driver
-class TestInBrowser(AioTestBase):
- @asyncio.coroutine
- def _test_in_webdriver(self, driver, use_resources):
- # TODO: Use pytest's fixtures to test use resources/not use resources.
- servers = IntegrationServers(use_resources)
- yield from servers.start_servers()
-
- def selenium_thread():
- driver.get(servers.origin_server_url)
- assert "aiohttp_cors" in driver.title
-
- wait = WebDriverWait(driver, 10)
-
- run_button = wait.until(EC.element_to_be_clickable(
- (By.ID, "runTestsButton")))
-
- # Start tests.
- run_button.send_keys(Keys.RETURN)
-
- # Wait while test will finish (until clear button is not
- # activated).
- wait.until(EC.element_to_be_clickable(
- (By.ID, "clearResultsButton")))
-
- # Get results json
- results_area = driver.find_element_by_id("results")
-
- return json.loads(results_area.get_attribute("value"))
-
- try:
- results = yield from self.loop.run_in_executor(
- self.thread_pool_executor, selenium_thread)
-
- self.assertEqual(results["status"], "success")
- for test_name, test_data in results["data"].items():
- with self.subTest(group_name=test_name):
- self.assertEqual(test_data["status"], "success",
- msg=(test_name, test_data))
-
- finally:
- yield from servers.stop_servers()
-
- @asynctest
- @asyncio.coroutine
- def test_firefox(self):
- try:
- driver = webdriver.Firefox()
- except selenium.common.exceptions.WebDriverException:
- raise unittest.SkipTest
-
- try:
- yield from self._test_in_webdriver(driver, False)
- finally:
- driver.close()
-
- @asynctest
- @asyncio.coroutine
- def test_chromium(self):
- try:
- driver = _get_chrome_driver()
- except selenium.common.exceptions.WebDriverException:
- raise unittest.SkipTest
-
- try:
- yield from self._test_in_webdriver(driver, False)
- finally:
- driver.close()
-
- @asynctest
- @asyncio.coroutine
- def test_firefox_resource(self):
- try:
- driver = webdriver.Firefox()
- except selenium.common.exceptions.WebDriverException:
- raise unittest.SkipTest
-
- try:
- yield from self._test_in_webdriver(driver, True)
- finally:
- driver.close()
-
- @asynctest
- @asyncio.coroutine
- def test_chromium_resource(self):
- try:
- driver = _get_chrome_driver()
- except selenium.common.exceptions.WebDriverException:
- raise unittest.SkipTest
-
- try:
- yield from self._test_in_webdriver(driver, True)
- finally:
- driver.close()
+@pytest.fixture(params=[(False, False),
+ (True, False),
+ (False, True)])
+def server(request, loop):
+ async def inner():
+ # to grab implicit loop
+ return IntegrationServers(*request.param)
+ return loop.run_until_complete(inner())
+
+
+@pytest.fixture(params=[webdriver.Firefox,
+ _get_chrome_driver])
+def driver(request):
+ try:
+ driver = request.param()
+ except selenium.common.exceptions.WebDriverException:
+ pytest.skip("Driver is not supported")
+
+ yield driver
+ driver.close()
+
+
+async def test_in_webdriver(driver, server):
+ loop = asyncio.get_event_loop()
+ await server.start_servers()
+
+ def selenium_thread():
+ driver.get(server.origin_server_url)
+ assert "aiohttp_cors" in driver.title
+
+ wait = WebDriverWait(driver, 10)
+
+ run_button = wait.until(EC.element_to_be_clickable(
+ (By.ID, "runTestsButton")))
+
+ # Start tests.
+ run_button.send_keys(Keys.RETURN)
+
+ # Wait while test will finish (until clear button is not
+ # activated).
+ wait.until(EC.element_to_be_clickable(
+ (By.ID, "clearResultsButton")))
+
+ # Get results json
+ results_area = driver.find_element_by_id("results")
+
+ return json.loads(results_area.get_attribute("value"))
+
+ try:
+ results = await loop.run_in_executor(
+ None, selenium_thread)
+
+ assert results["status"] == "success"
+ for test_name, test_data in results["data"].items():
+ assert test_data["status"] == "success"
+
+ finally:
+ await server.stop_servers()
def _run_integration_server():
@@ -322,7 +293,7 @@ def _run_integration_server():
loop = asyncio.get_event_loop()
- servers = IntegrationServers()
+ servers = IntegrationServers(False, True)
logger.info("Starting integration servers...")
loop.run_until_complete(servers.start_servers())
diff --git a/tests/unit/test___about__.py b/tests/unit/test___about__.py
index 99c5672..4b16655 100644
--- a/tests/unit/test___about__.py
+++ b/tests/unit/test___about__.py
@@ -15,15 +15,12 @@
"""Test aiohttp_cors package metainformation.
"""
-import unittest
from pkg_resources import parse_version
import aiohttp_cors
-class TestMetaInformation(unittest.TestCase):
- """Test package metainformation"""
- # pylint: disable=no-self-use
- def test_version(self):
- """Test package version string"""
- parse_version(aiohttp_cors.__version__)
+def test_version():
+ """Test package version string"""
+ # not raised
+ parse_version(aiohttp_cors.__version__)
diff --git a/tests/unit/test_cors_config.py b/tests/unit/test_cors_config.py
index 96e837d..5b8d8f3 100644
--- a/tests/unit/test_cors_config.py
+++ b/tests/unit/test_cors_config.py
@@ -16,75 +16,120 @@
"""
import asyncio
-import unittest
+import pytest
from aiohttp import web
-from aiohttp_cors import CorsConfig, ResourceOptions
+from aiohttp_cors import CorsConfig, ResourceOptions, CorsViewMixin
-def _handler(request):
+async def _handler(request):
return web.Response(text="Done")
-class TestCorsConfig(unittest.TestCase):
- """Unit tests for CorsConfig"""
-
- def setUp(self):
- self.loop = asyncio.new_event_loop()
- self.app = web.Application(loop=self.loop)
- self.cors = CorsConfig(self.app, defaults={
- "*": ResourceOptions()
- })
- self.get_route = self.app.router.add_route(
- "GET", "/get_path", _handler)
- self.options_route = self.app.router.add_route(
- "OPTIONS", "/options_path", _handler)
-
- def tearDown(self):
- self.loop.close()
-
- def test_add_options_route(self):
- """Test configuring OPTIONS route"""
-
- with self.assertRaises(RuntimeError):
- self.cors.add(self.options_route.resource)
-
- def test_plain_named_route(self):
- """Test adding plain named route."""
- # Adding CORS routes should not introduce new named routes.
- self.assertEqual(len(self.app.router.keys()), 0)
- route = self.app.router.add_route(
- "GET", "/{name}", _handler, name="dynamic_named_route")
- self.assertEqual(len(self.app.router.keys()), 1)
- self.cors.add(route)
- self.assertEqual(len(self.app.router.keys()), 1)
-
- def test_dynamic_named_route(self):
- """Test adding dynamic named route."""
- self.assertEqual(len(self.app.router.keys()), 0)
- route = self.app.router.add_route(
- "GET", "/{name}", _handler, name="dynamic_named_route")
- self.assertEqual(len(self.app.router.keys()), 1)
- self.cors.add(route)
- self.assertEqual(len(self.app.router.keys()), 1)
-
- def test_static_named_route(self):
- """Test adding dynamic named route."""
- self.assertEqual(len(self.app.router.keys()), 0)
- route = self.app.router.add_static(
- "/file", "/", name="dynamic_named_route")
- self.assertEqual(len(self.app.router.keys()), 1)
- self.cors.add(route)
- self.assertEqual(len(self.app.router.keys()), 1)
-
- def test_static_resource(self):
- """Test adding static resource."""
- self.assertEqual(len(self.app.router.keys()), 0)
- self.app.router.add_static(
- "/file", "/", name="dynamic_named_route")
- self.assertEqual(len(self.app.router.keys()), 1)
- for resource in list(self.app.router.resources()):
- if issubclass(resource, web.StaticResource):
- self.cors.add(resource)
- self.assertEqual(len(self.app.router.keys()), 1)
+class _View(web.View, CorsViewMixin):
+
+ @asyncio.coroutine
+ def get(self):
+ return web.Response(text="Done")
+
+
+@pytest.fixture
+def app():
+ return web.Application()
+
+
+@pytest.fixture
+def cors(app):
+ return CorsConfig(app, defaults={
+ "*": ResourceOptions()
+ })
+
+
+@pytest.fixture
+def get_route(app):
+ return app.router.add_route(
+ "GET", "/get_path", _handler)
+
+
+@pytest.fixture
+def options_route(app):
+ return app.router.add_route(
+ "OPTIONS", "/options_path", _handler)
+
+
+def test_add_options_route(cors, options_route):
+ """Test configuring OPTIONS route"""
+
+ with pytest.raises(ValueError,
+ match="/options_path already has OPTIONS handler"):
+ cors.add(options_route.resource)
+
+
+def test_plain_named_route(app, cors):
+ """Test adding plain named route."""
+ # Adding CORS routes should not introduce new named routes.
+ assert len(app.router.keys()) == 0
+ route = app.router.add_route(
+ "GET", "/{name}", _handler, name="dynamic_named_route")
+ assert len(app.router.keys()) == 1
+ cors.add(route)
+ assert len(app.router.keys()) == 1
+
+
+def test_dynamic_named_route(app, cors):
+ """Test adding dynamic named route."""
+ assert len(app.router.keys()) == 0
+ route = app.router.add_route(
+ "GET", "/{name}", _handler, name="dynamic_named_route")
+ assert len(app.router.keys()) == 1
+ cors.add(route)
+ assert len(app.router.keys()) == 1
+
+
+def test_static_named_route(app, cors):
+ """Test adding dynamic named route."""
+ assert len(app.router.keys()) == 0
+ route = app.router.add_static(
+ "/file", "/", name="dynamic_named_route")
+ assert len(app.router.keys()) == 1
+ cors.add(route)
+ assert len(app.router.keys()) == 1
+
+
+def test_static_resource(app, cors):
+ """Test adding static resource."""
+ assert len(app.router.keys()) == 0
+ app.router.add_static(
+ "/file", "/", name="dynamic_named_route")
+ assert len(app.router.keys()) == 1
+ for resource in list(app.router.resources()):
+ if issubclass(resource, web.StaticResource):
+ cors.add(resource)
+ assert len(app.router.keys()) == 1
+
+
+def test_web_view_resource(app, cors):
+ """Test adding resource with web.View as handler"""
+ assert len(app.router.keys()) == 0
+ route = app.router.add_route(
+ "GET", "/{name}", _View, name="dynamic_named_route")
+ assert len(app.router.keys()) == 1
+ cors.add(route)
+ assert len(app.router.keys()) == 1
+
+
+def test_web_view_warning(app, cors):
+ """Test adding resource with web.View as handler"""
+ route = app.router.add_route("*", "/", _View)
+ with pytest.warns(DeprecationWarning):
+ cors.add(route, webview=True)
+
+
+def test_disable_bare_view(app, cors):
+ class View(web.View):
+ pass
+
+ route = app.router.add_route("*", "/", View)
+ with pytest.raises(ValueError):
+ cors.add(route)
diff --git a/tests/unit/test_mixin.py b/tests/unit/test_mixin.py
new file mode 100644
index 0000000..fb33b2e
--- /dev/null
+++ b/tests/unit/test_mixin.py
@@ -0,0 +1,125 @@
+import asyncio
+
+from unittest import mock
+
+import pytest
+from aiohttp import web
+
+from aiohttp_cors import CorsConfig, APP_CONFIG_KEY
+from aiohttp_cors import ResourceOptions, CorsViewMixin, custom_cors
+
+
+DEFAULT_CONFIG = {
+ '*': ResourceOptions()
+}
+
+CLASS_CONFIG = {
+ '*': ResourceOptions()
+}
+
+CUSTOM_CONFIG = {
+ 'www.client1.com': ResourceOptions(allow_headers=['X-Host'])
+}
+
+
+class SimpleView(web.View, CorsViewMixin):
+ async def get(self):
+ return web.Response(text="Done")
+
+
+class SimpleViewWithConfig(web.View, CorsViewMixin):
+
+ cors_config = CLASS_CONFIG
+
+ async def get(self):
+ return web.Response(text="Done")
+
+
+class CustomMethodView(web.View, CorsViewMixin):
+
+ cors_config = CLASS_CONFIG
+
+ async def get(self):
+ return web.Response(text="Done")
+
+ @custom_cors(CUSTOM_CONFIG)
+ async def post(self):
+ return web.Response(text="Done")
+
+
+@pytest.fixture
+def _app():
+ return web.Application()
+
+
+@pytest.fixture
+def cors(_app):
+ ret = CorsConfig(_app, defaults=DEFAULT_CONFIG)
+ _app[APP_CONFIG_KEY] = ret
+ return ret
+
+
+@pytest.fixture
+def app(_app, cors):
+ # a trick to install a cors into app
+ return _app
+
+
+def test_raise_exception_when_cors_not_configure():
+ request = mock.Mock()
+ request.app = {}
+ view = CustomMethodView(request)
+
+ with pytest.raises(ValueError):
+ view.get_request_config(request, 'post')
+
+
+async def test_raises_forbidden_when_config_not_found(app):
+ app[APP_CONFIG_KEY].defaults = {}
+ request = mock.Mock()
+ request.app = app
+ request.headers = {
+ 'Origin': '*',
+ 'Access-Control-Request-Method': 'GET'
+ }
+ view = SimpleView(request)
+
+ with pytest.raises(web.HTTPForbidden):
+ await view.options()
+
+
+def test_method_with_custom_cors(app):
+ """Test adding resource with web.View as handler"""
+ request = mock.Mock()
+ request.app = app
+ view = CustomMethodView(request)
+
+ assert hasattr(view.post, 'post_cors_config')
+ assert asyncio.iscoroutinefunction(view.post)
+ config = view.get_request_config(request, 'post')
+
+ assert config.get('www.client1.com') == CUSTOM_CONFIG['www.client1.com']
+
+
+def test_method_with_class_config(app):
+ """Test adding resource with web.View as handler"""
+ request = mock.Mock()
+ request.app = app
+ view = SimpleViewWithConfig(request)
+
+ assert not hasattr(view.get, 'get_cors_config')
+ config = view.get_request_config(request, 'get')
+
+ assert config.get('*') == CLASS_CONFIG['*']
+
+
+def test_method_with_default_config(app):
+ """Test adding resource with web.View as handler"""
+ request = mock.Mock()
+ request.app = app
+ view = SimpleView(request)
+
+ assert not hasattr(view.get, 'get_cors_config')
+ config = view.get_request_config(request, 'get')
+
+ assert config.get('*') == DEFAULT_CONFIG['*']
diff --git a/tests/unit/test_preflight_handler.py b/tests/unit/test_preflight_handler.py
new file mode 100644
index 0000000..84fc8bc
--- /dev/null
+++ b/tests/unit/test_preflight_handler.py
@@ -0,0 +1,12 @@
+from unittest import mock
+
+import pytest
+
+from aiohttp_cors.preflight_handler import _PreflightHandler
+
+
+async def test_raises_when_handler_not_extend():
+ request = mock.Mock()
+ handler = _PreflightHandler()
+ with pytest.raises(NotImplementedError):
+ await handler._get_config(request, 'origin', 'GET')
diff --git a/tests/unit/test_resource_options.py b/tests/unit/test_resource_options.py
index 8c6631e..2ba88c2 100644
--- a/tests/unit/test_resource_options.py
+++ b/tests/unit/test_resource_options.py
@@ -15,46 +15,37 @@
"""aiohttp_cors.resource_options unit tests.
"""
-import unittest
+import pytest
from aiohttp_cors.resource_options import ResourceOptions
-class TestResourceOptions(unittest.TestCase):
- """Unit tests for ResourceOptions class"""
-
- def test_init_no_args(self):
- """Test construction without arguments"""
- opts = ResourceOptions()
-
- self.assertFalse(opts.allow_credentials)
- self.assertFalse(opts.expose_headers)
- self.assertFalse(opts.allow_headers)
- self.assertIsNone(opts.max_age)
-
- def test_comparison(self):
- self.assertTrue(ResourceOptions() == ResourceOptions())
- self.assertFalse(ResourceOptions() != ResourceOptions())
- self.assertFalse(
- ResourceOptions(allow_credentials=True) == ResourceOptions())
- self.assertTrue(
- ResourceOptions(allow_credentials=True) != ResourceOptions())
-
- def test_allow_methods(self):
- self.assertIsNone(ResourceOptions().allow_methods)
- self.assertEqual(
- ResourceOptions(allow_methods='*').allow_methods,
- '*')
- self.assertEqual(
- ResourceOptions(allow_methods=[]).allow_methods,
- frozenset())
- self.assertEqual(
- ResourceOptions(allow_methods=['get']).allow_methods,
+def test_init_no_args():
+ """Test construction without arguments"""
+ opts = ResourceOptions()
+
+ assert not opts.allow_credentials
+ assert not opts.expose_headers
+ assert not opts.allow_headers
+ assert opts.max_age is None
+
+
+def test_comparison():
+ assert ResourceOptions() == ResourceOptions()
+ assert not (ResourceOptions() != ResourceOptions())
+ assert not (ResourceOptions(allow_credentials=True) == ResourceOptions())
+ assert ResourceOptions(allow_credentials=True) != ResourceOptions()
+
+
+def test_allow_methods():
+ assert ResourceOptions().allow_methods is None
+ assert ResourceOptions(allow_methods='*').allow_methods == '*'
+ assert ResourceOptions(allow_methods=[]).allow_methods == frozenset()
+ assert (ResourceOptions(allow_methods=['get']).allow_methods ==
frozenset(['GET']))
- self.assertEqual(
- ResourceOptions(allow_methods=['get', 'Post']).allow_methods,
+ assert (ResourceOptions(allow_methods=['get', 'Post']).allow_methods ==
{'GET', 'POST'})
- with self.assertRaises(ValueError):
- ResourceOptions(allow_methods='GET')
+ with pytest.raises(ValueError):
+ ResourceOptions(allow_methods='GET')
# TODO: test arguments parsing
diff --git a/tests/unit/test_urldispatcher_router_adapter.py b/tests/unit/test_urldispatcher_router_adapter.py
index ad5d9c1..d81d2a3 100644
--- a/tests/unit/test_urldispatcher_router_adapter.py
+++ b/tests/unit/test_urldispatcher_router_adapter.py
@@ -15,10 +15,9 @@
"""aiohttp_cors.urldispatcher_router_adapter unit tests.
"""
-import asyncio
-import unittest
from unittest import mock
+import pytest
from aiohttp import web
from aiohttp_cors.urldispatcher_router_adapter import \
@@ -26,90 +25,90 @@ from aiohttp_cors.urldispatcher_router_adapter import \
from aiohttp_cors import ResourceOptions
-def _handler(request):
+async def _handler(request):
return web.Response(text="Done")
-class TestResourcesUrlDispatcherRouterAdapter(unittest.TestCase):
- """Unit tests for CorsConfig"""
-
- def setUp(self):
- self.loop = asyncio.new_event_loop()
- self.app = web.Application(loop=self.loop)
-
- self.adapter = ResourcesUrlDispatcherRouterAdapter(
- self.app.router, defaults={
- "*": ResourceOptions()
- })
- self.get_route = self.app.router.add_route(
- "GET", "/get_path", _handler)
- self.options_route = self.app.router.add_route(
- "OPTIONS", "/options_path", _handler)
-
- def tearDown(self):
- self.loop.close()
-
- def test_add_get_route(self):
- """Test configuring GET route"""
- result = self.adapter.add_preflight_handler(
- self.get_route.resource, _handler)
- self.assertIsNone(result)
-
- self.assertEqual(len(self.adapter._resource_config), 0)
- self.assertEqual(
- len(self.adapter._resources_with_preflight_handlers), 1)
- self.assertEqual(len(self.adapter._preflight_routes), 1)
-
- def test_add_options_route(self):
- """Test configuring OPTIONS route"""
-
- with self.assertRaisesRegex(
- ValueError,
- "CORS must be enabled for route's resource first"):
- self.adapter.add_preflight_handler(self.options_route, _handler)
-
- self.assertFalse(self.adapter._resources_with_preflight_handlers)
- self.assertFalse(self.adapter._preflight_routes)
-
- def test_get_non_preflight_request_config(self):
- self.adapter.add_preflight_handler(
- self.get_route.resource, _handler)
- self.adapter.set_config_for_routing_entity(
- self.get_route.resource, {
- 'http://example.org': ResourceOptions(),
- })
-
- self.adapter.add_preflight_handler(
- self.get_route, _handler)
- self.adapter.set_config_for_routing_entity(
- self.get_route, {
- 'http://test.example.org': ResourceOptions(),
- })
-
- request = mock.Mock()
-
- with mock.patch('aiohttp_cors.urldispatcher_router_adapter.'
- 'ResourcesUrlDispatcherRouterAdapter.'
- 'is_cors_enabled_on_request'
- ) as is_cors_enabled_on_request, \
- mock.patch('aiohttp_cors.urldispatcher_router_adapter.'
- 'ResourcesUrlDispatcherRouterAdapter.'
- '_request_resource'
- ) as _request_resource:
- is_cors_enabled_on_request.return_value = True
- _request_resource.return_value = self.get_route.resource
-
- self.assertEqual(
- self.adapter.get_non_preflight_request_config(request),
+@pytest.fixture
+def app():
+ return web.Application()
+
+
+@pytest.fixture
+def adapter(app):
+ return ResourcesUrlDispatcherRouterAdapter(
+ app.router, defaults={
+ "*": ResourceOptions()
+ })
+
+
+@pytest.fixture
+def get_route(app):
+ return app.router.add_route(
+ "GET", "/get_path", _handler)
+
+
+@pytest.fixture
+def options_route(app):
+ return app.router.add_route(
+ "OPTIONS", "/options_path", _handler)
+
+
+def test_add_get_route(adapter, get_route):
+ """Test configuring GET route"""
+ result = adapter.add_preflight_handler(
+ get_route.resource, _handler)
+ assert result is None
+
+ assert len(adapter._resource_config) == 0
+ assert len(adapter._resources_with_preflight_handlers) == 1
+ assert len(adapter._preflight_routes) == 1
+
+
+def test_add_options_route(adapter, options_route):
+ """Test configuring OPTIONS route"""
+
+ adapter.add_preflight_handler(options_route, _handler)
+
+ assert not adapter._resources_with_preflight_handlers
+ assert not adapter._preflight_routes
+
+
+def test_get_non_preflight_request_config(adapter, get_route):
+ adapter.add_preflight_handler(get_route.resource, _handler)
+ adapter.set_config_for_routing_entity(
+ get_route.resource, {
+ 'http://example.org': ResourceOptions(),
+ })
+
+ adapter.add_preflight_handler(get_route, _handler)
+ adapter.set_config_for_routing_entity(
+ get_route, {
+ 'http://test.example.org': ResourceOptions(),
+ })
+
+ request = mock.Mock()
+
+ with mock.patch('aiohttp_cors.urldispatcher_router_adapter.'
+ 'ResourcesUrlDispatcherRouterAdapter.'
+ 'is_cors_enabled_on_request'
+ ) as is_cors_enabled_on_request, \
+ mock.patch('aiohttp_cors.urldispatcher_router_adapter.'
+ 'ResourcesUrlDispatcherRouterAdapter.'
+ '_request_resource'
+ ) as _request_resource:
+ is_cors_enabled_on_request.return_value = True
+ _request_resource.return_value = get_route.resource
+
+ assert (adapter.get_non_preflight_request_config(request) ==
{
'*': ResourceOptions(),
'http://example.org': ResourceOptions(),
})
- request.method = 'GET'
+ request.method = 'GET'
- self.assertEqual(
- self.adapter.get_non_preflight_request_config(request),
+ assert (adapter.get_non_preflight_request_config(request) ==
{
'*': ResourceOptions(),
'http://example.org': ResourceOptions(),
diff --git a/tox.ini b/tox.ini
index 5a9786a..9668d37 100644
--- a/tox.ini
+++ b/tox.ini
@@ -18,5 +18,6 @@ commands =
[pytest]
testpaths = aiohttp_cors tests
+addopts= --cov=aiohttp_cors --cov-report=term --cov-report=html --cov-branch --no-cov-on-fail
;addopts = --cov aiohttp_cors
; --pylint-rcfile=.pylintrc --pylint