
Commit c99463d

[SPARK-26979][PYTHON][FOLLOW-UP] Make binary math/string functions take string as columns as well
## What changes were proposed in this pull request?

This is a followup of #23882 to handle binary math/string functions. For instance, see the cases below:

**Before:**

```python
>>> from pyspark.sql.functions import lit, ascii
>>> spark.range(1).select(lit('a').alias("value")).select(ascii("value"))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/.../spark/python/pyspark/sql/functions.py", line 51, in _
    jc = getattr(sc._jvm.functions, name)(col._jc if isinstance(col, Column) else col)
  File "/.../spark/python/lib/py4j-0.10.8.1-src.zip/py4j/java_gateway.py", line 1286, in __call__
  File "/.../spark/python/pyspark/sql/utils.py", line 63, in deco
    return f(*a, **kw)
  File "/.../spark/python/lib/py4j-0.10.8.1-src.zip/py4j/protocol.py", line 332, in get_return_value
py4j.protocol.Py4JError: An error occurred while calling z:org.apache.spark.sql.functions.ascii. Trace:
py4j.Py4JException: Method ascii([class java.lang.String]) does not exist
	at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318)
	at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:339)
	at py4j.Gateway.invoke(Gateway.java:276)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.lang.Thread.run(Thread.java:748)
```

```python
>>> from pyspark.sql.functions import atan2
>>> spark.range(1).select(atan2("id", "id"))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/.../spark/python/pyspark/sql/functions.py", line 78, in _
    jc = getattr(sc._jvm.functions, name)(col1._jc if isinstance(col1, Column) else float(col1),
ValueError: could not convert string to float: id
```

**After:**

```python
>>> from pyspark.sql.functions import lit, ascii
>>> spark.range(1).select(lit('a').alias("value")).select(ascii("value"))
DataFrame[ascii(value): int]
```

```python
>>> from pyspark.sql.functions import atan2
>>> spark.range(1).select(atan2("id", "id"))
DataFrame[ATAN2(id, id): double]
```

Note that:

- This PR causes a slight behaviour change for math functions. For instance, numbers as strings (e.g., `"1"`) were previously accepted as arguments of binary math functions; after this PR, they are recognised as column names.
- I intentionally didn't document this behaviour change since we're going ahead for Spark 3.0, and I don't think numbers as strings make much sense in math functions.
- There is another exception, `when`, which takes strings as literal values, as below. This PR doesn't fix this ambiguity.

```python
>>> spark.range(1).select(when(lit(True), col("id"))).show()
```

```
+--------------------------+
|CASE WHEN true THEN id END|
+--------------------------+
|                         0|
+--------------------------+
```

```python
>>> spark.range(1).select(when(lit(True), "id")).show()
```

```
+--------------------------+
|CASE WHEN true THEN id END|
+--------------------------+
|                        id|
+--------------------------+
```

This PR also cleans up the helper naming. #23882 changed it to:

- Rename `_create_function` to `_create_name_function`.
- Define a new `_create_function` that takes strings as column names.

In this PR, I propose to:

- Revert the `_create_name_function` name back to `_create_function`.
- Define a new `_create_function_over_column` that takes strings as column names.

## How was this patch tested?

Unit tests were added for the binary math / string functions.

Closes #24121 from HyukjinKwon/SPARK-26979.
Authored-by: Hyukjin Kwon <gurwls223@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
Parent: 8b0aa59
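As a quick illustration of the new argument handling described above (a minimal sketch, assuming a PySpark shell where `spark` is a running SparkSession):

```python
from pyspark.sql.functions import atan2, hypot

df = spark.range(10)

df.select(atan2(df.id, df.id)).show()  # Column arguments: behaviour unchanged
df.select(atan2("id", "id")).show()    # strings are now resolved as column names
df.select(hypot("id", 2)).show()       # plain numbers still coerce to float, as before
```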

File tree: 2 files changed (+64, -29 lines)

python/pyspark/sql/functions.py

Lines changed: 51 additions & 28 deletions
```diff
@@ -30,15 +30,22 @@
 
 from pyspark import since, SparkContext
 from pyspark.rdd import ignore_unicode_prefix, PythonEvalType
-from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal
+from pyspark.sql.column import Column, _to_java_column, _to_seq, _create_column_from_literal, \
+    _create_column_from_name
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.types import StringType, DataType
 # Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409
 from pyspark.sql.udf import UserDefinedFunction, _create_udf
 
+# Note to developers: all of PySpark functions here take string as column names whenever possible.
+# Namely, if columns are referred as arguments, they can be always both Column or string,
+# even though there might be few exceptions for legacy or inevitable reasons.
+# If you are fixing other language APIs together, also please note that Scala side is not the case
+# since it requires to make every single overridden definition.
 
-def _create_name_function(name, doc=""):
-    """ Create a function that takes a column name argument, by name"""
+
+def _create_function(name, doc=""):
+    """Create a PySpark function by its name"""
     def _(col):
         sc = SparkContext._active_spark_context
         jc = getattr(sc._jvm.functions, name)(col._jc if isinstance(col, Column) else col)
@@ -48,8 +55,11 @@ def _(col):
     return _
 
 
-def _create_function(name, doc=""):
-    """ Create a function that takes a Column object, by name"""
+def _create_function_over_column(name, doc=""):
+    """Similar with `_create_function` but creates a PySpark function that takes a column
+    (as string as well). This is mainly for PySpark functions to take strings as
+    column names.
+    """
     def _(col):
         sc = SparkContext._active_spark_context
         jc = getattr(sc._jvm.functions, name)(_to_java_column(col))
```
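Why `_create_function_over_column` can take a plain string comes down to `_to_java_column`, which dispatches on the argument type. The sketch below reflects my understanding of the `pyspark.sql.column` internals this change relies on; treat the exact body as an assumption rather than a verbatim quote:

```python
from pyspark.sql.column import Column, _create_column_from_name

def _to_java_column_sketch(col):
    # Column objects already wrap a JVM column; unwrap it directly.
    if isinstance(col, Column):
        return col._jc
    # Strings are interpreted as column names (`basestring` on Python 2).
    if isinstance(col, str):
        return _create_column_from_name(col)
    raise TypeError("Invalid argument, not a string or column: %r" % (col,))
```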
```diff
@@ -71,9 +81,23 @@ def _create_binary_mathfunction(name, doc=""):
     """ Create a binary mathfunction by name"""
     def _(col1, col2):
         sc = SparkContext._active_spark_context
-        # users might write ints for simplicity. This would throw an error on the JVM side.
-        jc = getattr(sc._jvm.functions, name)(col1._jc if isinstance(col1, Column) else float(col1),
-                                              col2._jc if isinstance(col2, Column) else float(col2))
+        # For legacy reasons, the arguments here can be implicitly converted into floats,
+        # if they are not columns or strings.
+        if isinstance(col1, Column):
+            arg1 = col1._jc
+        elif isinstance(col1, basestring):
+            arg1 = _create_column_from_name(col1)
+        else:
+            arg1 = float(col1)
+
+        if isinstance(col2, Column):
+            arg2 = col2._jc
+        elif isinstance(col2, basestring):
+            arg2 = _create_column_from_name(col2)
+        else:
+            arg2 = float(col2)
+
+        jc = getattr(sc._jvm.functions, name)(arg1, arg2)
         return Column(jc)
     _.__name__ = name
     _.__doc__ = doc
```
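The new branches above repeat the same three-way dispatch for each argument. Factored into a hypothetical helper, the logic reads like this (illustrative only; the committed code keeps the branches inline and uses `basestring` since it still supports Python 2):

```python
from pyspark.sql.column import Column, _create_column_from_name

def _to_jvm_math_arg(value):
    # Hypothetical helper mirroring the inline branches in _create_binary_mathfunction.
    if isinstance(value, Column):
        return value._jc                         # already wraps a JVM column
    if isinstance(value, str):                   # `basestring` on Python 2
        return _create_column_from_name(value)   # strings now name columns
    return float(value)                          # legacy path: coerce numbers to float

# With the helper, the wrapped body would reduce to:
#     jc = getattr(sc._jvm.functions, name)(_to_jvm_math_arg(col1), _to_jvm_math_arg(col2))
```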
```diff
@@ -96,18 +120,15 @@ def _():
     >>> df.select(lit(5).alias('height')).withColumn('spark_user', lit(True)).take(1)
     [Row(height=5, spark_user=True)]
     """
-_name_functions = {
-    # name functions take a column name as their argument
+_functions = {
     'lit': _lit_doc,
     'col': 'Returns a :class:`Column` based on the given column name.',
     'column': 'Returns a :class:`Column` based on the given column name.',
     'asc': 'Returns a sort expression based on the ascending order of the given column name.',
     'desc': 'Returns a sort expression based on the descending order of the given column name.',
 }
 
-_functions = {
-    'upper': 'Converts a string expression to upper case.',
-    'lower': 'Converts a string expression to upper case.',
+_functions_over_column = {
     'sqrt': 'Computes the square root of the specified float value.',
     'abs': 'Computes the absolute value.',
 
@@ -120,7 +141,7 @@ def _():
     'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
 }
 
-_functions_1_4 = {
+_functions_1_4_over_column = {
     # unary math functions
     'acos': ':return: inverse cosine of `col`, as if computed by `java.lang.Math.acos()`',
     'asin': ':return: inverse sine of `col`, as if computed by `java.lang.Math.asin()`',
@@ -155,7 +176,7 @@ def _():
     'bitwiseNOT': 'Computes bitwise not.',
 }
 
-_name_functions_2_4 = {
+_functions_2_4 = {
     'asc_nulls_first': 'Returns a sort expression based on the ascending order of the given' +
                        ' column name, and null values return before non-null values.',
     'asc_nulls_last': 'Returns a sort expression based on the ascending order of the given' +
@@ -186,7 +207,7 @@ def _():
     >>> df2.agg(collect_set('age')).collect()
     [Row(collect_set(age)=[5, 2])]
     """
-_functions_1_6 = {
+_functions_1_6_over_column = {
     # unary math functions
     'stddev': 'Aggregate function: returns the unbiased sample standard deviation of' +
               ' the expression in a group.',
@@ -203,7 +224,7 @@ def _():
     'collect_set': _collect_set_doc
 }
 
-_functions_2_1 = {
+_functions_2_1_over_column = {
     # unary math functions
     'degrees': """
     Converts an angle measured in radians to an approximately equivalent angle
@@ -268,24 +289,24 @@ def _():
 _functions_deprecated = {
 }
 
-for _name, _doc in _name_functions.items():
-    globals()[_name] = since(1.3)(_create_name_function(_name, _doc))
 for _name, _doc in _functions.items():
     globals()[_name] = since(1.3)(_create_function(_name, _doc))
-for _name, _doc in _functions_1_4.items():
-    globals()[_name] = since(1.4)(_create_function(_name, _doc))
+for _name, _doc in _functions_over_column.items():
+    globals()[_name] = since(1.3)(_create_function_over_column(_name, _doc))
+for _name, _doc in _functions_1_4_over_column.items():
+    globals()[_name] = since(1.4)(_create_function_over_column(_name, _doc))
 for _name, _doc in _binary_mathfunctions.items():
     globals()[_name] = since(1.4)(_create_binary_mathfunction(_name, _doc))
 for _name, _doc in _window_functions.items():
     globals()[_name] = since(1.6)(_create_window_function(_name, _doc))
-for _name, _doc in _functions_1_6.items():
-    globals()[_name] = since(1.6)(_create_function(_name, _doc))
-for _name, _doc in _functions_2_1.items():
-    globals()[_name] = since(2.1)(_create_function(_name, _doc))
+for _name, _doc in _functions_1_6_over_column.items():
+    globals()[_name] = since(1.6)(_create_function_over_column(_name, _doc))
+for _name, _doc in _functions_2_1_over_column.items():
+    globals()[_name] = since(2.1)(_create_function_over_column(_name, _doc))
 for _name, _message in _functions_deprecated.items():
     globals()[_name] = _wrap_deprecated_function(globals()[_name], _message)
-for _name, _doc in _name_functions_2_4.items():
-    globals()[_name] = since(2.4)(_create_name_function(_name, _doc))
+for _name, _doc in _functions_2_4.items():
+    globals()[_name] = since(2.4)(_create_function(_name, _doc))
 del _name, _doc
 
 
```
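For readers new to this file, the loops above are how PySpark defines its public functions: each dictionary entry becomes a module-level function at import time via `globals()`. A toy reduction of the pattern (illustrative only, not Spark code):

```python
def _make_greeting(name, doc=""):
    def _(who):
        return "%s, %s!" % (name, who)
    _.__name__ = name  # so help() and tracebacks show the intended name
    _.__doc__ = doc
    return _

_greetings = {'hello': 'Says hello.', 'goodbye': 'Says goodbye.'}
for _name, _doc in _greetings.items():
    globals()[_name] = _make_greeting(_name, _doc)
del _name, _doc

print(hello("world"))   # hello, world!
print(goodbye.__doc__)  # Says goodbye.
```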

```diff
@@ -1450,6 +1471,8 @@ def hash(*cols):
 # ---------------------- String/Binary functions ------------------------------
 
 _string_functions = {
+    'upper': 'Converts a string expression to upper case.',
+    'lower': 'Converts a string expression to lower case.',
     'ascii': 'Computes the numeric value of the first character of the string column.',
     'base64': 'Computes the BASE64 encoding of a binary column and returns it as a string column.',
     'unbase64': 'Decodes a BASE64 encoded string column and returns it as a binary column.',
@@ -1460,7 +1483,7 @@ def hash(*cols):
 
 
 for _name, _doc in _string_functions.items():
-    globals()[_name] = since(1.5)(_create_function(_name, _doc))
+    globals()[_name] = since(1.5)(_create_function_over_column(_name, _doc))
 del _name, _doc
 
 
```

python/pyspark/sql/tests/test_functions.py

Lines changed: 13 additions & 1 deletion
```diff
@@ -129,6 +129,12 @@ def assert_close(a, b):
                      df.select(functions.pow(df.a, 2.0)).collect())
         assert_close([math.hypot(i, 2 * i) for i in range(10)],
                      df.select(functions.hypot(df.a, df.b)).collect())
+        assert_close([math.hypot(i, 2 * i) for i in range(10)],
+                     df.select(functions.hypot("a", u"b")).collect())
+        assert_close([math.hypot(i, 2) for i in range(10)],
+                     df.select(functions.hypot("a", 2)).collect())
+        assert_close([math.hypot(i, 2) for i in range(10)],
+                     df.select(functions.hypot(df.a, 2)).collect())
 
     def test_rand_functions(self):
         df = self.df
@@ -151,7 +157,8 @@ def test_rand_functions(self):
         self.assertEqual(sorted(rndn1), sorted(rndn2))
 
     def test_string_functions(self):
-        from pyspark.sql.functions import col, lit
+        from pyspark.sql import functions
+        from pyspark.sql.functions import col, lit, _string_functions
         df = self.spark.createDataFrame([['nick']], schema=['name'])
         self.assertRaisesRegexp(
             TypeError,
@@ -162,6 +169,11 @@ def test_string_functions(self):
             TypeError,
             lambda: df.select(col('name').substr(long(0), long(1))))
 
+        for name in _string_functions.keys():
+            self.assertEqual(
+                df.select(getattr(functions, name)("name")).first()[0],
+                df.select(getattr(functions, name)(col("name"))).first()[0])
+
     def test_array_contains_function(self):
         from pyspark.sql.functions import array_contains
 
```
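The new loop in `test_string_functions` boils down to one assertion per entry of `_string_functions`; for a single function it looks like this (a sketch, assuming a `spark` session, with `ascii` standing in for any listed function):

```python
from pyspark.sql import functions
from pyspark.sql.functions import col

df = spark.createDataFrame([['nick']], schema=['name'])
# The string form and the Column form must produce identical results.
assert (df.select(functions.ascii("name")).first()[0]
        == df.select(functions.ascii(col("name"))).first()[0])
```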
