summary | refs | log | tree | commit | diff
path: root/src
diff options
context:
space:
mode:
author Jose Plana <jplana@gmail.com> 2013-11-27 14:29:41 -0800
committer Jose Plana <jplana@gmail.com> 2013-11-27 14:29:41 -0800
commit 449f3f3b3c010401903638f8ba5ce83d54e8b39e (patch)
tree 0839ca0ef3ffe9a87ca445421205ea4c04d396c2 /src
parent da123a1a326e8c400999af3ba7e38bf1cbefd31f (diff)
parent fbe7be18cbceec4ab6a114ee54dd333e400a4486 (diff)
Merge pull request #5 from jplana/allow-reconnect
Handling reconnections
Diffstat (limited to 'src')
-rw-r--r-- src/etcd/client.py 92
-rw-r--r-- src/etcd/tests/integration/helpers.py 71
-rw-r--r-- src/etcd/tests/integration/test_simple.py 110
3 files changed, 226 insertions, 47 deletions
diff --git a/src/etcd/client.py b/src/etcd/client.py
index 8e50368..2ce351c 100644
--- a/src/etcd/client.py
+++ b/src/etcd/client.py
@@ -17,6 +17,7 @@ class Client(object):
"""
Client for etcd, the distributed log service using raft.
"""
+
def __init__(
self,
host='127.0.0.1',
@@ -26,12 +27,15 @@ class Client(object):
protocol='http',
cert=None,
ca_cert=None,
+ allow_reconnect=False,
):
"""
Initialize the client.
Args:
- host (str): IP to connect to.
+ host (mixed):
+ If a string, IP to connect to.
+ If a tuple ((host, port), (host, port), ...)
port (int): Port used to connect to etcd.
@@ -47,15 +51,33 @@ class Client(object):
ca_cert (str): The ca certificate. If pressent it will enable
validation.
+ allow_reconnect (bool): allow the client to reconnect to another
+ etcd server in the cluster in the case the
+ default one does not respond.
+
"""
- self._host = host
- self._port = port
+ self._machines_cache = []
+
self._protocol = protocol
- self._base_uri = "%s://%s:%d" % (protocol, host, port)
+
+ def uri(protocol, host, port):
+ return '%s://%s:%d' % (protocol, host, port)
+
+ if not isinstance(host, tuple):
+ self._host = host
+ self._port = port
+ else:
+ self._host, self._port = host[0]
+ self._machines_cache.extend(
+ [uri(self._protocol, *conn) for conn in host])
+
+ self._base_uri = uri(self._protocol, self._host, self._port)
+
self.version_prefix = '/v1'
self._read_timeout = read_timeout
self._allow_redirect = allow_redirect
+ self._allow_reconnect = allow_reconnect
self._MGET = 'GET'
self._MPOST = 'POST'
@@ -104,6 +126,14 @@ class Client(object):
self.http = urllib3.PoolManager(num_pools=10, **kw)
+ if self._allow_reconnect:
+ # we need the set of servers in the cluster in order to try
+ # reconnecting upon error.
+ self._machines_cache = self.machines
+ self._machines_cache.remove(self._base_uri)
+ else:
+ self._machines_cache = []
+
@property
def base_uri(self):
"""URI used by the client to connect to etcd."""
@@ -359,27 +389,49 @@ class Client(object):
except:
raise etcd.EtcdException('Unable to decode server response')
+ def _next_server(self):
+ """ Selects the next server in the list, refreshes the server list. """
+ try:
+ return self._machines_cache.pop()
+ except IndexError:
+ raise etcd.EtcdException('No more machines in the cluster')
+
def api_execute(self, path, method, params=None):
""" Executes the query. """
- url = self._base_uri + path
-
- if (method == self._MGET) or (method == self._MDELETE):
- response = self.http.request(
- method,
- url,
- fields=params,
- redirect=self.allow_redirect)
-
- elif method == self._MPOST:
- response = self.http.request_encode_body(
- method,
- url,
- fields=params,
- encode_multipart=False,
- redirect=self.allow_redirect)
+
+ some_request_failed = False
+ response = False
+
+ while not response:
+ try:
+ url = self._base_uri + path
+
+ if (method == self._MGET) or (method == self._MDELETE):
+ response = self.http.request(
+ method,
+ url,
+ fields=params,
+ redirect=self.allow_redirect)
+
+ elif method == self._MPOST:
+ response = self.http.request_encode_body(
+ method,
+ url,
+ fields=params,
+ encode_multipart=False,
+ redirect=self.allow_redirect)
+
+ except urllib3.exceptions.MaxRetryError:
+ self._base_uri = self._next_server()
+ some_request_failed = True
+
+ if some_request_failed:
+ self._machines_cache = self.machines
+ self._machines_cache.remove(self._base_uri)
if response.status == 200:
return response.data
+
else:
try:
error = json.loads(response.data)
diff --git a/src/etcd/tests/integration/helpers.py b/src/etcd/tests/integration/helpers.py
index e199f6d..e1de006 100644
--- a/src/etcd/tests/integration/helpers.py
+++ b/src/etcd/tests/integration/helpers.py
@@ -16,42 +16,61 @@ class EtcdProcessHelper(object):
proc_name='etcd',
port_range_start=4001,
internal_port_range_start=7001,
- ):
+ cluster=False):
+
self.base_directory = base_directory
self.proc_name = proc_name
self.port_range_start = port_range_start
self.internal_port_range_start = internal_port_range_start
- self.processes = []
+ self.processes = {}
+ self.cluster = cluster
def run(self, number=1, proc_args=None):
- log = logging.getLogger()
for i in range(0, number):
- directory = tempfile.mkdtemp(
- dir=self.base_directory,
- prefix='python-etcd.%d-' % i)
- log.debug('Created directory %s' % directory)
- daemon_args = [
- self.proc_name,
- '-d', directory,
- '-n', 'test-node-%d' % i,
- '-s', '127.0.0.1:%d' % (self.internal_port_range_start + i),
- '-c', '127.0.0.1:%d' % (self.port_range_start + i),
- ]
- if proc_args:
- daemon_args.extend(proc_args)
- daemon = subprocess.Popen(daemon_args)
- log.debug('Started %d' % daemon.pid)
- time.sleep(2)
- self.processes.append((directory, daemon))
+ self.add_one(i, proc_args)
def stop(self):
log = logging.getLogger()
- for directory, process in self.processes:
- process.kill()
- time.sleep(2)
- log.debug('Killed etcd pid:%d' % process.pid)
- shutil.rmtree(directory)
- log.debug('Removed directory %s' % directory)
+ for key in [k for k in self.processes.keys()]:
+ self.kill_one(key)
+
+ def add_one(self, slot, proc_args=None):
+ log = logging.getLogger()
+ directory = tempfile.mkdtemp(
+ dir=self.base_directory,
+ prefix='python-etcd.%d-' % slot)
+
+ log.debug('Created directory %s' % directory)
+ daemon_args = [
+ self.proc_name,
+ '-d', directory,
+ '-n', 'test-node-%d' % slot,
+ '-s', '127.0.0.1:%d' % (self.internal_port_range_start + slot),
+ '-c', '127.0.0.1:%d' % (self.port_range_start + slot),
+ ]
+
+ if proc_args:
+ daemon_args.extend(proc_args)
+
+ if slot > 0 and self.cluster:
+ daemon_args.append('-C')
+ daemon_args.append(
+ '127.0.0.1:%d' % self.internal_port_range_start)
+
+ daemon = subprocess.Popen(daemon_args)
+ log.debug('Started %d' % daemon.pid)
+ log.debug('Params: %s' % daemon_args)
+ time.sleep(2)
+ self.processes[slot] = (directory, daemon)
+
+ def kill_one(self, slot):
+ log = logging.getLogger()
+ dir, process = self.processes.pop(slot)
+ process.kill()
+ time.sleep(2)
+ log.debug('Killed etcd pid:%d', process.pid)
+ shutil.rmtree(dir)
+ log.debug('Removed directory %s' % dir)
class TestingCA(object):
diff --git a/src/etcd/tests/integration/test_simple.py b/src/etcd/tests/integration/test_simple.py
index 8050ac5..f25b283 100644
--- a/src/etcd/tests/integration/test_simple.py
+++ b/src/etcd/tests/integration/test_simple.py
@@ -63,7 +63,7 @@ class TestSimple(EtcdIntegrationTest):
def test_machines(self):
""" INTEGRATION: retrieve machines """
- self.assertEquals(self.client.machines, ['http://127.0.0.1:6001'])
+ self.assertEquals(self.client.machines[0], 'http://127.0.0.1:6001')
def test_leader(self):
""" INTEGRATION: retrieve leader """
@@ -149,6 +149,114 @@ class TestErrors(EtcdIntegrationTest):
pass
+class TestClusterFunctions(EtcdIntegrationTest):
+ @classmethod
+ def setUpClass(cls):
+ program = cls._get_exe()
+ cls.directory = tempfile.mkdtemp(prefix='python-etcd')
+
+ cls.processHelper = helpers.EtcdProcessHelper(
+ cls.directory,
+ proc_name=program,
+ port_range_start=6001,
+ internal_port_range_start=8001,
+ cluster=True)
+
+ def test_reconnect(self):
+ """ INTEGRATION: get key after the server we're connected fails. """
+ self.processHelper.stop()
+ self.processHelper.run(number=3)
+ self.client = etcd.Client(port=6001, allow_reconnect=True)
+ set_result = self.client.set('/test_set', 'test-key1')
+ get_result = self.client.get('/test_set')
+
+ self.assertEquals('test-key1', get_result.value)
+
+ self.processHelper.kill_one(0)
+
+ get_result = self.client.get('/test_set')
+ self.assertEquals('test-key1', get_result.value)
+
+ def test_reconnect_with_several_hosts_passed(self):
+ """ INTEGRATION: receive several hosts at connection setup. """
+ self.processHelper.stop()
+ self.processHelper.run(number=3)
+ self.client = etcd.Client(
+ host=(
+ ('127.0.0.1', 6004),
+ ('127.0.0.1', 6001)),
+ allow_reconnect=True)
+ set_result = self.client.set('/test_set', 'test-key1')
+ get_result = self.client.get('/test_set')
+
+ self.assertEquals('test-key1', get_result.value)
+
+ self.processHelper.kill_one(0)
+
+ get_result = self.client.get('/test_set')
+ self.assertEquals('test-key1', get_result.value)
+
+ def test_reconnect_not_allowed(self):
+ """ INTEGRATION: fail on server kill if not allow_reconnect """
+ self.processHelper.stop()
+ self.processHelper.run(number=3)
+ self.client = etcd.Client(port=6001, allow_reconnect=False)
+ self.processHelper.kill_one(0)
+ self.assertRaises(etcd.EtcdException, self.client.get, '/test_set')
+
+ def test_reconnet_fails(self):
+ """ INTEGRATION: fails to reconnect if no available machines """
+ self.processHelper.stop()
+ # Start with three instances (0, 1, 2)
+ self.processHelper.run(number=3)
+ # Connect to instance 0
+ self.client = etcd.Client(port=6001, allow_reconnect=True)
+ set_result = self.client.set('/test_set', 'test-key1')
+
+ get_result = self.client.get('/test_set')
+ self.assertEquals('test-key1', get_result.value)
+ self.processHelper.kill_one(2)
+ self.processHelper.kill_one(1)
+ self.processHelper.kill_one(0)
+ self.assertRaises(etcd.EtcdException, self.client.get, '/test_set')
+
+ def test_reconnect_to_failed_node(self):
+ """ INTEGRATION: after a server failed and recovered we can connect."""
+
+ self.processHelper.stop()
+ # Start with three instances (0, 1, 2)
+ self.processHelper.run(number=3)
+
+ # Connect to instance 0
+ self.client = etcd.Client(port=6001, allow_reconnect=True)
+ set_result = self.client.set('/test_set', 'test-key1')
+
+ get_result = self.client.get('/test_set')
+ self.assertEquals('test-key1', get_result.value)
+
+ # kill 1 -> instances = (0, 2)
+ self.processHelper.kill_one(1)
+
+ get_result = self.client.get('/test_set')
+ self.assertEquals('test-key1', get_result.value)
+
+ # kill 0 -> Instances (2)
+ self.processHelper.kill_one(0)
+
+ get_result = self.client.get('/test_set')
+ self.assertEquals('test-key1', get_result.value)
+
+ # Add 0 (failed server) -> Instances (0,2)
+ self.processHelper.add_one(0)
+ # Instances (0, 2)
+
+ # kill 2 -> Instances (0) (previously failed)
+ self.processHelper.kill_one(2)
+
+ get_result = self.client.get('/test_set')
+ self.assertEquals('test-key1', get_result.value)
+
+
class TestWatch(EtcdIntegrationTest):
def test_watch(self):