diff --git a/distributed/client.py b/distributed/client.py
index 4e424a7386..5dfc436553 100644
--- a/distributed/client.py
+++ b/distributed/client.py
@@ -2210,8 +2210,8 @@ async def _gather(self, futures, errors="raise", direct=None, local_worker=None)
         if mismatched_futures:
             raise ValueError(
                 "Cannot gather Futures created by another client. "
-                f"These are the {len(mismatched_futures)} (out of {len(futures)}) mismatched Futures and their client IDs "
-                f"(this client is {self.id}): "
+                f"These are the {len(mismatched_futures)} (out of {len(futures)}) "
+                f"mismatched Futures and their client IDs (this client is {self.id}): "
                 f"{ {f: f.client.id for f in mismatched_futures} }"
             )
         keys = [stringify(future.key) for future in future_set]
diff --git a/distributed/scheduler.py b/distributed/scheduler.py
index 3d59b15982..e83e97c7eb 100644
--- a/distributed/scheduler.py
+++ b/distributed/scheduler.py
@@ -5887,7 +5887,6 @@ async def gather(
         self, keys: Collection[str], serializers: list[str] | None = None
     ) -> dict[str, Any]:
         """Collect data from workers to the scheduler"""
-        stimulus_id = f"gather-{time()}"
         data = {}
         missing_keys = list(keys)
         failed_keys: list[str] = []
@@ -5924,20 +5923,6 @@ async def gather(
             for key in failed_keys
         }
         logger.error("Couldn't gather keys: %s", failed_states)
-
-        if missing_workers:
-            with log_errors():
-                # Remove suspicious workers from the scheduler and shut them down.
-                await asyncio.gather(
-                    *(
-                        self.remove_worker(
-                            address=worker, close=True, stimulus_id=stimulus_id
-                        )
-                        for worker in missing_workers
-                    )
-                )
-            logger.error("Shut down unresponsive workers:: %s", missing_workers)
-
         return {"status": "error", "keys": list(failed_keys)}
 
     @log_errors
diff --git a/distributed/tests/test_scheduler.py b/distributed/tests/test_scheduler.py
index 890b4d8dd1..95d0ee9849 100644
--- a/distributed/tests/test_scheduler.py
+++ b/distributed/tests/test_scheduler.py
@@ -74,7 +74,7 @@
     varying,
     wait_for_state,
 )
-from distributed.worker import dumps_function, get_worker, secede
+from distributed.worker import dumps_function, secede
 
 pytestmark = pytest.mark.ci1
 
@@ -2860,56 +2860,41 @@ async def test_gather_no_workers(c, s, a, b):
     assert list(res["keys"]) == ["x"]
 
 
-@gen_cluster(client=True, client_kwargs={"direct_to_workers": False})
-async def test_gather_bad_worker_removed(c, s, a, b):
-    """
-    Upon connection failure or missing expected keys during gather, a worker is
-    shut down. The tasks should be rescheduled onto different workers, transparently
-    to `client.gather`.
+@pytest.mark.parametrize("direct", [False, True])
+@gen_cluster(
+    client=True,
+    nthreads=[("", 1)],
+    # This behaviour is independent of retries.
+    # Disable them to reduce the complexity of this test.
+    config={"distributed.comm.retry.count": 0},
+)
+async def test_gather_bad_worker(c, s, a, direct):
+    """Upon connection failure, gather() tries again indefinitely and transparently,
+    for as long as the batched comms channel is active.
""" - x = c.submit(slowinc, 1, workers=[a.address], allow_other_workers=True) - - def finalizer(*args): - return get_worker().address + x = c.submit(inc, 1, key="x") + c.rpc = await FlakyConnectionPool(failing_connections=3) + s.rpc = await FlakyConnectionPool(failing_connections=1) - fin = c.submit( - finalizer, x, key="final", workers=[a.address], allow_other_workers=True - ) + with captured_logger("distributed.scheduler") as sched_logger: + with captured_logger("distributed.client") as client_logger: + assert await c.gather(x, direct=direct) == 2 - s.rpc = await FlakyConnectionPool(failing_connections=1) + assert "Couldn't gather keys: {'x': 'memory'}" in sched_logger.getvalue() + assert "Couldn't gather 1 keys, rescheduling ('x',)" in client_logger.getvalue() - # This behaviour is independent of retries. Remove them to reduce complexity - # of this setup - with dask.config.set({"distributed.comm.retry.count": 0}): - with captured_logger( - logging.getLogger("distributed.scheduler") - ) as sched_logger, captured_logger( - logging.getLogger("distributed.client") - ) as client_logger: - # Gather using the client (as an ordinary user would) - # Upon a missing key, the client will remove the bad worker and - # reschedule the computations - - # Both tasks are rescheduled onto `b`, since `a` was removed. - assert await fin == b.address - - await a.finished() - assert list(s.workers) == [b.address] - - sched_logger = sched_logger.getvalue() - client_logger = client_logger.getvalue() - assert "Shut down unresponsive workers" in sched_logger - - assert "Couldn't gather 1 keys, rescheduling" in client_logger - - assert s.tasks[fin.key].who_has == {s.workers[b.address]} - assert a.state.executed_count == 2 - assert b.state.executed_count >= 1 - # ^ leave room for a future switch from `remove_worker` to `retire_workers` - - # Ensure that the communication was done via the scheduler, i.e. we actually hit a - # bad connection - assert s.rpc.cnn_count > 0 + if direct: + # 1. try direct=True; fail + # 2. fall back to direct=False; fail + # 3. try direct=True again; fail + # 4. fall back to direct=False again; success + assert c.rpc.cnn_count == 2 + assert s.rpc.cnn_count == 2 + else: + # 1. try direct=False; fail + # 2. try again direct=False; success + assert c.rpc.cnn_count == 0 + assert s.rpc.cnn_count == 2 @gen_cluster(client=True)