diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
index c287e3469108..eae239a25589 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4276,6 +4276,19 @@ object functions {
   def substring(str: Column, pos: Int, len: Int): Column =
     Column.fn("substring", str, lit(pos), lit(len))

+  /**
+   * Substring starts at `pos` and is of length `len` when str is String type or returns the slice
+   * of byte array that starts at `pos` in byte and is of length `len` when str is Binary type
+   *
+   * @note
+   *   The position is not zero based, but 1 based index.
+   *
+   * @group string_funcs
+   * @since 4.0.0
+   */
+  def substring(str: Column, pos: Column, len: Column): Column =
+    Column.fn("substring", str, pos, len)
+
   /**
    * Returns the substring from string str before count occurrences of the delimiter delim. If
    * count is positive, everything the left of the final delimiter (counting from left) is
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
index e0ad8f7078ca..987a50e13645 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/PlanGenerationTestSuite.scala
@@ -1780,6 +1780,10 @@ class PlanGenerationTestSuite
     fn.substring(fn.col("g"), 4, 5)
   }

+  functionTest("substring using columns") {
+    fn.substring(fn.col("g"), fn.col("a"), fn.col("b"))
+  }
+
   functionTest("substring_index") {
     fn.substring_index(fn.col("g"), ";", 5)
   }
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_substring_using_columns.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_substring_using_columns.explain
new file mode 100644
index 000000000000..3050d15d9754
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_substring_using_columns.explain
@@ -0,0 +1,2 @@
+Project [substring(g#0, a#0, cast(b#0 as int)) AS substring(g, a, b)#0]
++- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git a/connector/connect/common/src/test/resources/query-tests/explain-results/function_substring_with_columns.explain b/connector/connect/common/src/test/resources/query-tests/explain-results/function_substring_with_columns.explain
new file mode 100644
index 000000000000..fe07244fc9ce
--- /dev/null
+++ b/connector/connect/common/src/test/resources/query-tests/explain-results/function_substring_with_columns.explain
@@ -0,0 +1,2 @@
+Project [substring(g#0, 4, 5) AS substring(g, 4, 5)#0]
++- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
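A minimal PySpark sketch of what the golden files above capture (the session and data here are illustrative, not part of the patch): passing a double column such as b as len makes the analyzer insert the implicit cast recorded as cast(b#0 as int) in function_substring_using_columns.explain.

from pyspark.sql import SparkSession
from pyspark.sql import functions as sf

spark = SparkSession.builder.getOrCreate()

# Mirrors the query-test schema: g is a string, a an int, b a double.
df = spark.createDataFrame([("Spark SQL", 2, 3.0)], ["g", "a", "b"])

# pos and len passed as Columns; the double column b is implicitly cast
# to int, matching the cast(b#0 as int) in the .explain golden file.
df.select(sf.substring(sf.col("g"), sf.col("a"), sf.col("b"))).explain()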
"struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e" + } + }, + "expressions": [{ + "unresolvedFunction": { + "functionName": "substring", + "arguments": [{ + "unresolvedAttribute": { + "unparsedIdentifier": "g" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "a" + } + }, { + "unresolvedAttribute": { + "unparsedIdentifier": "b" + } + }] + } + }] + } +} \ No newline at end of file diff --git a/connector/connect/common/src/test/resources/query-tests/queries/function_substring_using_columns.proto.bin b/connector/connect/common/src/test/resources/query-tests/queries/function_substring_using_columns.proto.bin new file mode 100644 index 000000000000..f14b44ef5a50 Binary files /dev/null and b/connector/connect/common/src/test/resources/query-tests/queries/function_substring_using_columns.proto.bin differ diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 07dfcaf2e2b7..2edbc9f5abe1 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -10915,7 +10915,9 @@ def sentences( @_try_remote_functions -def substring(str: "ColumnOrName", pos: int, len: int) -> Column: +def substring( + str: "ColumnOrName", pos: Union["ColumnOrName", int], len: Union["ColumnOrName", int] +) -> Column: """ Substring starts at `pos` and is of length `len` when str is String type or returns the slice of byte array that starts at `pos` in byte and is of length `len` @@ -10934,11 +10936,14 @@ def substring(str: "ColumnOrName", pos: int, len: int) -> Column: ---------- str : :class:`~pyspark.sql.Column` or str target column to work on. - pos : int + pos : :class:`~pyspark.sql.Column` or str or int starting position in str. - len : int + len : :class:`~pyspark.sql.Column` or str or int length of chars. + .. versionchanged:: 4.0.0 + `pos` and `len` now also accept Columns or names of Columns. + Returns ------- :class:`~pyspark.sql.Column` @@ -10949,9 +10954,18 @@ def substring(str: "ColumnOrName", pos: int, len: int) -> Column: >>> df = spark.createDataFrame([('abcd',)], ['s',]) >>> df.select(substring(df.s, 1, 2).alias('s')).collect() [Row(s='ab')] + >>> df = spark.createDataFrame([('Spark', 2, 3)], ['s', 'p', 'l']) + >>> df.select(substring(df.s, 2, df.l).alias('s')).collect() + [Row(s='par')] + >>> df.select(substring(df.s, df.p, 3).alias('s')).collect() + [Row(s='par')] + >>> df.select(substring(df.s, df.p, df.l).alias('s')).collect() + [Row(s='par')] """ from pyspark.sql.classic.column import _to_java_column + pos = _to_java_column(lit(pos) if isinstance(pos, int) else pos) + len = _to_java_column(lit(len) if isinstance(len, int) else len) return _invoke_function("substring", _to_java_column(str), pos, len) @@ -13969,7 +13983,10 @@ def array_position(col: "ColumnOrName", value: Any) -> Column: col : :class:`~pyspark.sql.Column` or str target column to work on. value : Any - value to look for. + value or a :class:`~pyspark.sql.Column` expression to look for. + + .. versionchanged:: 4.0.0 + `value` now also accepts a Column type. 
@@ -13969,7 +13983,10 @@ def array_position(col: "ColumnOrName", value: Any) -> Column:
     col : :class:`~pyspark.sql.Column` or str
         target column to work on.
     value : Any
-        value to look for.
+        value or a :class:`~pyspark.sql.Column` expression to look for.
+
+        .. versionchanged:: 4.0.0
+            `value` now also accepts a Column type.

     Returns
     -------
@@ -14034,9 +14051,22 @@ def array_position(col: "ColumnOrName", value: Any) -> Column:
     +-----------------------+
     |                      3|
     +-----------------------+
+
+    Example 6: Finding the position of a column's value in an array of integers
+
+    >>> from pyspark.sql import functions as sf
+    >>> df = spark.createDataFrame([([10, 20, 30], 20)], ['data', 'col'])
+    >>> df.select(sf.array_position(df.data, df.col)).show()
+    +-------------------------+
+    |array_position(data, col)|
+    +-------------------------+
+    |                        2|
+    +-------------------------+
+
     """
     from pyspark.sql.classic.column import _to_java_column

+    value = _to_java_column(value) if isinstance(value, Column) else value
     return _invoke_function("array_position", _to_java_column(col), value)
@@ -14402,7 +14432,10 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column:
     col : :class:`~pyspark.sql.Column` or str
         name of column containing array
     element :
-        element to be removed from the array
+        element or a :class:`~pyspark.sql.Column` expression to be removed from the array
+
+        .. versionchanged:: 4.0.0
+            `element` now also accepts a Column type.

     Returns
     -------
@@ -14470,9 +14503,21 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column:
     +---------------------+
     |                   []|
     +---------------------+
+
+    Example 6: Removing a column's value from a simple array
+
+    >>> from pyspark.sql import functions as sf
+    >>> df = spark.createDataFrame([([1, 2, 3, 1, 1], 1)], ['data', 'col'])
+    >>> df.select(sf.array_remove(df.data, df.col)).show()
+    +-----------------------+
+    |array_remove(data, col)|
+    +-----------------------+
+    |                 [2, 3]|
+    +-----------------------+
     """
     from pyspark.sql.classic.column import _to_java_column

+    element = _to_java_column(element) if isinstance(element, Column) else element
     return _invoke_function("array_remove", _to_java_column(col), element)

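One subtlety in the isinstance(..., Column) checks above, shown with an illustrative sketch: only actual Column objects take the new code path; any other value, including a string that happens to match a column name, is still treated as a literal.

from pyspark.sql import SparkSession
from pyspark.sql import functions as sf

spark = SparkSession.builder.getOrCreate()

# A Column argument is compared against the column's per-row value: removes 1.
df = spark.createDataFrame([([1, 2, 3, 1, 1], 1)], ["data", "col"])
df.select(sf.array_remove(df.data, sf.col("col"))).show()

# A plain string stays a literal even though a column named col exists:
# this removes the string "col" from the array, leaving ["a", "b"], rather
# than removing the value of the col column ("a").
df2 = spark.createDataFrame([(["a", "b", "col"], "a")], ["data", "col"])
df2.select(sf.array_remove(df2.data, "col")).show()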
@@ -17237,7 +17282,10 @@ def map_contains_key(col: "ColumnOrName", value: Any) -> Column:
     col : :class:`~pyspark.sql.Column` or str
         The name of the column or an expression that represents the map.
     value :
-        A literal value.
+        A literal value, or a :class:`~pyspark.sql.Column` expression.
+
+        .. versionchanged:: 4.0.0
+            `value` now also accepts a Column type.

     Returns
     -------
@@ -17267,9 +17315,21 @@ def map_contains_key(col: "ColumnOrName", value: Any) -> Column:
     +--------------------------+
     |                     false|
     +--------------------------+
+
+    Example 3: Check for key using a column
+
+    >>> from pyspark.sql import functions as sf
+    >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data, 1 as key")
+    >>> df.select(sf.map_contains_key("data", sf.col("key"))).show()
+    +---------------------------+
+    |map_contains_key(data, key)|
+    +---------------------------+
+    |                       true|
+    +---------------------------+
     """
     from pyspark.sql.classic.column import _to_java_column

+    value = _to_java_column(value) if isinstance(value, Column) else value
     return _invoke_function("map_contains_key", _to_java_column(col), value)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 52733611e42a..882918eb78c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4234,6 +4234,19 @@ object functions {
   def substring(str: Column, pos: Int, len: Int): Column =
     Column.fn("substring", str, lit(pos), lit(len))

+  /**
+   * Substring starts at `pos` and is of length `len` when str is String type or
+   * returns the slice of byte array that starts at `pos` in byte and is of length `len`
+   * when str is Binary type
+   *
+   * @note The position is not zero based, but 1 based index.
+   *
+   * @group string_funcs
+   * @since 4.0.0
+   */
+  def substring(str: Column, pos: Column, len: Column): Column =
+    Column.fn("substring", str, pos, len)
+
   /**
    * Returns the substring from string str before count occurrences of the delimiter delim.
    * If count is positive, everything the left of the final delimiter (counting from left) is
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 3fc0b572d80b..31c1cac9fb71 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -332,6 +332,11 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession {
     // scalastyle:on
   }

+  test("string substring function using columns") {
+    val df = Seq(("Spark", 2, 3)).toDF("a", "b", "c")
+    checkAnswer(df.select(substring($"a", $"b", $"c")), Row("par"))
+  }
+
   test("string encode/decode function") {
     val bytes = Array[Byte](-27, -92, -89, -27, -115, -125, -28, -72, -106, -25, -107, -116)
     // scalastyle:off
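End to end, the PySpark equivalent of the new Scala unit test (a sketch; assumes a running SparkSession): with 1-based indexing, substring("Spark", 2, 3) yields "par" whether the position and length come from literals or from columns.

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, substring

spark = SparkSession.builder.getOrCreate()

# Same fixture as Seq(("Spark", 2, 3)).toDF("a", "b", "c") in the Scala test.
df = spark.createDataFrame([("Spark", 2, 3)], ["a", "b", "c"])
assert df.select(substring(col("a"), col("b"), col("c"))).head()[0] == "par"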