Don't stop Adaptive on error (#8871)
hendrikmakait authored Sep 27, 2024
1 parent f7b7f17 commit e52da46
Showing 4 changed files with 387 additions and 273 deletions.
115 changes: 100 additions & 15 deletions distributed/deploy/adaptive.py
@@ -1,20 +1,39 @@
from __future__ import annotations

import logging
from collections.abc import Hashable
from datetime import timedelta
from inspect import isawaitable
from typing import TYPE_CHECKING, Any, Callable, Literal, cast

from tornado.ioloop import IOLoop

import dask.config
from dask.utils import parse_timedelta

from distributed.compatibility import PeriodicCallback
from distributed.core import Status
from distributed.deploy.adaptive_core import AdaptiveCore
from distributed.protocol import pickle
from distributed.utils import log_errors

if TYPE_CHECKING:
from typing_extensions import TypeAlias

from distributed.deploy.cluster import Cluster
from distributed.scheduler import WorkerState

logger = logging.getLogger(__name__)


AdaptiveStateState: TypeAlias = Literal[
"starting",
"running",
"stopped",
"inactive",
]


class Adaptive(AdaptiveCore):
'''
Adaptively allocate workers based on scheduler load. A superclass.
@@ -81,16 +100,21 @@ class Adaptive(AdaptiveCore):
specified in the dask config under the distributed.adaptive key.
'''

interval: float | None
periodic_callback: PeriodicCallback | None
#: Whether this adaptive strategy is periodically adapting
state: AdaptiveStateState

def __init__(
self,
cluster=None,
interval=None,
minimum=None,
maximum=None,
wait_count=None,
target_duration=None,
worker_key=None,
**kwargs,
cluster: Cluster,
interval: str | float | timedelta | None = None,
minimum: int | None = None,
maximum: int | float | None = None,
wait_count: int | None = None,
target_duration: str | float | timedelta | None = None,
worker_key: Callable[[WorkerState], Hashable] | None = None,
**kwargs: Any,
):
self.cluster = cluster
self.worker_key = worker_key
@@ -99,20 +123,78 @@ def __init__(
if interval is None:
interval = dask.config.get("distributed.adaptive.interval")
if minimum is None:
minimum = dask.config.get("distributed.adaptive.minimum")
minimum = cast(int, dask.config.get("distributed.adaptive.minimum"))
if maximum is None:
maximum = dask.config.get("distributed.adaptive.maximum")
maximum = cast(float, dask.config.get("distributed.adaptive.maximum"))
if wait_count is None:
wait_count = dask.config.get("distributed.adaptive.wait-count")
wait_count = cast(int, dask.config.get("distributed.adaptive.wait-count"))
if target_duration is None:
target_duration = dask.config.get("distributed.adaptive.target-duration")
target_duration = cast(
str, dask.config.get("distributed.adaptive.target-duration")
)

self.interval = parse_timedelta(interval, "seconds")
self.periodic_callback = None

if self.interval and self.cluster:
import weakref

self_ref = weakref.ref(self)

async def _adapt():
adaptive = self_ref()
if not adaptive or adaptive.state != "running":
return
if adaptive.cluster.status != Status.running:
adaptive.stop(reason="cluster-not-running")
return
try:
await adaptive.adapt()
except Exception:
logger.warning(
"Adaptive encountered an error while adapting", exc_info=True
)

self.periodic_callback = PeriodicCallback(_adapt, self.interval * 1000)
self.state = "starting"
self.loop.add_callback(self._start)
else:
self.state = "inactive"

self.target_duration = parse_timedelta(target_duration)

super().__init__(
minimum=minimum, maximum=maximum, wait_count=wait_count, interval=interval
super().__init__(minimum=minimum, maximum=maximum, wait_count=wait_count)

def _start(self) -> None:
if self.state != "starting":
return

assert self.periodic_callback is not None
self.periodic_callback.start()
self.state = "running"
logger.info(
"Adaptive scaling started: minimum=%s maximum=%s",
self.minimum,
self.maximum,
)

def stop(self, reason: str = "unknown") -> None:
if self.state in ("inactive", "stopped"):
return

if self.state == "running":
assert self.periodic_callback is not None
self.periodic_callback.stop()
logger.info(
"Adaptive scaling stopped: minimum=%s maximum=%s. Reason: %s",
self.minimum,
self.maximum,
reason,
)

self.periodic_callback = None
self.state = "stopped"

@property
def scheduler(self):
return self.cluster.scheduler_comm
@@ -210,6 +292,9 @@ async def scale_up(self, n):
def loop(self) -> IOLoop:
"""Override Adaptive.loop"""
if self.cluster:
return self.cluster.loop
return self.cluster.loop # type: ignore[return-value]
else:
return IOLoop.current()

def __del__(self):
self.stop(reason="adaptive-deleted")
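
Note: the key behavioral change in this file is that a failing adapt() call no longer shuts adaptive scaling down. The new _adapt wrapper logs a warning and leaves the periodic callback running, stopping only when the cluster itself is no longer running. As a rough standalone illustration of that control flow (FlakyAdaptive and periodic_adapt below are hypothetical stand-ins, not code from this commit):

import asyncio
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("adaptive-sketch")


class FlakyAdaptive:
    """Illustrative stand-in for an adaptive policy whose adapt() can fail."""

    def __init__(self):
        self.state = "running"
        self.cluster_status = "running"
        self.calls = 0

    async def adapt(self):
        self.calls += 1
        if self.calls == 1:
            # Simulate a transient failure, e.g. a scheduler comm error.
            raise OSError("transient comm failure")

    def stop(self, reason="unknown"):
        logger.info("Adaptive scaling stopped. Reason: %s", reason)
        self.state = "stopped"


async def periodic_adapt(adaptive, interval=0.1, cycles=3):
    # Mirrors the new _adapt wrapper above: stop only when the adaptive policy
    # or the cluster is no longer running; otherwise log the error and try
    # again on the next tick.
    for _ in range(cycles):
        if adaptive.state != "running":
            return
        if adaptive.cluster_status != "running":
            adaptive.stop(reason="cluster-not-running")
            return
        try:
            await adaptive.adapt()
        except Exception:
            logger.warning("Adaptive encountered an error while adapting", exc_info=True)
        await asyncio.sleep(interval)


asyncio.run(periodic_adapt(FlakyAdaptive()))  # the failing first cycle does not stop the loop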
116 changes: 20 additions & 96 deletions distributed/deploy/adaptive_core.py
@@ -2,37 +2,24 @@

import logging
import math
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from collections.abc import Iterable
from datetime import timedelta
from typing import TYPE_CHECKING, Literal, cast
from typing import TYPE_CHECKING, cast

import tlz as toolz
from tornado.ioloop import IOLoop

import dask.config
from dask.utils import parse_timedelta

from distributed.compatibility import PeriodicCallback
from distributed.metrics import time

if TYPE_CHECKING:
from typing_extensions import TypeAlias

from distributed.scheduler import WorkerState

logger = logging.getLogger(__name__)


AdaptiveStateState: TypeAlias = Literal[
"starting",
"running",
"stopped",
"inactive",
]


class AdaptiveCore:
class AdaptiveCore(ABC):
"""
The core logic for adaptive deployments, with none of the cluster details
@@ -91,54 +78,22 @@ class AdaptiveCore:
minimum: int
maximum: int | float
wait_count: int
interval: int | float
periodic_callback: PeriodicCallback | None
plan: set[WorkerState]
requested: set[WorkerState]
observed: set[WorkerState]
close_counts: defaultdict[WorkerState, int]
_adapting: bool
#: Whether this adaptive strategy is periodically adapting
_state: AdaptiveStateState
log: deque[tuple[float, dict]]
_adapting: bool

def __init__(
self,
minimum: int = 0,
maximum: int | float = math.inf,
wait_count: int = 3,
interval: str | int | float | timedelta = "1s",
):
if not isinstance(maximum, int) and not math.isinf(maximum):
raise TypeError(f"maximum must be int or inf; got {maximum}")
raise ValueError(f"maximum must be int or inf; got {maximum}")

self.minimum = minimum
self.maximum = maximum
self.wait_count = wait_count
self.interval = parse_timedelta(interval, "seconds")
self.periodic_callback = None

if self.interval:
import weakref

self_ref = weakref.ref(self)

async def _adapt():
core = self_ref()
if core:
await core.adapt()

self.periodic_callback = PeriodicCallback(_adapt, self.interval * 1000)
self._state = "starting"
self.loop.add_callback(self._start)
else:
self._state = "inactive"
try:
self.plan = set()
self.requested = set()
self.observed = set()
except Exception:
pass

# internal state
self.close_counts = defaultdict(int)
@@ -147,38 +102,22 @@ async def _adapt():
maxlen=dask.config.get("distributed.admin.low-level-log-length")
)

def _start(self) -> None:
if self._state != "starting":
return

assert self.periodic_callback is not None
self.periodic_callback.start()
self._state = "running"
logger.info(
"Adaptive scaling started: minimum=%s maximum=%s",
self.minimum,
self.maximum,
)

def stop(self) -> None:
if self._state in ("inactive", "stopped"):
return
@property
@abstractmethod
def plan(self) -> set[WorkerState]: ...

if self._state == "running":
assert self.periodic_callback is not None
self.periodic_callback.stop()
logger.info(
"Adaptive scaling stopped: minimum=%s maximum=%s",
self.minimum,
self.maximum,
)
@property
@abstractmethod
def requested(self) -> set[WorkerState]: ...

self.periodic_callback = None
self._state = "stopped"
@property
@abstractmethod
def observed(self) -> set[WorkerState]: ...

@abstractmethod
async def target(self) -> int:
"""The target number of workers that should exist"""
raise NotImplementedError()
...

async def workers_to_close(self, target: int) -> list:
"""
Expand All @@ -198,11 +137,11 @@ async def safe_target(self) -> int:

return n

async def scale_down(self, n: int) -> None:
raise NotImplementedError()
@abstractmethod
async def scale_down(self, n: int) -> None: ...

async def scale_up(self, workers: Iterable) -> None:
raise NotImplementedError()
@abstractmethod
async def scale_up(self, workers: Iterable) -> None: ...

async def recommendations(self, target: int) -> dict:
"""
@@ -270,20 +209,5 @@ async def adapt(self) -> None:
await self.scale_up(**recommendations)
if status == "down":
await self.scale_down(**recommendations)
except OSError:
if status != "down":
logger.error("Adaptive stopping due to error", exc_info=True)
self.stop()
else:
logger.error(
"Error during adaptive downscaling. Ignoring.", exc_info=True
)
finally:
self._adapting = False

def __del__(self):
self.stop()

@property
def loop(self) -> IOLoop:
return IOLoop.current()
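
Note: AdaptiveCore is now an abstract base class. The plan/requested/observed views and the target/scale_up/scale_down coroutines must come from a subclass (in this repository, Adaptive). A minimal sketch of what a concrete subclass has to supply, assuming a distributed installation that includes this change and the abstract interface shown in the diff above (InMemoryAdaptive and its string-based worker bookkeeping are illustrative only):

from __future__ import annotations

from collections.abc import Iterable

from distributed.deploy.adaptive_core import AdaptiveCore


class InMemoryAdaptive(AdaptiveCore):
    """Toy subclass tracking workers as plain strings rather than WorkerState."""

    def __init__(self, **kwargs):
        self._workers: set[str] = set()
        super().__init__(**kwargs)

    # The three worker-state views are now abstract read-only properties.
    @property
    def plan(self) -> set:
        return self._workers

    @property
    def requested(self) -> set:
        return self._workers

    @property
    def observed(self) -> set:
        return self._workers

    # target, scale_up and scale_down are now abstract coroutines.
    async def target(self) -> int:
        return 2

    async def scale_up(self, workers: Iterable) -> None:
        self._workers.update(workers)

    async def scale_down(self, n: int) -> None:
        for worker in list(self._workers)[:n]:
            self._workers.discard(worker)


# With the ABC in place, instantiation fails with a TypeError if any of the
# abstract members above were left out.
adaptive = InMemoryAdaptive()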