summaryrefslogtreecommitdiff
path: root/src/etcd/lock.py
blob: 9068576c9c213a22ef3ae6369fe8ac98361f8095 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import contextlib

import etcd


class Lock(object):
    """
    Lock object using etcd's lock module.

    All requests go through ``client.api_execute`` against the path
    ``/mod/v2/lock<key>``.  The index returned when the lock is acquired is
    remembered on the instance and sent back on release/renew.

    The object can be used as a context manager::

        with Lock(client, '/mykey', ttl=30) as lock:
            ...  # lock is held here and released on exit
    """

    def __init__(self, client, key, ttl=0, value=None):
        """
        Initialize a lock object.

        Args:
            client (Client):  etcd client to use for communication.

            key (string):  key to lock.  A leading '/' is added if missing.

            ttl (int):  ttl (in seconds) for the lock to live.
                        0 or None to lock forever.

            value (mixed):  value to store on the lock.
        """
        self.client = client
        # The key is appended verbatim to the lock module path, so it must
        # always begin with a slash.
        if not key.startswith('/'):
            key = '/' + key
        self.key = key
        self.ttl = ttl
        self.value = value
        # Index handed back on acquire; None while this object holds nothing.
        self._index = None

    def __enter__(self):
        """Acquire the lock and return this object for ``with ... as``."""
        # acquire() returns self; returning it makes `with lock as l:` bind
        # the Lock instead of None.
        return self.acquire()

    def __exit__(self, exc_type, exc_value, traceback):
        """Release the lock on leaving the ``with`` block.

        Returns None, so an exception raised inside the block propagates.
        """
        self.release()

    @property
    def _path(self):
        """Request path of this lock inside the lock module."""
        return u'/mod/v2/lock{}'.format(self.key)

    def acquire(self, timeout=None):
        """Acquire the lock from etcd. Blocks until lock is acquired.

        Args:
            timeout:  passed through to the client's ``api_execute``.

        Returns:
            Lock: this object, with the lock index stored on it.
        """
        params = {u'ttl': self.ttl}
        # Only send a value when one was actually configured.
        if self.value is not None:
            params[u'value'] = self.value

        res = self.client.api_execute(
            self._path, self.client._MPOST, params=params, timeout=timeout)
        self._index = res.data.decode('utf-8')
        return self

    def get(self):
        """
        Get Information on the lock.
        This allows to operate on locks that have not been acquired directly.

        Returns:
            Lock: this object, with ``value`` and the lock index refreshed
            from the server response.

        Raises:
            etcd.EtcdException: if the response carries no data.
        """
        res = self.client.api_execute(self._path, self.client._MGET)
        if not res.data:
            raise etcd.EtcdException('Lock is non-existent (or expired)')
        self.value = res.data.decode('utf-8')
        self._get_index()
        return self

    def _get_index(self):
        """Fetch the lock's index field and store it on this object.

        Raises:
            etcd.EtcdException: if the response carries no data.
        """
        res = self.client.api_execute(
            self._path,
            self.client._MGET,
            params={u'field': u'index'})
        if not res.data:
            raise etcd.EtcdException('Lock is non-existent (or expired)')
        self._index = res.data.decode('utf-8')

    def is_locked(self):
        """Check if lock is currently locked.

        Returns:
            bool: True when the index query returns non-empty data.
        """
        params = {u'field': u'index'}
        res = self.client.api_execute(
            self._path, self.client._MGET, params=params)
        return bool(res.data)

    def release(self):
        """Release this lock.

        Raises:
            etcd.EtcdException: if no lock index is known, i.e. the lock was
                neither acquired nor looked up through this object.
        """
        if not self._index:
            raise etcd.EtcdException(
                u'Cannot release lock that has not been locked')
        params = {u'index': self._index}
        self.client.api_execute(
            self._path, self.client._MDELETE, params=params)
        # Forget the index only after the delete request did not raise.
        self._index = None

    def renew(self, new_ttl, timeout=None):
        """
        Renew the TTL on this lock.

        Args:
            new_ttl (int): new TTL to set.

            timeout:  passed through to the client's ``api_execute``.

        Raises:
            etcd.EtcdException: if no lock index is known, i.e. the lock was
                neither acquired nor looked up through this object.
        """
        if not self._index:
            raise etcd.EtcdException(
                u'Cannot renew lock that has not been locked')
        params = {u'ttl': new_ttl, u'index': self._index}
        # Forward the caller's timeout, as acquire() does.
        self.client.api_execute(
            self._path, self.client._MPUT, params=params, timeout=timeout)