Ensure that serialized data is measured correctly (#7593)
fjetter authored Mar 13, 2023
1 parent 1373685 commit 6cab0e2
Showing 4 changed files with 84 additions and 52 deletions.
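The changes register dask.sizeof handlers for the serialization wrappers (Serialize, Serialized, ToPickle, Pickled) so that measuring a wrapped message reports the size of the wrapped payload rather than the size of the small wrapper object, and they switch the comm layer to dask's sizeof so those handlers take effect when deciding whether to offload serialization. A quick illustration of the resulting behavior, grounded in the new test_sizeof_serialize below:

# After this commit, sizeof() of a wrapper tracks the wrapped payload,
# within the 5% headroom asserted by test_sizeof_serialize.
from dask.sizeof import sizeof
from distributed.protocol.serialize import Serialize

payload = b"0" * 100_000
assert 100_000 <= sizeof(Serialize(payload)) < 105_000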
2 changes: 1 addition & 1 deletion distributed/comm/utils.py
@@ -5,10 +5,10 @@
 import socket
 
 import dask
+from dask.sizeof import sizeof
 from dask.utils import parse_bytes
 
 from distributed import protocol
-from distributed.sizeof import safe_sizeof as sizeof
 from distributed.utils import get_ip, get_ipv6, nbytes, offload
 
 logger = logging.getLogger(__name__)
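Note: to_frames in this module compares sizeof(msg) against an offload threshold to decide whether serialization should leave the event loop, so using dask.sizeof.sizeof directly means the wrapper registrations below are what get measured. A minimal sketch of that pattern (hypothetical helper names, not distributed's actual implementation):

# Sketch only: threshold-gated offload of serialization, assuming a
# ThreadPoolExecutor; to_frames_sketch and _executor are illustrative.
import asyncio
from concurrent.futures import ThreadPoolExecutor

from dask.sizeof import sizeof

OFFLOAD_THRESHOLD = 10_000_000  # illustrative; the real value is configurable
_executor = ThreadPoolExecutor(max_workers=1)

async def to_frames_sketch(msg, serialize):
    if sizeof(msg) > OFFLOAD_THRESHOLD:
        # Large message: serialize on a worker thread, not the event loop.
        return await asyncio.get_running_loop().run_in_executor(
            _executor, serialize, msg
        )
    return serialize(msg)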
13 changes: 13 additions & 0 deletions distributed/protocol/serialize.py
@@ -13,6 +13,7 @@
 
 import dask
 from dask.base import normalize_token
+from dask.sizeof import sizeof
 from dask.utils import typename
 
 from distributed.protocol import pickle
@@ -639,6 +640,18 @@ def replace_inner(x):
     return replace_inner(x)
 
 
+@sizeof.register(ToPickle)
+@sizeof.register(Serialize)
+def sizeof_serialize(obj):
+    return sizeof(obj.data)
+
+
+@sizeof.register(Pickled)
+@sizeof.register(Serialized)
+def sizeof_serialized(obj):
+    return sizeof(obj.header) + sizeof(obj.frames)
+
+
 def serialize_bytelist(x, **kwargs):
     header, frames = serialize_and_split(x, **kwargs)
     if frames:
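dask.sizeof.sizeof is a single-dispatch registry, so @sizeof.register(Cls) attaches a measurement function per type; the two handlers above delegate to the wrapped data, or to the header plus frames once serialized. The same pattern works for any user-defined wrapper. A minimal sketch (MyWrapper is hypothetical, not part of distributed):

from dask.sizeof import sizeof

class MyWrapper:
    def __init__(self, data):
        self.data = data

@sizeof.register(MyWrapper)
def sizeof_mywrapper(obj):
    # Delegate to the wrapped object, mirroring sizeof_serialize above.
    return sizeof(obj.data)

assert sizeof(MyWrapper(b"x" * 1_000)) >= 1_000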
19 changes: 19 additions & 0 deletions distributed/protocol/tests/test_protocol.py
@@ -6,6 +6,7 @@
 import pytest
 
 import dask
+from dask.sizeof import sizeof
 
 from distributed.protocol import dumps, loads, maybe_compress, msgpack, to_serialize
 from distributed.protocol.compression import (
@@ -15,8 +16,10 @@
 )
 from distributed.protocol.cuda import cuda_deserialize, cuda_serialize
 from distributed.protocol.serialize import (
+    Pickled,
     Serialize,
     Serialized,
+    ToPickle,
     dask_deserialize,
     dask_serialize,
     deserialize,
@@ -393,3 +396,19 @@ def _(header, frames):
     header, frames = serialize(MyObj(), serializers=serializers)
     o = deserialize(header, frames)
     assert isinstance(o, MyObj)
+
+
+@pytest.mark.parametrize(
+    "Wrapper, Wrapped",
+    [
+        (Serialize, Serialized),
+        (to_serialize, Serialized),
+        (ToPickle, Pickled),
+    ],
+)
+def test_sizeof_serialize(Wrapper, Wrapped):
+    size = 100_000
+    ser_obj = Wrapper(b"0" * size)
+    assert size <= sizeof(ser_obj) < size * 1.05
+    serialized = Wrapped(*serialize(ser_obj))
+    assert size <= sizeof(serialized) < size * 1.05
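The parametrization covers both the lazy wrappers (Serialize, ToPickle, and the to_serialize helper) and their post-serialization counterparts (Serialized, Pickled); the upper bound of size * 1.05 leaves headroom for header and frame bookkeeping. For reference, the Wrapped(*serialize(...)) step looks like this in isolation (a sketch, assuming a standard distributed install):

from distributed.protocol.serialize import Serialize, Serialized, serialize

# serialize() yields a header dict and a list of frames (buffers);
# Serialized simply carries those two parts.
header, frames = serialize(Serialize(b"0" * 100_000))
serialized = Serialized(header, frames)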
102 changes: 51 additions & 51 deletions distributed/tests/test_worker.py
@@ -25,7 +25,7 @@
 from tornado.ioloop import IOLoop
 
 import dask
-from dask import delayed
+from dask import delayed, istask
 from dask.system import CPU_COUNT
 from dask.utils import tmpfile
 
@@ -42,7 +42,8 @@
     wait,
 )
 from distributed.comm.registry import backends
-from distributed.compatibility import LINUX, WINDOWS, to_thread
+from distributed.comm.utils import OFFLOAD_THRESHOLD
+from distributed.compatibility import LINUX, WINDOWS, randbytes, to_thread
 from distributed.core import CommClosedError, Status, rpc
 from distributed.diagnostics import nvml
 from distributed.diagnostics.plugin import (
@@ -2780,64 +2781,36 @@ async def test_forget_dependents_after_release(c, s, a):
     assert fut2.key not in {d.key for d in a.state.tasks[fut.key].dependents}
 
 
-@pytest.mark.filterwarnings("ignore:Large object of size")
 @gen_cluster(client=True)
 async def test_steal_during_task_deserialization(c, s, a, b, monkeypatch):
     stealing_ext = s.extensions["stealing"]
     await stealing_ext.stop()
-    from distributed.utils import ThreadPoolExecutor
-
-    class CountingThreadPool(ThreadPoolExecutor):
-        counter = 0
+    in_deserialize = asyncio.Event()
+    wait_in_deserialize = asyncio.Event()
 
-        def submit(self, *args, **kwargs):
-            CountingThreadPool.counter += 1
-            return super().submit(*args, **kwargs)
+    async def custom_worker_offload(func, *args):
+        res = func(*args)
+        if not istask(args) and istask(res):
+            in_deserialize.set()
+            await wait_in_deserialize.wait()
+        return res
 
-    # Ensure we're always offloading
-    monkeypatch.setattr("distributed.worker.OFFLOAD_THRESHOLD", 1)
-    threadpool = CountingThreadPool(
-        max_workers=1, thread_name_prefix="Counting-Offload-Threadpool"
-    )
-    try:
-        monkeypatch.setattr("distributed.utils._offload_executor", threadpool)
-
-        class SlowDeserializeCallable:
-            def __init__(self, delay=0.1):
-                self.delay = delay
-
-            def __getstate__(self):
-                return self.delay
-
-            def __setstate__(self, state):
-                delay = state
-                import time
-
-                time.sleep(delay)
-                return SlowDeserializeCallable(delay)
-
-            def __call__(self, *args, **kwargs):
-                return 41
-
-        slow_deserialized_func = SlowDeserializeCallable()
-        fut = c.submit(
-            slow_deserialized_func, 1, workers=[a.address], allow_other_workers=True
-        )
-
-        while CountingThreadPool.counter == 0:
-            await asyncio.sleep(0)
+    monkeypatch.setattr("distributed.worker.offload", custom_worker_offload)
+    obj = randbytes(OFFLOAD_THRESHOLD + 1)
+    fut = c.submit(lambda _: 41, obj, workers=[a.address], allow_other_workers=True)
 
-        ts = s.tasks[fut.key]
-        a.handle_stimulus(StealRequestEvent(key=fut.key, stimulus_id="test"))
-        stealing_ext.scheduler.send_task_to_worker(b.address, ts)
+    await in_deserialize.wait()
+    ts = s.tasks[fut.key]
+    a.handle_stimulus(StealRequestEvent(key=fut.key, stimulus_id="test"))
+    stealing_ext.scheduler.send_task_to_worker(b.address, ts)
 
-        fut2 = c.submit(inc, fut, workers=[a.address])
-        fut3 = c.submit(inc, fut2, workers=[a.address])
-
-        assert await fut2 == 42
-        await fut3
-
-    finally:
-        threadpool.shutdown()
+    fut2 = c.submit(inc, fut, workers=[a.address])
+    fut3 = c.submit(inc, fut2, workers=[a.address])
+    wait_in_deserialize.set()
+    assert await fut2 == 42
+    await fut3
 
 
 @gen_cluster(client=True)
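The rewritten test no longer counts threadpool submissions; instead it monkeypatches distributed.worker.offload and detects the deserialization step by its shape: the offloaded callable receives raw serialized input (not a task) and returns a runnable task, which dask's istask recognizes. In isolation (a sketch; istask is re-exported at the dask top level, as the import hunk above shows):

from operator import add

from dask import istask

assert istask((add, 1, 2))      # tuple with a callable head: a task
assert not istask((1, 2))       # plain data tuple
assert not istask([add, 1, 2])  # lists are never tasks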
@@ -3768,3 +3741,30 @@ def print_stderr(*args, **kwargs):
 
     assert "" == out
     assert "" == err
+
+
+class EnsureOffloaded:
+    def __init__(self, main_thread_id):
+        self.main_thread_id = main_thread_id
+        self.data = randbytes(OFFLOAD_THRESHOLD + 1)
+
+    def __sizeof__(self):
+        return len(self.data)
+
+    def __getstate__(self):
+        assert self.main_thread_id
+        assert self.main_thread_id != threading.get_ident()
+        return (self.data, self.main_thread_id, threading.get_ident())
+
+    def __setstate__(self, state):
+        _, main_thread, serialize_thread = state
+        assert main_thread != threading.get_ident()
+        return EnsureOffloaded(main_thread)
+
+
+@gen_cluster(client=True)
+async def test_offload_getdata(c, s, a, b):
+    """Test that functions wrapped by offload() are metered"""
+    x = c.submit(EnsureOffloaded, threading.get_ident(), key="x", workers=[a.address])
+    y = c.submit(lambda x: None, x, key="y", workers=[b.address])
+    await y
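test_offload_getdata works because dask's sizeof falls back to sys.getsizeof for unregistered types, and sys.getsizeof consults an object's __sizeof__; reporting a value above OFFLOAD_THRESHOLD is enough to route the object's (de)serialization through offload(), where __getstate__ and __setstate__ assert they run off the main thread. The hook in isolation (a self-contained sketch, independent of distributed):

from dask.sizeof import sizeof

class PretendsToBeHuge:
    def __sizeof__(self):
        # Report ~1 GB regardless of the actual memory footprint.
        return 1_000_000_000

# dask's fallback (sys.getsizeof) honors __sizeof__, plus small GC overhead.
assert sizeof(PretendsToBeHuge()) >= 1_000_000_000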
