Skip to content

Commit aa4bfb0

Browse files
committed
Revert "[SPARK-48591][PYTHON] Simplify the if-else branches with F.lit"
Reverts #46946 because it can cause a circular import: `pyspark/sql/connect/column.py` imported `pyspark.sql.connect.functions.builtin`, which imports `pyspark.sql.connect.udf`, which in turn imports `Column` from the still partially initialized `column.py`.

```
File "/home/jenkins/python/pyspark/sql/connect/functions/__init__.py", line 20, in <module>
    from pyspark.sql.connect.functions.builtin import * # noqa: F401,F403
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jenkins/python/pyspark/sql/connect/functions/builtin.py", line 60, in <module>
    from pyspark.sql.connect.udf import _create_py_udf
File "/home/jenkins/python/pyspark/sql/connect/udf.py", line 38, in <module>
    from pyspark.sql.connect.column import Column
ImportError: cannot import name 'Column' from partially initialized module 'pyspark.sql.connect.column' (most likely due to a circular import) (/home/jenkins/python/pyspark/sql/connect/column.py)
Had test failures in delta.connect.tests.test_deltatable with python; see logs.
```

Closes #46985 from zhengruifeng/revert_simplify_column.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
1 parent 2d2bedf commit aa4bfb0

File tree

1 file changed

+25
-20
lines changed

1 file changed

+25
-20
lines changed

python/pyspark/sql/connect/column.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,13 @@
2727
Any,
2828
Union,
2929
Optional,
30-
cast,
3130
)
3231

3332
from pyspark.sql.column import Column as ParentColumn
3433
from pyspark.errors import PySparkTypeError, PySparkAttributeError, PySparkValueError
3534
from pyspark.sql.types import DataType
3635

3736
import pyspark.sql.connect.proto as proto
38-
from pyspark.sql.connect.functions import builtin as F
3937
from pyspark.sql.connect.expressions import (
4038
Expression,
4139
UnresolvedFunction,
@@ -310,12 +308,14 @@ def when(self, condition: ParentColumn, value: Any) -> ParentColumn:
310308
message_parameters={},
311309
)
312310

313-
return Column(
314-
CaseWhen(
315-
branches=self._expr._branches + [(condition._expr, F.lit(value)._expr)],
316-
else_value=None,
317-
)
318-
)
311+
if isinstance(value, Column):
312+
_value = value._expr
313+
else:
314+
_value = LiteralExpression._from_value(value)
315+
316+
_branches = self._expr._branches + [(condition._expr, _value)]
317+
318+
return Column(CaseWhen(branches=_branches, else_value=None))
319319

320320
def otherwise(self, value: Any) -> ParentColumn:
321321
if not isinstance(self._expr, CaseWhen):
@@ -328,12 +328,12 @@ def otherwise(self, value: Any) -> ParentColumn:
328328
"otherwise() can only be applied once on a Column previously generated by when()"
329329
)
330330

331-
return Column(
332-
CaseWhen(
333-
branches=self._expr._branches,
334-
else_value=cast(Expression, F.lit(value)._expr),
335-
)
336-
)
331+
if isinstance(value, Column):
332+
_value = value._expr
333+
else:
334+
_value = LiteralExpression._from_value(value)
335+
336+
return Column(CaseWhen(branches=self._expr._branches, else_value=_value))
337337

338338
def like(self: ParentColumn, other: str) -> ParentColumn:
339339
return _bin_op("like", self, other)
@@ -457,11 +457,14 @@ def isin(self, *cols: Any) -> ParentColumn:
457457
else:
458458
_cols = list(cols)
459459

460-
return Column(
461-
UnresolvedFunction(
462-
"in", [self._expr] + [cast(Expression, F.lit(c)._expr) for c in _cols]
463-
)
464-
)
460+
_exprs = [self._expr]
461+
for c in _cols:
462+
if isinstance(c, Column):
463+
_exprs.append(c._expr)
464+
else:
465+
_exprs.append(LiteralExpression._from_value(c))
466+
467+
return Column(UnresolvedFunction("in", _exprs))
465468

466469
def between(
467470
self,
@@ -551,8 +554,10 @@ def __getitem__(self, k: Any) -> ParentColumn:
551554
message_parameters={},
552555
)
553556
return self.substr(k.start, k.stop)
557+
elif isinstance(k, Column):
558+
return Column(UnresolvedExtractValue(self._expr, k._expr))
554559
else:
555-
return Column(UnresolvedExtractValue(self._expr, cast(Expression, F.lit(k)._expr)))
560+
return Column(UnresolvedExtractValue(self._expr, LiteralExpression._from_value(k)))
556561

557562
def __iter__(self) -> None:
558563
raise PySparkTypeError(

0 commit comments

Comments
 (0)