diff options
Diffstat (limited to 'src/etcd/client.py')
-rw-r--r-- | src/etcd/client.py | 258 |
1 files changed, 179 insertions, 79 deletions
diff --git a/src/etcd/client.py b/src/etcd/client.py index 03b0451..afeabe8 100644 --- a/src/etcd/client.py +++ b/src/etcd/client.py @@ -18,6 +18,8 @@ import urllib3 import urllib3.util import json import ssl +import dns.resolver +from functools import wraps import etcd try: @@ -42,16 +44,22 @@ class Client(object): _comparison_conditions = set(('prevValue', 'prevIndex', 'prevExist')) _read_options = set(('recursive', 'wait', 'waitIndex', 'sorted', 'quorum')) _del_conditions = set(('prevValue', 'prevIndex')) + + http = None + def __init__( self, host='127.0.0.1', port=4001, + srv_domain=None, version_prefix='/v2', read_timeout=60, allow_redirect=True, protocol='http', cert=None, ca_cert=None, + username=None, + password=None, allow_reconnect=False, use_proxies=False, expected_cluster_id=None, @@ -67,6 +75,8 @@ class Client(object): port (int): Port used to connect to etcd. + srv_domain (str): Domain to search the SRV record for cluster autodiscovery. + version_prefix (str): Url or version prefix in etcd url (default=/v2). read_timeout (int): max seconds to wait for a read. @@ -81,6 +91,10 @@ class Client(object): ca_cert (str): The ca certificate. If pressent it will enable validation. + username (str): username for etcd authentication. + + password (str): password for etcd authentication. + allow_reconnect (bool): allow the client to reconnect to another etcd server in the cluster in the case the default one does not respond. @@ -98,8 +112,15 @@ class Client(object): by host. By default this will use up to 10 connections. """ - _log.debug("New etcd client created for %s:%s%s", - host, port, version_prefix) + + # If a DNS record is provided, use it to get the hosts list + if srv_domain is not None: + try: + host = self._discover(srv_domain) + except Exception as e: + _log.error("Could not discover the etcd hosts from %s: %s", + srv_domain, e) + self._protocol = protocol def uri(protocol, host, port): @@ -151,8 +172,20 @@ class Client(object): kw['ca_certs'] = ca_cert kw['cert_reqs'] = ssl.CERT_REQUIRED + self.username = None + self.password = None + if username and password: + self.username = username + self.password = password + elif username: + _log.warning('Username provided without password, both are required for authentication') + elif password: + _log.warning('Password provided without username, both are required for authentication') + self.http = urllib3.PoolManager(num_pools=10, **kw) + _log.debug("New etcd client created for %s", self.base_uri) + if self._allow_reconnect: # we need the set of servers in the cluster in order to try # reconnecting upon error. The cluster members will be @@ -174,6 +207,27 @@ class Client(object): _log.debug("Machines cache initialised to %s", self._machines_cache) + def _discover(self, domain): + srv_name = "_etcd._tcp.{}".format(domain) + answers = dns.resolver.query(srv_name, 'SRV') + hosts = [] + for answer in answers: + hosts.append( + (answer.target.to_text(omit_final_dot=True), answer.port)) + _log.debug("Found %s", hosts) + if not len(hosts): + raise ValueError("The SRV record is present but no host were found") + return tuple(hosts) + + def __del__(self): + """Clean up open connections""" + if self.http is not None: + try: + self.http.clear() + except ReferenceError: + # this may hit an already-cleared weakref + pass + @property def base_uri(self): """URI used by the client to connect to etcd.""" @@ -221,6 +275,7 @@ class Client(object): response = self.http.request( self._MGET, uri, + headers=self._get_headers(), timeout=self.read_timeout, redirect=self.allow_redirect) @@ -278,9 +333,9 @@ class Client(object): try: leader = json.loads( - self.api_execute(self.version_prefix + '/stats/leader', + self.api_execute(self.version_prefix + '/stats/self', self._MGET).data.decode('utf-8')) - return self.members[leader['leader']] + return self.members[leader['leaderInfo']['leader']] except Exception as e: raise etcd.EtcdException("Cannot get leader data: %s" % e) @@ -718,85 +773,123 @@ class Client(object): _log.info("Selected new etcd server %s", mach) return mach + def _wrap_request(payload): + @wraps(payload) + def wrapper(self, path, method, params=None, timeout=None): + some_request_failed = False + response = False + + if timeout is None: + timeout = self.read_timeout + + if timeout == 0: + timeout = None + + if not path.startswith('/'): + raise ValueError('Path does not start with /') + + while not response: + try: + response = payload(self, path, method, + params=params, timeout=timeout) + # Check the cluster ID hasn't changed under us. We use + # preload_content=False above so we can read the headers + # before we wait for the content of a watch. + self._check_cluster_id(response) + # Now force the data to be preloaded in order to trigger any + # IO-related errors in this method rather than when we try to + # access it later. + _ = response.data + # urllib3 doesn't wrap all httplib exceptions and earlier versions + # don't wrap socket errors either. + except (urllib3.exceptions.HTTPError, + HTTPException, socket.error) as e: + if (isinstance(params, dict) and + params.get("wait") == "true" and + isinstance(e, + urllib3.exceptions.ReadTimeoutError)): + _log.debug("Watch timed out.") + raise etcd.EtcdWatchTimedOut( + "Watch timed out: %r" % e, + cause=e + ) + _log.error("Request to server %s failed: %r", + self._base_uri, e) + if self._allow_reconnect: + _log.info("Reconnection allowed, looking for another " + "server.") + # _next_server() raises EtcdException if there are no + # machines left to try, breaking out of the loop. + self._base_uri = self._next_server(cause=e) + some_request_failed = True + + # if exception is raised on _ = response.data + # the condition for while loop will be False + # but we should retry + response = False + else: + _log.debug("Reconnection disabled, giving up.") + raise etcd.EtcdConnectionFailed( + "Connection to etcd failed due to %r" % e, + cause=e + ) + except etcd.EtcdClusterIdChanged as e: + _log.warning(e) + raise + except: + _log.exception("Unexpected request failure, re-raising.") + raise + + if some_request_failed: + if not self._use_proxies: + # The cluster may have changed since last invocation + self._machines_cache = self.machines + self._machines_cache.remove(self._base_uri) + return self._handle_server_response(response) + return wrapper + + @_wrap_request def api_execute(self, path, method, params=None, timeout=None): """ Executes the query. """ - - some_request_failed = False - response = False - - if timeout is None: - timeout = self.read_timeout - - if timeout == 0: - timeout = None - - if not path.startswith('/'): - raise ValueError('Path does not start with /') - - while not response: - try: - url = self._base_uri + path - - if (method == self._MGET) or (method == self._MDELETE): - response = self.http.request( - method, - url, - timeout=timeout, - fields=params, - redirect=self.allow_redirect, - preload_content=False) - - elif (method == self._MPUT) or (method == self._MPOST): - response = self.http.request_encode_body( - method, - url, - fields=params, - timeout=timeout, - encode_multipart=False, - redirect=self.allow_redirect, - preload_content=False) - else: + url = self._base_uri + path + + if (method == self._MGET) or (method == self._MDELETE): + return self.http.request( + method, + url, + timeout=timeout, + fields=params, + redirect=self.allow_redirect, + headers=self._get_headers(), + preload_content=False) + + elif (method == self._MPUT) or (method == self._MPOST): + return self.http.request_encode_body( + method, + url, + fields=params, + timeout=timeout, + encode_multipart=False, + redirect=self.allow_redirect, + headers=self._get_headers(), + preload_content=False) + else: raise etcd.EtcdException( 'HTTP method {} not supported'.format(method)) - - # Check the cluster ID hasn't changed under us. We use - # preload_content=False above so we can read the headers - # before we wait for the content of a watch. - self._check_cluster_id(response) - # Now force the data to be preloaded in order to trigger any - # IO-related errors in this method rather than when we try to - # access it later. - _ = response.data - # urllib3 doesn't wrap all httplib exceptions and earlier versions - # don't wrap socket errors either. - except (urllib3.exceptions.HTTPError, - HTTPException, - socket.error) as e: - _log.error("Request to server %s failed: %r", - self._base_uri, e) - if self._allow_reconnect: - _log.info("Reconnection allowed, looking for another " - "server.") - # _next_server() raises EtcdException if there are no - # machines left to try, breaking out of the loop. - self._base_uri = self._next_server(cause=e) - some_request_failed = True - else: - _log.debug("Reconnection disabled, giving up.") - raise etcd.EtcdConnectionFailed( - "Connection to etcd failed due to %r" % e, - cause=e - ) - except: - _log.exception("Unexpected request failure, re-raising.") - raise - - if some_request_failed: - if not self._use_proxies: - # The cluster may have changed since last invocation - self._machines_cache = self.machines - self._machines_cache.remove(self._base_uri) - return self._handle_server_response(response) + + @_wrap_request + def api_execute_json(self, path, method, params=None, timeout=None): + url = self._base_uri + path + json_payload = json.dumps(params) + headers = self._get_headers() + headers['Content-Type'] = 'application/json' + return self.http.urlopen(method, + url, + body=json_payload, + timeout=timeout, + redirect=self.allow_redirect, + headers=headers, + preload_content=False) def _check_cluster_id(self, response): cluster_id = response.getheader("x-etcd-cluster-id") @@ -827,8 +920,15 @@ class Client(object): # throw the appropriate exception try: r = json.loads(resp) + r['status'] = response.status except (TypeError, ValueError): # Bad JSON, make a response locally. r = {"message": "Bad response", "cause": str(resp)} etcd.EtcdError.handle(r) + + def _get_headers(self): + if self.username and self.password: + credentials = ':'.join((self.username, self.password)) + return urllib3.make_headers(basic_auth=credentials) + return {} |