Skip to content

Commit 86f06d8

Browse files
authored
Optimize serialization performance (mars-project#2914)
1 parent 44afdb6 commit 86f06d8

File tree

8 files changed

+323
-106
lines changed

8 files changed

+323
-106
lines changed

asv_bench/benchmarks/serialize.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import cloudpickle
1516
import numpy as np
1617
import pandas as pd
1718

@@ -26,7 +27,17 @@
2627
NDArrayField,
2728
StringField,
2829
FieldTypes,
30+
BoolField,
31+
Int32Field,
32+
Float32Field,
33+
SliceField,
34+
Datetime64Field,
35+
Timedelta64Field,
36+
TupleField,
37+
DictField,
2938
)
39+
from mars.services.subtask import Subtask
40+
from mars.services.task import new_task_id
3041

3142

3243
class SerializableChild(Serializable):
@@ -45,6 +56,22 @@ class SerializableParent(Serializable):
4556
children = ListField("children", field_type=FieldTypes.reference)
4657

4758

59+
class MySerializable(Serializable):
60+
_bool_val = BoolField("f1")
61+
_int32_val = Int32Field("f2")
62+
_int64_val = Int64Field("f3")
63+
_float32_val = Float32Field("f4")
64+
_float64_val = Float64Field("f5")
65+
_string_val = StringField("f6")
66+
_datetime64_val = Datetime64Field("f7")
67+
_timedelta64_val = Timedelta64Field("f8")
68+
_datatype_val = DataTypeField("f9")
69+
_slice_val = SliceField("f10")
70+
_list_val = ListField("list_val", FieldTypes.int64)
71+
_tuple_val = TupleField("tuple_val", FieldTypes.string)
72+
_dict_val = DictField("dict_val", FieldTypes.string, FieldTypes.bytes)
73+
74+
4875
class SerializationSuite:
4976
def setup(self):
5077
children = []
@@ -63,5 +90,72 @@ def setup(self):
6390
children.append(child)
6491
self.test_data = SerializableParent(children=children)
6592

93+
self.subtasks = []
94+
for i in range(10000):
95+
subtask = Subtask(
96+
subtask_id=new_task_id(),
97+
stage_id=new_task_id(),
98+
logic_key=new_task_id(),
99+
session_id=new_task_id(),
100+
task_id=new_task_id(),
101+
chunk_graph=None,
102+
expect_bands=[
103+
("ray://mars_cluster_1649927648/17/0", "numa-0"),
104+
],
105+
bands_specified=False,
106+
virtual=False,
107+
priority=(1, 0),
108+
retryable=True,
109+
extra_config={},
110+
)
111+
self.subtasks.append(subtask)
112+
113+
self.test_basic_serializable = []
114+
for i in range(10000):
115+
my_serializable = MySerializable(
116+
_bool_val=True,
117+
_int32_val=-32,
118+
_int64_val=-64,
119+
_float32_val=np.float32(2.0),
120+
_float64_val=2.0,
121+
_complex64_val=np.complex64(1 + 2j),
122+
_complex128_val=1 + 2j,
123+
_string_val="string_value",
124+
_datetime64_val=pd.Timestamp(123),
125+
_timedelta64_val=pd.Timedelta(days=1),
126+
_datatype_val=np.dtype(np.int32),
127+
_slice_val=slice(1, 10, 2),
128+
_list_val=[1, 2],
129+
_tuple_val=("a", "b"),
130+
_dict_val={"a": b"bytes_value"},
131+
)
132+
self.test_basic_serializable.append(my_serializable)
133+
134+
self.test_list = list(range(100000))
135+
self.test_tuple = tuple(range(100000))
136+
self.test_dict = {i: i for i in range(100000)}
137+
66138
def time_serialize_deserialize(self):
67139
deserialize(*serialize(self.test_data))
140+
141+
def time_serialize_deserialize_basic(self):
142+
deserialize(*serialize(self.test_basic_serializable))
143+
144+
def time_pickle_serialize_deserialize_basic(self):
145+
deserialize(
146+
*cloudpickle.loads(
147+
cloudpickle.dumps(serialize(self.test_basic_serializable))
148+
)
149+
)
150+
151+
def time_pickle_serialize_deserialize_subtask(self):
152+
deserialize(*cloudpickle.loads(cloudpickle.dumps(serialize(self.subtasks))))
153+
154+
def time_pickle_serialize_deserialize_list(self):
155+
deserialize(*cloudpickle.loads(cloudpickle.dumps(serialize(self.test_list))))
156+
157+
def time_pickle_serialize_deserialize_tuple(self):
158+
deserialize(*cloudpickle.loads(cloudpickle.dumps(serialize(self.test_tuple))))
159+
160+
def time_pickle_serialize_deserialize_dict(self):
161+
deserialize(*cloudpickle.loads(cloudpickle.dumps(serialize(self.test_dict))))

mars/core/operand/base.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,10 @@ class Operand(Base, OperatorLogicKeyGeneratorMixin, metaclass=OperandMetaclass):
181181
_inputs = ListField(
182182
"inputs", FieldTypes.reference(EntityData), default_factory=list
183183
)
184-
_outputs = ListField("outputs", default=None)
184+
# outputs are weak-refs which are not pickle-able
185+
_outputs = ListField(
186+
"outputs", default=None, on_serialize=lambda outputs: [o() for o in outputs]
187+
)
185188
_output_types = ListField(
186189
"output_type", FieldTypes.reference(OutputType), default=None
187190
)
@@ -327,13 +330,6 @@ def on_input_modify(self, new_input):
327330
class OperandSerializer(SerializableSerializer):
328331
serializer_name = "operand"
329332

330-
@classmethod
331-
def _get_tag_to_values(cls, obj: Operand):
332-
tag_to_values = super()._get_tag_to_values(obj)
333-
# outputs are weak-refs which are not pickle-able
334-
tag_to_values["outputs"] = [out_ref() for out_ref in tag_to_values["outputs"]]
335-
return tag_to_values
336-
337333
def deserialize(self, header: Dict, buffers: List, context: Dict) -> Operand:
338334
# convert outputs back to weak-refs
339335
operand: Operand = (yield from super().deserialize(header, buffers, context))

mars/core/operand/fetch.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import cloudpickle
16+
1517
from ... import opcodes
18+
from ...serialization.core import cached_pickle_dumps
1619
from ...serialization.serializables import FieldTypes, StringField, ListField
1720
from .base import Operand
1821
from .core import TileableOperandMixin
@@ -44,6 +47,21 @@ def execute(cls, ctx, op):
4447
class FetchShuffle(Operand):
4548
_op_type_ = opcodes.FETCH_SHUFFLE
4649

47-
source_keys = ListField("source_keys", FieldTypes.string)
48-
source_idxes = ListField("source_idxes", FieldTypes.tuple(FieldTypes.uint64))
49-
source_mappers = ListField("source_mappers", FieldTypes.uint16)
50+
source_keys = ListField(
51+
"source_keys",
52+
FieldTypes.string,
53+
on_serialize=cached_pickle_dumps,
54+
on_deserialize=cloudpickle.loads,
55+
)
56+
source_idxes = ListField(
57+
"source_idxes",
58+
FieldTypes.tuple(FieldTypes.uint64),
59+
on_serialize=cached_pickle_dumps,
60+
on_deserialize=cloudpickle.loads,
61+
)
62+
source_mappers = ListField(
63+
"source_mappers",
64+
FieldTypes.uint16,
65+
on_serialize=cached_pickle_dumps,
66+
on_deserialize=cloudpickle.loads,
67+
)

0 commit comments

Comments
 (0)