Skip to content

Commit

Permalink
Remove dumps_task (#8067)
Browse files Browse the repository at this point in the history
Co-authored-by: Hendrik Makait <hendrik@coiled.io>
  • Loading branch information
fjetter and hendrikmakait authored Aug 11, 2023
1 parent 163165b commit 4f30abc
Show file tree
Hide file tree
Showing 12 changed files with 101 additions and 346 deletions.
18 changes: 0 additions & 18 deletions distributed/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from dask.utils import parse_timedelta

from distributed import profile, protocol
from distributed.collections import LRU
from distributed.comm import (
Comm,
CommClosedError,
Expand All @@ -40,7 +39,6 @@
from distributed.counter import Counter
from distributed.diskutils import WorkDir, WorkSpace
from distributed.metrics import context_meter, time
from distributed.protocol import pickle
from distributed.system_monitor import SystemMonitor
from distributed.utils import (
NoOpAwaitable,
Expand All @@ -64,21 +62,6 @@
Coro = Coroutine[Any, Any, T]


cache_loads: LRU[bytes, Callable[..., Any]] = LRU(maxsize=100)


def loads_function(bytes_object):
"""Load a function from bytes, cache bytes"""
if len(bytes_object) < 100000:
try:
result = cache_loads[bytes_object]
except KeyError:
result = pickle.loads(bytes_object)
cache_loads[bytes_object] = result
return result
return pickle.loads(bytes_object)


class Status(Enum):
"""
This Enum contains the various states a cluster, worker, scheduler and nanny can be
Expand Down Expand Up @@ -519,7 +502,6 @@ def func(data):
if load:
try:
import_file(out_filename)
cache_loads.data.clear()
except Exception as e:
logger.exception(e)
raise e
Expand Down
15 changes: 6 additions & 9 deletions distributed/recreate_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
from dask.utils import stringify

from distributed.client import futures_of, wait
from distributed.protocol.serialize import ToPickle
from distributed.utils import sync
from distributed.utils_comm import pack_data
from distributed.worker import _deserialize

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -42,7 +42,10 @@ def get_error_cause(self, *args, keys=(), **kwargs):
def get_runspec(self, *args, key=None, **kwargs):
key = self._process_key(key)
ts = self.scheduler.tasks.get(key)
return {"task": ts.run_spec, "deps": [dts.key for dts in ts.dependencies]}
return {
"task": ToPickle(ts.run_spec),
"deps": [dts.key for dts in ts.dependencies],
}


class ReplayTaskClient:
Expand Down Expand Up @@ -83,13 +86,7 @@ async def _get_raw_components_from_future(self, future):
await wait(future)
key = future.key
spec = await self.scheduler.get_runspec(key=key)
deps, task = spec["deps"], spec["task"]
if isinstance(task, dict):
function, args, kwargs = _deserialize(**task)
return (function, args, kwargs, deps)
else:
function, args, kwargs = _deserialize(task=task)
return (function, args, kwargs, deps)
return (*spec["task"], spec["deps"])

async def _prepare_raw_components(self, raw_components):
"""
Expand Down
32 changes: 15 additions & 17 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@
)
from distributed.utils_perf import disable_gc_diagnosis, enable_gc_diagnosis
from distributed.variable import VariableExtension
from distributed.worker import dumps_task
from distributed.worker import _normalize_task

if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
Expand Down Expand Up @@ -156,6 +156,8 @@
# (recommendations, client messages, worker messages)
RecsMsgs: TypeAlias = tuple[Recs, Msgs, Msgs]

T_runspec: TypeAlias = tuple[Callable, tuple, dict[str, Any]]

logger = logging.getLogger(__name__)
LOG_PDB = dask.config.get("distributed.admin.pdb-on-err")
DEFAULT_DATA_SIZE = parse_bytes(
Expand Down Expand Up @@ -1176,7 +1178,7 @@ class TaskState:
#: "pure data" (such as, for example, a piece of data loaded in the scheduler using
#: :meth:`Client.scatter`). A "pure data" task cannot be computed again if its
#: value is lost.
run_spec: object
run_spec: T_runspec | None

#: The priority provides each task with a relative ranking which is used to break
#: ties when many tasks are being considered for execution.
Expand Down Expand Up @@ -1375,7 +1377,7 @@ class TaskState:
def __init__(
self,
key: str,
run_spec: object,
run_spec: T_runspec | None,
state: TaskStateState,
):
self.key = key
Expand Down Expand Up @@ -1787,7 +1789,7 @@ def __pdict__(self) -> dict[str, Any]:
def new_task(
self,
key: str,
spec: object,
spec: T_runspec | None,
state: TaskStateState,
computation: Computation | None = None,
) -> TaskState:
Expand Down Expand Up @@ -3339,10 +3341,7 @@ def _task_to_msg(self, ts: TaskState, duration: float = -1) -> dict[str, Any]:
dts.key: [ws.address for ws in dts.who_has] for dts in ts.dependencies
},
"nbytes": {dts.key: dts.nbytes for dts in ts.dependencies},
"run_spec": None,
"function": None,
"args": None,
"kwargs": None,
"run_spec": ToPickle(ts.run_spec),
"resource_restrictions": ts.resource_restrictions,
"actor": ts.actor,
"annotations": ts.annotations,
Expand All @@ -3351,11 +3350,6 @@ def _task_to_msg(self, ts: TaskState, duration: float = -1) -> dict[str, Any]:
if self.validate:
assert all(msg["who_has"].values())

if isinstance(ts.run_spec, dict):
msg.update(ts.run_spec)
else:
msg["run_spec"] = ts.run_spec

return msg


Expand Down Expand Up @@ -4602,7 +4596,11 @@ async def update_graph(
self.digest_metric("update-graph-duration", end - start)

def _generate_taskstates(
self, keys: set[str], dsk: dict, dependencies: dict, computation: Computation
self,
keys: set[str],
dsk: dict[str, T_runspec],
dependencies: dict[str, set[str]],
computation: Computation,
) -> tuple:
# Get or create task states
runnable = []
Expand Down Expand Up @@ -8483,8 +8481,8 @@ def transition(


def _materialize_graph(
graph: HighLevelGraph, global_annotations: dict
) -> tuple[dict, dict, dict]:
graph: HighLevelGraph, global_annotations: dict[str, Any]
) -> tuple[dict[str, T_runspec], dict[str, set[str]], dict[str, Any]]:
dsk = dask.utils.ensure_dict(graph)
annotations_by_type: defaultdict[str, dict[str, Any]] = defaultdict(dict)
for annotations_type, value in global_annotations.items():
Expand Down Expand Up @@ -8540,6 +8538,6 @@ def _materialize_graph(
for k in list(dsk):
if dsk[k] is k:
del dsk[k]
dsk = valmap(dumps_task, dsk)
dsk = valmap(_normalize_task, dsk)

return dsk, dependencies, annotations_by_type
14 changes: 3 additions & 11 deletions distributed/tests/test_cancelled_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -1111,12 +1111,11 @@ def test_workerstate_resumed_waiting_to_flight(ws):
assert ws.tasks["x"].state == "flight"


@pytest.mark.parametrize("critical_section", ["execute", "deserialize_task"])
@pytest.mark.parametrize("resume_inside_critical_section", [False, True])
@pytest.mark.parametrize("resumed_status", ["executing", "resumed"])
@gen_cluster(client=True, nthreads=[("", 1)])
async def test_execute_preamble_early_cancel(
c, s, b, critical_section, resume_inside_critical_section, resumed_status
c, s, b, resume_inside_critical_section, resumed_status
):
"""Test multiple race conditions in the preamble of Worker.execute(), which used to
cause a task to remain permanently in resumed state or to crash the worker through
Expand All @@ -1129,15 +1128,8 @@ async def test_execute_preamble_early_cancel(
test_worker.py::test_execute_preamble_abort_retirement
"""
async with BlockedExecute(s.address) as a:
if critical_section == "execute":
in_ev = a.in_execute
block_ev = a.block_execute
a.block_deserialize_task.set()
else:
assert critical_section == "deserialize_task"
in_ev = a.in_deserialize_task
block_ev = a.block_deserialize_task
a.block_execute.set()
in_ev = a.in_execute
block_ev = a.block_execute

async def resume():
if resumed_status == "executing":
Expand Down
78 changes: 32 additions & 46 deletions distributed/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import dask
from dask import delayed
from dask.highlevelgraph import HighLevelGraph, MaterializedLayer
from dask.utils import apply, parse_timedelta, stringify, tmpfile, typename
from dask.utils import parse_timedelta, stringify, tmpfile, typename

from distributed import (
CancelledError,
Expand Down Expand Up @@ -74,7 +74,7 @@
varying,
wait_for_state,
)
from distributed.worker import dumps_function, dumps_task, get_worker, secede
from distributed.worker import dumps_function, get_worker, secede

pytestmark = pytest.mark.ci1

Expand Down Expand Up @@ -345,7 +345,26 @@ async def test_decide_worker_rootish_while_last_worker_is_retiring(c, s, a):
await wait(xs + ys)


@pytest.mark.slow
from distributed import WorkerPlugin


class CountData(WorkerPlugin):
def __init__(self, keys):
self.keys = keys
self.worker = None
self.count = 0

def setup(self, worker):
self.worker = worker

def transition(self, start, finish, *args, **kwargs):
count = 0
for k in self.worker.data:
if k in self.keys:
count += 1
self.count = max(self.count, count)


@gen_cluster(
nthreads=[("", 2)] * 4,
client=True,
Expand All @@ -359,33 +378,18 @@ async def test_graph_execution_width(c, s, *workers):
The number of parallel work streams match the number of threads.
"""

class Refcount:
"Track how many instances of this class exist; logs the count at creation and deletion"

count = 0
lock = dask.utils.SerializableLock()
log = []

def __init__(self):
with self.lock:
type(self).count += 1
self.log.append(self.count)

def __del__(self):
with self.lock:
self.log.append(self.count)
type(self).count -= 1

roots = [delayed(Refcount)() for _ in range(32)]
roots = [delayed(inc)(ix) for ix in range(32)]
passthrough1 = [delayed(slowidentity)(r, delay=0) for r in roots]
passthrough2 = [delayed(slowidentity)(r, delay=0) for r in passthrough1]
done = [delayed(lambda r: None)(r) for r in passthrough2]

await c.register_worker_plugin(
CountData(keys=[f.key for f in roots]), name="count-roots"
)
fs = c.compute(done)
await wait(fs)
# NOTE: the max should normally equal `total_nthreads`. But some macOS CI machines
# are slow enough that they aren't able to reach the full parallelism of 8 threads.
assert max(Refcount.log) <= s.total_nthreads

res = await c.run(lambda dask_worker: dask_worker.plugins["count-roots"].count)
assert all(0 < count <= 2 for count in res.values())


@gen_cluster(client=True, nthreads=[("", 1)])
Expand Down Expand Up @@ -953,24 +957,6 @@ def test_dumps_function():
assert a != c


def test_dumps_task():
d = dumps_task((inc, 1))
assert set(d) == {"function", "args"}

def f(x, y=2):
return x + y

d = dumps_task((apply, f, (1,), {"y": 10}))
assert cloudpickle.loads(d["function"])(1, 2) == 3
assert cloudpickle.loads(d["args"]) == (1,)
assert cloudpickle.loads(d["kwargs"]) == {"y": 10}

d = dumps_task((apply, f, (1,)))
assert cloudpickle.loads(d["function"])(1, 2) == 3
assert cloudpickle.loads(d["args"]) == (1,)
assert set(d) == {"function", "args"}


@pytest.mark.parametrize("worker_saturation", [1.0, float("inf")])
@gen_cluster(client=True)
async def test_ready_remove_worker(c, s, a, b, worker_saturation):
Expand Down Expand Up @@ -1357,9 +1343,9 @@ async def test_update_graph_culls(s, a, b):
layers={
"foo": MaterializedLayer(
{
"x": dumps_task((inc, 1)),
"y": dumps_task((inc, "x")),
"z": dumps_task((inc, 2)),
"x": (inc, 1),
"y": (inc, "x"),
"z": (inc, 2),
}
)
},
Expand Down
3 changes: 0 additions & 3 deletions distributed/tests/test_spans.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,14 +518,12 @@ async def test_worker_metrics(c, s, a, b):

# metrics for foo include self and its child bar
assert list(foo_metrics) == [
("execute", "x", "deserialize", "seconds"),
("execute", "x", "thread-cpu", "seconds"),
("execute", "x", "thread-noncpu", "seconds"),
("execute", "x", "executor", "seconds"),
("execute", "x", "other", "seconds"),
("execute", "x", "memory-read", "count"),
("execute", "x", "memory-read", "bytes"),
("execute", "y", "deserialize", "seconds"),
("execute", "y", "thread-cpu", "seconds"),
("execute", "y", "thread-noncpu", "seconds"),
("execute", "y", "executor", "seconds"),
Expand All @@ -536,7 +534,6 @@ async def test_worker_metrics(c, s, a, b):
list(bar0_metrics)
== list(bar1_metrics)
== [
("execute", "y", "deserialize", "seconds"),
("execute", "y", "thread-cpu", "seconds"),
("execute", "y", "thread-noncpu", "seconds"),
("execute", "y", "executor", "seconds"),
Expand Down
Loading

0 comments on commit 4f30abc

Please sign in to comment.