
Commit

Don't shut down unresponsive workers on gather() (#8101)
crusaderky authored Aug 14, 2023
1 parent 6ce34d8 commit ac5ddc3
Showing 3 changed files with 34 additions and 64 deletions.
4 changes: 2 additions & 2 deletions distributed/client.py
@@ -2210,8 +2210,8 @@ async def _gather(self, futures, errors="raise", direct=None, local_worker=None)
         if mismatched_futures:
             raise ValueError(
                 "Cannot gather Futures created by another client. "
-                f"These are the {len(mismatched_futures)} (out of {len(futures)}) mismatched Futures and their client IDs "
-                f"(this client is {self.id}): "
+                f"These are the {len(mismatched_futures)} (out of {len(futures)}) "
+                f"mismatched Futures and their client IDs (this client is {self.id}): "
                 f"{ {f: f.client.id for f in mismatched_futures} }"
             )
         keys = [stringify(future.key) for future in future_set]
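
The rewrapped message above is the ValueError a client raises when asked to gather Futures it did not create. A minimal sketch of how that error surfaces, assuming a local cluster with two clients (the setup is illustrative, not part of this commit):

# Illustrative sketch only: triggering the "Cannot gather Futures created by
# another client" ValueError whose message is rewrapped above. The cluster
# setup is an assumption for the example, not part of this commit.
from distributed import Client, LocalCluster

if __name__ == "__main__":
    cluster = LocalCluster(n_workers=1, dashboard_address=":0")
    creator = Client(cluster)  # the client that owns the Future
    other = Client(cluster)    # a second client on the same cluster

    fut = creator.submit(sum, [1, 2, 3])
    try:
        other.gather(fut)      # mismatched Future -> ValueError
    except ValueError as exc:
        print(exc)             # lists the mismatched Futures and their client IDs

    other.close()
    creator.close()
    cluster.close()
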
15 changes: 0 additions & 15 deletions distributed/scheduler.py
@@ -5887,7 +5887,6 @@ async def gather(
         self, keys: Collection[str], serializers: list[str] | None = None
     ) -> dict[str, Any]:
         """Collect data from workers to the scheduler"""
-        stimulus_id = f"gather-{time()}"
         data = {}
         missing_keys = list(keys)
         failed_keys: list[str] = []
@@ -5924,20 +5923,6 @@
             for key in failed_keys
         }
         logger.error("Couldn't gather keys: %s", failed_states)
-
-        if missing_workers:
-            with log_errors():
-                # Remove suspicious workers from the scheduler and shut them down.
-                await asyncio.gather(
-                    *(
-                        self.remove_worker(
-                            address=worker, close=True, stimulus_id=stimulus_id
-                        )
-                        for worker in missing_workers
-                    )
-                )
-            logger.error("Shut down unresponsive workers:: %s", missing_workers)
-
         return {"status": "error", "keys": list(failed_keys)}
 
     @log_errors
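
With the removed block, a failed gather no longer tears down the workers it could not reach; the scheduler only logs the failure and returns the keys it could not collect, leaving retries to the caller. The sketch below illustrates that contract with a hypothetical gather_with_retry helper and fetch callback; it is not the actual Client._gather implementation:

# Rough illustration only, not the actual distributed.Client._gather code:
# keep re-requesting the keys the scheduler reports as failed, instead of
# expecting it to shut down the unresponsive worker. Names are assumptions.
import asyncio
from typing import Any, Awaitable, Callable


async def gather_with_retry(
    fetch: Callable[[list[str]], Awaitable[dict[str, Any]]],
    keys: list[str],
    delay: float = 0.1,
) -> dict[str, Any]:
    """Retry until every requested key has been collected."""
    data: dict[str, Any] = {}
    pending = list(keys)
    while pending:
        response = await fetch(pending)  # e.g. the scheduler's "gather" handler
        data.update(response.get("data", {}))
        # After this commit the error response only carries the failed keys;
        # the workers holding them are left alone.
        pending = list(response.get("keys", []))
        if pending:
            await asyncio.sleep(delay)  # brief pause before trying again
    return data
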
79 changes: 32 additions & 47 deletions distributed/tests/test_scheduler.py
@@ -74,7 +74,7 @@
     varying,
     wait_for_state,
 )
-from distributed.worker import dumps_function, get_worker, secede
+from distributed.worker import dumps_function, secede
 
 pytestmark = pytest.mark.ci1
 
@@ -2860,56 +2860,41 @@ async def test_gather_no_workers(c, s, a, b):
     assert list(res["keys"]) == ["x"]
 
 
-@gen_cluster(client=True, client_kwargs={"direct_to_workers": False})
-async def test_gather_bad_worker_removed(c, s, a, b):
-    """
-    Upon connection failure or missing expected keys during gather, a worker is
-    shut down. The tasks should be rescheduled onto different workers, transparently
-    to `client.gather`.
+@pytest.mark.parametrize("direct", [False, True])
+@gen_cluster(
+    client=True,
+    nthreads=[("", 1)],
+    # This behaviour is independent of retries.
+    # Disable them to reduce the complexity of this test.
+    config={"distributed.comm.retry.count": 0},
+)
+async def test_gather_bad_worker(c, s, a, direct):
+    """Upon connection failure, gather() tries again indefinitely and transparently,
+    for as long as the batched comms channel is active.
     """
-    x = c.submit(slowinc, 1, workers=[a.address], allow_other_workers=True)
-
-    def finalizer(*args):
-        return get_worker().address
+    x = c.submit(inc, 1, key="x")
+    c.rpc = await FlakyConnectionPool(failing_connections=3)
+    s.rpc = await FlakyConnectionPool(failing_connections=1)
 
-    fin = c.submit(
-        finalizer, x, key="final", workers=[a.address], allow_other_workers=True
-    )
+    with captured_logger("distributed.scheduler") as sched_logger:
+        with captured_logger("distributed.client") as client_logger:
+            assert await c.gather(x, direct=direct) == 2
 
-    s.rpc = await FlakyConnectionPool(failing_connections=1)
+    assert "Couldn't gather keys: {'x': 'memory'}" in sched_logger.getvalue()
+    assert "Couldn't gather 1 keys, rescheduling ('x',)" in client_logger.getvalue()
 
-    # This behaviour is independent of retries. Remove them to reduce complexity
-    # of this setup
-    with dask.config.set({"distributed.comm.retry.count": 0}):
-        with captured_logger(
-            logging.getLogger("distributed.scheduler")
-        ) as sched_logger, captured_logger(
-            logging.getLogger("distributed.client")
-        ) as client_logger:
-            # Gather using the client (as an ordinary user would)
-            # Upon a missing key, the client will remove the bad worker and
-            # reschedule the computations
-
-            # Both tasks are rescheduled onto `b`, since `a` was removed.
-            assert await fin == b.address
-
-            await a.finished()
-            assert list(s.workers) == [b.address]
-
-            sched_logger = sched_logger.getvalue()
-            client_logger = client_logger.getvalue()
-            assert "Shut down unresponsive workers" in sched_logger
-
-            assert "Couldn't gather 1 keys, rescheduling" in client_logger
-
-            assert s.tasks[fin.key].who_has == {s.workers[b.address]}
-            assert a.state.executed_count == 2
-            assert b.state.executed_count >= 1
-            # ^ leave room for a future switch from `remove_worker` to `retire_workers`
-
-            # Ensure that the communication was done via the scheduler, i.e. we actually hit a
-            # bad connection
-            assert s.rpc.cnn_count > 0
+    if direct:
+        # 1. try direct=True; fail
+        # 2. fall back to direct=False; fail
+        # 3. try direct=True again; fail
+        # 4. fall back to direct=False again; success
+        assert c.rpc.cnn_count == 2
+        assert s.rpc.cnn_count == 2
+    else:
+        # 1. try direct=False; fail
+        # 2. try again direct=False; success
+        assert c.rpc.cnn_count == 0
+        assert s.rpc.cnn_count == 2
 
 
 @gen_cluster(client=True)
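
The new test is parametrized over direct, which selects between the two gather paths whose connection counts are asserted above. A minimal sketch of those two paths from user code, assuming an in-process cluster started only for the example:

# Minimal sketch of the two gather paths the parametrized test exercises.
# The in-process cluster here is an assumption for the example only.
from distributed import Client

if __name__ == "__main__":
    client = Client(processes=False, n_workers=1, threads_per_worker=1)
    fut = client.submit(lambda x: x + 1, 1)

    # direct=False: the data travels through the scheduler (Scheduler.gather).
    via_scheduler = client.gather(fut, direct=False)

    # direct=True: the client fetches from the workers and falls back to the
    # scheduler when a worker connection fails, as the test's comments describe.
    direct_result = client.gather(fut, direct=True)

    assert via_scheduler == direct_result == 2
    client.close()
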
