Skip to content

Commit 0bbd049

Browse files
zhengruifeng authored and HyukjinKwon committed
[SPARK-48591][PYTHON] Simplify the if-else branches with F.lit
### What changes were proposed in this pull request?
Simplify the if-else branches with `F.lit`, which accepts both Column and non-Column input.

### Why are the changes needed?
Code cleanup.

### Does this PR introduce _any_ user-facing change?
No, this is an internal minor refactor.

### How was this patch tested?
CI

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #46946 from zhengruifeng/column_simplify.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
1 parent d1d29c9 commit 0bbd049

File tree

1 file changed

+20
-25
lines changed

1 file changed

+20
-25
lines changed

python/pyspark/sql/connect/column.py

Lines changed: 20 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,15 @@
2727
Any,
2828
Union,
2929
Optional,
30+
cast,
3031
)
3132

3233
from pyspark.sql.column import Column as ParentColumn
3334
from pyspark.errors import PySparkTypeError, PySparkAttributeError, PySparkValueError
3435
from pyspark.sql.types import DataType
3536

3637
import pyspark.sql.connect.proto as proto
38+
from pyspark.sql.connect.functions import builtin as F
3739
from pyspark.sql.connect.expressions import (
3840
Expression,
3941
UnresolvedFunction,
@@ -308,14 +310,12 @@ def when(self, condition: ParentColumn, value: Any) -> ParentColumn:
308310
message_parameters={},
309311
)
310312

311-
if isinstance(value, Column):
312-
_value = value._expr
313-
else:
314-
_value = LiteralExpression._from_value(value)
315-
316-
_branches = self._expr._branches + [(condition._expr, _value)]
317-
318-
return Column(CaseWhen(branches=_branches, else_value=None))
313+
return Column(
314+
CaseWhen(
315+
branches=self._expr._branches + [(condition._expr, F.lit(value)._expr)],
316+
else_value=None,
317+
)
318+
)
319319

320320
def otherwise(self, value: Any) -> ParentColumn:
321321
if not isinstance(self._expr, CaseWhen):
@@ -328,12 +328,12 @@ def otherwise(self, value: Any) -> ParentColumn:
328328
"otherwise() can only be applied once on a Column previously generated by when()"
329329
)
330330

331-
if isinstance(value, Column):
332-
_value = value._expr
333-
else:
334-
_value = LiteralExpression._from_value(value)
335-
336-
return Column(CaseWhen(branches=self._expr._branches, else_value=_value))
331+
return Column(
332+
CaseWhen(
333+
branches=self._expr._branches,
334+
else_value=cast(Expression, F.lit(value)._expr),
335+
)
336+
)
337337

338338
def like(self: ParentColumn, other: str) -> ParentColumn:
339339
return _bin_op("like", self, other)
@@ -457,14 +457,11 @@ def isin(self, *cols: Any) -> ParentColumn:
457457
else:
458458
_cols = list(cols)
459459

460-
_exprs = [self._expr]
461-
for c in _cols:
462-
if isinstance(c, Column):
463-
_exprs.append(c._expr)
464-
else:
465-
_exprs.append(LiteralExpression._from_value(c))
466-
467-
return Column(UnresolvedFunction("in", _exprs))
460+
return Column(
461+
UnresolvedFunction(
462+
"in", [self._expr] + [cast(Expression, F.lit(c)._expr) for c in _cols]
463+
)
464+
)
468465

469466
def between(
470467
self,
@@ -554,10 +551,8 @@ def __getitem__(self, k: Any) -> ParentColumn:
554551
message_parameters={},
555552
)
556553
return self.substr(k.start, k.stop)
557-
elif isinstance(k, Column):
558-
return Column(UnresolvedExtractValue(self._expr, k._expr))
559554
else:
560-
return Column(UnresolvedExtractValue(self._expr, LiteralExpression._from_value(k)))
555+
return Column(UnresolvedExtractValue(self._expr, cast(Expression, F.lit(k)._expr)))
561556

562557
def __iter__(self) -> None:
563558
raise PySparkTypeError(

0 commit comments

Comments
 (0)