Skip to content

Commit 9c35c43

Browse files
committed
[SPARK-52580][PS] Avoid CAST_INVALID_INPUT of replace in ANSI mode
### What changes were proposed in this pull request? Avoid CAST_INVALID_INPUT of `replace` in ANSI mode. Specifically, under ANSI mode: - used try_cast() to safely cast values - for NaN checks, we now avoid F.isnan() on non-numeric types. An example of the spark plan difference between ANSI on/off is: ``` # if the original column is of StringType # ANSI off Column<'CASE WHEN in(C, 0, 1, 2, 3, 5, 6) THEN 4 ELSE C END'> # ANSI on Column<'CASE WHEN in(C, TRY_CAST(0 AS STRING), TRY_CAST(1 AS STRING), TRY_CAST(2 AS STRING), TRY_CAST(3 AS STRING), TRY_CAST(5 AS STRING), TRY_CAST(6 AS STRING)) THEN TRY_CAST(4 AS STRING) ELSE TRY_CAST(C AS STRING) END'> ``` ### Why are the changes needed? Ensure pandas on Spark works well with ANSI mode on. Part of https://issues.apache.org/jira/browse/SPARK-52556. ### Does this PR introduce _any_ user-facing change? Yes, `replace` works in ANSI mode, for example ```py >>> ps.set_option("compute.fail_on_ansi_mode", False) >>> ps.set_option("compute.ansi_mode_support", True) >>> pdf = pd.DataFrame( ... {"A": [0, 1, 2, 3, np.nan], "B": [5, 6, 7, 8, np.nan], "C": ["a", "b", "c", "d", None]}, ... index=np.random.rand(5), ... ) >>> psdf = ps.from_pandas(pdf) >>> psdf["C"].replace([0, 1, 2, 3, 5, 6], 4) 0.458472 a 0.749773 b 0.222904 c 0.397280 d 0.293933 None Name: C, dtype: object >>> psdf.replace([0, 1, 2, 3, 5, 6], [6, 5, 4, 3, 2, 1]) A B C 0.458472 6.0 2.0 a 0.749773 5.0 1.0 b 0.222904 4.0 7.0 c 0.397280 3.0 8.0 d 0.293933 NaN NaN None ``` ### How was this patch tested? Unit tests ### Was this patch authored or co-authored using generative AI tooling? No Closes #51297 from xinrong-meng/replace. Authored-by: Xinrong Meng <[email protected]> Signed-off-by: Xinrong Meng <[email protected]>
1 parent e9a285e commit 9c35c43

File tree

2 files changed

+46
-11
lines changed

2 files changed

+46
-11
lines changed

python/pyspark/pandas/series.py

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@
104104
from pyspark.pandas.plot import PandasOnSparkPlotAccessor
105105
from pyspark.pandas.utils import (
106106
combine_frames,
107+
is_ansi_mode_enabled,
107108
is_name_like_tuple,
108109
is_name_like_value,
109110
name_like_string,
@@ -5106,33 +5107,68 @@ def replace(
51065107
)
51075108
)
51085109
to_replace = {k: v for k, v in zip(to_replace, value)}
5110+
5111+
spark_session = self._internal.spark_frame.sparkSession
5112+
ansi_mode = is_ansi_mode_enabled(spark_session)
5113+
col_type = self.spark.data_type
5114+
51095115
if isinstance(to_replace, dict):
51105116
is_start = True
51115117
if len(to_replace) == 0:
51125118
current = self.spark.column
51135119
else:
51145120
for to_replace_, value in to_replace.items():
5115-
cond = (
5116-
(F.isnan(self.spark.column) | self.spark.column.isNull())
5117-
if pd.isna(to_replace_)
5118-
else (self.spark.column == F.lit(to_replace_))
5119-
)
5121+
if pd.isna(to_replace_):
5122+
if ansi_mode and isinstance(col_type, NumericType):
5123+
cond = F.isnan(self.spark.column) | self.spark.column.isNull()
5124+
else:
5125+
cond = self.spark.column.isNull()
5126+
else:
5127+
to_replace_lit = (
5128+
F.lit(to_replace_).try_cast(col_type)
5129+
if ansi_mode
5130+
else F.lit(to_replace_)
5131+
)
5132+
cond = self.spark.column == to_replace_lit
5133+
value_expr = F.lit(value).try_cast(col_type) if ansi_mode else F.lit(value)
51205134
if is_start:
5121-
current = F.when(cond, value)
5135+
current = F.when(cond, value_expr)
51225136
is_start = False
51235137
else:
5124-
current = current.when(cond, value)
5138+
current = current.when(cond, value_expr)
51255139
current = current.otherwise(self.spark.column)
51265140
else:
51275141
if regex:
51285142
# to_replace must be a string
51295143
cond = self.spark.column.rlike(cast(str, to_replace))
51305144
else:
5131-
cond = self.spark.column.isin(to_replace)
5145+
if ansi_mode:
5146+
to_replace_values = (
5147+
[to_replace]
5148+
if not is_list_like(to_replace) or isinstance(to_replace, str)
5149+
else to_replace
5150+
)
5151+
to_replace_values = cast(List[Any], to_replace_values)
5152+
literals = [F.lit(v).try_cast(col_type) for v in to_replace_values]
5153+
cond = self.spark.column.isin(literals)
5154+
else:
5155+
cond = self.spark.column.isin(to_replace)
51325156
# to_replace may be a scalar
51335157
if np.array(pd.isna(to_replace)).any():
5134-
cond = cond | F.isnan(self.spark.column) | self.spark.column.isNull()
5135-
current = F.when(cond, value).otherwise(self.spark.column)
5158+
if ansi_mode:
5159+
if isinstance(col_type, NumericType):
5160+
cond = cond | F.isnan(self.spark.column) | self.spark.column.isNull()
5161+
else:
5162+
cond = cond | self.spark.column.isNull()
5163+
else:
5164+
cond = cond | F.isnan(self.spark.column) | self.spark.column.isNull()
5165+
5166+
if ansi_mode:
5167+
value_expr = F.lit(value).try_cast(col_type)
5168+
current = F.when(cond, value_expr).otherwise(self.spark.column.try_cast(col_type))
5169+
5170+
else:
5171+
current = F.when(cond, value).otherwise(self.spark.column)
51365172

51375173
return self._with_new_scol(current) # TODO: dtype?
51385174

python/pyspark/pandas/tests/computation/test_missing_data.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,6 @@ def test_fillna(self):
274274
pdf.fillna({("x", "a"): -1, ("x", "b"): -2, ("y", "c"): -5}),
275275
)
276276

277-
@unittest.skipIf(is_ansi_mode_test, ansi_mode_not_supported_message)
278277
def test_replace(self):
279278
pdf = pd.DataFrame(
280279
{

0 commit comments

Comments
 (0)