Ensure that serialized data is measured correctly #7593

Merged
2 changes: 1 addition & 1 deletion distributed/comm/utils.py
@@ -5,10 +5,10 @@
 import socket
 
 import dask
-from dask.sizeof import sizeof
 from dask.utils import parse_bytes
 
 from distributed import protocol
+from distributed.sizeof import safe_sizeof as sizeof
 from distributed.utils import get_ip, get_ipv6, nbytes, offload
 
 logger = logging.getLogger(__name__)
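For context, switching from dask.sizeof.sizeof to distributed.sizeof.safe_sizeof means that a misbehaving __sizeof__ cannot crash the comm layer while it measures outgoing messages. A minimal sketch of what such a guard looks like (illustrative only; the real helper lives in distributed/sizeof.py and its fallback value may differ):

import logging

from dask.sizeof import sizeof

logger = logging.getLogger(__name__)


def safe_sizeof_sketch(obj, default_size=1_000_000):
    # Best-effort size estimate that never raises: fall back to a fixed
    # default if the object's sizeof handler blows up.
    try:
        return sizeof(obj)
    except Exception:
        logger.warning("Sizeof calculation failed; returning default", exc_info=True)
        return int(default_size)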
13 changes: 13 additions & 0 deletions distributed/protocol/serialize.py
@@ -13,6 +13,7 @@

 import dask
 from dask.base import normalize_token
+from dask.sizeof import sizeof
 from dask.utils import typename
 
 from distributed.protocol import pickle
@@ -639,6 +640,18 @@ def replace_inner(x):
     return replace_inner(x)
 
 
+@sizeof.register(ToPickle)
+@sizeof.register(Serialize)
+def sizeof_serialize(obj):
+    return sizeof(obj.data)
+
+
+@sizeof.register(Pickled)
+@sizeof.register(Serialized)
+def sizeof_serialized(obj):
+    return sizeof(obj.header) + sizeof(obj.frames)
+
+
 def serialize_bytelist(x, **kwargs):
     header, frames = serialize_and_split(x, **kwargs)
     if frames:
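dask.sizeof.sizeof is a Dispatch object, so registering handlers like the ones above is all it takes for Dask's memory accounting to see through a wrapper type instead of reporting the few bytes of the wrapper shell. The same pattern for a hypothetical container type (Wrapper is illustrative, not part of this PR):

from dask.sizeof import sizeof


class Wrapper:
    # Hypothetical container holding an arbitrary payload.
    def __init__(self, data):
        self.data = data


@sizeof.register(Wrapper)
def sizeof_wrapper(obj):
    # Measure the payload, not the wrapper itself.
    return sizeof(obj.data)


assert sizeof(Wrapper(b"0" * 100_000)) >= 100_000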
19 changes: 19 additions & 0 deletions distributed/protocol/tests/test_protocol.py
@@ -6,6 +6,7 @@
 import pytest
 
 import dask
+from dask.sizeof import sizeof
 
 from distributed.protocol import dumps, loads, maybe_compress, msgpack, to_serialize
 from distributed.protocol.compression import (
@@ -15,8 +16,10 @@
 )
 from distributed.protocol.cuda import cuda_deserialize, cuda_serialize
 from distributed.protocol.serialize import (
+    Pickled,
     Serialize,
     Serialized,
+    ToPickle,
     dask_deserialize,
     dask_serialize,
     deserialize,
@@ -393,3 +396,19 @@ def _(header, frames):
     header, frames = serialize(MyObj(), serializers=serializers)
     o = deserialize(header, frames)
     assert isinstance(o, MyObj)
+
+
+@pytest.mark.parametrize(
+    "Wrapper, Wrapped",
+    [
+        (Serialize, Serialized),
+        (to_serialize, Serialized),
+        (ToPickle, Pickled),
+    ],
+)
+def test_sizeof_serialize(Wrapper, Wrapped):
+    size = 100_000
+    ser_obj = Wrapper(b"0" * size)
+    assert size <= sizeof(ser_obj) < size * 1.05
+    serialized = Wrapped(*serialize(ser_obj))
+    assert size <= sizeof(serialized) < size * 1.05
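In the test above, Wrapped(*serialize(ser_obj)) works because serialize() returns a (header, frames) pair that the Serialized/Pickled wrappers store verbatim, and the upper bound of size * 1.05 allows a few percent of header and wrapper overhead. A quick round-trip sketch of that pair (run outside the test suite):

from distributed.protocol.serialize import deserialize, serialize

# serialize() produces a metadata dict plus a list of payload frames;
# deserialize() reassembles the original object from them.
header, frames = serialize({"x": 1})
assert deserialize(header, frames) == {"x": 1}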
108 changes: 62 additions & 46 deletions distributed/tests/test_worker.py
@@ -25,7 +25,7 @@
 from tornado.ioloop import IOLoop
 
 import dask
-from dask import delayed
+from dask import delayed, istask
 from dask.system import CPU_COUNT
 from dask.utils import tmpfile
 
@@ -42,7 +42,8 @@
     wait,
 )
 from distributed.comm.registry import backends
-from distributed.compatibility import LINUX, WINDOWS, to_thread
+from distributed.comm.utils import OFFLOAD_THRESHOLD
+from distributed.compatibility import LINUX, WINDOWS, randbytes, to_thread
 from distributed.core import CommClosedError, Status, rpc
 from distributed.diagnostics import nvml
 from distributed.diagnostics.plugin import (
@@ -2780,64 +2781,52 @@ async def test_forget_dependents_after_release(c, s, a):
     assert fut2.key not in {d.key for d in a.state.tasks[fut.key].dependents}
 
 
+@pytest.mark.filterwarnings("ignore:Large object of size")
 @gen_cluster(client=True)
 async def test_steal_during_task_deserialization(c, s, a, b, monkeypatch):
     stealing_ext = s.extensions["stealing"]
     await stealing_ext.stop()
-    from distributed.utils import ThreadPoolExecutor
-
-    class CountingThreadPool(ThreadPoolExecutor):
-        counter = 0
-
-        def submit(self, *args, **kwargs):
-            CountingThreadPool.counter += 1
-            return super().submit(*args, **kwargs)
-
-    # Ensure we're always offloading
-    monkeypatch.setattr("distributed.worker.OFFLOAD_THRESHOLD", 1)
-    threadpool = CountingThreadPool(
-        max_workers=1, thread_name_prefix="Counting-Offload-Threadpool"
-    )
-    try:
-        monkeypatch.setattr("distributed.utils._offload_executor", threadpool)
+    class SlowDeserializeCallable:
+        def __init__(self):
+            self.data = b"0" * (OFFLOAD_THRESHOLD + 1)
 
-        class SlowDeserializeCallable:
-            def __init__(self, delay=0.1):
-                self.delay = delay
+        def __getstate__(self):
+            return self.data
 
-            def __getstate__(self):
-                return self.delay
+        def __setstate__(self, state):
+            return SlowDeserializeCallable()
 
-            def __setstate__(self, state):
-                delay = state
-                import time
+        def __sizeof__(self):
+            return OFFLOAD_THRESHOLD * 2
 
-                time.sleep(delay)
-                return SlowDeserializeCallable(delay)
+        def __call__(self, *args, **kwargs):
+            return 41
 
-            def __call__(self, *args, **kwargs):
-                return 41
+    in_deserialize = asyncio.Event()
+    wait_in_deserialize = asyncio.Event()
 
-        slow_deserialized_func = SlowDeserializeCallable()
-        fut = c.submit(
-            slow_deserialized_func, 1, workers=[a.address], allow_other_workers=True
-        )
+    async def custom_worker_offload(func, *args):
+        res = func(*args)
+        if not istask(args) and istask(res):
+            in_deserialize.set()
+            await wait_in_deserialize.wait()
+        return res
 
-        while CountingThreadPool.counter == 0:
-            await asyncio.sleep(0)
+    monkeypatch.setattr("distributed.worker.offload", custom_worker_offload)
fjetter (Member, Author) commented:

The test logic is now slightly different, but I believe more robust. We don't truly care about the offloading part, only that there is an await during deserialization. Therefore, I only patch the offload method. To ensure that we're truly in the task-spec deserialization, I put in the istask guard above.
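For reference, istask treats a tuple whose first element is callable as a task spec, so the guard fires only when the raw serialized arguments (not a task) deserialize into something runnable (a task). A quick illustration:

from operator import add

from dask import istask

assert istask((add, 1, 2))    # tuple with a callable head: a task
assert not istask((1, 2, 3))  # plain tuple: not a task
assert not istask(add)        # not a tuple at all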

+
+    obj = SlowDeserializeCallable()
+    fut = c.submit(lambda _: 41, obj, workers=[a.address], allow_other_workers=True)
 
-        ts = s.tasks[fut.key]
-        a.handle_stimulus(StealRequestEvent(key=fut.key, stimulus_id="test"))
-        stealing_ext.scheduler.send_task_to_worker(b.address, ts)
+    await in_deserialize.wait()
+    ts = s.tasks[fut.key]
+    a.handle_stimulus(StealRequestEvent(key=fut.key, stimulus_id="test"))
+    stealing_ext.scheduler.send_task_to_worker(b.address, ts)
 
-        fut2 = c.submit(inc, fut, workers=[a.address])
-        fut3 = c.submit(inc, fut2, workers=[a.address])
-
-        assert await fut2 == 42
-        await fut3

-
-    finally:
-        threadpool.shutdown()

fjetter (Member, Author) commented on lines -2839 to -2840:

Somehow, this actually shut down the real offload threadpool, not just the mock; i.e., nothing in our test suite was using the offloader threadpool after this test ran 🤯

fjetter (Member, Author) commented:

I do not entirely understand why this was shutting down the actual threadpool, but I don't care. I removed the mock and it works now.

+    fut2 = c.submit(inc, fut, workers=[a.address])
+    fut3 = c.submit(inc, fut2, workers=[a.address])
+    wait_in_deserialize.set()
+    assert await fut2 == 42
+    await fut3


 @gen_cluster(client=True)
@@ -3768,3 +3757,30 @@ def print_stderr(*args, **kwargs):

     assert "" == out
     assert "" == err


+class EnsureOffloaded:
+    def __init__(self, main_thread_id):
+        self.main_thread_id = main_thread_id
+        self.data = randbytes(OFFLOAD_THRESHOLD + 1)
+
+    def __sizeof__(self):
+        return len(self.data)
+
+    def __getstate__(self):
+        assert self.main_thread_id
+        assert self.main_thread_id != threading.get_ident()
+        return (self.data, self.main_thread_id, threading.get_ident())
+
+    def __setstate__(self, state):
+        _, main_thread, serialize_thread = state
+        assert main_thread != threading.get_ident()
+        return EnsureOffloaded(main_thread)
+
+
+@gen_cluster(client=True)
+async def test_offload_getdata(c, s, a, b):
+    """Test that functions wrapped by offload() are metered"""
+    x = c.submit(EnsureOffloaded, threading.get_ident(), key="x", workers=[a.address])
+    y = c.submit(lambda x: None, x, key="y", workers=[b.address])
+    await y
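This test relies on pickle's __getstate__/__setstate__ hooks running on the offload thread, and on __sizeof__ steering that decision: dask's fallback sizeof handler delegates to sys.getsizeof, which in turn consults __sizeof__, so an object can declare itself large enough to cross OFFLOAD_THRESHOLD regardless of its real footprint. A standalone sketch of that mechanism:

import sys

from dask.sizeof import sizeof


class Padded:
    # Claims a far larger size than the instance actually occupies.
    def __sizeof__(self):
        return 10_000_000


# sys.getsizeof reports __sizeof__ plus GC overhead, and dask's default
# handler for unregistered types builds on sys.getsizeof.
assert sys.getsizeof(Padded()) >= 10_000_000
assert sizeof(Padded()) >= 10_000_000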
1 change: 0 additions & 1 deletion distributed/utils.py
@@ -1397,7 +1397,6 @@ def is_valid_xml(text):
 
 
 _offload_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="Dask-Offload")
-weakref.finalize(_offload_executor, _offload_executor.shutdown)
fjetter (Member, Author) commented on Feb 28, 2023:

This is just an attempt to fix the error I'm seeing.

Two reasons why I believe this line should be removed regardless of whether it is a fix:

  1. All Python versions 3.8+ already ensure that worker threads terminate on interpreter shutdown; they explicitly handle the cases of collected executors, interpreter shutdown, and instance shutdown identically.
  2. Judging by the finalize docs, I'm not even sure this callback is ever triggered:

     "A finalizer will never invoke its callback during the later part of the interpreter shutdown when module globals are liable to have been replaced by None."

Since _offload_executor is a module global, unless it has been replaced by None it has no chance of being GCed/finalized.
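Reason 1 is easy to check empirically: concurrent.futures registers an exit hook that joins executor worker threads, so a never-shut-down executor does not hang the interpreter. A small sketch (spawning a fresh interpreter so shutdown behavior is actually exercised):

import subprocess
import sys
import textwrap

code = textwrap.dedent(
    """
    from concurrent.futures import ThreadPoolExecutor

    ex = ThreadPoolExecutor(max_workers=1)
    print(ex.submit(lambda: "done").result())
    # Deliberately no ex.shutdown(): the exit hook joins the worker anyway.
    """
)
result = subprocess.run([sys.executable, "-c", code], capture_output=True, text=True)
assert result.returncode == 0 and result.stdout.strip() == "done"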

fjetter (Member, Author) commented:

This turned out not to be the culprit after all, but I still suggest removing the line.

A Contributor commented:

> Judging by the finalize docs I'm not even sure if this callback is ever triggered

It's interesting; later in the docs it states:

    "Note: It is important to ensure that func, args and kwargs do not own any references to obj, either directly or indirectly, since otherwise obj will never be garbage collected. In particular, func should not be a bound method of obj."

Notably the last part, which I suppose is another reason this line is unneeded.

fjetter (Member, Author) commented:

Yes, that note is interesting as well; this finalize is useless for many reasons.
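That note describes exactly the pattern deleted here: in weakref.finalize(_offload_executor, _offload_executor.shutdown), the callback is a bound method of the tracked object, so the finalizer itself keeps the executor alive forever. A small sketch demonstrating the pitfall:

import gc
import weakref
from concurrent.futures import ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=1)

# The bound method pool.shutdown holds a strong reference to pool, and
# finalize() stores the callback strongly, so pool can never be collected.
fin = weakref.finalize(pool, pool.shutdown)

del pool
gc.collect()
assert fin.alive  # still alive: the finalizer's own callback pins the object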

fjetter (Member, Author) commented:

I opened #7639; our code base is riddled with this pattern.

crusaderky (Collaborator) commented on Mar 10, 2023:

This one line makes me very nervous. Everything you wrote makes perfect sense, but just in case you're wrong, could you move it to its own PR so that it doesn't end up in the release? It has the potential to leave workers stuck on shutdown, and it may also behave differently across Python versions and OSs, so I believe some thorough testing is in order.



 def import_term(name: str) -> AnyType: