diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index b63e06bccae1..ef48091a35b0 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -96,6 +96,10 @@ def _unary_op(name: str, self: ParentColumn) -> ParentColumn:
     return Column(UnresolvedFunction(name, [self._expr]))  # type: ignore[list-item]
 
 
+def _to_expr(v: Any) -> Expression:
+    return v._expr if isinstance(v, Column) else LiteralExpression._from_value(v)
+
+
 @with_origin_to_class
 class Column(ParentColumn):
     def __new__(
@@ -310,14 +314,12 @@ def when(self, condition: ParentColumn, value: Any) -> ParentColumn:
                 message_parameters={},
             )
 
-        if isinstance(value, Column):
-            _value = value._expr
-        else:
-            _value = LiteralExpression._from_value(value)
-
-        _branches = self._expr._branches + [(condition._expr, _value)]
-
-        return Column(CaseWhen(branches=_branches, else_value=None))
+        return Column(
+            CaseWhen(
+                branches=self._expr._branches + [(condition._expr, _to_expr(value))],
+                else_value=None,
+            )
+        )
 
     def otherwise(self, value: Any) -> ParentColumn:
         if not isinstance(self._expr, CaseWhen):
@@ -330,12 +332,12 @@ def otherwise(self, value: Any) -> ParentColumn:
                 "otherwise() can only be applied once on a Column previously generated by when()"
             )
 
-        if isinstance(value, Column):
-            _value = value._expr
-        else:
-            _value = LiteralExpression._from_value(value)
-
-        return Column(CaseWhen(branches=self._expr._branches, else_value=_value))
+        return Column(
+            CaseWhen(
+                branches=self._expr._branches,
+                else_value=_to_expr(value),
+            )
+        )
 
     def like(self: ParentColumn, other: str) -> ParentColumn:
         return _bin_op("like", self, other)
@@ -360,22 +362,15 @@ def substr(
                 },
             )
 
-        if isinstance(length, Column):
-            length_expr = length._expr
-            start_expr = startPos._expr  # type: ignore[union-attr]
-        elif isinstance(length, int):
-            length_expr = LiteralExpression._from_value(length)
-            start_expr = LiteralExpression._from_value(startPos)
+        if isinstance(length, (Column, int)):
+            length_expr = _to_expr(length)
+            start_expr = _to_expr(startPos)
         else:
             raise PySparkTypeError(
                 error_class="NOT_COLUMN_OR_INT",
                 message_parameters={"arg_name": "startPos", "arg_type": type(length).__name__},
             )
-        return Column(
-            UnresolvedFunction(
-                "substr", [self._expr, start_expr, length_expr]  # type: ignore[list-item]
-            )
-        )
+        return Column(UnresolvedFunction("substr", [self._expr, start_expr, length_expr]))
 
     def __eq__(self, other: Any) -> ParentColumn:  # type: ignore[override]
         if other is None or isinstance(
@@ -459,14 +454,7 @@ def isin(self, *cols: Any) -> ParentColumn:
         else:
             _cols = list(cols)
 
-        _exprs = [self._expr]
-        for c in _cols:
-            if isinstance(c, Column):
-                _exprs.append(c._expr)
-            else:
-                _exprs.append(LiteralExpression._from_value(c))
-
-        return Column(UnresolvedFunction("in", _exprs))
+        return Column(UnresolvedFunction("in", [self._expr] + [_to_expr(c) for c in _cols]))
 
     def between(
         self,
@@ -556,10 +544,8 @@ def __getitem__(self, k: Any) -> ParentColumn:
                     message_parameters={},
                 )
             return self.substr(k.start, k.stop)
-        elif isinstance(k, Column):
-            return Column(UnresolvedExtractValue(self._expr, k._expr))
         else:
-            return Column(UnresolvedExtractValue(self._expr, LiteralExpression._from_value(k)))
+            return Column(UnresolvedExtractValue(self._expr, _to_expr(k)))
 
     def __iter__(self) -> None:
         raise PySparkTypeError(
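
The net effect of the diff is easiest to see in isolation: every call site that previously branched on "is this a Column or a plain Python literal?" now funnels through the single `_to_expr` helper. Below is a minimal, self-contained sketch of that coercion pattern; the `Expression`, `LiteralExpression`, and `Column` classes here are simplified stand-ins for the real Spark Connect classes, not their actual implementations.

```python
from typing import Any, List


class Expression:
    """Stand-in for pyspark.sql.connect.expressions.Expression."""


class LiteralExpression(Expression):
    def __init__(self, value: Any) -> None:
        self.value = value

    @classmethod
    def _from_value(cls, value: Any) -> "LiteralExpression":
        return cls(value)


class Column:
    """Stand-in for the Connect Column: a thin wrapper around an Expression."""

    def __init__(self, expr: Expression) -> None:
        self._expr = expr


def _to_expr(v: Any) -> Expression:
    # Same shape as the helper added in the diff: unwrap Columns, wrap literals.
    return v._expr if isinstance(v, Column) else LiteralExpression._from_value(v)


def isin_args(self_col: Column, cols: List[Any]) -> List[Expression]:
    # The isin() hunk in one line: the old explicit loop with an isinstance
    # branch per element becomes a list comprehension over _to_expr.
    return [self_col._expr] + [_to_expr(c) for c in cols]


if __name__ == "__main__":
    col = Column(Expression())
    exprs = isin_args(col, [1, "a", Column(Expression())])
    # Prints: ['Expression', 'LiteralExpression', 'LiteralExpression', 'Expression']
    print([type(e).__name__ for e in exprs])
```

Because the helper is applied uniformly in `when()`, `otherwise()`, `substr()`, `isin()`, and `__getitem__()`, the accepted argument types and resulting expressions are unchanged; only the duplicated branching is removed.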