diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index f267baf9854e..e86fc7b2ea14 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4197,6 +4197,20 @@ object functions {
*/
def split(str: Column, pattern: String): Column = Column.fn("split", str, lit(pattern))
+ /**
+ * Splits str around matches of the given pattern.
+ *
+ * @param str
+ * a string expression to split
+ * @param pattern
+ * a string column representing a regular expression. The regex string should be a Java
+ * regular expression.
+ *
+ * @group string_funcs
+ * @since 4.0.0
+ */
+ def split(str: Column, pattern: Column): Column = Column.fn("split", str, pattern)
+
/**
* Splits str around matches of the given pattern.
*
@@ -4218,6 +4232,27 @@ object functions {
def split(str: Column, pattern: String, limit: Int): Column =
Column.fn("split", str, lit(pattern), lit(limit))
+ /**
+ * Splits str around matches of the given pattern.
+ *
+ * @param str
+ * a string expression to split
+ * @param pattern
+ * a string column representing a regular expression. The regex string should be a Java
+ * regular expression.
+ * @param limit
+ * an integer column which controls the number of times the regex is applied.
+ *
+ *   - limit greater than 0: The resulting array's length will not be more than limit,
+ *     and the resulting array's last entry will contain all input beyond the last matched
+ *     regex.
+ *   - limit less than or equal to 0: `regex` will be applied as many times as
+ *     possible, and the resulting array can be of any size.
+ *
+ * @group string_funcs
+ * @since 4.0.0
+ */
+ def split(str: Column, pattern: Column, limit: Column): Column =
+ Column.fn("split", str, pattern, limit)
+
/**
* Substring starts at `pos` and is of length `len` when str is String type or returns the slice
* of byte array that starts at `pos` in byte and is of length `len` when str is Binary type
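
To make the new overloads concrete, here is a minimal usage sketch (spark-shell style, so an active `spark` session and its implicits are assumed; the data is illustrative and not taken from this patch):

```scala
// Paste into spark-shell (active `spark` session assumed).
import org.apache.spark.sql.functions.{col, split}

// Each row carries its own regex, so the pre-existing String-pattern
// overload (one pattern for the whole column) cannot express this.
val df = Seq(("oneAtwoBthreeC", "[ABC]"), ("1A2B3C", "[1-9]+")).toDF("s", "pattern")

// New in 4.0.0: the pattern is resolved per row from a column.
df.select(split(col("s"), col("pattern"))).show(truncate = false)
// +---------------------+
// |split(s, pattern, -1)|
// +---------------------+
// |[one, two, three, ]  |
// |[, A, B, C]          |
// +---------------------+
```
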
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 5844df8a4889..b08197b8e3f8 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -1762,10 +1762,18 @@ class PlanGenerationTestSuite
fn.split(fn.col("g"), ";")
}
+ functionTest("split using columns") {
+ fn.split(fn.col("g"), fn.col("g"))
+ }
+
functionTest("split with limit") {
fn.split(fn.col("g"), ";", 10)
}
+ functionTest("split with limit using columns") {
+ fn.split(fn.col("g"), lit(";"), fn.col("a"))
+ }
+
functionTest("substring") {
fn.substring(fn.col("g"), 4, 5)
}
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_split_using_columns.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_split_using_columns.explain
new file mode 100644
index 000000000000..2ce3052d7d75
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_split_using_columns.explain
@@ -0,0 +1,2 @@
+Project [split(g#0, g#0, -1) AS split(g, g, -1)#0]
++- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_split_with_limit_using_columns.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_split_with_limit_using_columns.explain
new file mode 100644
index 000000000000..2d16b9eed332
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_split_with_limit_using_columns.explain
@@ -0,0 +1,2 @@
+Project [split(g#0, ;, a#0) AS split(g, ;, a)#0]
++- LocalRelation , [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_split_using_columns.json b/connector/connect/common/src/test/resources/query-tests/queries/function_split_using_columns.json
new file mode 100644
index 000000000000..98ef0e54e621
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/function_split_using_columns.json
@@ -0,0 +1,29 @@
+{
+ "common": {
+ "planId": "1"
+ },
+ "project": {
+ "input": {
+ "common": {
+ "planId": "0"
+ },
+ "localRelation": {
+ "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
+ }
+ },
+ "expressions": [{
+ "unresolvedFunction": {
+ "functionName": "split",
+ "arguments": [{
+ "unresolvedAttribute": {
+ "unparsedIdentifier": "g"
+ }
+ }, {
+ "unresolvedAttribute": {
+ "unparsedIdentifier": "g"
+ }
+ }]
+ }
+ }]
+ }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_split_using_columns.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_split_using_columns.proto.bin
new file mode 100644
index 000000000000..a87702f83d1b
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/function_split_using_columns.proto.bin differ
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_split_with_limit_using_columns.json b/connector/connect/common/src/test/resources/query-tests/queries/function_split_with_limit_using_columns.json
new file mode 100644
index 000000000000..138f9d70b2c8
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/function_split_with_limit_using_columns.json
@@ -0,0 +1,33 @@
+{
+ "common": {
+ "planId": "1"
+ },
+ "project": {
+ "input": {
+ "common": {
+ "planId": "0"
+ },
+ "localRelation": {
+ "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
+ }
+ },
+ "expressions": [{
+ "unresolvedFunction": {
+ "functionName": "split",
+ "arguments": [{
+ "unresolvedAttribute": {
+ "unparsedIdentifier": "g"
+ }
+ }, {
+ "literal": {
+ "string": ";"
+ }
+ }, {
+ "unresolvedAttribute": {
+ "unparsedIdentifier": "a"
+ }
+ }]
+ }
+ }]
+ }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_split_with_limit_using_columns.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_split_with_limit_using_columns.proto.bin
new file mode 100644
index 000000000000..04e24be40e9d
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/function_split_with_limit_using_columns.proto.bin differ
diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py
index 478ec317287d..29901d998242 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -2476,8 +2476,13 @@ def repeat(col: "ColumnOrName", n: Union["ColumnOrName", int]) -> Column:
repeat.__doc__ = pysparkfuncs.repeat.__doc__
-def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column:
- return _invoke_function("split", _to_col(str), lit(pattern), lit(limit))
+def split(
+ str: "ColumnOrName",
+ pattern: Union[Column, str],
+ limit: Union["ColumnOrName", int] = -1,
+) -> Column:
+ limit = lit(limit) if isinstance(limit, int) else _to_col(limit)
+ return _invoke_function("split", _to_col(str), lit(pattern), limit)
split.__doc__ = pysparkfuncs.split.__doc__
diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py
index 41f3d3d5909f..1cdba4b8b3ab 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -10944,7 +10944,11 @@ def repeat(col: "ColumnOrName", n: Union["ColumnOrName", int]) -> Column:
@_try_remote_functions
-def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column:
+def split(
+ str: "ColumnOrName",
+ pattern: Union[Column, str],
+ limit: Union["ColumnOrName", int] = -1,
+) -> Column:
"""
Splits str around matches of the given pattern.
@@ -10957,10 +10961,10 @@ def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column:
----------
str : :class:`~pyspark.sql.Column` or str
a string expression to split
- pattern : str
+ pattern : :class:`~pyspark.sql.Column` or str
a string representing a regular expression. The regex string should be
a Java regular expression.
- limit : int, optional
+ limit : :class:`~pyspark.sql.Column` or str or int
an integer which controls the number of times `pattern` is applied.
* ``limit > 0``: The resulting array's length will not be more than `limit`, and the
@@ -10972,6 +10976,11 @@ def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column:
.. versionchanged:: 3.0
`split` now takes an optional `limit` field. If not provided, default limit value is -1.
+ .. versionchanged:: 4.0.0
+ `pattern` now accepts a column. It does not accept a column name, since strings remain
+ interpreted as regular expressions, for backwards compatibility.
+ In addition to int, `limit` now accepts a column or a column name.
+
Returns
-------
:class:`~pyspark.sql.Column`
@@ -10979,13 +10988,53 @@ def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column:
Examples
--------
+ >>> import pyspark.sql.functions as sf
>>> df = spark.createDataFrame([('oneAtwoBthreeC',)], ['s',])
- >>> df.select(split(df.s, '[ABC]', 2).alias('s')).collect()
- [Row(s=['one', 'twoBthreeC'])]
- >>> df.select(split(df.s, '[ABC]', -1).alias('s')).collect()
- [Row(s=['one', 'two', 'three', ''])]
+ >>> df.select(sf.split(df.s, '[ABC]', 2).alias('s')).show()
+ +-----------------+
+ | s|
+ +-----------------+
+ |[one, twoBthreeC]|
+ +-----------------+
+
+ >>> import pyspark.sql.functions as sf
+ >>> df = spark.createDataFrame([('oneAtwoBthreeC',)], ['s',])
+ >>> df.select(sf.split(df.s, '[ABC]', -1).alias('s')).show()
+ +-------------------+
+ | s|
+ +-------------------+
+ |[one, two, three, ]|
+ +-------------------+
+
+ >>> import pyspark.sql.functions as sf
+ >>> df = spark.createDataFrame(
+ ... [('oneAtwoBthreeC', '[ABC]'), ('1A2B3C', '[1-9]+'), ('aa2bb3cc4', '[1-9]+')],
+ ... ['s', 'pattern']
+ ... )
+ >>> df.select(sf.split(df.s, df.pattern).alias('s')).show()
+ +-------------------+
+ | s|
+ +-------------------+
+ |[one, two, three, ]|
+ | [, A, B, C]|
+ | [aa, bb, cc, ]|
+ +-------------------+
+
+ >>> import pyspark.sql.functions as sf
+ >>> df = spark.createDataFrame(
+ ... [('oneAtwoBthreeC', '[ABC]', 2), ('1A2B3C', '[1-9]+', -1)],
+ ... ['s', 'pattern', 'expected_parts']
+ ... )
+ >>> df.select(sf.split(df.s, df.pattern, df.expected_parts).alias('s')).show()
+ +-----------------+
+ | s|
+ +-----------------+
+ |[one, twoBthreeC]|
+ | [, A, B, C]|
+ +-----------------+
"""
- return _invoke_function("split", _to_java_column(str), pattern, limit)
+ limit = lit(limit) if isinstance(limit, int) else limit
+ return _invoke_function_over_columns("split", str, lit(pattern), limit)
@_try_remote_functions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index e21375713b8a..99dc43a1da15 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4144,9 +4144,11 @@ object functions {
/**
* Splits str around matches of the given pattern.
*
- * @param str a string expression to split
- * @param pattern a string representing a regular expression. The regex string should be
- * a Java regular expression.
+ * @param str
+ * a string expression to split
+ * @param pattern
+ * a string representing a regular expression. The regex string should be a Java regular
+ * expression.
*
* @group string_funcs
* @since 1.5.0
@@ -4156,17 +4158,31 @@ object functions {
/**
* Splits str around matches of the given pattern.
*
- * @param str a string expression to split
- * @param pattern a string representing a regular expression. The regex string should be
- * a Java regular expression.
- * @param limit an integer expression which controls the number of times the regex is applied.
- *
- * - limit greater than 0: The resulting array's length will not be more than limit,
- * and the resulting array's last entry will contain all input beyond the last
- * matched regex.
- * - limit less than or equal to 0: `regex` will be applied as many times as
- * possible, and the resulting array can be of any size.
- *
+ * @param str
+ * a string expression to split
+ * @param pattern
+ * a string column representing a regular expression. The regex string should be a Java
+ * regular expression.
+ *
+ * @group string_funcs
+ * @since 4.0.0
+ */
+ def split(str: Column, pattern: Column): Column = Column.fn("split", str, pattern)
+
+ /**
+ * Splits str around matches of the given pattern.
+ *
+ * @param str
+ * a string expression to split
+ * @param pattern
+ * a string representing a regular expression. The regex string should be a Java regular
+ * expression.
+ * @param limit
+ * an integer expression which controls the number of times the regex is applied.
+ * - limit greater than 0: The resulting array's length will not be more than limit, and the
+ * resulting array's last entry will contain all input beyond the last matched regex.
+ * - limit less than or equal to 0: `regex` will be applied as many times as possible, and
+ * the resulting array can be of any size.
*
* @group string_funcs
* @since 3.0.0
@@ -4174,6 +4190,27 @@ object functions {
def split(str: Column, pattern: String, limit: Int): Column =
Column.fn("split", str, lit(pattern), lit(limit))
+ /**
+ * Splits str around matches of the given pattern.
+ *
+ * @param str
+ * a string expression to split
+ * @param pattern
+ * a string column representing a regular expression. The regex string should be a Java
+ * regular expression.
+ * @param limit
+ * an integer column which controls the number of times the regex is applied.
+ * - limit greater than 0: The resulting array's length will not be more than limit,
+ * and the resulting array's last entry will contain all input beyond the last matched
+ * regex.
+ *   - limit less than or equal to 0: `regex` will be applied as many times as
+ * possible, and the resulting array can be of any size.
+ *
+ * @group string_funcs
+ * @since 4.0.0
+ */
+ def split(str: Column, pattern: Column, limit: Column): Column =
+ Column.fn("split", str, pattern, limit)
+
/**
* Substring starts at `pos` and is of length `len` when str is String type or
* returns the slice of byte array that starts at `pos` in byte and is of length `len`
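
A short sketch of the three-argument overload, pairing each row with its own limit (spark-shell style, illustrative data; the outputs follow the limit semantics spelled out in the scaladoc above and match the values asserted in the test suite below):

```scala
// Paste into spark-shell (active `spark` session assumed).
import org.apache.spark.sql.functions.split

val df = Seq(
  ("aa2bb3cc4", "[1-9]+", 2),  // limit > 0: at most 2 entries, tail kept intact
  ("aa2bb3cc4", "[1-9]+", -1)  // limit <= 0: split as many times as possible
).toDF("a", "b", "c")

df.select(split($"a", $"b", $"c")).show(truncate = false)
// +---------------+
// |split(a, b, c) |
// +---------------+
// |[aa, bb3cc4]   |
// |[aa, bb, cc, ] |
// +---------------+
```
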
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 75b4415db6b5..3fc0b572d80b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -525,6 +525,33 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession {
Row(Seq("aa", "bb", "cc", "")))
}
+ test("SPARK-47845: string split function with column types") {
+ val df = Seq(
+ ("aa2bb3cc4", "[1-9]+", 0),
+ ("aa2bb3cc4", "[1-9]+", 2),
+ ("aa2bb3cc4", "[1-9]+", -2)).toDF("a", "b", "c")
+
+ // without limit
+ val expectedNoLimit = Seq(
+ Row(Seq("aa", "bb", "cc", "")),
+ Row(Seq("aa", "bb", "cc", "")),
+ Row(Seq("aa", "bb", "cc", "")))
+
+ checkAnswer(df.select(split($"a", $"b")), expectedNoLimit)
+
+ checkAnswer(df.selectExpr("split(a, b)"), expectedNoLimit)
+
+ // with limit
+ val expectedWithLimit = Seq(
+ Row(Seq("aa", "bb", "cc", "")),
+ Row(Seq("aa", "bb3cc4")),
+ Row(Seq("aa", "bb", "cc", "")))
+
+ checkAnswer(df.select(split($"a", $"b", $"c")), expectedWithLimit)
+
+ checkAnswer(df.selectExpr("split(a, b, c)"), expectedWithLimit)
+ }
+
test("string / binary length function") {
val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123, 2.0f, 3.015))
.toDF("a", "b", "c", "d", "e")