diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index f267baf9854e..e86fc7b2ea14 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4197,6 +4197,20 @@ object functions {
    */
  def split(str: Column, pattern: String): Column = Column.fn("split", str, lit(pattern))
 
+  /**
+   * Splits str around matches of the given pattern.
+   *
+   * @param str
+   *   a string expression to split
+   * @param pattern
+   *   a string column containing a regular expression. The regex string should be a Java
+   *   regular expression.
+   *
+   * @group string_funcs
+   * @since 4.0.0
+   */
+  def split(str: Column, pattern: Column): Column = Column.fn("split", str, pattern)
+
   /**
    * Splits str around matches of the given pattern.
    *
@@ -4218,6 +4232,27 @@
   def split(str: Column, pattern: String, limit: Int): Column =
     Column.fn("split", str, lit(pattern), lit(limit))
 
+  /**
+   * Splits str around matches of the given pattern.
+   *
+   * @param str
+   *   a string expression to split
+   * @param pattern
+   *   a string column containing a regular expression. The regex string should be a Java
+   *   regular expression.
+   * @param limit
+   *   an integer column which controls the number of times the regex is applied.
+   *   <ul> <li>limit greater than 0: The resulting array's length will not be more than limit,
+   *   and the resulting array's last entry will contain all input beyond the last matched
+   *   regex.</li> <li>limit less than or equal to 0: regex will be applied as many times as
+   *   possible, and the resulting array can be of any size.</li> </ul>
+   *
+   * @group string_funcs
+   * @since 4.0.0
+   */
+  def split(str: Column, pattern: Column, limit: Column): Column =
+    Column.fn("split", str, pattern, limit)
+
   /**
    * Substring starts at `pos` and is of length `len` when str is String type or returns the slice
    * of byte array that starts at `pos` in byte and is of length `len` when str is Binary type
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index 5844df8a4889..b08197b8e3f8 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -1762,10 +1762,18 @@ class PlanGenerationTestSuite
     fn.split(fn.col("g"), ";")
   }
 
+  functionTest("split using columns") {
+    fn.split(fn.col("g"), fn.col("g"))
+  }
+
   functionTest("split with limit") {
     fn.split(fn.col("g"), ";", 10)
   }
 
+  functionTest("split with limit using columns") {
+    fn.split(fn.col("g"), lit(";"), fn.col("a"))
+  }
+
   functionTest("substring") {
     fn.substring(fn.col("g"), 4, 5)
   }
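For reference, a minimal usage sketch of the two overloads added above. This is illustrative only; the session setup, data, and column names are not part of this patch:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.split

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val df = Seq(
  ("oneAtwoBthreeC", "[ABC]", 2),
  ("aa2bb3cc4", "[1-9]+", -1)).toDF("s", "pattern", "limit")

// New in 4.0.0: the pattern may vary per row.
df.select(split($"s", $"pattern")).show()

// New in 4.0.0: both the pattern and the limit may be columns.
df.select(split($"s", $"pattern", $"limit")).show()
```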
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_split_using_columns.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_split_using_columns.explain
new file mode 100644
index 000000000000..2ce3052d7d75
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_split_using_columns.explain
@@ -0,0 +1,2 @@
+Project [split(g#0, g#0, -1) AS split(g, g, -1)#0]
++- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_split_with_limit_using_columns.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_split_with_limit_using_columns.explain
new file mode 100644
index 000000000000..2d16b9eed332
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_split_with_limit_using_columns.explain
@@ -0,0 +1,2 @@
+Project [split(g#0, ;, a#0) AS split(g, ;, a)#0]
++- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_split_using_columns.json b/connector/connect/common/src/test/resources/query-tests/queries/function_split_using_columns.json
new file mode 100644
index 000000000000..98ef0e54e621
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/function_split_using_columns.json
@@ -0,0 +1,29 @@
+{
+  "common": {
+    "planId": "1"
+  },
+  "project": {
+    "input": {
+      "common": {
+        "planId": "0"
+      },
+      "localRelation": {
+        "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
+      }
+    },
+    "expressions": [{
+      "unresolvedFunction": {
+        "functionName": "split",
+        "arguments": [{
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "g"
+          }
+        }, {
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "g"
+          }
+        }]
+      }
+    }]
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_split_using_columns.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_split_using_columns.proto.bin
new file mode 100644
index 000000000000..a87702f83d1b
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/function_split_using_columns.proto.bin differ
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_split_with_limit_using_columns.json b/connector/connect/common/src/test/resources/query-tests/queries/function_split_with_limit_using_columns.json
new file mode 100644
index 000000000000..138f9d70b2c8
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/queries/function_split_with_limit_using_columns.json
@@ -0,0 +1,33 @@
+{
+  "common": {
+    "planId": "1"
+  },
+  "project": {
+    "input": {
+      "common": {
+        "planId": "0"
+      },
+      "localRelation": {
+        "schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
+      }
+    },
+    "expressions": [{
+      "unresolvedFunction": {
+        "functionName": "split",
+        "arguments": [{
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "g"
+          }
+        }, {
+          "literal": {
+            "string": ";"
+          }
+        }, {
+          "unresolvedAttribute": {
+            "unparsedIdentifier": "a"
+          }
+        }]
+      }
+    }]
+  }
+}
\ No newline at end of file
diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_split_with_limit_using_columns.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_split_with_limit_using_columns.proto.bin
new file mode 100644
index 000000000000..04e24be40e9d
Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/function_split_with_limit_using_columns.proto.bin differ
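The explain goldens above show Catalyst normalizing the two-argument call to a three-argument `split` with the default limit of -1. A quick way to observe the same thing locally (sketch only; assumes an existing `SparkSession` named `spark`):

```scala
import org.apache.spark.sql.functions.split
import spark.implicits._

val df = Seq(("a;b;c", ";")).toDF("g", "p")
// The analyzed plan prints split(g#.., p#.., -1): the omitted limit is
// filled in with -1, matching function_split_using_columns.explain.
df.select(split($"g", $"p")).explain(true)
```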
diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py
index 478ec317287d..29901d998242 100644
--- a/python/pyspark/sql/connect/functions/builtin.py
+++ b/python/pyspark/sql/connect/functions/builtin.py
@@ -2476,8 +2476,13 @@ def repeat(col: "ColumnOrName", n: Union["ColumnOrName", int]) -> Column:
 repeat.__doc__ = pysparkfuncs.repeat.__doc__
 
 
-def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column:
-    return _invoke_function("split", _to_col(str), lit(pattern), lit(limit))
+def split(
+    str: "ColumnOrName",
+    pattern: Union[Column, str],
+    limit: Union["ColumnOrName", int] = -1,
+) -> Column:
+    limit = lit(limit) if isinstance(limit, int) else _to_col(limit)
+    return _invoke_function("split", _to_col(str), lit(pattern), limit)
 
 
 split.__doc__ = pysparkfuncs.split.__doc__
diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py
index 41f3d3d5909f..1cdba4b8b3ab 100644
--- a/python/pyspark/sql/functions/builtin.py
+++ b/python/pyspark/sql/functions/builtin.py
@@ -10944,7 +10944,11 @@ def repeat(col: "ColumnOrName", n: Union["ColumnOrName", int]) -> Column:
 
 
 @_try_remote_functions
-def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column:
+def split(
+    str: "ColumnOrName",
+    pattern: Union[Column, str],
+    limit: Union["ColumnOrName", int] = -1,
+) -> Column:
     """
     Splits str around matches of the given pattern.
 
@@ -10957,10 +10961,10 @@ def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column:
     ----------
     str : :class:`~pyspark.sql.Column` or str
         a string expression to split
-    pattern : str
+    pattern : :class:`~pyspark.sql.Column` or str
         a string representing a regular expression. The regex string should be
         a Java regular expression.
-    limit : int, optional
+    limit : :class:`~pyspark.sql.Column` or str or int
         an integer which controls the number of times `pattern` is applied.
 
         * ``limit > 0``: The resulting array's length will not be more than `limit`, and the
@@ -10972,6 +10976,11 @@ def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column:
     .. versionchanged:: 3.0
        `split` now takes an optional `limit` field. If not provided, default limit value is -1.
 
+    .. versionchanged:: 4.0.0
+       `pattern` now accepts a column. It does not accept a column name, since a plain string
+       is still interpreted as a regular expression, for backwards compatibility. In addition
+       to int, `limit` now accepts a column or a column name.
+
     Returns
     -------
     :class:`~pyspark.sql.Column`
@@ -10979,13 +10988,53 @@ def split(str: "ColumnOrName", pattern: str, limit: int = -1) -> Column:
 
     Examples
     --------
+    >>> import pyspark.sql.functions as sf
     >>> df = spark.createDataFrame([('oneAtwoBthreeC',)], ['s',])
-    >>> df.select(split(df.s, '[ABC]', 2).alias('s')).collect()
-    [Row(s=['one', 'twoBthreeC'])]
-    >>> df.select(split(df.s, '[ABC]', -1).alias('s')).collect()
-    [Row(s=['one', 'two', 'three', ''])]
+    >>> df.select(sf.split(df.s, '[ABC]', 2).alias('s')).show()
+    +-----------------+
+    |                s|
+    +-----------------+
+    |[one, twoBthreeC]|
+    +-----------------+
+
+    >>> import pyspark.sql.functions as sf
+    >>> df = spark.createDataFrame([('oneAtwoBthreeC',)], ['s',])
+    >>> df.select(sf.split(df.s, '[ABC]', -1).alias('s')).show()
+    +-------------------+
+    |                  s|
+    +-------------------+
+    |[one, two, three, ]|
+    +-------------------+
+
+    >>> import pyspark.sql.functions as sf
+    >>> df = spark.createDataFrame(
+    ...     [('oneAtwoBthreeC', '[ABC]'), ('1A2B3C', '[1-9]+'), ('aa2bb3cc4', '[1-9]+')],
+    ...     ['s', 'pattern']
+    ... )
+    >>> df.select(sf.split(df.s, df.pattern).alias('s')).show()
+    +-------------------+
+    |                  s|
+    +-------------------+
+    |[one, two, three, ]|
+    |        [, A, B, C]|
+    |     [aa, bb, cc, ]|
+    +-------------------+
+
+    >>> import pyspark.sql.functions as sf
+    >>> df = spark.createDataFrame(
+    ...     [('oneAtwoBthreeC', '[ABC]', 2), ('1A2B3C', '[1-9]+', -1)],
+    ...     ['s', 'pattern', 'expected_parts']
+    ... )
+    >>> df.select(sf.split(df.s, df.pattern, df.expected_parts).alias('s')).show()
+    +-----------------+
+    |                s|
+    +-----------------+
+    |[one, twoBthreeC]|
+    |      [, A, B, C]|
+    +-----------------+
     """
-    return _invoke_function("split", _to_java_column(str), pattern, limit)
+    limit = lit(limit) if isinstance(limit, int) else limit
+    return _invoke_function_over_columns("split", str, lit(pattern), limit)
 
 
 @_try_remote_functions
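The limit semantics documented above are identical on the JVM side. A small Scala sketch mirroring the docstring examples (data is illustrative; assumes a `SparkSession` named `spark`):

```scala
import org.apache.spark.sql.functions.{col, split}
import spark.implicits._

val df = Seq("oneAtwoBthreeC").toDF("s")
// limit > 0: at most `limit` parts; the last part keeps the unsplit remainder.
df.select(split(col("s"), "[ABC]", 2)).show()   // [one, twoBthreeC]
// limit <= 0: the pattern is applied as many times as possible.
df.select(split(col("s"), "[ABC]", -1)).show()  // [one, two, three, ]
```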
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index e21375713b8a..99dc43a1da15 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4144,9 +4144,11 @@ object functions {
   /**
    * Splits str around matches of the given pattern.
    *
-   * @param str a string expression to split
-   * @param pattern a string representing a regular expression. The regex string should be
-   *                a Java regular expression.
+   * @param str
+   *   a string expression to split
+   * @param pattern
+   *   a string representing a regular expression. The regex string should be a Java regular
+   *   expression.
    *
    * @group string_funcs
    * @since 1.5.0
@@ -4156,17 +4158,31 @@
   /**
    * Splits str around matches of the given pattern.
    *
-   * @param str a string expression to split
-   * @param pattern a string representing a regular expression. The regex string should be
-   *                a Java regular expression.
-   * @param limit an integer expression which controls the number of times the regex is applied.
-   *        <ul>
-   *          <li>limit greater than 0: The resulting array's length will not be more than limit,
-   *          and the resulting array's last entry will contain all input beyond the last
-   *          matched regex.</li>
-   *          <li>limit less than or equal to 0: regex will be applied as many times as
-   *          possible, and the resulting array can be of any size.</li>
-   *        </ul>
+   * @param str
+   *   a string expression to split
+   * @param pattern
+   *   a string column containing a regular expression. The regex string should be a Java
+   *   regular expression.
+   *
+   * @group string_funcs
+   * @since 4.0.0
+   */
+  def split(str: Column, pattern: Column): Column = Column.fn("split", str, pattern)
+
+  /**
+   * Splits str around matches of the given pattern.
+   *
+   * @param str
+   *   a string expression to split
+   * @param pattern
+   *   a string representing a regular expression. The regex string should be a Java regular
+   *   expression.
+   * @param limit
+   *   an integer expression which controls the number of times the regex is applied.
+   *   <ul> <li>limit greater than 0: The resulting array's length will not be more than limit,
+   *   and the resulting array's last entry will contain all input beyond the last matched
+   *   regex.</li> <li>limit less than or equal to 0: regex will be applied as many times as
+   *   possible, and the resulting array can be of any size.</li> </ul>
    *
    * @group string_funcs
    * @since 3.0.0
@@ -4174,6 +4190,27 @@
   def split(str: Column, pattern: String, limit: Int): Column =
     Column.fn("split", str, lit(pattern), lit(limit))
 
+  /**
+   * Splits str around matches of the given pattern.
+   *
+   * @param str
+   *   a string expression to split
+   * @param pattern
+   *   a string column containing a regular expression. The regex string should be a Java
+   *   regular expression.
+   * @param limit
+   *   an integer column which controls the number of times the regex is applied.
+   *   <ul> <li>limit greater than 0: The resulting array's length will not be more than limit,
+   *   and the resulting array's last entry will contain all input beyond the last matched
+   *   regex.</li> <li>limit less than or equal to 0: regex will be applied as many times as
+   *   possible, and the resulting array can be of any size.</li> </ul>
+   *
+   * @group string_funcs
+   * @since 4.0.0
+   */
+  def split(str: Column, pattern: Column, limit: Column): Column =
+    Column.fn("split", str, pattern, limit)
+
   /**
    * Substring starts at `pos` and is of length `len` when str is String type or
    * returns the slice of byte array that starts at `pos` in byte and is of length `len`
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 75b4415db6b5..3fc0b572d80b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -525,6 +525,33 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession {
       Row(Seq("aa", "bb", "cc", "")))
   }
 
+  test("SPARK-47845: string split function with column types") {
+    val df = Seq(
+      ("aa2bb3cc4", "[1-9]+", 0),
+      ("aa2bb3cc4", "[1-9]+", 2),
+      ("aa2bb3cc4", "[1-9]+", -2)).toDF("a", "b", "c")
+
+    // without limit
+    val expectedNoLimit = Seq(
+      Row(Seq("aa", "bb", "cc", "")),
+      Row(Seq("aa", "bb", "cc", "")),
+      Row(Seq("aa", "bb", "cc", "")))
+
+    checkAnswer(df.select(split($"a", $"b")), expectedNoLimit)
+
+    checkAnswer(df.selectExpr("split(a, b)"), expectedNoLimit)
+
+    // with limit
+    val expectedWithLimit = Seq(
+      Row(Seq("aa", "bb", "cc", "")),
+      Row(Seq("aa", "bb3cc4")),
+      Row(Seq("aa", "bb", "cc", "")))
+
+    checkAnswer(df.select(split($"a", $"b", $"c")), expectedWithLimit)
+
+    checkAnswer(df.selectExpr("split(a, b, c)"), expectedWithLimit)
+  }
+
   test("string / binary length function") {
     val df = Seq(("123", Array[Byte](1, 2, 3, 4), 123, 2.0f, 3.015))
       .toDF("a", "b", "c", "d", "e")
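Note that the test also exercises the SQL path via `selectExpr`. Since this patch touches only the Scala and Python API wrappers and not the Catalyst `StringSplit` expression, SQL presumably already accepted column arguments; the new overloads bring the DataFrame APIs to parity. A sketch of the SQL form (table and column names illustrative; assumes a `SparkSession` named `spark`):

```scala
import spark.implicits._

Seq(("aa2bb3cc4", "[1-9]+", 2)).toDF("a", "b", "c").createOrReplaceTempView("t")
// pattern and limit come from columns, not literals.
spark.sql("SELECT split(a, b, c) FROM t").show()  // [aa, bb3cc4]
```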