Support usage of columns as parameters to more pyspark functions
Ronserruya committed Jun 5, 2024
commit 7d85f1ccd89d20fb0327847e8d8713a8f8907b14
51 changes: 42 additions & 9 deletions python/pyspark/sql/functions/builtin.py
@@ -10915,7 +10915,7 @@ def sentences(


@_try_remote_functions
def substring(str: "ColumnOrName", pos: int, len: int) -> Column:
def substring(str: "ColumnOrName", pos: Union["ColumnOrName", int], len: Union["ColumnOrName", int]) -> Column:
"""
Substring starts at `pos` and is of length `len` when str is String type or
returns the slice of byte array that starts at `pos` in byte and is of length `len`
@@ -10934,9 +10934,9 @@ def substring(str: "ColumnOrName", pos: int, len: int) -> Column:
----------
str : :class:`~pyspark.sql.Column` or str
target column to work on.
pos : int
pos : :class:`~pyspark.sql.Column` or str or int
starting position in str.
len : int
len : :class:`~pyspark.sql.Column` or str or int
length of chars.

Returns
@@ -10949,14 +10949,20 @@ def substring(str: "ColumnOrName", pos: int, len: int) -> Column:
>>> df = spark.createDataFrame([('abcd',)], ['s',])
>>> df.select(substring(df.s, 1, 2).alias('s')).collect()
[Row(s='ab')]

>>> df = spark.createDataFrame([('abcd', 2, 3)], ['s', 'start', 'len'])
>>> df.select(substring(df.s, df.start, df.len).alias('s')).collect()
[Row(s='bcd')]
"""
from pyspark.sql.classic.column import _to_java_column

pos = _to_java_column(pos) if isinstance(pos, (str, Column)) else pos
len = _to_java_column(len) if isinstance(len, (str, Column)) else len
return _invoke_function("substring", _to_java_column(str), pos, len)


@_try_remote_functions
def substring_index(str: "ColumnOrName", delim: str, count: int) -> Column:
def substring_index(str: "ColumnOrName", delim: Union[Column, str], count: Union["ColumnOrName", int]) -> Column:
"""
Returns the substring from string str before count occurrences of the delimiter delim.
If count is positive, everything the left of the final delimiter (counting from left) is
@@ -10972,9 +10978,9 @@ def substring_index(str: "ColumnOrName", delim: str, count: int) -> Column:
----------
str : :class:`~pyspark.sql.Column` or str
target column to work on.
delim : str
delim : :class:`~pyspark.sql.Column` or str
delimiter of values.
count : int
count : :class:`~pyspark.sql.Column` or str or int
number of occurrences.

Returns
@@ -10992,6 +10998,8 @@ def substring_index(str: "ColumnOrName", delim: str, count: int) -> Column:
"""
from pyspark.sql.classic.column import _to_java_column

delim = delim._jc if isinstance(delim, Column) else delim
count = _to_java_column(count) if isinstance(count, (str, Column)) else count
return _invoke_function("substring_index", _to_java_column(str), delim, count)


@@ -13969,7 +13977,7 @@ def array_position(col: "ColumnOrName", value: Any) -> Column:
col : :class:`~pyspark.sql.Column` or str
target column to work on.
value : Any
value to look for.
value or a :class:`~pyspark.sql.Column` expression to look for.

Returns
-------
@@ -14034,9 +14042,21 @@ def array_position(col: "ColumnOrName", value: Any) -> Column:
+-----------------------+
| 3|
+-----------------------+

Example 6: Finding the position of a column's value in an array of integers

>>> from pyspark.sql import functions as sf
>>> df = spark.createDataFrame([([10, 20, 30], 20)], ['data', 'col'])
>>> df.select(sf.array_position(df.data, df.col)).show()
+-------------------------+
|array_position(data, col)|
+-------------------------+
|                        2|
+-------------------------+
"""
from pyspark.sql.classic.column import _to_java_column

value = value._jc if isinstance(value, Column) else value
return _invoke_function("array_position", _to_java_column(col), value)


@@ -14402,7 +14422,7 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column:
col : :class:`~pyspark.sql.Column` or str
name of column containing array
element :
element to be removed from the array
element or a :class:`~pyspark.sql.Column` expression to be removed from the array

Returns
-------
@@ -14470,9 +14490,21 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column:
+---------------------+
| []|
+---------------------+

Example 6: Removing a column's value from a simple array

>>> from pyspark.sql import functions as sf
>>> df = spark.createDataFrame([([1, 2, 3, 1, 1], 1)], ['data', 'col'])
>>> df.select(sf.array_remove(df.data, df.col)).show()
+-----------------------+
|array_remove(data, col)|
+-----------------------+
|                 [2, 3]|
+-----------------------+
"""
from pyspark.sql.classic.column import _to_java_column

element = element._jc if isinstance(element, Column) else element
return _invoke_function("array_remove", _to_java_column(col), element)


@@ -17237,7 +17269,7 @@ def map_contains_key(col: "ColumnOrName", value: Any) -> Column:
col : :class:`~pyspark.sql.Column` or str
The name of the column or an expression that represents the map.
value :
A literal value.
A literal value, or a :class:`~pyspark.sql.Column` expression.

Returns
-------
@@ -17270,6 +17302,7 @@ def map_contains_key(col: "ColumnOrName", value: Any) -> Column:
"""
from pyspark.sql.classic.column import _to_java_column

value = value._jc if isinstance(value, Column) else value
Contributor:

ditto

Contributor:

Example 3 ("Check for key using a column") was already supported.

Contributor (Author):

Is it? In Spark 3.5.1 this gives me an error:

from pyspark.sql import functions as F

df = spark.sql("select map(1, 2, 3, 4) as m, 1 as k")
df.select(F.map_contains_key(df.m, df.k))

# pyspark.errors.exceptions.base.PySparkTypeError: [NOT_ITERABLE] Column is not iterable.

which makes sense, since this passes a Column to _invoke_function, which expects only native types or JavaObject for its args.

Contributor @zhengruifeng, Jun 14, 2024:

I see, it was not supported in Classic mode, but supported in Connect mode.

Classic:

In [2]: df = spark.sql("select map(1, 2, 3, 4) as m, 1 as k")
   ...: df.select(F.map_contains_key(df.m, df.k))
---------------------------------------------------------------------------
PySparkTypeError                          Traceback (most recent call last)
Cell In[2], line 2
      1 df = spark.sql("select map(1, 2, 3, 4) as m, 1 as k")
----> 2 df.select(F.map_contains_key(df.m, df.k))

...

File ~/Dev/spark/python/pyspark/sql/classic/column.py:415, in Column.__iter__(self)
    414 def __iter__(self) -> None:
--> 415     raise PySparkTypeError(
    416         error_class="NOT_ITERABLE", message_parameters={"objectName": "Column"}
    417     )

PySparkTypeError: [NOT_ITERABLE] Column is not iterable.

Connect:

In [1]: from pyspark.sql import functions as F

In [2]: df = spark.sql("select map(1, 2, 3, 4) as m, 1 as k")
   ...: df.select(F.map_contains_key(df.m, df.k))
Out[2]: DataFrame[map_contains_key(m, k): boolean]

There is a slight difference in the handling of `value: Any` typed parameters: Spark Connect always converts a `value: Any` argument to a Column/Expression (because the UnresolvedFunction proto requires it), while some functions (e.g. map_contains_key) in Classic don't do this.

We will need to revisit all the Any typed parameters in functions. cc @HyukjinKwon
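
A rough sketch of the difference being described, for illustration only (the helper names below are made up; this is not the actual Connect or Classic code path):

from pyspark.sql import Column
from pyspark.sql import functions as F

def connect_style(value):
    # Connect-style handling (illustrative): any non-Column value is lifted
    # to a literal Column, so a Column argument works transparently.
    return value if isinstance(value, Column) else F.lit(value)

def classic_style_in_this_patch(value):
    # What this patch does on the Classic side: unwrap a Column to its
    # underlying JVM column (a Py4J JavaObject, which _invoke_function
    # accepts) and pass Python literals through untouched.
    return value._jc if isinstance(value, Column) else value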

Contributor:

This is an interesting behavior difference. What's the reason for not converting the classic PySpark value to a column/expression?

Contributor:

I guess it was not by design; it seems to be just due to the type mismatch in the internal helper functions.

return _invoke_function("map_contains_key", _to_java_column(col), value)

