@@ -4276,6 +4276,19 @@ object functions {
def substring(str: Column, pos: Int, len: Int): Column =
Column.fn("substring", str, lit(pos), lit(len))

/**
* Substring starts at `pos` and is of length `len` when str is String type, or returns the slice
* of the byte array that starts at `pos` (in bytes) and is of length `len` when str is Binary
* type.
*
* @note
*   The position is not zero-based, but a 1-based index.
*
* @group string_funcs
* @since 4.0.0
*/
def substring(str: Column, pos: Column, len: Column): Column =
Column.fn("substring", str, pos, len)
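
A minimal usage sketch of the new overload (a hedged example, not part of the change itself: it assumes a SparkSession `spark` with `spark.implicits._` in scope, and the DataFrame and column names are illustrative):

import org.apache.spark.sql.functions.substring
import spark.implicits._

val df = Seq(("Spark", 2, 3)).toDF("s", "p", "l")
// pos and len are read per row from columns p and l; the position is 1-based
df.select(substring($"s", $"p", $"l")).first() // Row("par")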

/**
* Returns the substring from string str before count occurrences of the delimiter delim. If
* count is positive, everything to the left of the final delimiter (counting from left) is
@@ -1780,6 +1780,10 @@ class PlanGenerationTestSuite
fn.substring(fn.col("g"), 4, 5)
}

functionTest("substring using columns") {
fn.substring(fn.col("g"), fn.col("a"), fn.col("b"))
}

functionTest("substring_index") {
fn.substring_index(fn.col("g"), ";", 5)
}
@@ -0,0 +1,2 @@
Project [substring(g#0, a#0, cast(b#0 as int)) AS substring(g, a, b)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
@@ -0,0 +1,2 @@
Project [substring(g#0, 4, 5) AS substring(g, 4, 5)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
@@ -0,0 +1,33 @@
{
"common": {
"planId": "1"
},
"project": {
"input": {
"common": {
"planId": "0"
},
"localRelation": {
"schema": "struct\u003cid:bigint,a:int,b:double,d:struct\u003cid:bigint,a:int,b:double\u003e,e:array\u003cint\u003e,f:map\u003cstring,struct\u003cid:bigint,a:int,b:double\u003e\u003e,g:string\u003e"
}
},
"expressions": [{
"unresolvedFunction": {
"functionName": "substring",
"arguments": [{
"unresolvedAttribute": {
"unparsedIdentifier": "g"
}
}, {
"unresolvedAttribute": {
"unparsedIdentifier": "a"
}
}, {
"unresolvedAttribute": {
"unparsedIdentifier": "b"
}
}]
}
}]
}
}
Binary file not shown.
72 changes: 66 additions & 6 deletions python/pyspark/sql/functions/builtin.py
@@ -10915,7 +10915,9 @@ def sentences(


@_try_remote_functions
def substring(str: "ColumnOrName", pos: int, len: int) -> Column:
def substring(
str: "ColumnOrName", pos: Union["ColumnOrName", int], len: Union["ColumnOrName", int]
) -> Column:
"""
Substring starts at `pos` and is of length `len` when str is String type, or
returns the slice of the byte array that starts at `pos` (in bytes) and is of length `len`
@@ -10934,11 +10936,14 @@ def substring(str: "ColumnOrName", pos: int, len: int) -> Column:
----------
str : :class:`~pyspark.sql.Column` or str
target column to work on.
pos : int
pos : :class:`~pyspark.sql.Column` or str or int
starting position in str.
len : int
len : :class:`~pyspark.sql.Column` or str or int
length of the substring.

.. versionchanged:: 4.0.0
`pos` and `len` now also accept Columns or names of Columns.

Returns
-------
:class:`~pyspark.sql.Column`
@@ -10949,9 +10954,18 @@ def substring(str: "ColumnOrName", pos: int, len: int) -> Column:
>>> df = spark.createDataFrame([('abcd',)], ['s',])
>>> df.select(substring(df.s, 1, 2).alias('s')).collect()
[Row(s='ab')]
>>> df = spark.createDataFrame([('Spark', 2, 3)], ['s', 'p', 'l'])
>>> df.select(substring(df.s, 2, df.l).alias('s')).collect()
[Row(s='par')]
>>> df.select(substring(df.s, df.p, 3).alias('s')).collect()
[Row(s='par')]
>>> df.select(substring(df.s, df.p, df.l).alias('s')).collect()
[Row(s='par')]
"""
from pyspark.sql.classic.column import _to_java_column

# Plain ints are wrapped as literal Columns; Column objects and column-name strings convert directly.
pos = _to_java_column(lit(pos) if isinstance(pos, int) else pos)
len = _to_java_column(lit(len) if isinstance(len, int) else len)
return _invoke_function("substring", _to_java_column(str), pos, len)


@@ -13969,7 +13983,10 @@ def array_position(col: "ColumnOrName", value: Any) -> Column:
col : :class:`~pyspark.sql.Column` or str
target column to work on.
value : Any
value to look for.
value or a :class:`~pyspark.sql.Column` expression to look for.

.. versionchanged:: 4.0.0
`value` now also accepts a Column type.

Returns
-------
@@ -14034,9 +14051,22 @@ def array_position(col: "ColumnOrName", value: Any) -> Column:
+-----------------------+
| 3|
+-----------------------+

Example 6: Finding the position of a column's value in an array of integers

>>> from pyspark.sql import functions as sf
>>> df = spark.createDataFrame([([10, 20, 30], 20)], ['data', 'col'])
>>> df.select(sf.array_position(df.data, df.col)).show()
+-------------------------+
|array_position(data, col)|
+-------------------------+
| 2|
+-------------------------+

"""
from pyspark.sql.classic.column import _to_java_column

value = _to_java_column(value) if isinstance(value, Column) else value
return _invoke_function("array_position", _to_java_column(col), value)


@@ -14402,7 +14432,10 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column:
col : :class:`~pyspark.sql.Column` or str
name of column containing array
element :
element to be removed from the array
element or a :class:`~pyspark.sql.Column` expression to be removed from the array

.. versionchanged:: 4.0.0
`element` now also accepts a Column type.

Returns
-------
@@ -14470,9 +14503,21 @@ def array_remove(col: "ColumnOrName", element: Any) -> Column:
+---------------------+
| []|
+---------------------+

Example 6: Removing a column's value from a simple array

>>> from pyspark.sql import functions as sf
>>> df = spark.createDataFrame([([1, 2, 3, 1, 1], 1)], ['data', 'col'])
>>> df.select(sf.array_remove(df.data, df.col)).show()
+-----------------------+
|array_remove(data, col)|
+-----------------------+
| [2, 3]|
+-----------------------+
"""
from pyspark.sql.classic.column import _to_java_column

element = _to_java_column(element) if isinstance(element, Column) else element
return _invoke_function("array_remove", _to_java_column(col), element)


@@ -17237,7 +17282,10 @@ def map_contains_key(col: "ColumnOrName", value: Any) -> Column:
col : :class:`~pyspark.sql.Column` or str
The name of the column or an expression that represents the map.
value :
A literal value.
A literal value, or a :class:`~pyspark.sql.Column` expression.

.. versionchanged:: 4.0.0
`value` now also accepts a Column type.

Returns
-------
@@ -17267,9 +17315,21 @@ def map_contains_key(col: "ColumnOrName", value: Any) -> Column:
+--------------------------+
| false|
+--------------------------+

Example 3: Check for key using a column

>>> from pyspark.sql import functions as sf
>>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as data, 1 as key")
>>> df.select(sf.map_contains_key("data", sf.col("key"))).show()
+---------------------------+
|map_contains_key(data, key)|
+---------------------------+
| true|
+---------------------------+
"""
from pyspark.sql.classic.column import _to_java_column

value = _to_java_column(value) if isinstance(value, Column) else value
return _invoke_function("map_contains_key", _to_java_column(col), value)


13 changes: 13 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -4234,6 +4234,19 @@ object functions {
def substring(str: Column, pos: Int, len: Int): Column =
Column.fn("substring", str, lit(pos), lit(len))

/**
* Substring starts at `pos` and is of length `len` when str is String type, or
* returns the slice of the byte array that starts at `pos` (in bytes) and is of length `len`
* when str is Binary type.
*
* @note The position is not zero-based, but a 1-based index.
*
* @group string_funcs
* @since 4.0.0
*/
def substring(str: Column, pos: Column, len: Column): Column =
Column.fn("substring", str, pos, len)
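
Note the golden plan above contains `cast(b#0 as int)`: when a non-integer column is supplied for `pos` or `len`, the analyzer casts it implicitly. A small sketch under the same assumptions as the earlier example (SparkSession `spark`, `spark.implicits._` in scope):

import org.apache.spark.sql.functions.substring
import spark.implicits._

val df = Seq(("Spark", 2, 3.0)).toDF("s", "p", "l")
// the double column l is implicitly cast to int, as in cast(b#0 as int) above
df.select(substring($"s", $"p", $"l")).first() // Row("par")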

/**
* Returns the substring from string str before count occurrences of the delimiter delim.
* If count is positive, everything to the left of the final delimiter (counting from left) is
@@ -332,6 +332,11 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession {
// scalastyle:on
}

test("string substring function using columns") {
val df = Seq(("Spark", 2, 3)).toDF("a", "b", "c")
checkAnswer(df.select(substring($"a", $"b", $"c")), Row("par"))
}

test("string encode/decode function") {
val bytes = Array[Byte](-27, -92, -89, -27, -115, -125, -28, -72, -106, -25, -107, -116)
// scalastyle:off