Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 22 additions & 36 deletions python/pyspark/sql/connect/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ def _unary_op(name: str, self: ParentColumn) -> ParentColumn:
return Column(UnresolvedFunction(name, [self._expr])) # type: ignore[list-item]


def _to_expr(v: Any) -> Expression:
return v._expr if isinstance(v, Column) else LiteralExpression._from_value(v)


@with_origin_to_class
class Column(ParentColumn):
def __new__(
Expand Down Expand Up @@ -310,14 +314,12 @@ def when(self, condition: ParentColumn, value: Any) -> ParentColumn:
message_parameters={},
)

if isinstance(value, Column):
_value = value._expr
else:
_value = LiteralExpression._from_value(value)

_branches = self._expr._branches + [(condition._expr, _value)]

return Column(CaseWhen(branches=_branches, else_value=None))
return Column(
CaseWhen(
branches=self._expr._branches + [(condition._expr, _to_expr(value))],
else_value=None,
)
)

def otherwise(self, value: Any) -> ParentColumn:
if not isinstance(self._expr, CaseWhen):
Expand All @@ -330,12 +332,12 @@ def otherwise(self, value: Any) -> ParentColumn:
"otherwise() can only be applied once on a Column previously generated by when()"
)

if isinstance(value, Column):
_value = value._expr
else:
_value = LiteralExpression._from_value(value)

return Column(CaseWhen(branches=self._expr._branches, else_value=_value))
return Column(
CaseWhen(
branches=self._expr._branches,
else_value=_to_expr(value),
)
)

def like(self: ParentColumn, other: str) -> ParentColumn:
return _bin_op("like", self, other)
Expand All @@ -360,22 +362,15 @@ def substr(
},
)

if isinstance(length, Column):
length_expr = length._expr
start_expr = startPos._expr # type: ignore[union-attr]
elif isinstance(length, int):
length_expr = LiteralExpression._from_value(length)
start_expr = LiteralExpression._from_value(startPos)
if isinstance(length, (Column, int)):
length_expr = _to_expr(length)
start_expr = _to_expr(startPos)
else:
raise PySparkTypeError(
error_class="NOT_COLUMN_OR_INT",
message_parameters={"arg_name": "startPos", "arg_type": type(length).__name__},
)
return Column(
UnresolvedFunction(
"substr", [self._expr, start_expr, length_expr] # type: ignore[list-item]
)
)
return Column(UnresolvedFunction("substr", [self._expr, start_expr, length_expr]))

def __eq__(self, other: Any) -> ParentColumn: # type: ignore[override]
if other is None or isinstance(
Expand Down Expand Up @@ -459,14 +454,7 @@ def isin(self, *cols: Any) -> ParentColumn:
else:
_cols = list(cols)

_exprs = [self._expr]
for c in _cols:
if isinstance(c, Column):
_exprs.append(c._expr)
else:
_exprs.append(LiteralExpression._from_value(c))

return Column(UnresolvedFunction("in", _exprs))
return Column(UnresolvedFunction("in", [self._expr] + [_to_expr(c) for c in _cols]))

def between(
self,
Expand Down Expand Up @@ -556,10 +544,8 @@ def __getitem__(self, k: Any) -> ParentColumn:
message_parameters={},
)
return self.substr(k.start, k.stop)
elif isinstance(k, Column):
return Column(UnresolvedExtractValue(self._expr, k._expr))
else:
return Column(UnresolvedExtractValue(self._expr, LiteralExpression._from_value(k)))
return Column(UnresolvedExtractValue(self._expr, _to_expr(k)))

def __iter__(self) -> None:
raise PySparkTypeError(
Expand Down