Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speedup mars deserialization by __new__ #3283

Merged
merged 6 commits into from
Oct 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mars/core/entity/tileables.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,11 @@ def __init__(self: TileableType, *args, **kwargs):
self._chunks = sorted(chunks, key=attrgetter("index"))
except AttributeError: # pragma: no cover
pass
self._entities = WeakSet()
self._executed_sessions = []

def __on_deserialize__(self):
super(TileableData, self).__on_deserialize__()
self._entities = WeakSet()
self._executed_sessions = []

Expand Down
12 changes: 11 additions & 1 deletion mars/dataframe/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1814,7 +1814,7 @@ def median(

class BaseDataFrameChunkData(LazyMetaChunkData):
__slots__ = ("_dtypes_value",)
_no_copy_attrs_ = ChunkData._no_copy_attrs_ | {"_dtypes"}
_no_copy_attrs_ = ChunkData._no_copy_attrs_ | {"_dtypes", "_columns_value"}

# required fields
_shape = TupleField(
Expand Down Expand Up @@ -1851,6 +1851,10 @@ def __init__(
)
self._dtypes_value = None

def __on_deserialize__(self):
super(BaseDataFrameChunkData, self).__on_deserialize__()
self._dtypes_value = None

def __len__(self):
return self.shape[0]

Expand Down Expand Up @@ -1992,6 +1996,12 @@ def __init__(
self._dtypes_value = None
self._dtypes_dict = None

def __on_deserialize__(self):
super().__on_deserialize__()
self._accessors = dict()
self._dtypes_value = None
self._dtypes_dict = None

def _get_params(self) -> Dict[str, Any]:
# params return the properties which useful to rebuild a new tileable object
return {
Expand Down
16 changes: 11 additions & 5 deletions mars/serialization/serializables/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

import cloudpickle

from ...utils import no_default
from ..core import Serializer, Placeholder, buffered
from .field import Field
from .field_type import (
Expand Down Expand Up @@ -144,6 +143,9 @@ def __init__(self, *args, **kwargs):
for k, v in values.items():
fields[k].set(self, v)

def __on_deserialize__(self):
pass

def __repr__(self):
values = ", ".join(
[
Expand All @@ -169,6 +171,10 @@ def copy(self) -> "Serializable":
_primitive_serial_cache = weakref.WeakKeyDictionary()


class _NoFieldValue:
pass


class SerializableSerializer(Serializer):
"""
Leverage DictSerializer to perform serde.
Expand All @@ -184,7 +190,7 @@ def _get_field_values(cls, obj: Serializable, fields):
value = field.on_serialize(value)
except AttributeError:
# Most field values are not None, serialize by list is more efficient than dict.
value = no_default
value = _NoFieldValue
values.append(value)
return values

Expand All @@ -203,7 +209,7 @@ def serial(self, obj: Serializable, context: Dict):

@staticmethod
def _set_field_value(obj: Serializable, field: Field, value):
if value is no_default:
if value is _NoFieldValue:
return
if type(value) is Placeholder:
if field.on_deserialize is not None:
Expand All @@ -224,7 +230,7 @@ def deserial(self, serialized: Tuple, context: Dict, subs: List) -> Serializable
if type(primitives) is not list:
primitives = cloudpickle.loads(primitives)

obj = obj_class()
obj = obj_class.__new__(obj_class)

if primitives:
for field, value in zip(obj_class._PRIMITIVE_FIELDS, primitives):
Expand All @@ -233,7 +239,7 @@ def deserial(self, serialized: Tuple, context: Dict, subs: List) -> Serializable
if obj_class._NON_PRIMITIVE_FIELDS:
for field, value in zip(obj_class._NON_PRIMITIVE_FIELDS, subs[0]):
self._set_field_value(obj, field, value)

obj.__on_deserialize__()
return obj


Expand Down
5 changes: 5 additions & 0 deletions mars/services/subtask/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ def __init__(
self._pure_depend_keys = None
self._repr = None

def __on_deserialize__(self):
super(Subtask, self).__on_deserialize__()
self._pure_depend_keys = None
self._repr = None

@property
def expect_band(self):
if self.expect_bands:
Expand Down
1 change: 1 addition & 0 deletions mars/tensor/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class TensorOrder(Enum):

class TensorChunkData(ChunkData):
__slots__ = ()
_no_copy_attrs_ = ChunkData._no_copy_attrs_ | {"dtype"}
type_name = "Tensor"

# required fields
Expand Down