Skip to content

Commit 7732192

Browse files
authored
Optimize DataFrameIsin's tile (mars-project#2864)
1 parent 75e1c88 commit 7732192

File tree

14 files changed

+295
-339
lines changed

14 files changed

+295
-339
lines changed

mars/dataframe/arithmetic/bitwise_or.py

Lines changed: 24 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -15,8 +15,8 @@
1515
import operator
1616

1717
from ... import opcodes as OperandDef
18-
from ...utils import classproperty
19-
from .core import DataFrameBinopUfunc
18+
from ...utils import classproperty, TreeReductionBuilder
19+
from .core import DataFrameBinopUfunc, DataFrameArithmeticTreeMixin
2020

2121

2222
class DataFrameOr(DataFrameBinopUfunc):
@@ -36,6 +36,10 @@ def tensor_op_type(self):
3636
return TensorBitor
3737

3838

39+
class DataFrameTreeOr(DataFrameArithmeticTreeMixin, DataFrameOr):
40+
_op_type_ = OperandDef.TREE_OR
41+
42+
3943
def bitor(df, other, axis="columns", level=None, fill_value=None):
4044
op = DataFrameOr(axis=axis, level=level, fill_value=fill_value, lhs=df, rhs=other)
4145
return op(df, other)
@@ -44,3 +48,21 @@ def bitor(df, other, axis="columns", level=None, fill_value=None):
4448
def rbitor(df, other, axis="columns", level=None, fill_value=None):
4549
op = DataFrameOr(axis=axis, level=level, fill_value=fill_value, lhs=other, rhs=df)
4650
return op.rcall(df, other)
51+
52+
53+
def tree_dataframe_or(
54+
*args, index=None, combine_size=None, axis="columns", level=None, fill_value=None
55+
):
56+
class MultiplyBuilder(TreeReductionBuilder):
57+
def _build_reduction(self, inputs, final=False):
58+
op = DataFrameTreeOr(
59+
axis=axis,
60+
level=level,
61+
fill_value=fill_value,
62+
output_types=inputs[0].op.output_types,
63+
)
64+
params = inputs[0].params.copy()
65+
params["index"] = index
66+
return op.new_chunk(inputs, **params)
67+
68+
return MultiplyBuilder(combine_size).build(args)

mars/dataframe/arithmetic/core.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414

1515
import itertools
1616
import copy
17+
from functools import reduce
1718

1819
import numpy as np
1920
import pandas as pd
@@ -618,6 +619,8 @@ def _new_chunks(self, inputs, kws=None, **kw):
618619
inp, (DATAFRAME_CHUNK_TYPE, SERIES_CHUNK_TYPE, TENSOR_CHUNK_TYPE)
619620
)
620621
]
622+
# use first two to infer(for tree operand)
623+
property_inputs = property_inputs[:2]
621624
if len(property_inputs) == 1:
622625
properties = self._calc_properties(*property_inputs)
623626
elif any(inp.ndim == 2 for inp in property_inputs):
@@ -871,6 +874,17 @@ def __call__(self, df):
871874
)
872875

873876

877+
class DataFrameArithmeticTreeMixin:
878+
@classmethod
879+
def execute(cls, ctx, op):
880+
inputs = [ctx[c.key] for c in op.inputs]
881+
ctx[op.outputs[0].key] = reduce(op._operator, inputs)
882+
883+
def _set_inputs(self, inputs):
884+
inputs = self._get_inputs_data(inputs)
885+
setattr(self, "_inputs", inputs)
886+
887+
874888
class DataFrameUnaryUfunc(DataFrameUnaryOp, TensorUfuncMixin):
875889
pass
876890

mars/dataframe/base/isin.py

Lines changed: 91 additions & 58 deletions
Original file line number | Diff line number | Diff line change
@@ -12,58 +12,48 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import itertools
16+
1517
import numpy as np
1618
import pandas as pd
1719
from pandas.api.types import is_list_like
1820

