diff options
author | Andrew Shadura <andrewsh@debian.org> | 2016-12-13 14:30:44 +0100 |
---|---|---|
committer | Andrew Shadura <andrewsh@debian.org> | 2016-12-13 14:30:44 +0100 |
commit | 9c9b36e9c28964142ddd6b9dd0dfb415351db4aa (patch) | |
tree | 425fee28bc9280dd0a667b1c7ab4e3090fd5984c /waitress | |
parent | b72233be19c69f2e3ae65d8bd859d5a12981b784 (diff) |
Imported Upstream version 1.0.1
Diffstat (limited to 'waitress')
-rw-r--r-- | waitress/__init__.py | 3 | ||||
-rw-r--r-- | waitress/adjustments.py | 134 | ||||
-rw-r--r-- | waitress/buffers.py | 2 | ||||
-rw-r--r-- | waitress/channel.py | 2 | ||||
-rw-r--r-- | waitress/compat.py | 29 | ||||
-rw-r--r-- | waitress/parser.py | 2 | ||||
-rw-r--r-- | waitress/runner.py | 37 | ||||
-rw-r--r-- | waitress/server.py | 161 | ||||
-rw-r--r-- | waitress/task.py | 11 | ||||
-rw-r--r-- | waitress/tests/test_adjustments.py | 144 | ||||
-rw-r--r-- | waitress/tests/test_buffers.py | 2 | ||||
-rw-r--r-- | waitress/tests/test_functional.py | 4 | ||||
-rw-r--r-- | waitress/tests/test_parser.py | 17 | ||||
-rw-r--r-- | waitress/tests/test_server.py | 36 | ||||
-rw-r--r-- | waitress/tests/test_task.py | 21 |
15 files changed, 571 insertions, 34 deletions
diff --git a/waitress/__init__.py b/waitress/__init__.py index 27210d4..775fe3a 100644 --- a/waitress/__init__.py +++ b/waitress/__init__.py @@ -10,8 +10,7 @@ def serve(app, **kw): logging.basicConfig() server = _server(app, **kw) if not _quiet: # pragma: no cover - print('serving on http://%s:%s' % (server.effective_host, - server.effective_port)) + server.print_listen('Serving on http://{}:{}') if _profile: # pragma: no cover profile('server.run()', globals(), locals(), (), False) else: diff --git a/waitress/adjustments.py b/waitress/adjustments.py index d5b237b..1a56621 100644 --- a/waitress/adjustments.py +++ b/waitress/adjustments.py @@ -15,7 +15,13 @@ """ import getopt import socket -import sys + +from waitress.compat import ( + PY2, + WIN, + string_types, + HAS_IPV6, + ) truthy = frozenset(('t', 'true', 'y', 'yes', 'on', '1')) @@ -36,6 +42,22 @@ def asoctal(s): """Convert the given octal string to an actual number.""" return int(s, 8) +def aslist_cronly(value): + if isinstance(value, string_types): + value = filter(None, [x.strip() for x in value.splitlines()]) + return list(value) + +def aslist(value): + """ Return a list of strings, separating the input based on newlines + and, if flatten=True (the default), also split on spaces within + each line.""" + values = aslist_cronly(value) + result = [] + for value in values: + subvalues = value.split() + result.extend(subvalues) + return result + def slash_fixed_str(s): s = s.strip() if s: @@ -44,6 +66,12 @@ def slash_fixed_str(s): s = '/' + s.lstrip('/').rstrip('/') return s +class _str_marker(str): + pass + +class _int_marker(int): + pass + class Adjustments(object): """This class contains tunable parameters. """ @@ -51,6 +79,9 @@ class Adjustments(object): _params = ( ('host', str), ('port', int), + ('ipv4', asbool), + ('ipv6', asbool), + ('listen', aslist), ('threads', int), ('trusted_proxy', str), ('url_scheme', str), @@ -77,10 +108,12 @@ class Adjustments(object): _param_map = dict(_params) # hostname or IP address to listen on - host = '0.0.0.0' + host = _str_marker('0.0.0.0') # TCP port to listen on - port = 8080 + port = _int_marker(8080) + + listen = ['{}:{}'.format(host, port)] # mumber of threads available for tasks threads = 4 @@ -174,14 +207,96 @@ class Adjustments(object): # The asyncore.loop flag to use poll() instead of the default select(). asyncore_use_poll = False + # Enable IPv4 by default + ipv4 = True + + # Enable IPv6 by default + ipv6 = True + def __init__(self, **kw): + + if 'listen' in kw and ('host' in kw or 'port' in kw): + raise ValueError('host and or port may not be set if listen is set.') + for k, v in kw.items(): if k not in self._param_map: raise ValueError('Unknown adjustment %r' % k) setattr(self, k, self._param_map[k](v)) - if (sys.platform[:3] == "win" and - self.host == 'localhost'): # pragma: no cover - self.host = '' + + if (not isinstance(self.host, _str_marker) or + not isinstance(self.port, _int_marker)): + self.listen = ['{}:{}'.format(self.host, self.port)] + + enabled_families = socket.AF_UNSPEC + + if not self.ipv4 and not HAS_IPV6: # pragma: no cover + raise ValueError( + 'IPv4 is disabled but IPv6 is not available. Cowardly refusing to start.' + ) + + if self.ipv4 and not self.ipv6: + enabled_families = socket.AF_INET + + if not self.ipv4 and self.ipv6 and HAS_IPV6: + enabled_families = socket.AF_INET6 + + wanted_sockets = [] + hp_pairs = [] + for i in self.listen: + if ':' in i: + (host, port) = i.rsplit(":", 1) + + # IPv6 we need to make sure that we didn't split on the address + if ']' in port: # pragma: nocover + (host, port) = (i, str(self.port)) + else: + (host, port) = (i, str(self.port)) + + if WIN and PY2: # pragma: no cover + try: + # Try turning the port into an integer + port = int(port) + except: + raise ValueError( + 'Windows does not support service names instead of port numbers' + ) + + try: + if '[' in host and ']' in host: # pragma: nocover + host = host.strip('[').rstrip(']') + + if host == '*': + host = None + + for s in socket.getaddrinfo( + host, + port, + enabled_families, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + socket.AI_PASSIVE + ): + (family, socktype, proto, _, sockaddr) = s + + # It seems that getaddrinfo() may sometimes happily return + # the same result multiple times, this of course makes + # bind() very unhappy... + # + # Split on %, and drop the zone-index from the host in the + # sockaddr. Works around a bug in OS X whereby + # getaddrinfo() returns the same link-local interface with + # two different zone-indices (which makes no sense what so + # ever...) yet treats them equally when we attempt to bind(). + if ( + sockaddr[1] == 0 or + (sockaddr[0].split('%', 1)[0], sockaddr[1]) not in hp_pairs + ): + wanted_sockets.append((family, socktype, proto, sockaddr)) + hp_pairs.append((sockaddr[0].split('%', 1)[0], sockaddr[1])) + except: + raise ValueError('Invalid host/port specified.') + + self.listen = wanted_sockets @classmethod def parse_args(cls, argv): @@ -203,9 +318,15 @@ class Adjustments(object): 'help': False, 'call': False, } + opts, args = getopt.getopt(argv, '', long_opts) for opt, value in opts: param = opt.lstrip('-').replace('-', '_') + + if param == 'listen': + kw['listen'] = '{} {}'.format(kw.get('listen', ''), value) + continue + if param.startswith('no_'): param = param[3:] kw[param] = 'false' @@ -215,4 +336,5 @@ class Adjustments(object): kw[param] = 'true' else: kw[param] = value + return kw, args diff --git a/waitress/buffers.py b/waitress/buffers.py index f174a79..cacc094 100644 --- a/waitress/buffers.py +++ b/waitress/buffers.py @@ -44,7 +44,7 @@ class FileBasedBuffer(object): return self.remain def __nonzero__(self): - return self.remain > 0 + return True __bool__ = __nonzero__ # py3 diff --git a/waitress/channel.py b/waitress/channel.py index c806bee..ca02511 100644 --- a/waitress/channel.py +++ b/waitress/channel.py @@ -249,6 +249,8 @@ class HTTPChannel(logging_dispatcher, object): 'Unexpected error when closing an outbuf') continue # pragma: no cover (coverage bug, it is hit) else: + if hasattr(outbuf, 'prune'): + outbuf.prune() dobreak = True while outbuflen > 0: diff --git a/waitress/compat.py b/waitress/compat.py index 9e06cde..700f7a1 100644 --- a/waitress/compat.py +++ b/waitress/compat.py @@ -1,5 +1,7 @@ import sys import types +import platform +import warnings try: import urlparse @@ -7,8 +9,12 @@ except ImportError: # pragma: no cover from urllib import parse as urlparse # True if we are running on Python 3. +PY2 = sys.version_info[0] == 2 PY3 = sys.version_info[0] == 3 +# True if we are running on Windows +WIN = platform.system() == 'Windows' + if PY3: # pragma: no cover string_types = str, integer_types = int, @@ -109,3 +115,26 @@ try: MAXINT = sys.maxint except AttributeError: # pragma: no cover MAXINT = sys.maxsize + + +# Fix for issue reported in https://github.com/Pylons/waitress/issues/138, +# Python on Windows may not define IPPROTO_IPV6 in socket. +import socket + +HAS_IPV6 = socket.has_ipv6 + +if hasattr(socket, 'IPPROTO_IPV6') and hasattr(socket, 'IPV6_V6ONLY'): + IPPROTO_IPV6 = socket.IPPROTO_IPV6 + IPV6_V6ONLY = socket.IPV6_V6ONLY +else: # pragma: no cover + if WIN: + IPPROTO_IPV6 = 41 + IPV6_V6ONLY = 27 + else: + warnings.warn( + 'OS does not support required IPv6 socket flags. This is requirement ' + 'for Waitress. Please open an issue at https://github.com/Pylons/waitress. ' + 'IPv6 support has been disabled.', + RuntimeWarning + ) + HAS_IPV6 = False diff --git a/waitress/parser.py b/waitress/parser.py index 9962b83..fc71d68 100644 --- a/waitress/parser.py +++ b/waitress/parser.py @@ -182,6 +182,8 @@ class HTTPRequestParser(object): index = line.find(b':') if index > 0: key = line[:index] + if b'_' in key: + continue value = line[index + 1:].strip() key1 = tostr(key.upper().replace(b'-', b'_')) # If a header already exists, we append subsequent values diff --git a/waitress/runner.py b/waitress/runner.py index 04cd78f..abdb38e 100644 --- a/waitress/runner.py +++ b/waitress/runner.py @@ -42,9 +42,46 @@ Standard options: Hostname or IP address on which to listen, default is '0.0.0.0', which means "all IP addresses on this host". + Note: May not be used together with --listen + --port=PORT TCP port on which to listen, default is '8080' + Note: May not be used together with --listen + + --listen=ip:port + Tell waitress to listen on an ip port combination. + + Example: + + --listen=127.0.0.1:8080 + --listen=[::1]:8080 + --listen=*:8080 + + This option may be used multiple times to listen on multipe sockets. + A wildcard for the hostname is also supported and will bind to both + IPv4/IPv6 depending on whether they are enabled or disabled. + + --[no-]ipv4 + Toggle on/off IPv4 support. + + Example: + + --no-ipv4 + + This will disable IPv4 socket support. This affects wildcard matching + when generating the list of sockets. + + --[no-]ipv6 + Toggle on/off IPv6 support. + + Example: + + --no-ipv6 + + This will turn on IPv6 socket support. This affects wildcard matching + when generating a list of sockets. + --unix-socket=PATH Path of Unix socket. If a socket path is specified, a Unix domain socket is made instead of the usual inet domain socket. diff --git a/waitress/server.py b/waitress/server.py index 87338c8..d3fbd79 100644 --- a/waitress/server.py +++ b/waitress/server.py @@ -22,7 +22,14 @@ from waitress import trigger from waitress.adjustments import Adjustments from waitress.channel import HTTPChannel from waitress.task import ThreadedTaskDispatcher -from waitress.utilities import cleanup_unix_socket, logging_dispatcher +from waitress.utilities import ( + cleanup_unix_socket, + logging_dispatcher, + ) +from waitress.compat import ( + IPPROTO_IPV6, + IPV6_V6ONLY, + ) def create_server(application, map=None, @@ -42,11 +49,90 @@ def create_server(application, 'to return a WSGI app within your application.' ) adj = Adjustments(**kw) + + if map is None: # pragma: nocover + map = {} + + dispatcher = _dispatcher + if dispatcher is None: + dispatcher = ThreadedTaskDispatcher() + dispatcher.set_thread_count(adj.threads) + if adj.unix_socket and hasattr(socket, 'AF_UNIX'): - cls = UnixWSGIServer - else: - cls = TcpWSGIServer - return cls(application, map, _start, _sock, _dispatcher, adj) + sockinfo = (socket.AF_UNIX, socket.SOCK_STREAM, None, None) + return UnixWSGIServer( + application, + map, + _start, + _sock, + dispatcher=dispatcher, + adj=adj, + sockinfo=sockinfo) + + effective_listen = [] + last_serv = None + for sockinfo in adj.listen: + # When TcpWSGIServer is called, it registers itself in the map. This + # side-effect is all we need it for, so we don't store a reference to + # or return it to the user. + last_serv = TcpWSGIServer( + application, + map, + _start, + _sock, + dispatcher=dispatcher, + adj=adj, + sockinfo=sockinfo) + effective_listen.append((last_serv.effective_host, last_serv.effective_port)) + + # We are running a single server, so we can just return the last server, + # saves us from having to create one more object + if len(adj.listen) == 1: + # In this case we have no need to use a MultiSocketServer + return last_serv + + # Return a class that has a utility function to print out the sockets it's + # listening on, and has a .run() function. All of the TcpWSGIServers + # registered themselves in the map above. + return MultiSocketServer(map, adj, effective_listen, dispatcher) + + +# This class is only ever used if we have multiple listen sockets. It allows +# the serve() API to call .run() which starts the asyncore loop, and catches +# SystemExit/KeyboardInterrupt so that it can atempt to cleanly shut down. +class MultiSocketServer(object): + asyncore = asyncore # test shim + + def __init__(self, + map=None, + adj=None, + effective_listen=None, + dispatcher=None, + ): + self.adj = adj + self.map = map + self.effective_listen = effective_listen + self.task_dispatcher = dispatcher + + def print_listen(self, format_str): # pragma: nocover + for l in self.effective_listen: + l = list(l) + + if ':' in l[0]: + l[0] = '[{}]'.format(l[0]) + + print(format_str.format(*l)) + + def run(self): + try: + self.asyncore.loop( + timeout=self.adj.asyncore_loop_timeout, + map=self.map, + use_poll=self.adj.asyncore_use_poll, + ) + except (SystemExit, KeyboardInterrupt): + self.task_dispatcher.shutdown() + class BaseWSGIServer(logging_dispatcher, object): @@ -54,15 +140,15 @@ class BaseWSGIServer(logging_dispatcher, object): next_channel_cleanup = 0 socketmod = socket # test shim asyncore = asyncore # test shim - family = None def __init__(self, application, map=None, _start=True, # test shim _sock=None, # test shim - _dispatcher=None, # test shim + dispatcher=None, # dispatcher adj=None, # adjustments + sockinfo=None, # opaque object **kw ): if adj is None: @@ -72,20 +158,30 @@ class BaseWSGIServer(logging_dispatcher, object): # conflicts with apps and libs that use the asyncore global socket # map ala https://github.com/Pylons/waitress/issues/63 map = {} + if sockinfo is None: + sockinfo = adj.listen[0] + + self.sockinfo = sockinfo + self.family = sockinfo[0] + self.socktype = sockinfo[1] self.application = application self.adj = adj self.trigger = trigger.trigger(map) - if _dispatcher is None: - _dispatcher = ThreadedTaskDispatcher() - _dispatcher.set_thread_count(self.adj.threads) - self.task_dispatcher = _dispatcher + if dispatcher is None: + dispatcher = ThreadedTaskDispatcher() + dispatcher.set_thread_count(self.adj.threads) + + self.task_dispatcher = dispatcher self.asyncore.dispatcher.__init__(self, _sock, map=map) if _sock is None: - self.create_socket(self.family, socket.SOCK_STREAM) + self.create_socket(self.family, self.socktype) + if self.family == socket.AF_INET6: # pragma: nocover + self.socket.setsockopt(IPPROTO_IPV6, IPV6_V6ONLY, 1) + self.set_reuse_addr() self.bind_server_socket() self.effective_host, self.effective_port = self.getsockname() - self.server_name = self.get_server_name(self.adj.host) + self.server_name = self.get_server_name(self.effective_host) self.active_channels = {} if _start: self.accept_connections() @@ -99,12 +195,13 @@ class BaseWSGIServer(logging_dispatcher, object): server_name = str(ip) else: server_name = str(self.socketmod.gethostname()) + # Convert to a host name if necessary. for c in server_name: if c != '.' and not c.isdigit(): return server_name try: - if server_name == '0.0.0.0': + if server_name == '0.0.0.0' or server_name == '::': return 'localhost' server_name = self.socketmod.gethostbyaddr(server_name)[0] except socket.error: # pragma: no cover @@ -186,25 +283,51 @@ class BaseWSGIServer(logging_dispatcher, object): if (not channel.requests) and channel.last_activity < cutoff: channel.will_close = True -class TcpWSGIServer(BaseWSGIServer): + def print_listen(self, format_str): # pragma: nocover + print(format_str.format(self.effective_host, self.effective_port)) - family = socket.AF_INET + +class TcpWSGIServer(BaseWSGIServer): def bind_server_socket(self): - self.bind((self.adj.host, self.adj.port)) + (_, _, _, sockaddr) = self.sockinfo + self.bind(sockaddr) def getsockname(self): - return self.socket.getsockname() + return self.socketmod.getnameinfo( + self.socket.getsockname(), + self.socketmod.NI_NUMERICSERV) def set_socket_options(self, conn): for (level, optname, value) in self.adj.socket_options: conn.setsockopt(level, optname, value) + if hasattr(socket, 'AF_UNIX'): class UnixWSGIServer(BaseWSGIServer): - family = socket.AF_UNIX + def __init__(self, + application, + map=None, + _start=True, # test shim + _sock=None, # test shim + dispatcher=None, # dispatcher + adj=None, # adjustments + sockinfo=None, # opaque object + **kw): + if sockinfo is None: + sockinfo = (socket.AF_UNIX, socket.SOCK_STREAM, None, None) + + super(UnixWSGIServer, self).__init__( + application, + map=map, + _start=_start, + _sock=_sock, + dispatcher=dispatcher, + adj=adj, + sockinfo=sockinfo, + **kw) def bind_server_socket(self): cleanup_unix_socket(self.adj.unix_socket) diff --git a/waitress/task.py b/waitress/task.py index 7136c32..4ce410c 100644 --- a/waitress/task.py +++ b/waitress/task.py @@ -358,6 +358,9 @@ class WSGITask(Task): if not status.__class__ is str: raise AssertionError('status %s is not a string' % status) + if '\n' in status or '\r' in status: + raise ValueError("carriage return/line " + "feed character present in status") self.status = status @@ -371,6 +374,14 @@ class WSGITask(Task): raise AssertionError( 'Header value %r is not a string in %r' % (v, (k, v)) ) + + if '\n' in v or '\r' in v: + raise ValueError("carriage return/line " + "feed character present in header value") + if '\n' in k or '\r' in k: + raise ValueError("carriage return/line " + "feed character present in header name") + kl = k.lower() if kl == 'content-length': self.content_length = int(v) diff --git a/waitress/tests/test_adjustments.py b/waitress/tests/test_adjustments.py index f2b28c2..9446705 100644 --- a/waitress/tests/test_adjustments.py +++ b/waitress/tests/test_adjustments.py @@ -1,4 +1,10 @@ import sys +import socket + +from waitress.compat import ( + PY2, + WIN, + ) if sys.version_info[:2] == (2, 6): # pragma: no cover import unittest2 as unittest @@ -45,13 +51,35 @@ class Test_asbool(unittest.TestCase): class TestAdjustments(unittest.TestCase): + def _hasIPv6(self): # pragma: nocover + if not socket.has_ipv6: + return False + + try: + socket.getaddrinfo( + '::1', + 0, + socket.AF_UNSPEC, + socket.SOCK_STREAM, + socket.IPPROTO_TCP, + socket.AI_PASSIVE | socket.AI_ADDRCONFIG + ) + + return True + except socket.gaierror as e: + # Check to see what the error is + if e.errno == socket.EAI_ADDRFAMILY: + return False + else: + raise e + def _makeOne(self, **kw): from waitress.adjustments import Adjustments return Adjustments(**kw) def test_goodvars(self): inst = self._makeOne( - host='host', + host='localhost', port='8080', threads='5', trusted_proxy='192.168.1.1', @@ -74,8 +102,11 @@ class TestAdjustments(unittest.TestCase): unix_socket='/tmp/waitress.sock', unix_socket_perms='777', url_prefix='///foo/', + ipv4=True, + ipv6=False, ) - self.assertEqual(inst.host, 'host') + + self.assertEqual(inst.host, 'localhost') self.assertEqual(inst.port, 8080) self.assertEqual(inst.threads, 5) self.assertEqual(inst.trusted_proxy, '192.168.1.1') @@ -98,10 +129,96 @@ class TestAdjustments(unittest.TestCase): self.assertEqual(inst.unix_socket, '/tmp/waitress.sock') self.assertEqual(inst.unix_socket_perms, 0o777) self.assertEqual(inst.url_prefix, '/foo') + self.assertEqual(inst.ipv4, True) + self.assertEqual(inst.ipv6, False) + + bind_pairs = [ + sockaddr[:2] + for (family, _, _, sockaddr) in inst.listen + if family == socket.AF_INET + ] + + # On Travis, somehow we start listening to two sockets when resolving + # localhost... + self.assertEqual(('127.0.0.1', 8080), bind_pairs[0]) + + def test_goodvar_listen(self): + inst = self._makeOne(listen='127.0.0.1') + + bind_pairs = [(host, port) for (_, _, _, (host, port)) in inst.listen] + + self.assertEqual(bind_pairs, [('127.0.0.1', 8080)]) + + def test_default_listen(self): + inst = self._makeOne() + + bind_pairs = [(host, port) for (_, _, _, (host, port)) in inst.listen] + + self.assertEqual(bind_pairs, [('0.0.0.0', 8080)]) + + def test_multiple_listen(self): + inst = self._makeOne(listen='127.0.0.1:9090 127.0.0.1:8080') + + bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen] + + self.assertEqual(bind_pairs, + [('127.0.0.1', 9090), + ('127.0.0.1', 8080)]) + + def test_wildcard_listen(self): + inst = self._makeOne(listen='*:8080') + + bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen] + + self.assertTrue(len(bind_pairs) >= 1) + + def test_ipv6_no_port(self): # pragma: nocover + if not self._hasIPv6(): + return + + inst = self._makeOne(listen='[::1]') + + bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen] + + self.assertEqual(bind_pairs, [('::1', 8080)]) + + def test_bad_port(self): + self.assertRaises(ValueError, self._makeOne, listen='127.0.0.1:test') + + def test_service_port(self): + if WIN and PY2: # pragma: no cover + # On Windows and Python 2 this is broken, so we raise a ValueError + self.assertRaises( + ValueError, + self._makeOne, + listen='127.0.0.1:http', + ) + return + + inst = self._makeOne(listen='127.0.0.1:http 0.0.0.0:https') + + bind_pairs = [sockaddr[:2] for (_, _, _, sockaddr) in inst.listen] + + self.assertEqual(bind_pairs, [('127.0.0.1', 80), ('0.0.0.0', 443)]) + + def test_dont_mix_host_port_listen(self): + self.assertRaises( + ValueError, + self._makeOne, + host='localhost', + port='8080', + listen='127.0.0.1:8080', + ) def test_badvar(self): self.assertRaises(ValueError, self._makeOne, nope=True) + def test_ipv4_disabled(self): + self.assertRaises(ValueError, self._makeOne, ipv4=False, listen="127.0.0.1:8080") + + def test_ipv6_disabled(self): + self.assertRaises(ValueError, self._makeOne, ipv6=False, listen="[::]:8080") + class TestCLI(unittest.TestCase): def parse(self, argv): @@ -147,10 +264,31 @@ class TestCLI(unittest.TestCase): self.assertDictContainsSubset({ 'host': 'localhost', 'port': '80', - 'unix_socket_perms':'777', + 'unix_socket_perms': '777', }, opts) self.assertSequenceEqual(args, []) + def test_listen_params(self): + opts, args = self.parse([ + '--listen=test:80', + ]) + + self.assertDictContainsSubset({ + 'listen': ' test:80' + }, opts) + self.assertSequenceEqual(args, []) + + def test_multiple_listen_params(self): + opts, args = self.parse([ + '--listen=test:80', + '--listen=test:8080', + ]) + + self.assertDictContainsSubset({ + 'listen': ' test:80 test:8080' + }, opts) + self.assertSequenceEqual(args, []) + def test_bad_param(self): import getopt self.assertRaises(getopt.GetoptError, self.parse, ['--no-host']) diff --git a/waitress/tests/test_buffers.py b/waitress/tests/test_buffers.py index 8a4ce6e..46a215e 100644 --- a/waitress/tests/test_buffers.py +++ b/waitress/tests/test_buffers.py @@ -31,7 +31,7 @@ class TestFileBasedBuffer(unittest.TestCase): inst.remain = 10 self.assertEqual(bool(inst), True) inst.remain = 0 - self.assertEqual(bool(inst), False) + self.assertEqual(bool(inst), True) def test_append(self): f = io.BytesIO(b'data') diff --git a/waitress/tests/test_functional.py b/waitress/tests/test_functional.py index 020486a..59ef4e4 100644 --- a/waitress/tests/test_functional.py +++ b/waitress/tests/test_functional.py @@ -34,6 +34,8 @@ class FixtureTcpWSGIServer(server.TcpWSGIServer): """A version of TcpWSGIServer that relays back what it's bound to. """ + family = socket.AF_INET # Testing + def __init__(self, application, queue, **kw): # pragma: no cover # Coverage doesn't see this as it's ran in a separate process. kw['port'] = 0 # Bind to any available port. @@ -1386,6 +1388,8 @@ if hasattr(socket, 'AF_UNIX'): """A version of UnixWSGIServer that relays back what it's bound to. """ + family = socket.AF_UNIX # Testing + def __init__(self, application, queue, **kw): # pragma: no cover # Coverage doesn't see this as it's ran in a separate process. # To permit parallel testing, use a PID-dependent socket. diff --git a/waitress/tests/test_parser.py b/waitress/tests/test_parser.py index 423d75a..781d7c7 100644 --- a/waitress/tests/test_parser.py +++ b/waitress/tests/test_parser.py @@ -408,9 +408,24 @@ Hello. self.assertEqual(self.parser.headers, { 'CONTENT_LENGTH': '7', 'X_FORWARDED_FOR': - '10.11.12.13, unknown,127.0.0.1, 255.255.255.255', + '10.11.12.13, unknown,127.0.0.1', }) + def testSpoofedHeadersDropped(self): + data = b"""\ +GET /foobar HTTP/8.4 +x-auth_user: bob +content-length: 7 + +Hello. +""" + self.feed(data) + self.assertTrue(self.parser.completed) + self.assertEqual(self.parser.headers, { + 'CONTENT_LENGTH': '7', + }) + + class DummyBodyStream(object): def getfile(self): diff --git a/waitress/tests/test_server.py b/waitress/tests/test_server.py index 0ff8871..39b90b3 100644 --- a/waitress/tests/test_server.py +++ b/waitress/tests/test_server.py @@ -34,10 +34,23 @@ class TestWSGIServer(unittest.TestCase): _start=_start, ) + def _makeOneWithMulti(self, adj=None, _start=True, + app=dummy_app, listen="127.0.0.1:0 127.0.0.1:0"): + sock = DummySock() + task_dispatcher = DummyTaskDispatcher() + map = {} + from waitress.server import create_server + return create_server( + app, + listen=listen, + map=map, + _dispatcher=task_dispatcher, + _start=_start, + _sock=sock) + def test_ctor_app_is_None(self): self.assertRaises(ValueError, self._makeOneWithMap, app=None) - def test_ctor_start_true(self): inst = self._makeOneWithMap(_start=True) self.assertEqual(inst.accepting, True) @@ -72,6 +85,10 @@ class TestWSGIServer(unittest.TestCase): result = inst.get_server_name('0.0.0.0') self.assertEqual(result, 'localhost') + def test_get_server_multi(self): + inst = self._makeOneWithMulti() + self.assertEqual(inst.__class__.__name__, 'MultiSocketServer') + def test_run(self): inst = self._makeOneWithMap(_start=False) inst.asyncore = DummyAsyncore() @@ -79,6 +96,13 @@ class TestWSGIServer(unittest.TestCase): inst.run() self.assertTrue(inst.task_dispatcher.was_shutdown) + def test_run_base_server(self): + inst = self._makeOneWithMulti(_start=False) + inst.asyncore = DummyAsyncore() + inst.task_dispatcher = DummyTaskDispatcher() + inst.run() + self.assertTrue(inst.task_dispatcher.was_shutdown) + def test_pull_trigger(self): inst = self._makeOneWithMap(_start=False) inst.trigger = DummyTrigger() @@ -242,6 +266,16 @@ if hasattr(socket, 'AF_UNIX'): [(inst, client, ('localhost', None), inst.adj)] ) + def test_creates_new_sockinfo(self): + from waitress.server import UnixWSGIServer + inst = UnixWSGIServer( + dummy_app, + unix_socket=self.unix_socket, + unix_socket_perms='600' + ) + + self.assertEqual(inst.sockinfo[0], socket.AF_UNIX) + class DummySock(object): accepted = False blocking = False diff --git a/waitress/tests/test_task.py b/waitress/tests/test_task.py index 6d6fcce..2a2759a 100644 --- a/waitress/tests/test_task.py +++ b/waitress/tests/test_task.py @@ -409,6 +409,27 @@ class TestWSGITask(unittest.TestCase): inst.channel.server.application = app self.assertRaises(AssertionError, inst.execute) + def test_execute_bad_header_value_control_characters(self): + def app(environ, start_response): + start_response('200 OK', [('a', '\n')]) + inst = self._makeOne() + inst.channel.server.application = app + self.assertRaises(ValueError, inst.execute) + + def test_execute_bad_header_name_control_characters(self): + def app(environ, start_response): + start_response('200 OK', [('a\r', 'value')]) + inst = self._makeOne() + inst.channel.server.application = app + self.assertRaises(ValueError, inst.execute) + + def test_execute_bad_status_control_characters(self): + def app(environ, start_response): + start_response('200 OK\r', []) + inst = self._makeOne() + inst.channel.server.application = app + self.assertRaises(ValueError, inst.execute) + def test_preserve_header_value_order(self): def app(environ, start_response): write = start_response('200 OK', [('C', 'b'), ('A', 'b'), ('A', 'a')]) |