diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 91a7f093cec19..11497695f85f2 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -249,6 +249,15 @@ def __init__(self, timezone): super(ArrowStreamPandasSerializer, self).__init__() self._timezone = timezone + def arrow_to_pandas(self, arrow_column): + from pyspark.sql.types import from_arrow_type, \ + _check_series_convert_date, _check_series_localize_timestamps + + s = arrow_column.to_pandas() + s = _check_series_convert_date(s, from_arrow_type(arrow_column.type)) + s = _check_series_localize_timestamps(s, self._timezone) + return s + def dump_stream(self, iterator, stream): """ Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or @@ -271,16 +280,11 @@ def load_stream(self, stream): """ Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. """ - from pyspark.sql.types import from_arrow_schema, _check_dataframe_convert_date, \ - _check_dataframe_localize_timestamps import pyarrow as pa reader = pa.open_stream(stream) - schema = from_arrow_schema(reader.schema) + for batch in reader: - pdf = batch.to_pandas() - pdf = _check_dataframe_convert_date(pdf, schema) - pdf = _check_dataframe_localize_timestamps(pdf, self._timezone) - yield [c for _, c in pdf.iteritems()] + yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()] def __repr__(self): return "ArrowStreamPandasSerializer" diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b9c0c57262c5d..dc1341ac74d3d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2267,6 +2267,31 @@ def pandas_udf(f=None, returnType=None, functionType=None): | 2| 1.1094003924504583| +---+-------------------+ + Alternatively, the user can define a function that takes two arguments. + In this case, the grouping key will be passed as the first argument and the data will + be passed as the second argument. The grouping key will be passed as a tuple of numpy + data types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in + as a `pandas.DataFrame` containing all columns from the original Spark DataFrame. + This is useful when the user does not want to hardcode grouping key in the function. + + >>> from pyspark.sql.functions import pandas_udf, PandasUDFType + >>> import pandas as pd # doctest: +SKIP + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) # doctest: +SKIP + >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP) # doctest: +SKIP + ... def mean_udf(key, pdf): + ... # key is a tuple of one numpy.int64, which is the value + ... # of 'id' for the current group + ... return pd.DataFrame([key + (pdf.v.mean(),)]) + >>> df.groupby('id').apply(mean_udf).show() # doctest: +SKIP + +---+---+ + | id| v| + +---+---+ + | 1|1.5| + | 2|6.0| + +---+---+ + .. seealso:: :meth:`pyspark.sql.GroupedData.apply` 3. 
GROUPED_AGG diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index fa3b7203e10ac..fa2f1f3578c73 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3843,7 +3843,7 @@ def foo(df): return df with self.assertRaisesRegexp(ValueError, 'Invalid function'): @pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUPED_MAP) - def foo(k, v): + def foo(k, v, w): return k @@ -4416,20 +4416,45 @@ def test_supported_types(self): from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col df = self.data.withColumn("arr", array(col("id"))) - foo_udf = pandas_udf( + # Different forms of group map pandas UDF, results of these are the same + + output_schema = StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('arr', ArrayType(LongType())), + StructField('v1', DoubleType()), + StructField('v2', LongType())]) + + udf1 = pandas_udf( lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), - StructType( - [StructField('id', LongType()), - StructField('v', IntegerType()), - StructField('arr', ArrayType(LongType())), - StructField('v1', DoubleType()), - StructField('v2', LongType())]), + output_schema, PandasUDFType.GROUPED_MAP ) - result = df.groupby('id').apply(foo_udf).sort('id').toPandas() - expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) - self.assertPandasEqual(expected, result) + udf2 = pandas_udf( + lambda _, pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + output_schema, + PandasUDFType.GROUPED_MAP + ) + + udf3 = pandas_udf( + lambda key, pdf: pdf.assign(id=key[0], v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + output_schema, + PandasUDFType.GROUPED_MAP + ) + + result1 = df.groupby('id').apply(udf1).sort('id').toPandas() + expected1 = df.toPandas().groupby('id').apply(udf1.func).reset_index(drop=True) + + result2 = df.groupby('id').apply(udf2).sort('id').toPandas() + expected2 = expected1 + + result3 = df.groupby('id').apply(udf3).sort('id').toPandas() + expected3 = expected1 + + self.assertPandasEqual(expected1, result1) + self.assertPandasEqual(expected2, result2) + self.assertPandasEqual(expected3, result3) def test_register_grouped_map_udf(self): from pyspark.sql.functions import pandas_udf, PandasUDFType @@ -4588,6 +4613,80 @@ def test_timestamp_dst(self): result = df.groupby('time').apply(foo_udf).sort('time') self.assertPandasEqual(df.toPandas(), result.toPandas()) + def test_udf_with_key(self): + from pyspark.sql.functions import pandas_udf, col, PandasUDFType + df = self.data + pdf = df.toPandas() + + def foo1(key, pdf): + import numpy as np + assert type(key) == tuple + assert type(key[0]) == np.int64 + + return pdf.assign(v1=key[0], + v2=pdf.v * key[0], + v3=pdf.v * pdf.id, + v4=pdf.v * pdf.id.mean()) + + def foo2(key, pdf): + import numpy as np + assert type(key) == tuple + assert type(key[0]) == np.int64 + assert type(key[1]) == np.int32 + + return pdf.assign(v1=key[0], + v2=key[1], + v3=pdf.v * key[0], + v4=pdf.v + key[1]) + + def foo3(key, pdf): + assert type(key) == tuple + assert len(key) == 0 + return pdf.assign(v1=pdf.v * pdf.id) + + # v2 is int because numpy.int64 * pd.Series results in pd.Series + # v3 is long because pd.Series * pd.Series results in pd.Series + udf1 = pandas_udf( + foo1, + 'id long, v int, v1 long, v2 int, v3 long, v4 double', + PandasUDFType.GROUPED_MAP) + + udf2 = pandas_udf( + foo2, + 'id long, v int, v1 long, v2 int, v3 int, v4 int', + PandasUDFType.GROUPED_MAP) + + udf3 = pandas_udf( + 
foo3, + 'id long, v int, v1 long', + PandasUDFType.GROUPED_MAP) + + # Test groupby column + result1 = df.groupby('id').apply(udf1).sort('id', 'v').toPandas() + expected1 = pdf.groupby('id')\ + .apply(lambda x: udf1.func((x.id.iloc[0],), x))\ + .sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected1, result1) + + # Test groupby expression + result2 = df.groupby(df.id % 2).apply(udf1).sort('id', 'v').toPandas() + expected2 = pdf.groupby(pdf.id % 2)\ + .apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))\ + .sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected2, result2) + + # Test complex groupby + result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas() + expected3 = pdf.groupby([pdf.id, pdf.v % 2])\ + .apply(lambda x: udf2.func((x.id.iloc[0], (x.v % 2).iloc[0],), x))\ + .sort_values(['id', 'v']).reset_index(drop=True) + self.assertPandasEqual(expected3, result3) + + # Test empty groupby + result4 = df.groupby().apply(udf3).sort('id', 'v').toPandas() + expected4 = udf3.func((), pdf) + self.assertPandasEqual(expected4, result4) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index cd857402db8f7..1632862d3f1ba 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1695,6 +1695,19 @@ def from_arrow_schema(arrow_schema): for field in arrow_schema]) +def _check_series_convert_date(series, data_type): + """ + Cast the series to datetime.date if it's a date type, otherwise returns the original series. + + :param series: pandas.Series + :param data_type: a Spark data type for the series + """ + if type(data_type) == DateType: + return series.dt.date + else: + return series + + def _check_dataframe_convert_date(pdf, schema): """ Correct date type value to use datetime.date. @@ -1705,8 +1718,7 @@ def _check_dataframe_convert_date(pdf, schema): :param schema: a Spark schema of the pandas.DataFrame """ for field in schema: - if type(field.dataType) == DateType: - pdf[field.name] = pdf[field.name].dt.date + pdf[field.name] = _check_series_convert_date(pdf[field.name], field.dataType) return pdf @@ -1725,6 +1737,29 @@ def _get_local_timezone(): return os.environ.get('TZ', 'dateutil/:') +def _check_series_localize_timestamps(s, timezone): + """ + Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone. + + If the input series is not a timestamp series, then the same series is returned. If the input + series is a timestamp series, then a converted series is returned. + + :param s: pandas.Series + :param timezone: the timezone to convert. if None then use local timezone + :return pandas.Series that have been converted to tz-naive + """ + from pyspark.sql.utils import require_minimum_pandas_version + require_minimum_pandas_version() + + from pandas.api.types import is_datetime64tz_dtype + tz = timezone or _get_local_timezone() + # TODO: handle nested timestamps, such as ArrayType(TimestampType())? 
+ if is_datetime64tz_dtype(s.dtype): + return s.dt.tz_convert(tz).dt.tz_localize(None) + else: + return s + + def _check_dataframe_localize_timestamps(pdf, timezone): """ Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone @@ -1736,12 +1771,8 @@ def _check_dataframe_localize_timestamps(pdf, timezone): from pyspark.sql.utils import require_minimum_pandas_version require_minimum_pandas_version() - from pandas.api.types import is_datetime64tz_dtype - tz = timezone or _get_local_timezone() for column, series in pdf.iteritems(): - # TODO: handle nested timestamps, such as ArrayType(TimestampType())? - if is_datetime64tz_dtype(series.dtype): - pdf[column] = series.dt.tz_convert(tz).dt.tz_localize(None) + pdf[column] = _check_series_localize_timestamps(series, timezone) return pdf diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index b9b490874f4fb..ce804c18e9b14 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -17,6 +17,8 @@ """ User-defined function related classes and functions """ +import sys +import inspect import functools from pyspark import SparkContext, since @@ -24,6 +26,7 @@ from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \ _parse_datatype_string, to_arrow_type, to_arrow_schema +from pyspark.util import _get_argspec __all__ = ["UDFRegistration"] @@ -41,18 +44,10 @@ def _create_udf(f, returnType, evalType): PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): - import inspect - import sys from pyspark.sql.utils import require_minimum_pyarrow_version - require_minimum_pyarrow_version() - if sys.version_info[0] < 3: - # `getargspec` is deprecated since python3.0 (incompatible with function annotations). - # See SPARK-23569. - argspec = inspect.getargspec(f) - else: - argspec = inspect.getfullargspec(f) + argspec = _get_argspec(f) if evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF and len(argspec.args) == 0 and \ argspec.varargs is None: @@ -61,11 +56,11 @@ def _create_udf(f, returnType, evalType): "Instead, create a 1-arg pandas_udf and ignore the arg in your function." ) - if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF and len(argspec.args) != 1: + if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \ + and len(argspec.args) not in (1, 2): raise ValueError( "Invalid function: pandas_udfs with function type GROUPED_MAP " - "must take a single arg that is a pandas DataFrame." - ) + "must take either one argument (data) or two arguments (key, data).") # Set the name of the UserDefinedFunction object to be the name of function f udf_obj = UserDefinedFunction( diff --git a/python/pyspark/util.py b/python/pyspark/util.py index ad4a0bc68ef41..6837b18b7d7a5 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -15,6 +15,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # + +import sys +import inspect from py4j.protocol import Py4JJavaError __all__ = [] @@ -45,6 +48,19 @@ def _exception_message(excp): return str(excp) +def _get_argspec(f): + """ + Get argspec of a function. Supports both Python 2 and Python 3. + """ + # `getargspec` is deprecated since python3.0 (incompatible with function annotations). + # See SPARK-23569. 
+ if sys.version_info[0] < 3: + argspec = inspect.getargspec(f) + else: + argspec = inspect.getfullargspec(f) + return argspec + + if __name__ == "__main__": import doctest (failure_count, test_count) = doctest.testmod() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 89a3a92bc66d6..202cac350aafc 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -34,6 +34,7 @@ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer from pyspark.sql.types import to_arrow_type +from pyspark.util import _get_argspec from pyspark import shuffle pickleSer = PickleSerializer() @@ -91,10 +92,16 @@ def verify_result_length(*a): def wrap_grouped_map_pandas_udf(f, return_type): - def wrapped(*series): + def wrapped(key_series, value_series): import pandas as pd + argspec = _get_argspec(f) + + if len(argspec.args) == 1: + result = f(pd.concat(value_series, axis=1)) + elif len(argspec.args) == 2: + key = tuple(s[0] for s in key_series) + result = f(key, pd.concat(value_series, axis=1)) - result = f(pd.concat(series, axis=1)) if not isinstance(result, pd.DataFrame): raise TypeError("Return type of the user-defined function should be " "pandas.DataFrame, but is {}".format(type(result))) @@ -149,18 +156,36 @@ def read_udfs(pickleSer, infile, eval_type): num_udfs = read_int(infile) udfs = {} call_udf = [] - for i in range(num_udfs): + mapper_str = "" + if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: + # Create function like this: + # lambda a: f([a[0]], [a[0], a[1]]) + + # We assume there is only one UDF here because grouped map doesn't + # support combining multiple UDFs. + assert num_udfs == 1 + + # See FlatMapGroupsInPandasExec for how arg_offsets are used to + # distinguish between grouping attributes and data attributes arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type) - udfs['f%d' % i] = udf - args = ["a[%d]" % o for o in arg_offsets] - call_udf.append("f%d(%s)" % (i, ", ".join(args))) - # Create function like this: - # lambda a: (f0(a0), f1(a1, a2), f2(a3)) - # In the special case of a single UDF this will return a single result rather - # than a tuple of results; this is the format that the JVM side expects. - mapper_str = "lambda a: (%s)" % (", ".join(call_udf)) - mapper = eval(mapper_str, udfs) + udfs['f'] = udf + split_offset = arg_offsets[0] + 1 + arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]] + arg1 = ["a[%d]" % o for o in arg_offsets[split_offset:]] + mapper_str = "lambda a: f([%s], [%s])" % (", ".join(arg0), ", ".join(arg1)) + else: + # Create function like this: + # lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3])) + # In the special case of a single UDF this will return a single result rather + # than a tuple of results; this is the format that the JVM side expects. 
+ for i in range(num_udfs): + arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type) + udfs['f%d' % i] = udf + args = ["a[%d]" % o for o in arg_offsets] + call_udf.append("f%d(%s)" % (i, ", ".join(args))) + mapper_str = "lambda a: (%s)" % (", ".join(call_udf)) + mapper = eval(mapper_str, udfs) func = lambda _, it: map(mapper, it) if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index c798fe5a92c54..513e174c7733e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.python import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import org.apache.spark.TaskContext import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} @@ -75,20 +76,63 @@ case class FlatMapGroupsInPandasExec( val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) - val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray) - val schema = StructType(child.schema.drop(groupingAttributes.length)) val sessionLocalTimeZone = conf.sessionLocalTimeZone val pandasRespectSessionTimeZone = conf.pandasRespectSessionTimeZone + // Deduplicate the grouping attributes. + // If a grouping attribute also appears in data attributes, then we don't need to send the + // grouping attribute to Python worker. If a grouping attribute is not in data attributes, + // then we need to send this grouping attribute to python worker. + // + // We use argOffsets to distinguish grouping attributes and data attributes as following: + // + // argOffsets[0] is the length of grouping attributes + // argOffsets[1 .. argOffsets[0]+1] is the arg offsets for grouping attributes + // argOffsets[argOffsets[0]+1 .. ] is the arg offsets for data attributes + + val dataAttributes = child.output.drop(groupingAttributes.length) + val groupingIndicesInData = groupingAttributes.map { attribute => + dataAttributes.indexWhere(attribute.semanticEquals) + } + + val groupingArgOffsets = new ArrayBuffer[Int] + val nonDupGroupingAttributes = new ArrayBuffer[Attribute] + val nonDupGroupingSize = groupingIndicesInData.count(_ == -1) + + // Non duplicate grouping attributes are added to nonDupGroupingAttributes and + // their offsets are 0, 1, 2 ... + // Duplicate grouping attributes are NOT added to nonDupGroupingAttributes and + // their offsets are n + index, where n is the total number of non duplicate grouping + // attributes and index is the index in the data attributes that the grouping attribute + // is a duplicate of. 
+ + groupingAttributes.zip(groupingIndicesInData).foreach { + case (attribute, index) => + if (index == -1) { + groupingArgOffsets += nonDupGroupingAttributes.length + nonDupGroupingAttributes += attribute + } else { + groupingArgOffsets += index + nonDupGroupingSize + } + } + + val dataArgOffsets = nonDupGroupingAttributes.length until + (nonDupGroupingAttributes.length + dataAttributes.length) + + val argOffsets = Array(Array(groupingAttributes.length) ++ groupingArgOffsets ++ dataArgOffsets) + + // Attributes after deduplication + val dedupAttributes = nonDupGroupingAttributes ++ dataAttributes + val dedupSchema = StructType.fromAttributes(dedupAttributes) + inputRDD.mapPartitionsInternal { iter => val grouped = if (groupingAttributes.isEmpty) { Iterator(iter) } else { val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) - val dropGrouping = - UnsafeProjection.create(child.output.drop(groupingAttributes.length), child.output) + val dedupProj = UnsafeProjection.create(dedupAttributes, child.output) groupedIter.map { - case (_, groupedRowIter) => groupedRowIter.map(dropGrouping) + case (_, groupedRowIter) => groupedRowIter.map(dedupProj) } } @@ -96,7 +140,7 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, schema, + PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, argOffsets, dedupSchema, sessionLocalTimeZone, pandasRespectSessionTimeZone) .compute(grouped, context.partitionId(), context)
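
The serializer change above moves date and timestamp normalization from whole-DataFrame helpers to the per-series helpers `_check_series_convert_date` and `_check_series_localize_timestamps`, so that `arrow_to_pandas` can convert one Arrow column at a time. The following is a minimal standalone sketch of those two per-series conversions using only pandas; the function names, sample data, and timezone are illustrative, not taken from the patch:

    import pandas as pd
    from pandas.api.types import is_datetime64tz_dtype

    def localize_timestamps(s, timezone):
        # Mirrors _check_series_localize_timestamps: convert tz-aware values to
        # tz-naive values expressed in the target timezone; pass others through.
        if is_datetime64tz_dtype(s.dtype):
            return s.dt.tz_convert(timezone).dt.tz_localize(None)
        return s

    def convert_date(s, is_date_type):
        # Mirrors _check_series_convert_date: datetime64 values become
        # datetime.date objects when the Spark type is DateType.
        return s.dt.date if is_date_type else s

    ts = pd.Series(pd.date_range("2018-03-01", periods=3, freq="D", tz="UTC"))
    print(localize_timestamps(ts, "America/Los_Angeles"))
    print(convert_date(ts.dt.tz_localize(None), True))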
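On the worker side, a GROUPED_MAP function may now declare either one argument (data only) or two (key, data), and `wrap_grouped_map_pandas_udf` inspects the function's argspec to decide whether to build and pass the key tuple. Below is a simplified, self-contained sketch of that dispatch; the helper names and sample data are illustrative, `inspect.getfullargspec` is Python 3 only (the patch's `_get_argspec` also covers Python 2), and the real wrapper additionally validates the returned DataFrame and converts it to Arrow:

    import inspect
    import pandas as pd

    def wrap_grouped_map(f):
        # Rough analogue of wrap_grouped_map_pandas_udf: key_series holds one
        # single-element Series per grouping column, value_series the data columns.
        nargs = len(inspect.getfullargspec(f).args)

        def wrapped(key_series, value_series):
            pdf = pd.concat(value_series, axis=1)
            if nargs == 1:
                return f(pdf)
            # Two-argument form: the key is a tuple of scalars, one per grouping
            # column, taken from the first row of each key series.
            key = tuple(s.iloc[0] for s in key_series)
            return f(key, pdf)

        return wrapped

    def mean_with_key(key, pdf):
        return pd.DataFrame([key + (pdf.v.mean(),)], columns=["id", "v"])

    wrapped = wrap_grouped_map(mean_with_key)
    print(wrapped([pd.Series([1], name="id")], [pd.Series([1.0, 2.0], name="v")]))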
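On the JVM side, FlatMapGroupsInPandasExec now sends both grouping and data columns, deduplicating any grouping column that already appears in the data, and encodes the split in argOffsets: argOffsets[0] is the number of grouping attributes, followed by their offsets and then the data offsets. read_udfs in worker.py rebuilds the two column lists from that layout. The sketch below is a Python rendering of the encode/decode round trip under those assumptions (the Scala code operates on attributes, not names; column names here are illustrative stand-ins for the deserialized series):

    # Encode, mirroring FlatMapGroupsInPandasExec: a grouping column that duplicates
    # a data column reuses the data column's slot instead of being sent twice.
    def encode_arg_offsets(grouping_cols, data_cols):
        non_dup = [g for g in grouping_cols if g not in data_cols]
        grouping_offsets = []
        for g in grouping_cols:
            if g in data_cols:
                grouping_offsets.append(len(non_dup) + data_cols.index(g))
            else:
                grouping_offsets.append(non_dup.index(g))
        data_offsets = list(range(len(non_dup), len(non_dup) + len(data_cols)))
        dedup_cols = non_dup + data_cols
        return [len(grouping_cols)] + grouping_offsets + data_offsets, dedup_cols

    # Decode, mirroring read_udfs in worker.py: split the offsets back into the
    # key columns and the data columns of each deserialized batch `a`.
    def split_key_and_data(arg_offsets, a):
        split = arg_offsets[0] + 1
        key_series = [a[o] for o in arg_offsets[1:split]]
        value_series = [a[o] for o in arg_offsets[split:]]
        return key_series, value_series

    offsets, cols = encode_arg_offsets(["id"], ["id", "v"])  # "id" is both grouping and data
    print(offsets, cols)            # [1, 0, 0, 1] ['id', 'v']
    key, data = split_key_and_data(offsets, cols)
    print(key, data)                # ['id'] ['id', 'v']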