1921
from ... import opcodes as OperandDef
20-
from ...core import ENTITY_TYPE, recursive_tile
22+
from ...core import ENTITY_TYPE
2123
from ...serialization.serializables import KeyField, AnyField
2224
from ...tensor.core import TENSOR_TYPE
23-
from ...utils import has_unknown_shape
2425
from ..core import DATAFRAME_TYPE, SERIES_TYPE, INDEX_TYPE
2526
from ..operands import DataFrameOperand, DataFrameOperandMixin
2627

2728

2829
class DataFrameIsin(DataFrameOperand, DataFrameOperandMixin):
2930
_op_type_ = OperandDef.ISIN
3031

31-
_input = KeyField("input")
32-
_values = AnyField("values")
33-
34-
def __init__(self, values=None, output_types=None, **kw):
35-
super().__init__(_values=values, _output_types=output_types, **kw)
36-
37-
@property
38-
def input(self):
39-
return self._input
40-
41-
@property
42-
def values(self):
43-
return self._values
32+
input = KeyField("input")
33+
values = AnyField("values")
4434

4535
def _set_inputs(self, inputs):
4636
super()._set_inputs(inputs)
4737
inputs_iter = iter(self._inputs)
48-
self._input = next(inputs_iter)
38+
self.input = next(inputs_iter)
4939
if len(self._inputs) > 1:
50-
if isinstance(self._values, dict):
40+
if isinstance(self.values, dict):
5141
new_values = dict()
52-
for k, v in self._values.items():
42+
for k, v in self.values.items():
5343
if isinstance(v, ENTITY_TYPE):
5444
new_values[k] = next(inputs_iter)
5545
else:
5646
new_values[k] = v
57-
self._values = new_values
47+
self.values = new_values
5848
else:
59-
self._values = self._inputs[1]
49+
self.values = self._inputs[1]
6050

6151
def __call__(self, elements):
6252
inputs = [elements]
63-
if isinstance(self._values, ENTITY_TYPE):
64-
inputs.append(self._values)
65-
elif isinstance(self._values, dict):
66-
for v in self._values.values():
53+
if isinstance(self.values, ENTITY_TYPE):
54+
inputs.append(self.values)
55+
elif isinstance(self.values, dict):
56+
for v in self.values.values():
6757
if isinstance(v, ENTITY_TYPE):
6858
inputs.append(v)
6959

@@ -87,47 +77,63 @@ def __call__(self, elements):
8777
dtypes=dtypes,
8878
)
8979

80+
@classmethod
81+
def _tile_entity_values(cls, op):
82+
from ..utils import auto_merge_chunks
83+
from ..arithmetic.bitwise_or import tree_dataframe_or
84+
from ...core.context import get_context
85+
86+
in_elements = op.input
87+
out_elements = op.outputs[0]
88+
# values contains mars objects
89+
chunks_list = []
90+
in_chunks = in_elements.chunks
91+
if any(len(t.chunks) > 4 for t in op.inputs):
92+
# yield and merge value chunks to reduce graph nodes
93+
yield list(
94+
itertools.chain.from_iterable(
95+
t.chunks for t in op.inputs if isinstance(t, ENTITY_TYPE)
96+
)
97+
)
98+
in_elements = auto_merge_chunks(get_context(), op.input)
99+
in_chunks = in_elements.chunks
100+
for value in op.inputs[1:]:
101+
if isinstance(value, DATAFRAME_TYPE + SERIES_TYPE):
102+
merged = auto_merge_chunks(get_context(), value)
103+
chunks_list.append(merged.chunks)
104+
elif isinstance(value, ENTITY_TYPE):
105+
chunks_list.append(value.chunks)
106+
else:
107+
for value in op.inputs[1:]:
108+
if isinstance(value, ENTITY_TYPE):
109+
chunks_list.append(value.chunks)
110+
111+
out_chunks = []
112+
for in_chunk in in_chunks:
113+
isin_chunks = []
114+
for value_chunks in itertools.product(*chunks_list):
115+
input_chunks = [in_chunk] + list(value_chunks)
116+
isin_chunks.append(cls._new_chunk(op, in_chunk, input_chunks))
117+
out_chunk = tree_dataframe_or(*isin_chunks, index=in_chunk.index)
118+
out_chunks.append(out_chunk)
119+
120+
new_op = op.copy()
121+
params = out_elements.params
122+
params["nsplits"] = in_elements.nsplits
123+
params["chunks"] = out_chunks
124+
return new_op.new_tileables(op.inputs, kws=[params])
125+
90126
@classmethod
91127
def tile(cls, op):
92128
in_elements = op.input
93129
out_elements = op.outputs[0]
94130

