@@ -4197,6 +4197,20 @@ object functions {
*/
def split(str: Column, pattern: String): Column = Column.fn("split", str, lit(pattern))

/**
* Splits str around matches of the given pattern.
*
* @param str
* a string expression to split
* @param pattern
* a column of string representing a regular expression. The regex string should be a Java
* regular expression.
*
* @group string_funcs
* @since 4.0.0
*/
def split(str: Column, pattern: Column): Column = Column.fn("split", str, pattern)
Contributor:
Let's add some new test cases for the connect module in PlanGenerationTestSuite.

Contributor:
@CTCC1 You also need to run connect/testOnly org.apache.spark.sql.connect.ProtoToParsedPlanTestSuite to generate the golden files needed for reverse validation testing.

Contributor Author:
thanks for the pointer, fixed


/**
* Splits str around matches of the given pattern.
*
@@ -4218,6 +4232,27 @@ object functions {
def split(str: Column, pattern: String, limit: Int): Column =
Column.fn("split", str, lit(pattern), lit(limit))

/**
* Splits str around matches of the given pattern.
*
* @param str
* a string expression to split
* @param pattern
* a column of string representing a regular expression. The regex string should be a Java
* regular expression.
* @param limit
* a column of integer expression which controls the number of times the regex is applied.
* <ul> <li>limit greater than 0: The resulting array's length will not be more than limit,
* and the resulting array's last entry will contain all input beyond the last matched
* regex.</li> <li>limit less than or equal to 0: `regex` will be applied as many times as
* possible, and the resulting array can be of any size.</li> </ul>
*
* @group string_funcs
* @since 4.0.0
*/
def split(str: Column, pattern: Column, limit: Column): Column =
Column.fn("split", str, pattern, limit)

/**
* Substring starts at `pos` and is of length `len` when str is String type or returns the slice
* of byte array that starts at `pos` in byte and is of length `len` when str is Binary type
@@ -1762,10 +1762,18 @@ class PlanGenerationTestSuite
fn.split(fn.col("g"), ";")
}

functionTest("split using columns") {
fn.split(fn.col("g"), fn.col("g"))
}

functionTest("split with limit") {
fn.split(fn.col("g"), ";", 10)
}

functionTest("split with limit using columns") {
fn.split(fn.col("g"), lit(";"), fn.col("a"))
}

functionTest("substring") {
fn.substring(fn.col("g"), 4, 5)
}
@@ -0,0 +1,2 @@
Project [split(g#0, g#0, -1) AS split(g, g, -1)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
@@ -0,0 +1,2 @@
Project [split(g#0, ;, a#0) AS split(g, ;, a)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
@@ -0,0 +1,29 @@
{
"common": {
"planId": "1"
},
"project": {
"input": {
"common": {
"planId": "0"
},
"localRelation": {
"schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
}
},
"expressions": [{
"unresolvedFunction": {
"functionName": "split",
"arguments": [{
"unresolvedAttribute": {
"unparsedIdentifier": "g"
}
}, {
"unresolvedAttribute": {
"unparsedIdentifier": "g"
}
}]
}
}]
}
}
@@ -0,0 +1,33 @@
{
"common": {
"planId": "1"
},
"project": {
"input": {
"common": {
"planId": "0"
},
"localRelation": {
"schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
}
},
"expressions": [{
"unresolvedFunction": {
"functionName": "split",
"arguments": [{
"unresolvedAttribute": {
"unparsedIdentifier": "g"
}
}, {
"literal": {
"string": ";"
}
}, {
"unresolvedAttribute": {
"unparsedIdentifier": "a"
}
}]
}
}]
}
}
9 changes: 7 additions & 2 deletions python/pyspark/sql/connect/functions/builtin.py
@@ -2476,8 +2476,13 @@ def repeat(col: "ColumnOrName", n: Union["ColumnOrName", int]) -> Column:
repeat.__doc__ = pysparkfuncs.repeat.__doc__


def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column:
return _invoke_function("split", _to_col(str), lit(pattern), lit(limit))
def split(
str: "ColumnOrName",
CTCC1 (Contributor Author), Apr 15, 2024:
We might want to consider renaming the variable that shadows a Python built-in, such as str. It's annoying here that it breaks the usage of isinstance(pattern, str).
Given that renaming the variable would be a backwards-incompatible change (for user code that passes it as a kwarg) and needs further discussion, I will work around this by aliasing the Python built-in (see the workaround comment in the code).

pattern: Union[Column, str],
limit: Union["ColumnOrName", int] = -1,
) -> Column:
limit = lit(limit) if isinstance(limit, int) else _to_col(limit)
return _invoke_function("split", _to_col(str), lit(pattern), limit)
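A minimal sketch of the aliasing workaround described in the comment above, using only the public pyspark.sql.functions API; builtin_str and split_sketch are illustrative names, not necessarily the identifiers used in the PR:

# Hedged sketch: alias the built-in str so a parameter named `str` can
# shadow it while isinstance checks remain possible.
from builtins import str as builtin_str

from pyspark.sql.functions import col, lit


def split_sketch(str, pattern, limit=-1):
    # Inside this body, `str` is the first argument, not the built-in type;
    # the built-in is still reachable through the builtin_str alias.
    str_col = col(str) if isinstance(str, builtin_str) else str
    pattern_col = lit(pattern) if isinstance(pattern, builtin_str) else pattern
    limit_col = lit(limit) if isinstance(limit, int) else limit
    return str_col, pattern_col, limit_col  # stand-in for the real invocation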


split.__doc__ = pysparkfuncs.split.__doc__
65 changes: 57 additions & 8 deletions python/pyspark/sql/functions/builtin.py
@@ -10944,7 +10944,11 @@ def repeat(col: "ColumnOrName", n: Union["ColumnOrName", int]) -> Column:


@_try_remote_functions
def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column:
def split(
str: "ColumnOrName",
pattern: Union[Column, str],
limit: Union["ColumnOrName", int] = -1,
) -> Column:
"""
Splits str around matches of the given pattern.

@@ -10957,10 +10961,10 @@ def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column:
----------
str : :class:`~pyspark.sql.Column` or str
a string expression to split
pattern : str
pattern : :class:`~pyspark.sql.Column` or str
a string representing a regular expression. The regex string should be
a Java regular expression.
limit : int, optional
limit : :class:`~pyspark.sql.Column` or str or int
an integer which controls the number of times `pattern` is applied.

* ``limit > 0``: The resulting array's length will not be more than `limit`, and the
@@ -10972,20 +10976,65 @@ def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column:
.. versionchanged:: 3.0
`split` now takes an optional `limit` field. If not provided, default limit value is -1.

.. versionchanged:: 4.0.0
`pattern` now accepts a column. It does not accept a column name, since the string type
remains accepted as a regular expression representation, for backwards compatibility.
In addition to int, `limit` now accepts a column and a column name.

Contributor:
Please add more doctests in the Examples section to test the newly supported types.

Those doctests will automatically be reused in the Spark Connect Python client.

Returns
-------
:class:`~pyspark.sql.Column`
array of separated strings.

Examples
--------
>>> import pyspark.sql.functions as sf
>>> df = spark.createDataFrame([('oneAtwoBthreeC',)], ['s',])
>>> df.select(split(df.s, '[ABC]', 2).alias('s')).collect()
[Row(s=['one', 'twoBthreeC'])]
>>> df.select(split(df.s, '[ABC]', -1).alias('s')).collect()
[Row(s=['one', 'two', 'three', ''])]
>>> df.select(sf.split(df.s, '[ABC]', 2).alias('s')).show()
+-----------------+
| s|
+-----------------+
|[one, twoBthreeC]|
+-----------------+

>>> import pyspark.sql.functions as sf
>>> df = spark.createDataFrame([('oneAtwoBthreeC',)], ['s',])
>>> df.select(sf.split(df.s, '[ABC]', -1).alias('s')).show()
+-------------------+
| s|
+-------------------+
|[one, two, three, ]|
+-------------------+

>>> import pyspark.sql.functions as sf
>>> df = spark.createDataFrame(
... [('oneAtwoBthreeC', '[ABC]'), ('1A2B3C', '[1-9]+'), ('aa2bb3cc4', '[1-9]+')],
... ['s', 'pattern']
... )
>>> df.select(sf.split(df.s, df.pattern).alias('s')).show()
+-------------------+
| s|
+-------------------+
|[one, two, three, ]|
| [, A, B, C]|
| [aa, bb, cc, ]|
+-------------------+

>>> import pyspark.sql.functions as sf
>>> df = spark.createDataFrame(
... [('oneAtwoBthreeC', '[ABC]', 2), ('1A2B3C', '[1-9]+', -1)],
... ['s', 'pattern', 'expected_parts']
... )
>>> df.select(sf.split(df.s, df.pattern, df.expected_parts).alias('s')).show()
+-----------------+
| s|
+-----------------+
|[one, twoBthreeC]|
| [, A, B, C]|
+-----------------+
"""
return _invoke_function("split", _to_java_column(str), pattern, limit)
limit = lit(limit) if isinstance(limit, int) else limit
Contributor:
I think the lit function accepts both Column and int.

Contributor Author:

limit can also be of type str referring to a column, hence the check, to avoid turning the column name into a string literal.

Contributor:
yea, we can use it for pattern

return _invoke_function_over_columns("split", str, lit(pattern), limit)
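To make the thread above concrete, a hedged usage sketch of the new signature, assuming an active SparkSession named spark on a build that includes this change: a plain string pattern is always treated as a regex literal (never a column name), while limit accepts an int, a Column, or a column name.

import pyspark.sql.functions as sf

df = spark.createDataFrame(
    [('oneAtwoBthreeC', '[ABC]', 2)], ['s', 'pattern', 'expected_parts'])

# A plain string is always a regex literal (backwards compatible) ...
df.select(sf.split(df.s, '[ABC]'))
# ... so a per-row pattern must be passed as a Column, not by name.
df.select(sf.split(df.s, sf.col('pattern')))
# limit, by contrast, may be an int, a Column, or a column name.
df.select(sf.split(df.s, sf.col('pattern'), 'expected_parts'))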


@_try_remote_functions
65 changes: 51 additions & 14 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4144,9 +4144,11 @@ object functions {
/**
* Splits str around matches of the given pattern.
*
* @param str a string expression to split
* @param pattern a string representing a regular expression. The regex string should be
* a Java regular expression.
* @param str
* a string expression to split
* @param pattern
* a string representing a regular expression. The regex string should be a Java regular
* expression.
*
* @group string_funcs
* @since 1.5.0
@@ -4156,24 +4158,59 @@ object functions {
/**
* Splits str around matches of the given pattern.
*
* @param str a string expression to split
* @param pattern a string representing a regular expression. The regex string should be
* a Java regular expression.
* @param limit an integer expression which controls the number of times the regex is applied.
* <ul>
* <li>limit greater than 0: The resulting array's length will not be more than limit,
* and the resulting array's last entry will contain all input beyond the last
* matched regex.</li>
* <li>limit less than or equal to 0: `regex` will be applied as many times as
* possible, and the resulting array can be of any size.</li>
* </ul>
* @param str
* a string expression to split
* @param pattern
* a column of string representing a regular expression. The regex string should be a Java
* regular expression.
*
* @group string_funcs
* @since 4.0.0
*/
def split(str: Column, pattern: Column): Column = Column.fn("split", str, pattern)

/**
* Splits str around matches of the given pattern.
*
* @param str
* a string expression to split
* @param pattern
* a string representing a regular expression. The regex string should be a Java regular
* expression.
* @param limit
* an integer expression which controls the number of times the regex is applied. <ul>
* <li>limit greater than 0: The resulting array's length will not be more than limit, and the
* resulting array's last entry will contain all input beyond the last matched regex.</li>
* <li>limit less than or equal to 0: `regex` will be applied as many times as possible, and
* the resulting array can be of any size.</li> </ul>
*
* @group string_funcs
* @since 3.0.0
*/
def split(str: Column, pattern: String, limit: Int): Column =
Column.fn("split", str, lit(pattern), lit(limit))

/**
* Splits str around matches of the given pattern.
*
* @param str
* a string expression to split
* @param pattern
* a column of string representing a regular expression. The regex string should be a Java
* regular expression.
* @param limit
* a column of integer expression which controls the number of times the regex is applied.
* <ul> <li>limit greater than 0: The resulting array's length will not be more than limit,
* and the resulting array's last entry will contain all input beyond the last matched
* regex.</li> <li>limit less than or equal to 0: `regex` will be applied as many times as
* possible, and the resulting array can be of any size.</li> </ul>
*
* @group string_funcs
* @since 4.0.0
*/
def split(str: Column, pattern: Column, limit: Column): Column =
Column.fn("split", str, pattern, limit)

/**
* Substring starts at `pos` and is of length `len` when str is String type or
* returns the slice of byte array that starts at `pos` in byte and is of length `len`
@@ -525,6 +525,33 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession {
Row(Seq("aa", "bb", "cc", "")))
}

test("SPARK-47845: string split function with column types") {
val df = Seq(
("aa2bb3cc4", "[1-9]+", 0),
("aa2bb3cc4", "[1-9]+", 2),
("aa2bb3cc4", "[1-9]+", -2)).toDF("a", "b", "c")

// without limit
val expectedNoLimit = Seq(
Row(Seq("aa", "bb", "cc", "")),
Row(Seq("aa", "bb", "cc", "")),
Row(Seq("aa", "bb", "cc", "")))

checkAnswer(df.select(split($"a", $"b")), expectedNoLimit)

checkAnswer(df.selectExpr("split(a, b)"), expectedNoLimit)

// with limit
val expectedWithLimit = Seq(
Row(Seq("aa", "bb", "cc", "")),
Row(Seq("aa", "bb3cc4")),
Row(Seq("aa", "bb", "cc", "")))

checkAnswer(df.select(split($"a", $"b", $"c")), expectedWithLimit)

checkAnswer(df.selectExpr("split(a, b, c)"), expectedWithLimit)
}

test("string / binary length function") {
val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123, 2.0f, 3.015))
.toDF("a", "b", "c", "d", "e")