Route disk spilling through a `_serialize_bytelist` wrapper: it always sets
on_error="raise" and, when `serialize_bytelist` accepts a `compression`
argument, passes the `distributed.worker.memory.spill-compression` config value.
The cuDF device-spill cluster test sets that config key to False.

diff --git a/dask_cuda/device_host_file.py b/dask_cuda/device_host_file.py
index a0fe92e8..e8d8bc08 100644
--- a/dask_cuda/device_host_file.py
+++ b/dask_cuda/device_host_file.py
@@ -1,4 +1,3 @@
-import functools
 import itertools
 import logging
 import os
@@ -8,6 +7,7 @@
 from zict import Buffer, File, Func
 from zict.common import ZictBase
 
+import dask
 from distributed.protocol import (
     dask_deserialize,
     dask_serialize,
@@ -17,13 +17,24 @@
     serialize_bytelist,
 )
 from distributed.sizeof import safe_sizeof
-from distributed.utils import nbytes
+from distributed.utils import has_arg, nbytes
 
 from .is_device_object import is_device_object
 from .is_spillable_object import is_spillable_object
 from .utils import nvtx_annotate
 
 
+def _serialize_bytelist(x, **kwargs):
+    kwargs["on_error"] = "raise"
+
+    if has_arg(serialize_bytelist, "compression"):
+        compression = dask.config.get("distributed.worker.memory.spill-compression")
+        return serialize_bytelist(x, compression=compression, **kwargs)
+    else:
+        # For Distributed < 2023.5.0 compatibility
+        return serialize_bytelist(x, **kwargs)
+
+
 class LoggedBuffer(Buffer):
     """Extends zict.Buffer with logging capabilities
 
@@ -192,7 +203,7 @@ def __init__(
 
         self.host_func = dict()
         self.disk_func = Func(
-            functools.partial(serialize_bytelist, on_error="raise"),
+            _serialize_bytelist,
             deserialize_bytes,
             File(self.disk_func_path),
         )
diff --git a/dask_cuda/tests/test_spill.py b/dask_cuda/tests/test_spill.py
index d795f8f8..cd36cb78 100644
--- a/dask_cuda/tests/test_spill.py
+++ b/dask_cuda/tests/test_spill.py
@@ -220,6 +220,7 @@ async def test_cudf_cluster_device_spill(params):
         {
             "distributed.comm.compression": False,
             "distributed.worker.memory.terminate": False,
+            "distributed.worker.memory.spill-compression": False,
         }
     ):
         async with LocalCUDACluster(