diff --git a/kazoo/recipe/lock.py b/kazoo/recipe/lock.py index 7722a978..b2144a3b 100644 --- a/kazoo/recipe/lock.py +++ b/kazoo/recipe/lock.py @@ -294,6 +294,7 @@ def _get_predecessor(self, node): (e.g. rlock), this and also edge cases where the lock's ephemeral node is gone. """ + node_sequence = node[len(self.prefix):] children = self.client.get_children(self.path) found_self = False # Filter out the contenders using the computed regex @@ -301,7 +302,12 @@ def _get_predecessor(self, node): for child in children: match = self._contenders_re.search(child) if match is not None: - contender_matches.append(match) + contender_sequence = match.group(1) + # Only consider contenders with a smaller sequence number. + # A contender with a smaller sequence number has a higher + # priority. + if contender_sequence < node_sequence: + contender_matches.append(match) if child == node: # Remember the node's match object so we can short circuit # below. @@ -313,15 +319,13 @@ def _get_predecessor(self, node): # node was removed. raise ForceRetryError() - predecessor = None - # Sort the contenders using the sequence number extracted by the regex, - # then extract the original string. - for match in sorted(contender_matches, key=lambda m: m.groups()): - if match is found_self: - break - predecessor = match.string + if not contender_matches: + return None - return predecessor + # Sort the contenders using the sequence number extracted by the regex + # and return the original string of the predecessor. + sorted_matches = sorted(contender_matches, key=lambda m: m.groups()) + return sorted_matches[-1].string def _find_node(self): children = self.client.get_children(self.path) diff --git a/kazoo/tests/test_lock.py b/kazoo/tests/test_lock.py index 0e16949e..97847cf3 100644 --- a/kazoo/tests/test_lock.py +++ b/kazoo/tests/test_lock.py @@ -467,6 +467,93 @@ def test_write_lock(self): gotten = lock2.acquire(blocking=False) assert gotten is False + def _rw_lock_order(self): + writer_running = threading.Event() + reader_running = threading.Event() + + def _thread(lock, event): + event.set() + with lock: + pass + + write_lock1 = self.client.WriteLock(self.lockpath, "writer 1") + write_lock2 = self.client.WriteLock(self.lockpath, "writer 2") + read_lock = self.client.ReadLock(self.lockpath, "reader") + + writer_thread = self.make_thread( + target=_thread, + args=(write_lock2, writer_running) + ) + reader_thread = self.make_thread( + target=_thread, + args=(read_lock, reader_running) + ) + + with write_lock1: + reader_thread.start() + reader_running.wait() + writer_thread.start() + writer_running.wait() + time.sleep(5) + + def test_rw_lock(self): + reader_event = self.make_event() + reader_lock = self.client.ReadLock(self. lockpath, "reader") + reader_thread = self.make_thread( + target=self._thread_lock_acquire_til_event, + args=("reader", reader_lock, reader_event) + ) + + writer_event = self.make_event() + writer_lock = self.client.WriteLock(self. lockpath, "writer") + writer_thread = self.make_thread( + target=self._thread_lock_acquire_til_event, + args=("writer", writer_lock, writer_event) + ) + + # acquire a write lock ourselves first to make the others line up + lock = self.client.WriteLock(self.lockpath, "test") + lock.acquire() + + reader_thread.start() + writer_thread.start() + + # wait for everyone to line up on the lock + wait = self.make_wait() + wait(lambda: len(lock.contenders()) == 3) + contenders = lock.contenders() + + assert contenders[0] == "test" + remaining = contenders[1:] + + # release the lock and contenders should claim it in order + lock.release() + + contender_bits = { + "reader": (reader_thread, reader_event), + "writer": (writer_thread, writer_event), + } + + for contender in ("reader", "writer"): + thread, event = contender_bits[contender] + + with self.condition: + while not self.active_thread: + self.condition.wait() + assert self.active_thread == contender + + assert lock.contenders() == remaining + remaining = remaining[1:] + + event.set() + + with self.condition: + while self.active_thread: + self.condition.wait() + + reader_thread.join() + writer_thread.join() + class TestSemaphore(KazooTestCase): def __init__(self, *args, **kw):