
Ensure that serialized data is measured correctly #7593

Merged
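
Background, reconstructed from the diff below: messages handed to the comm layer are wrapped in `Serialize`/`ToPickle` (and, once serialized, in `Serialized`/`Pickled`). Before this change, `dask.sizeof.sizeof` had no handler registered for these wrappers, so it fell back to `sys.getsizeof` and reported only the size of the small wrapper object rather than the payload it carries. A minimal sketch of that failure mode (the `Wrapper` class is a stand-in, and the byte counts are illustrative):

```python
from dask.sizeof import sizeof


class Wrapper:
    """Stand-in for distributed.protocol.serialize.Serialize."""

    def __init__(self, data):
        self.data = data


payload = b"0" * 100_000_000  # 100 MB

# The raw payload is measured correctly by the registered bytes handler ...
assert sizeof(payload) >= 100_000_000

# ... but an unregistered wrapper falls back to sys.getsizeof(), which sees
# only the wrapper object itself, so size-based decisions such as offloading
# serialization off the event loop are silently skipped.
print(sizeof(Wrapper(payload)))  # a few dozen bytes, not 100 MB
```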
2 changes: 1 addition & 1 deletion distributed/comm/utils.py
@@ -5,10 +5,10 @@
 import socket
 
 import dask
+from dask.sizeof import sizeof
 from dask.utils import parse_bytes
 
 from distributed import protocol
-from distributed.sizeof import safe_sizeof as sizeof
 from distributed.utils import get_ip, get_ipv6, nbytes, offload
 
 logger = logging.getLogger(__name__)
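The import swap above also changes failure semantics: `distributed.sizeof.safe_sizeof` wraps the dask dispatch in a try/except and substitutes a default estimate when measurement raises, while `dask.sizeof.sizeof` propagates the error. Roughly, `safe_sizeof` behaves like this sketch (a paraphrase of `distributed/sizeof.py`, not a verbatim copy):

```python
import logging

from dask.sizeof import sizeof

logger = logging.getLogger(__name__)


def safe_sizeof(obj, default_size=1e6):
    """Measure obj, falling back to a ~1 MB default if measurement fails."""
    try:
        return sizeof(obj)
    except Exception:
        logger.warning("Sizeof calculation failed. Defaulting to %s B", default_size)
        return int(default_size)
```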
13 changes: 13 additions & 0 deletions distributed/protocol/serialize.py
@@ -13,6 +13,7 @@

 import dask
 from dask.base import normalize_token
+from dask.sizeof import sizeof
 from dask.utils import typename
 
 from distributed.protocol import pickle
@@ -639,6 +640,18 @@ def replace_inner(x):
     return replace_inner(x)
 
 
+@sizeof.register(ToPickle)
+@sizeof.register(Serialize)
+def sizeof_serialize(obj):
+    return sizeof(obj.data)
+
+
+@sizeof.register(Pickled)
+@sizeof.register(Serialized)
+def sizeof_serialized(obj):
+    return sizeof(obj.header) + sizeof(obj.frames)
+
+
 def serialize_bytelist(x, **kwargs):
     header, frames = serialize_and_split(x, **kwargs)
     if frames:
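For readers unfamiliar with the mechanism: `dask.sizeof.sizeof` is a `dask.utils.Dispatch` instance, so the `@sizeof.register(...)` decorators above install type-based handlers, much like `functools.singledispatch`. A minimal sketch of the same pattern with a hypothetical wrapper type:

```python
from dask.sizeof import sizeof


class Frame:
    """Hypothetical container, used only to illustrate registration."""

    def __init__(self, payload: bytes):
        self.payload = payload


@sizeof.register(Frame)
def sizeof_frame(obj):
    # Delegate to the wrapped payload; sizeof() dispatches again on bytes.
    return sizeof(obj.payload)


assert sizeof(Frame(b"x" * 1000)) >= 1000
```

This recursion is also why `sizeof_serialized` can simply add `sizeof(obj.header)` and `sizeof(obj.frames)`: dask ships handlers for dicts and lists that sum the measured sizes of their contents.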
19 changes: 19 additions & 0 deletions distributed/protocol/tests/test_protocol.py
@@ -6,6 +6,7 @@
 import pytest
 
 import dask
+from dask.sizeof import sizeof
 
 from distributed.protocol import dumps, loads, maybe_compress, msgpack, to_serialize
 from distributed.protocol.compression import (
@@ -15,8 +16,10 @@
 )
 from distributed.protocol.cuda import cuda_deserialize, cuda_serialize
 from distributed.protocol.serialize import (
+    Pickled,
     Serialize,
     Serialized,
+    ToPickle,
     dask_deserialize,
     dask_serialize,
     deserialize,
@@ -393,3 +396,19 @@ def _(header, frames):
     header, frames = serialize(MyObj(), serializers=serializers)
     o = deserialize(header, frames)
     assert isinstance(o, MyObj)
+
+
+@pytest.mark.parametrize(
+    "Wrapper, Wrapped",
+    [
+        (Serialize, Serialized),
+        (to_serialize, Serialized),
+        (ToPickle, Pickled),
+    ],
+)
+def test_sizeof_serialize(Wrapper, Wrapped):
+    size = 100_000
+    ser_obj = Wrapper(b"0" * size)
+    assert size <= sizeof(ser_obj) < size * 1.05
+    serialized = Wrapped(*serialize(ser_obj))
+    assert size <= sizeof(serialized) < size * 1.05
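
The parametrized test drives both new handlers end to end; the 5% upper bound leaves room for wrapper, header, and container overhead on top of the 100 kB payload. An equivalent interactive session, using the same API the test imports:

```python
from dask.sizeof import sizeof

from distributed.protocol import serialize
from distributed.protocol.serialize import Serialize, Serialized

ser_obj = Serialize(b"0" * 100_000)
header, frames = serialize(ser_obj)  # header: dict, frames: list of buffers

# Both the lazy wrapper and the serialized form report roughly the payload size.
assert 100_000 <= sizeof(ser_obj) < 105_000
assert 100_000 <= sizeof(Serialized(header, frames)) < 105_000
```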
102 changes: 51 additions & 51 deletions distributed/tests/test_worker.py
@@ -25,7 +25,7 @@
 from tornado.ioloop import IOLoop
 
 import dask
-from dask import delayed
+from dask import delayed, istask
 from dask.system import CPU_COUNT
 from dask.utils import tmpfile
 
@@ -42,7 +42,8 @@
     wait,
 )
 from distributed.comm.registry import backends
-from distributed.compatibility import LINUX, WINDOWS, to_thread
+from distributed.comm.utils import OFFLOAD_THRESHOLD
+from distributed.compatibility import LINUX, WINDOWS, randbytes, to_thread
 from distributed.core import CommClosedError, Status, rpc
 from distributed.diagnostics import nvml
 from distributed.diagnostics.plugin import (
@@ -2780,64 +2781,36 @@ async def test_forget_dependents_after_release(c, s, a):
     assert fut2.key not in {d.key for d in a.state.tasks[fut.key].dependents}
 
 
+@pytest.mark.filterwarnings("ignore:Large object of size")
 @gen_cluster(client=True)
 async def test_steal_during_task_deserialization(c, s, a, b, monkeypatch):
     stealing_ext = s.extensions["stealing"]
     await stealing_ext.stop()
-    from distributed.utils import ThreadPoolExecutor
-
-    class CountingThreadPool(ThreadPoolExecutor):
-        counter = 0
+    in_deserialize = asyncio.Event()
+    wait_in_deserialize = asyncio.Event()
 
-        def submit(self, *args, **kwargs):
-            CountingThreadPool.counter += 1
-            return super().submit(*args, **kwargs)
+    async def custom_worker_offload(func, *args):
+        res = func(*args)
+        if not istask(args) and istask(res):
+            in_deserialize.set()
+            await wait_in_deserialize.wait()
+        return res
 
-    # Ensure we're always offloading
-    monkeypatch.setattr("distributed.worker.OFFLOAD_THRESHOLD", 1)
-    threadpool = CountingThreadPool(
-        max_workers=1, thread_name_prefix="Counting-Offload-Threadpool"
-    )
-    try:
-        monkeypatch.setattr("distributed.utils._offload_executor", threadpool)
-
-        class SlowDeserializeCallable:
-            def __init__(self, delay=0.1):
-                self.delay = delay
-
-            def __getstate__(self):
-                return self.delay
-
-            def __setstate__(self, state):
-                delay = state
-                import time
-
-                time.sleep(delay)
-                return SlowDeserializeCallable(delay)
-
-            def __call__(self, *args, **kwargs):
-                return 41
-
-        slow_deserialized_func = SlowDeserializeCallable()
-        fut = c.submit(
-            slow_deserialized_func, 1, workers=[a.address], allow_other_workers=True
-        )
-
-        while CountingThreadPool.counter == 0:
-            await asyncio.sleep(0)
+    monkeypatch.setattr("distributed.worker.offload", custom_worker_offload)
+    obj = randbytes(OFFLOAD_THRESHOLD + 1)
+    fut = c.submit(lambda _: 41, obj, workers=[a.address], allow_other_workers=True)
 
-        ts = s.tasks[fut.key]
-        a.handle_stimulus(StealRequestEvent(key=fut.key, stimulus_id="test"))
-        stealing_ext.scheduler.send_task_to_worker(b.address, ts)
+    await in_deserialize.wait()
+    ts = s.tasks[fut.key]
+    a.handle_stimulus(StealRequestEvent(key=fut.key, stimulus_id="test"))
+    stealing_ext.scheduler.send_task_to_worker(b.address, ts)
 
-        fut2 = c.submit(inc, fut, workers=[a.address])
-        fut3 = c.submit(inc, fut2, workers=[a.address])
-
-        assert await fut2 == 42
-        await fut3
-
-    finally:
-        threadpool.shutdown()
+    fut2 = c.submit(inc, fut, workers=[a.address])
+    fut3 = c.submit(inc, fut2, workers=[a.address])
+    wait_in_deserialize.set()
+    assert await fut2 == 42
+    await fut3
 
 
 @gen_cluster(client=True)

Comment on lines -2839 to -2840 (the removed `finally: threadpool.shutdown()`):

Member Author:
Somehow, this actually shut down the real offload threadpool, not just the mock; i.e., nothing in our test suite was using the offloader threadpool after this test ran 🤯

Member Author:
I do not entirely understand why this is shutting down the actual threadpool, but I don't care. I removed the mock and it works now.
@@ -3768,3 +3741,30 @@ def print_stderr(*args, **kwargs):

assert "" == out
assert "" == err


class EnsureOffloaded:
def __init__(self, main_thread_id):
self.main_thread_id = main_thread_id
self.data = randbytes(OFFLOAD_THRESHOLD + 1)

def __sizeof__(self):
return len(self.data)

def __getstate__(self):
assert self.main_thread_id
assert self.main_thread_id != threading.get_ident()
return (self.data, self.main_thread_id, threading.get_ident())

def __setstate__(self, state):
_, main_thread, serialize_thread = state
assert main_thread != threading.get_ident()
return EnsureOffloaded(main_thread)


@gen_cluster(client=True)
async def test_offload_getdata(c, s, a, b):
"""Test that functions wrapped by offload() are metered"""
x = c.submit(EnsureOffloaded, threading.get_ident(), key="x", workers=[a.address])
y = c.submit(lambda x: None, x, key="y", workers=[b.address])
await y
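
The `EnsureOffloaded` assertions hold because distributed serializes sufficiently large messages on a dedicated offload thread instead of the event loop, and the `sizeof` registrations above are what let the threshold check see the true payload size. A simplified sketch of that check, modeled loosely on `to_frames` in `distributed/comm/utils.py` (names and structure are approximations, not the actual implementation):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor

from dask.sizeof import sizeof

OFFLOAD_THRESHOLD = 10_000_000  # distributed defaults to 10 MB; configurable
_offload_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="Offload")


async def to_frames_sketch(msg, serialize_fn):
    # Small messages are serialized inline on the event loop; large ones
    # (as measured by sizeof, which now sees through Serialize wrappers)
    # run on the offload thread so the loop stays responsive.
    if sizeof(msg) > OFFLOAD_THRESHOLD:
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(_offload_executor, serialize_fn, msg)
    return serialize_fn(msg)
```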