95-
values_inputs = []
96131
if len(op.inputs) > 1:
97-
for value in op.inputs[1:]:
98-
# make sure arg has known shape when it's a md.Series
99-
if has_unknown_shape(value):
100-
yield
101-
value = yield from recursive_tile(value.rechunk(value.shape))
102-
values_inputs.append(value)
132+
return (yield from cls._tile_entity_values(op))
103133

104134
out_chunks = []
105135
for chunk in in_elements.chunks:
106-
chunk_op = op.copy().reset_key()
107-
chunk_inputs = [chunk]
108-
if len(op.inputs) > 1:
109-
chunk_inputs.extend(v.chunks[0] for v in values_inputs)
110-
if out_elements.ndim == 1:
111-
out_chunk = chunk_op.new_chunk(
112-
chunk_inputs,
113-
shape=chunk.shape,
114-
dtype=out_elements.dtype,
115-
index_value=chunk.index_value,
116-
name=out_elements.name,
117-
index=chunk.index,
118-
)
119-
else:
120-
chunk_dtypes = pd.Series(
121-
[np.dtype(bool) for _ in chunk.dtypes], index=chunk.dtypes.index
122-
)
123-
out_chunk = chunk_op.new_chunk(
124-
chunk_inputs,
125-
shape=chunk.shape,
126-
index_value=chunk.index_value,
127-
columns_value=chunk.columns_value,
128-
dtypes=chunk_dtypes,
129-
index=chunk.index,
130-
)
136+
out_chunk = cls._new_chunk(op, chunk, [chunk])
131137
out_chunks.append(out_chunk)
132138

133139
new_op = op.copy()
@@ -136,6 +142,33 @@ def tile(cls, op):
136142
params["chunks"] = out_chunks
137143
return new_op.new_tileables(op.inputs, kws=[params])
138144

145+
@classmethod
146+
def _new_chunk(cls, op, chunk, input_chunks):
147+
out_elements = op.outputs[0]
148+
chunk_op = op.copy().reset_key()
149+
if out_elements.ndim == 1:
150+
out_chunk = chunk_op.new_chunk(
151+
input_chunks,
152+
shape=chunk.shape,
153+
dtype=out_elements.dtype,
154+
index_value=chunk.index_value,
155+
name=out_elements.name,
156+
index=chunk.index,
157+
)
158+
else:
159+
chunk_dtypes = pd.Series(
160+
[np.dtype(bool) for _ in chunk.dtypes], index=chunk.dtypes.index
161+
)
162+
out_chunk = chunk_op.new_chunk(
163+
input_chunks,
164+
shape=chunk.shape,
165+
index_value=chunk.index_value,
166+
columns_value=chunk.columns_value,
167+
dtypes=chunk_dtypes,
168+
index=chunk.index,
169+
)
170+
return out_chunk
171+
139172
@classmethod
140173
def execute(cls, ctx, op):
141174
inputs_iter = iter(op.inputs)
@@ -222,7 +255,7 @@ def series_isin(elements, values):
222255
"only list-like objects are allowed to be passed to isin(), "
223256
f"you passed a [{type(values)}]"
224257
)
225-
op = DataFrameIsin(values)
258+
op = DataFrameIsin(values=values)
226259
return op(elements)
227260

228261

@@ -298,5 +331,5 @@ def df_isin(df, values):
298331
"only list-like objects or dict are allowed to be passed to isin(), "
299332
f"you passed a [{type(values)}]"
300333
)
301-
op = DataFrameIsin(values)
334+
op = DataFrameIsin(values=values)
302335
return op(df)

mars/dataframe/base/tests/test_base.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -681,39 +681,39 @@ def test_series_isin():
681681
assert c.op.inputs[0].index == (i,)
682682
assert c.op.inputs[0].shape == (10,)
683683
assert c.op.inputs[1].index == (0,)
684-
assert c.op.inputs[1].shape == (4,) # has been rechunked
684+
assert c.op.inputs[1].shape == (10,)
685685

