1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ import itertools
16+
1517import numpy as np
1618import pandas as pd
1719from pandas .api .types import is_list_like
1820
1921from ... import opcodes as OperandDef
20- from ...core import ENTITY_TYPE , recursive_tile
22+ from ...core import ENTITY_TYPE
2123from ...serialization .serializables import KeyField , AnyField
2224from ...tensor .core import TENSOR_TYPE
23- from ...utils import has_unknown_shape
2425from ..core import DATAFRAME_TYPE , SERIES_TYPE , INDEX_TYPE
2526from ..operands import DataFrameOperand , DataFrameOperandMixin
2627
2728
2829class DataFrameIsin (DataFrameOperand , DataFrameOperandMixin ):
2930 _op_type_ = OperandDef .ISIN
3031
31- _input = KeyField ("input" )
32- _values = AnyField ("values" )
33-
34- def __init__ (self , values = None , output_types = None , ** kw ):
35- super ().__init__ (_values = values , _output_types = output_types , ** kw )
36-
37- @property
38- def input (self ):
39- return self ._input
40-
41- @property
42- def values (self ):
43- return self ._values
32+ input = KeyField ("input" )
33+ values = AnyField ("values" )
4434
4535 def _set_inputs (self , inputs ):
4636 super ()._set_inputs (inputs )
4737 inputs_iter = iter (self ._inputs )
48- self ._input = next (inputs_iter )
38+ self .input = next (inputs_iter )
4939 if len (self ._inputs ) > 1 :
50- if isinstance (self ._values , dict ):
40+ if isinstance (self .values , dict ):
5141 new_values = dict ()
52- for k , v in self ._values .items ():
42+ for k , v in self .values .items ():
5343 if isinstance (v , ENTITY_TYPE ):
5444 new_values [k ] = next (inputs_iter )
5545 else :
5646 new_values [k ] = v
57- self ._values = new_values
47+ self .values = new_values
5848 else :
59- self ._values = self ._inputs [1 ]
49+ self .values = self ._inputs [1 ]
6050
6151 def __call__ (self , elements ):
6252 inputs = [elements ]
63- if isinstance (self ._values , ENTITY_TYPE ):
64- inputs .append (self ._values )
65- elif isinstance (self ._values , dict ):
66- for v in self ._values .values ():
53+ if isinstance (self .values , ENTITY_TYPE ):
54+ inputs .append (self .values )
55+ elif isinstance (self .values , dict ):
56+ for v in self .values .values ():
6757 if isinstance (v , ENTITY_TYPE ):
6858 inputs .append (v )
6959
@@ -87,47 +77,63 @@ def __call__(self, elements):
8777 dtypes = dtypes ,
8878 )
8979
80+ @classmethod
81+ def _tile_entity_values (cls , op ):
82+ from ..utils import auto_merge_chunks
83+ from ..arithmetic .bitwise_or import tree_dataframe_or
84+ from ...core .context import get_context
85+
86+ in_elements = op .input
87+ out_elements = op .outputs [0 ]
88+ # values contains mars objects
89+ chunks_list = []
90+ in_chunks = in_elements .chunks
91+ if any (len (t .chunks ) > 4 for t in op .inputs ):
92+ # yield and merge value chunks to reduce graph nodes
93+ yield list (
94+ itertools .chain .from_iterable (
95+ t .chunks for t in op .inputs if isinstance (t , ENTITY_TYPE )
96+ )
97+ )
98+ in_elements = auto_merge_chunks (get_context (), op .input )
99+ in_chunks = in_elements .chunks
100+ for value in op .inputs [1 :]:
101+ if isinstance (value , DATAFRAME_TYPE + SERIES_TYPE ):
102+ merged = auto_merge_chunks (get_context (), value )
103+ chunks_list .append (merged .chunks )
104+ elif isinstance (value , ENTITY_TYPE ):
105+ chunks_list .append (value .chunks )
106+ else :
107+ for value in op .inputs [1 :]:
108+ if isinstance (value , ENTITY_TYPE ):
109+ chunks_list .append (value .chunks )
110+
111+ out_chunks = []
112+ for in_chunk in in_chunks :
113+ isin_chunks = []
114+ for value_chunks in itertools .product (* chunks_list ):
115+ input_chunks = [in_chunk ] + list (value_chunks )
116+ isin_chunks .append (cls ._new_chunk (op , in_chunk , input_chunks ))
117+ out_chunk = tree_dataframe_or (* isin_chunks , index = in_chunk .index )
118+ out_chunks .append (out_chunk )
119+
120+ new_op = op .copy ()
121+ params = out_elements .params
122+ params ["nsplits" ] = in_elements .nsplits
123+ params ["chunks" ] = out_chunks
124+ return new_op .new_tileables (op .inputs , kws = [params ])
125+
90126 @classmethod
91127 def tile (cls , op ):
92128 in_elements = op .input
93129 out_elements = op .outputs [0 ]
94130
95- values_inputs = []
96131 if len (op .inputs ) > 1 :
97- for value in op .inputs [1 :]:
98- # make sure arg has known shape when it's a md.Series
99- if has_unknown_shape (value ):
100- yield
101- value = yield from recursive_tile (value .rechunk (value .shape ))
102- values_inputs .append (value )
132+ return (yield from cls ._tile_entity_values (op ))
103133
104134 out_chunks = []
105135 for chunk in in_elements .chunks :
106- chunk_op = op .copy ().reset_key ()
107- chunk_inputs = [chunk ]
108- if len (op .inputs ) > 1 :
109- chunk_inputs .extend (v .chunks [0 ] for v in values_inputs )
110- if out_elements .ndim == 1 :
111- out_chunk = chunk_op .new_chunk (
112- chunk_inputs ,
113- shape = chunk .shape ,
114- dtype = out_elements .dtype ,
115- index_value = chunk .index_value ,
116- name = out_elements .name ,
117- index = chunk .index ,
118- )
119- else :
120- chunk_dtypes = pd .Series (
121- [np .dtype (bool ) for _ in chunk .dtypes ], index = chunk .dtypes .index
122- )
123- out_chunk = chunk_op .new_chunk (
124- chunk_inputs ,
125- shape = chunk .shape ,
126- index_value = chunk .index_value ,
127- columns_value = chunk .columns_value ,
128- dtypes = chunk_dtypes ,
129- index = chunk .index ,
130- )
136+ out_chunk = cls ._new_chunk (op , chunk , [chunk ])
131137 out_chunks .append (out_chunk )
132138
133139 new_op = op .copy ()
@@ -136,6 +142,33 @@ def tile(cls, op):
136142 params ["chunks" ] = out_chunks
137143 return new_op .new_tileables (op .inputs , kws = [params ])
138144
145+ @classmethod
146+ def _new_chunk (cls , op , chunk , input_chunks ):
147+ out_elements = op .outputs [0 ]
148+ chunk_op = op .copy ().reset_key ()
149+ if out_elements .ndim == 1 :
150+ out_chunk = chunk_op .new_chunk (
151+ input_chunks ,
152+ shape = chunk .shape ,
153+ dtype = out_elements .dtype ,
154+ index_value = chunk .index_value ,
155+ name = out_elements .name ,
156+ index = chunk .index ,
157+ )
158+ else :
159+ chunk_dtypes = pd .Series (
160+ [np .dtype (bool ) for _ in chunk .dtypes ], index = chunk .dtypes .index
161+ )
162+ out_chunk = chunk_op .new_chunk (
163+ input_chunks ,
164+ shape = chunk .shape ,
165+ index_value = chunk .index_value ,
166+ columns_value = chunk .columns_value ,
167+ dtypes = chunk_dtypes ,
168+ index = chunk .index ,
169+ )
170+ return out_chunk
171+
139172 @classmethod
140173 def execute (cls , ctx , op ):
141174 inputs_iter = iter (op .inputs )
@@ -222,7 +255,7 @@ def series_isin(elements, values):
222255 "only list-like objects are allowed to be passed to isin(), "
223256 f"you passed a [{ type (values )} ]"
224257 )
225- op = DataFrameIsin (values )
258+ op = DataFrameIsin (values = values )
226259 return op (elements )
227260
228261
@@ -298,5 +331,5 @@ def df_isin(df, values):
298331 "only list-like objects or dict are allowed to be passed to isin(), "
299332 f"you passed a [{ type (values )} ]"
300333 )
301- op = DataFrameIsin (values )
334+ op = DataFrameIsin (values = values )
302335 return op (df )
0 commit comments