diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 3c33e2bed92d..98e8f15f18b9 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -37,8 +37,8 @@ from pyspark.sql.udf import UserDefinedFunction, _create_udf -def _create_function(name, doc=""): - """ Create a function for aggregator by name""" +def _create_name_function(name, doc=""): + """ Create a function that takes a column name argument, by name""" def _(col): sc = SparkContext._active_spark_context jc = getattr(sc._jvm.functions, name)(col._jc if isinstance(col, Column) else col) @@ -48,6 +48,17 @@ def _(col): return _ +def _create_function(name, doc=""): + """ Create a function that takes a Column object, by name""" + def _(col): + sc = SparkContext._active_spark_context + jc = getattr(sc._jvm.functions, name)(_to_java_column(col)) + return Column(jc) + _.__name__ = name + _.__doc__ = doc + return _ + + def _wrap_deprecated_function(func, message): """ Wrap the deprecated function to print out deprecation warnings""" def _(col): @@ -85,13 +96,16 @@ def _(): >>> df.select(lit(5).alias('height')).withColumn('spark_user', lit(True)).take(1) [Row(height=5, spark_user=True)] """ -_functions = { +_name_functions = { + # name functions take a column name as their argument 'lit': _lit_doc, 'col': 'Returns a :class:`Column` based on the given column name.', 'column': 'Returns a :class:`Column` based on the given column name.', 'asc': 'Returns a sort expression based on the ascending order of the given column name.', 'desc': 'Returns a sort expression based on the descending order of the given column name.', +} +_functions = { 'upper': 'Converts a string expression to upper case.', 'lower': 'Converts a string expression to upper case.', 'sqrt': 'Computes the square root of the specified float value.', @@ -141,7 +155,7 @@ def _(): 'bitwiseNOT': 'Computes bitwise not.', } -_functions_2_4 = { +_name_functions_2_4 = { 'asc_nulls_first': 'Returns a sort expression based on the ascending order of the given' + ' column name, and null values return before non-null values.', 'asc_nulls_last': 'Returns a sort expression based on the ascending order of the given' + @@ -254,6 +268,8 @@ def _(): _functions_deprecated = { } +for _name, _doc in _name_functions.items(): + globals()[_name] = since(1.3)(_create_name_function(_name, _doc)) for _name, _doc in _functions.items(): globals()[_name] = since(1.3)(_create_function(_name, _doc)) for _name, _doc in _functions_1_4.items(): @@ -268,8 +284,8 @@ def _(): globals()[_name] = since(2.1)(_create_function(_name, _doc)) for _name, _message in _functions_deprecated.items(): globals()[_name] = _wrap_deprecated_function(globals()[_name], _message) -for _name, _doc in _functions_2_4.items(): - globals()[_name] = since(2.4)(_create_function(_name, _doc)) +for _name, _doc in _name_functions_2_4.items(): + globals()[_name] = since(2.4)(_create_name_function(_name, _doc)) del _name, _doc @@ -1437,10 +1453,6 @@ def hash(*cols): 'ascii': 'Computes the numeric value of the first character of the string column.', 'base64': 'Computes the BASE64 encoding of a binary column and returns it as a string column.', 'unbase64': 'Decodes a BASE64 encoded string column and returns it as a binary column.', - 'initcap': 'Returns a new string column by converting the first letter of each word to ' + - 'uppercase. Words are delimited by whitespace.', - 'lower': 'Converts a string column to lower case.', - 'upper': 'Converts a string column to upper case.', 'ltrim': 'Trim the spaces from left end for the specified string value.', 'rtrim': 'Trim the spaces from right end for the specified string value.', 'trim': 'Trim the spaces from both ends for the specified string column.',