686686
# multiple chunk in one chunks
687-
a = from_pandas_series(pd.Series([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), chunk_size=2)
687+
a = from_pandas_series(pd.Series([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), chunk_size=5)
688688
b = from_pandas_series(pd.Series([2, 1, 9, 3]), chunk_size=4)
689689

690690
r = tile(a.isin(b))
691691
for i, c in enumerate(r.chunks):
692692
assert c.index == (i,)
693693
assert c.dtype == np.dtype("bool")
694-
assert c.shape == (2,)
694+
assert c.shape == (5,)
695695
assert len(c.op.inputs) == 2
696696
assert c.op.output_types[0] == OutputType.series
697697
assert c.op.inputs[0].index == (i,)
698-
assert c.op.inputs[0].shape == (2,)
698+
assert c.op.inputs[0].shape == (5,)
699699
assert c.op.inputs[1].index == (0,)
700700
assert c.op.inputs[1].shape == (4,)
701701

702702
# multiple chunk in multiple chunks
703-
a = from_pandas_series(pd.Series([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), chunk_size=2)
703+
a = from_pandas_series(pd.Series([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]), chunk_size=5)
704704
b = from_pandas_series(pd.Series([2, 1, 9, 3]), chunk_size=2)
705705

706706
r = tile(a.isin(b))
707707
for i, c in enumerate(r.chunks):
708708
assert c.index == (i,)
709709
assert c.dtype == np.dtype("bool")
710-
assert c.shape == (2,)
710+
assert c.shape == (5,)
711711
assert len(c.op.inputs) == 2
712712
assert c.op.output_types[0] == OutputType.series
713713
assert c.op.inputs[0].index == (i,)
714-
assert c.op.inputs[0].shape == (2,)
715-
assert c.op.inputs[1].index == (0,)
716-
assert c.op.inputs[1].shape == (4,) # has been rechunked
714+
assert c.op.inputs[0].shape == (5,)
715+
assert c.op.inputs[1].index == (i,)
716+
assert c.op.inputs[1].shape == (5,)
717717

718718
with pytest.raises(TypeError):
719719
_ = a.isin("sth")

mars/dataframe/base/tests/test_base_execution.py

Lines changed: 12 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -702,7 +702,7 @@ def test_isin_execution(setup):
702702

703703
# multiple chunk in multiple chunks
704704
a = pd.Series([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
705-
b = pd.Series([2, 1, 9, 3])
705+
b = pd.Series([2, 1, 9, 3] * 2)
706706
sa = from_pandas_series(a, chunk_size=2)
707707
sb = from_pandas_series(b, chunk_size=2)
708708

@@ -747,7 +747,17 @@ def test_isin_execution(setup):
747747
pd.testing.assert_frame_equal(result, expected)
748748

749749
# mars object
750-
b = tensor([2, 1, raw[1][0]], chunk_size=2)
750+
b = tensor([2, 1, raw[1][0]] * 2, chunk_size=2)
751+
r = df.isin(b)
752+
result = r.execute().fetch()
753+
expected = raw.isin([2, 1, raw[1][0]])
754+
pd.testing.assert_frame_equal(result, expected)
755+
756+
# mars object and trigger iterative tiling
757+
raw = pd.DataFrame(rs.randint(1000, size=(10, 3)))
758+
df = from_pandas_df(raw, chunk_size=(5, 2))
759+
760+
b = from_pandas_series(pd.Series([raw[1][0]] + list(range(9))), chunk_size=2)
751761
r = df.isin(b)
752762
result = r.execute().fetch()
753763
expected = raw.isin([2, 1, raw[1][0]])

mars/dataframe/core.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -433,6 +433,8 @@ def refresh_index_value(tileable: ENTITY_TYPE):
433433
index_value._index_value.should_be_monotonic = getattr(
434434
tileable.index_value, "should_be_monotonic", None
435435
)
436+
# keep key as original index_value's
437+
index_value._index_value._key = tileable.index_value.key
436438
tileable._index_value = index_value
437439

438440

0 commit comments

Comments
 (0)