summaryrefslogtreecommitdiff
path: root/src/etcd/client.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/etcd/client.py')
-rw-r--r--src/etcd/client.py258
1 files changed, 179 insertions, 79 deletions
diff --git a/src/etcd/client.py b/src/etcd/client.py
index 03b0451..afeabe8 100644
--- a/src/etcd/client.py
+++ b/src/etcd/client.py
@@ -18,6 +18,8 @@ import urllib3
import urllib3.util
import json
import ssl
+import dns.resolver
+from functools import wraps
import etcd
try:
@@ -42,16 +44,22 @@ class Client(object):
_comparison_conditions = set(('prevValue', 'prevIndex', 'prevExist'))
_read_options = set(('recursive', 'wait', 'waitIndex', 'sorted', 'quorum'))
_del_conditions = set(('prevValue', 'prevIndex'))
+
+ http = None
+
def __init__(
self,
host='127.0.0.1',
port=4001,
+ srv_domain=None,
version_prefix='/v2',
read_timeout=60,
allow_redirect=True,
protocol='http',
cert=None,
ca_cert=None,
+ username=None,
+ password=None,
allow_reconnect=False,
use_proxies=False,
expected_cluster_id=None,
@@ -67,6 +75,8 @@ class Client(object):
port (int): Port used to connect to etcd.
+ srv_domain (str): Domain to search the SRV record for cluster autodiscovery.
+
version_prefix (str): Url or version prefix in etcd url (default=/v2).
read_timeout (int): max seconds to wait for a read.
@@ -81,6 +91,10 @@ class Client(object):
ca_cert (str): The ca certificate. If pressent it will enable
validation.
+ username (str): username for etcd authentication.
+
+ password (str): password for etcd authentication.
+
allow_reconnect (bool): allow the client to reconnect to another
etcd server in the cluster in the case the
default one does not respond.
@@ -98,8 +112,15 @@ class Client(object):
by host. By default this will use up to 10
connections.
"""
- _log.debug("New etcd client created for %s:%s%s",
- host, port, version_prefix)
+
+ # If a DNS record is provided, use it to get the hosts list
+ if srv_domain is not None:
+ try:
+ host = self._discover(srv_domain)
+ except Exception as e:
+ _log.error("Could not discover the etcd hosts from %s: %s",
+ srv_domain, e)
+
self._protocol = protocol
def uri(protocol, host, port):
@@ -151,8 +172,20 @@ class Client(object):
kw['ca_certs'] = ca_cert
kw['cert_reqs'] = ssl.CERT_REQUIRED
+ self.username = None
+ self.password = None
+ if username and password:
+ self.username = username
+ self.password = password
+ elif username:
+ _log.warning('Username provided without password, both are required for authentication')
+ elif password:
+ _log.warning('Password provided without username, both are required for authentication')
+
self.http = urllib3.PoolManager(num_pools=10, **kw)
+ _log.debug("New etcd client created for %s", self.base_uri)
+
if self._allow_reconnect:
# we need the set of servers in the cluster in order to try
# reconnecting upon error. The cluster members will be
@@ -174,6 +207,27 @@ class Client(object):
_log.debug("Machines cache initialised to %s",
self._machines_cache)
+ def _discover(self, domain):
+ srv_name = "_etcd._tcp.{}".format(domain)
+ answers = dns.resolver.query(srv_name, 'SRV')
+ hosts = []
+ for answer in answers:
+ hosts.append(
+ (answer.target.to_text(omit_final_dot=True), answer.port))
+ _log.debug("Found %s", hosts)
+ if not len(hosts):
+ raise ValueError("The SRV record is present but no host were found")
+ return tuple(hosts)
+
+ def __del__(self):
+ """Clean up open connections"""
+ if self.http is not None:
+ try:
+ self.http.clear()
+ except ReferenceError:
+ # this may hit an already-cleared weakref
+ pass
+
@property
def base_uri(self):
"""URI used by the client to connect to etcd."""
@@ -221,6 +275,7 @@ class Client(object):
response = self.http.request(
self._MGET,
uri,
+ headers=self._get_headers(),
timeout=self.read_timeout,
redirect=self.allow_redirect)
@@ -278,9 +333,9 @@ class Client(object):
try:
leader = json.loads(
- self.api_execute(self.version_prefix + '/stats/leader',
+ self.api_execute(self.version_prefix + '/stats/self',
self._MGET).data.decode('utf-8'))
- return self.members[leader['leader']]
+ return self.members[leader['leaderInfo']['leader']]
except Exception as e:
raise etcd.EtcdException("Cannot get leader data: %s" % e)
@@ -718,85 +773,123 @@ class Client(object):
_log.info("Selected new etcd server %s", mach)
return mach
+ def _wrap_request(payload):
+ @wraps(payload)
+ def wrapper(self, path, method, params=None, timeout=None):
+ some_request_failed = False
+ response = False
+
+ if timeout is None:
+ timeout = self.read_timeout
+
+ if timeout == 0:
+ timeout = None
+
+ if not path.startswith('/'):
+ raise ValueError('Path does not start with /')
+
+ while not response:
+ try:
+ response = payload(self, path, method,
+ params=params, timeout=timeout)
+ # Check the cluster ID hasn't changed under us. We use
+ # preload_content=False above so we can read the headers
+ # before we wait for the content of a watch.
+ self._check_cluster_id(response)
+ # Now force the data to be preloaded in order to trigger any
+ # IO-related errors in this method rather than when we try to
+ # access it later.
+ _ = response.data
+ # urllib3 doesn't wrap all httplib exceptions and earlier versions
+ # don't wrap socket errors either.
+ except (urllib3.exceptions.HTTPError,
+ HTTPException, socket.error) as e:
+ if (isinstance(params, dict) and
+ params.get("wait") == "true" and
+ isinstance(e,
+ urllib3.exceptions.ReadTimeoutError)):
+ _log.debug("Watch timed out.")
+ raise etcd.EtcdWatchTimedOut(
+ "Watch timed out: %r" % e,
+ cause=e
+ )
+ _log.error("Request to server %s failed: %r",
+ self._base_uri, e)
+ if self._allow_reconnect:
+ _log.info("Reconnection allowed, looking for another "
+ "server.")
+ # _next_server() raises EtcdException if there are no
+ # machines left to try, breaking out of the loop.
+ self._base_uri = self._next_server(cause=e)
+ some_request_failed = True
+
+ # if exception is raised on _ = response.data
+ # the condition for while loop will be False
+ # but we should retry
+ response = False
+ else:
+ _log.debug("Reconnection disabled, giving up.")
+ raise etcd.EtcdConnectionFailed(
+ "Connection to etcd failed due to %r" % e,
+ cause=e
+ )
+ except etcd.EtcdClusterIdChanged as e:
+ _log.warning(e)
+ raise
+ except:
+ _log.exception("Unexpected request failure, re-raising.")
+ raise
+
+ if some_request_failed:
+ if not self._use_proxies:
+ # The cluster may have changed since last invocation
+ self._machines_cache = self.machines
+ self._machines_cache.remove(self._base_uri)
+ return self._handle_server_response(response)
+ return wrapper
+
+ @_wrap_request
def api_execute(self, path, method, params=None, timeout=None):
""" Executes the query. """
-
- some_request_failed = False
- response = False
-
- if timeout is None:
- timeout = self.read_timeout
-
- if timeout == 0:
- timeout = None
-
- if not path.startswith('/'):
- raise ValueError('Path does not start with /')
-
- while not response:
- try:
- url = self._base_uri + path
-
- if (method == self._MGET) or (method == self._MDELETE):
- response = self.http.request(
- method,
- url,
- timeout=timeout,
- fields=params,
- redirect=self.allow_redirect,
- preload_content=False)
-
- elif (method == self._MPUT) or (method == self._MPOST):
- response = self.http.request_encode_body(
- method,
- url,
- fields=params,
- timeout=timeout,
- encode_multipart=False,
- redirect=self.allow_redirect,
- preload_content=False)
- else:
+ url = self._base_uri + path
+
+ if (method == self._MGET) or (method == self._MDELETE):
+ return self.http.request(
+ method,
+ url,
+ timeout=timeout,
+ fields=params,
+ redirect=self.allow_redirect,
+ headers=self._get_headers(),
+ preload_content=False)
+
+ elif (method == self._MPUT) or (method == self._MPOST):
+ return self.http.request_encode_body(
+ method,
+ url,
+ fields=params,
+ timeout=timeout,
+ encode_multipart=False,
+ redirect=self.allow_redirect,
+ headers=self._get_headers(),
+ preload_content=False)
+ else:
raise etcd.EtcdException(
'HTTP method {} not supported'.format(method))
-
- # Check the cluster ID hasn't changed under us. We use
- # preload_content=False above so we can read the headers
- # before we wait for the content of a watch.
- self._check_cluster_id(response)
- # Now force the data to be preloaded in order to trigger any
- # IO-related errors in this method rather than when we try to
- # access it later.
- _ = response.data
- # urllib3 doesn't wrap all httplib exceptions and earlier versions
- # don't wrap socket errors either.
- except (urllib3.exceptions.HTTPError,
- HTTPException,
- socket.error) as e:
- _log.error("Request to server %s failed: %r",
- self._base_uri, e)
- if self._allow_reconnect:
- _log.info("Reconnection allowed, looking for another "
- "server.")
- # _next_server() raises EtcdException if there are no
- # machines left to try, breaking out of the loop.
- self._base_uri = self._next_server(cause=e)
- some_request_failed = True
- else:
- _log.debug("Reconnection disabled, giving up.")
- raise etcd.EtcdConnectionFailed(
- "Connection to etcd failed due to %r" % e,
- cause=e
- )
- except:
- _log.exception("Unexpected request failure, re-raising.")
- raise
-
- if some_request_failed:
- if not self._use_proxies:
- # The cluster may have changed since last invocation
- self._machines_cache = self.machines
- self._machines_cache.remove(self._base_uri)
- return self._handle_server_response(response)
+
+ @_wrap_request
+ def api_execute_json(self, path, method, params=None, timeout=None):
+ url = self._base_uri + path
+ json_payload = json.dumps(params)
+ headers = self._get_headers()
+ headers['Content-Type'] = 'application/json'
+ return self.http.urlopen(method,
+ url,
+ body=json_payload,
+ timeout=timeout,
+ redirect=self.allow_redirect,
+ headers=headers,
+ preload_content=False)
def _check_cluster_id(self, response):
cluster_id = response.getheader("x-etcd-cluster-id")
@@ -827,8 +920,15 @@ class Client(object):
# throw the appropriate exception
try:
r = json.loads(resp)
+ r['status'] = response.status
except (TypeError, ValueError):
# Bad JSON, make a response locally.
r = {"message": "Bad response",
"cause": str(resp)}
etcd.EtcdError.handle(r)
+
+ def _get_headers(self):
+ if self.username and self.password:
+ credentials = ':'.join((self.username, self.password))
+ return urllib3.make_headers(basic_auth=credentials)
+ return {}