Skip to content

Commit

Permalink
use randbytes compat
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter committed Mar 9, 2023
1 parent 9584b28 commit 4fe360d
Showing 1 changed file with 20 additions and 18 deletions.
38 changes: 20 additions & 18 deletions distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@
wait,
)
from distributed.comm.registry import backends
from distributed.compatibility import LINUX, WINDOWS, to_thread
from distributed.comm.utils import OFFLOAD_THRESHOLD
from distributed.compatibility import LINUX, WINDOWS, randbytes, to_thread
from distributed.core import CommClosedError, Status, rpc
from distributed.diagnostics import nvml
from distributed.diagnostics.plugin import (
Expand Down Expand Up @@ -3770,27 +3771,28 @@ def print_stderr(*args, **kwargs):
assert "" == err


@gen_cluster(client=True)
async def test_offload_getdata(c, s, a, b):
"""Test that functions wrapped by offload() are metered"""
import random
import threading
class EnsureOffloaded:
def __init__(self, main_thread_id):
self.main_thread_id = main_thread_id
self.data = randbytes(OFFLOAD_THRESHOLD + 1)

# TODO: this is not a real test, yet
print("main thread", threading.get_ident())
n = 200_000_000
def __sizeof__(self):
return len(self.data)

class C:
def __sizeof__(self):
return n
def __getstate__(self):
assert self.main_thread_id
assert self.main_thread_id != threading.get_ident()
return (self.data, self.main_thread_id, threading.get_ident())

def __getstate__(self):
print("__getstate__", threading.get_ident())
return random.randbytes(n)
def __setstate__(self, state):
_, main_thread, serialize_thread = state
assert main_thread != threading.get_ident()
return EnsureOffloaded(main_thread)

def __setstate__(self, state):
print("__setstate__", threading.get_ident())

x = c.submit(C, key="x", workers=[a.address])
@gen_cluster(client=True)
async def test_offload_getdata(c, s, a, b):
"""Test that functions wrapped by offload() are metered"""
x = c.submit(EnsureOffloaded, threading.get_ident(), key="x", workers=[a.address])
y = c.submit(lambda x: None, x, key="y", workers=[b.address])
await y

0 comments on commit 4fe360d

Please sign in to comment.