11from __future__ import annotations
22
3- from collections import defaultdict
43import copy
54import itertools
6- from typing import TYPE_CHECKING , Dict , List , Sequence , cast
5+ from typing import TYPE_CHECKING , Dict , List , Sequence
76
87import numpy as np
98
1413from pandas .core .dtypes .cast import ensure_dtype_can_hold_na , find_common_type
1514from pandas .core .dtypes .common import (
1615 is_categorical_dtype ,
17- is_datetime64_dtype ,
1816 is_datetime64tz_dtype ,
17+ is_dtype_equal ,
1918 is_extension_array_dtype ,
20- is_float_dtype ,
21- is_numeric_dtype ,
2219 is_sparse ,
23- is_timedelta64_dtype ,
2420)
2521from pandas .core .dtypes .concat import concat_compat
26- from pandas .core .dtypes .missing import isna_all
22+ from pandas .core .dtypes .missing import is_valid_na_for_dtype , isna_all
2723
2824import pandas .core .algorithms as algos
2925from pandas .core .arrays import DatetimeArray , ExtensionArray
3329
3430if TYPE_CHECKING :
3531 from pandas import Index
36- from pandas .core .arrays .sparse .dtype import SparseDtype
3732
3833
3934def concatenate_block_managers (
@@ -232,6 +227,29 @@ def dtype(self):
232227 return blk .dtype
233228 return ensure_dtype_can_hold_na (blk .dtype )
234229
230+ def is_valid_na_for (self , dtype : DtypeObj ) -> bool :
231+ """
232+ Check that we are all-NA of a type/dtype that is compatible with this dtype.
233+ Augments `self.is_na` with an additional check of the type of NA values.
234+ """
235+ if not self .is_na :
236+ return False
237+ if self .block is None :
238+ return True
239+
240+ if self .dtype == object :
241+ values = self .block .values
242+ return all (is_valid_na_for_dtype (x , dtype ) for x in values .ravel (order = "K" ))
243+
244+ if self .dtype .kind == dtype .kind == "M" and not is_dtype_equal (
245+ self .dtype , dtype
246+ ):
247+ # fill_values match but we should not cast self.block.values to dtype
248+ return False
249+
250+ na_value = self .block .fill_value
251+ return is_valid_na_for_dtype (na_value , dtype )
252+
235253 @cache_readonly
236254 def is_na (self ) -> bool :
237255 if self .block is None :
@@ -262,7 +280,7 @@ def get_reindexed_values(self, empty_dtype: DtypeObj, upcasted_na) -> ArrayLike:
262280 else :
263281 fill_value = upcasted_na
264282
265- if self .is_na :
283+ if self .is_valid_na_for ( empty_dtype ) :
266284 blk_dtype = getattr (self .block , "dtype" , None )
267285
268286 if blk_dtype == np .dtype (object ):
@@ -276,10 +294,9 @@ def get_reindexed_values(self, empty_dtype: DtypeObj, upcasted_na) -> ArrayLike:
276294 if is_datetime64tz_dtype (blk_dtype ) or is_datetime64tz_dtype (
277295 empty_dtype
278296 ):
279- if self .block is None :
280- # TODO(EA2D): special case unneeded with 2D EAs
281- i8values = np .full (self .shape [1 ], fill_value .value )
282- return DatetimeArray (i8values , dtype = empty_dtype )
297+ # TODO(EA2D): special case unneeded with 2D EAs
298+ i8values = np .full (self .shape [1 ], fill_value .value )
299+ return DatetimeArray (i8values , dtype = empty_dtype )
283300 elif is_categorical_dtype (blk_dtype ):
284301 pass
285302 elif is_extension_array_dtype (blk_dtype ):
@@ -295,6 +312,8 @@ def get_reindexed_values(self, empty_dtype: DtypeObj, upcasted_na) -> ArrayLike:
295312 empty_arr , allow_fill = True , fill_value = fill_value
296313 )
297314 else :
315+ # NB: we should never get here with empty_dtype integer or bool;
316+ # if we did, the missing_arr.fill would cast to gibberish
298317 missing_arr = np .empty (self .shape , dtype = empty_dtype )
299318 missing_arr .fill (fill_value )
300319 return missing_arr
@@ -362,14 +381,12 @@ def _concatenate_join_units(
362381 # concatting with at least one EA means we are concatting a single column
363382 # the non-EA values are 2D arrays with shape (1, n)
364383 to_concat = [t if isinstance (t , ExtensionArray ) else t [0 , :] for t in to_concat ]
365- concat_values = concat_compat (to_concat , axis = 0 )
366- if not isinstance (concat_values , ExtensionArray ) or (
367- isinstance (concat_values , DatetimeArray ) and concat_values .tz is None
368- ):
384+ concat_values = concat_compat (to_concat , axis = 0 , ea_compat_axis = True )
385+ if not is_extension_array_dtype (concat_values .dtype ):
369386 # if the result of concat is not an EA but an ndarray, reshape to
370387 # 2D to put it a non-EA Block
371- # special case DatetimeArray, which *is* an EA, but is put in a
372- # consolidated 2D block
388+ # special case DatetimeArray/TimedeltaArray , which *is* an EA, but
389+ # is put in a consolidated 2D block
373390 concat_values = np .atleast_2d (concat_values )
374391 else :
375392 concat_values = concat_compat (to_concat , axis = concat_axis )
@@ -419,108 +436,17 @@ def _get_empty_dtype(join_units: Sequence[JoinUnit]) -> DtypeObj:
419436 return empty_dtype
420437
421438 has_none_blocks = any (unit .block is None for unit in join_units )
422- dtypes = [None if unit .block is None else unit .dtype for unit in join_units ]
423439
424- filtered_dtypes = [
440+ dtypes = [
425441 unit .dtype for unit in join_units if unit .block is not None and not unit .is_na
426442 ]
427- if not len (filtered_dtypes ):
428- filtered_dtypes = [unit .dtype for unit in join_units if unit .block is not None ]
429- dtype_alt = find_common_type (filtered_dtypes )
430-
431- upcast_classes = _get_upcast_classes (join_units , dtypes )
432-
433- if is_extension_array_dtype (dtype_alt ):
434- return dtype_alt
435- elif dtype_alt == object :
436- return dtype_alt
437-
438- # TODO: de-duplicate with maybe_promote?
439- # create the result
440- if "extension" in upcast_classes :
441- return np .dtype ("object" )
442- elif "bool" in upcast_classes :
443- if has_none_blocks :
444- return np .dtype (np .object_ )
445- else :
446- return np .dtype (np .bool_ )
447- elif "datetimetz" in upcast_classes :
448- # GH-25014. We use NaT instead of iNaT, since this eventually
449- # ends up in DatetimeArray.take, which does not allow iNaT.
450- dtype = upcast_classes ["datetimetz" ]
451- return dtype [0 ]
452- elif "datetime" in upcast_classes :
453- return np .dtype ("M8[ns]" )
454- elif "timedelta" in upcast_classes :
455- return np .dtype ("m8[ns]" )
456- else :
457- try :
458- common_dtype = np .find_common_type (upcast_classes , [])
459- except TypeError :
460- # At least one is an ExtensionArray
461- return np .dtype (np .object_ )
462- else :
463- if is_float_dtype (common_dtype ):
464- return common_dtype
465- elif is_numeric_dtype (common_dtype ):
466- if has_none_blocks :
467- return np .dtype (np .float64 )
468- else :
469- return common_dtype
470-
471- msg = "invalid dtype determination in get_concat_dtype"
472- raise AssertionError (msg )
473-
474-
475- def _get_upcast_classes (
476- join_units : Sequence [JoinUnit ],
477- dtypes : Sequence [DtypeObj ],
478- ) -> Dict [str , List [DtypeObj ]]:
479- """Create mapping between upcast class names and lists of dtypes."""
480- upcast_classes : Dict [str , List [DtypeObj ]] = defaultdict (list )
481- null_upcast_classes : Dict [str , List [DtypeObj ]] = defaultdict (list )
482- for dtype , unit in zip (dtypes , join_units ):
483- if dtype is None :
484- continue
485-
486- upcast_cls = _select_upcast_cls_from_dtype (dtype )
487- # Null blocks should not influence upcast class selection, unless there
488- # are only null blocks, when same upcasting rules must be applied to
489- # null upcast classes.
490- if unit .is_na :
491- null_upcast_classes [upcast_cls ].append (dtype )
492- else :
493- upcast_classes [upcast_cls ].append (dtype )
494-
495- if not upcast_classes :
496- upcast_classes = null_upcast_classes
497-
498- return upcast_classes
499-
500-
501- def _select_upcast_cls_from_dtype (dtype : DtypeObj ) -> str :
502- """Select upcast class name based on dtype."""
503- if is_categorical_dtype (dtype ):
504- return "extension"
505- elif is_datetime64tz_dtype (dtype ):
506- return "datetimetz"
507- elif is_extension_array_dtype (dtype ):
508- return "extension"
509- elif issubclass (dtype .type , np .bool_ ):
510- return "bool"
511- elif issubclass (dtype .type , np .object_ ):
512- return "object"
513- elif is_datetime64_dtype (dtype ):
514- return "datetime"
515- elif is_timedelta64_dtype (dtype ):
516- return "timedelta"
517- elif is_sparse (dtype ):
518- dtype = cast ("SparseDtype" , dtype )
519- return dtype .subtype .name
520- elif is_float_dtype (dtype ) or is_numeric_dtype (dtype ):
521- return dtype .name
522- else :
523- return "float"
443+ if not len (dtypes ):
444+ dtypes = [unit .dtype for unit in join_units if unit .block is not None ]
445+
446+ dtype = find_common_type (dtypes )
447+ if has_none_blocks :
448+ dtype = ensure_dtype_can_hold_na (dtype )
449+ return dtype
524450
525451
526452def _is_uniform_join_units (join_units : List [JoinUnit ]) -> bool :
0 commit comments