
Commit 2cb23a8

icexelloss authored and HyukjinKwon committed
[SPARK-23011][SQL][PYTHON] Support alternative function form with group aggregate pandas UDF
## What changes were proposed in this pull request?

This PR proposes to support an alternative function form with group aggregate pandas UDF. The current form:

```
def foo(pdf):
    return ...
```

takes a single argument that is a pandas DataFrame. With this PR, an alternative form is also supported:

```
def foo(key, pdf):
    return ...
```

The alternative form takes two arguments: a tuple that represents the grouping key, and a pandas DataFrame that represents the data.

## How was this patch tested?

GroupbyApplyTests

Author: Li Jin <[email protected]>

Closes #20295 from icexelloss/SPARK-23011-groupby-apply-key.
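For reference, a minimal end-to-end sketch of the two forms side by side (not part of the diff; it assumes a running `SparkSession` named `spark` and a build that includes this change):

```
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))

# Current form: the function only sees the data for each group.
@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
def mean_v(pdf):
    return pd.DataFrame([(pdf.id.iloc[0], pdf.v.mean())])

# Alternative form: the grouping key is passed as a tuple in the first argument.
@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
def mean_v_with_key(key, pdf):
    return pd.DataFrame([key + (pdf.v.mean(),)])

# Both produce the same result.
df.groupby("id").apply(mean_v).show()
df.groupby("id").apply(mean_v_with_key).show()
```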
1 parent d6632d1 commit 2cb23a8

File tree: 8 files changed (+294, -55 lines)


python/pyspark/serializers.py

Lines changed: 11 additions & 7 deletions
@@ -250,6 +250,15 @@ def __init__(self, timezone):
         super(ArrowStreamPandasSerializer, self).__init__()
         self._timezone = timezone
 
+    def arrow_to_pandas(self, arrow_column):
+        from pyspark.sql.types import from_arrow_type, \
+            _check_series_convert_date, _check_series_localize_timestamps
+
+        s = arrow_column.to_pandas()
+        s = _check_series_convert_date(s, from_arrow_type(arrow_column.type))
+        s = _check_series_localize_timestamps(s, self._timezone)
+        return s
+
     def dump_stream(self, iterator, stream):
         """
         Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or
@@ -272,16 +281,11 @@ def load_stream(self, stream):
         """
         Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
         """
-        from pyspark.sql.types import from_arrow_schema, _check_dataframe_convert_date, \
-            _check_dataframe_localize_timestamps
         import pyarrow as pa
         reader = pa.open_stream(stream)
-        schema = from_arrow_schema(reader.schema)
+
         for batch in reader:
-            pdf = batch.to_pandas()
-            pdf = _check_dataframe_convert_date(pdf, schema)
-            pdf = _check_dataframe_localize_timestamps(pdf, self._timezone)
-            yield [c for _, c in pdf.iteritems()]
+            yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]
 
     def __repr__(self):
         return "ArrowStreamPandasSerializer"

python/pyspark/sql/functions.py

Lines changed: 25 additions & 0 deletions
@@ -2267,6 +2267,31 @@ def pandas_udf(f=None, returnType=None, functionType=None):
        |  2| 1.1094003924504583|
        +---+-------------------+
 
+       Alternatively, the user can define a function that takes two arguments.
+       In this case, the grouping key will be passed as the first argument and the data will
+       be passed as the second argument. The grouping key will be passed as a tuple of numpy
+       data types, e.g., `numpy.int32` and `numpy.float64`. The data will still be passed in
+       as a `pandas.DataFrame` containing all columns from the original Spark DataFrame.
+       This is useful when the user does not want to hardcode grouping key in the function.
+
+       >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+       >>> import pandas as pd  # doctest: +SKIP
+       >>> df = spark.createDataFrame(
+       ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+       ...     ("id", "v"))  # doctest: +SKIP
+       >>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)  # doctest: +SKIP
+       ... def mean_udf(key, pdf):
+       ...     # key is a tuple of one numpy.int64, which is the value
+       ...     # of 'id' for the current group
+       ...     return pd.DataFrame([key + (pdf.v.mean(),)])
+       >>> df.groupby('id').apply(mean_udf).show()  # doctest: +SKIP
+       +---+---+
+       | id|  v|
+       +---+---+
+       |  1|1.5|
+       |  2|6.0|
+       +---+---+
+
        .. seealso:: :meth:`pyspark.sql.GroupedData.apply`
 
    3. GROUPED_AGG
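Beyond the single-column doctest above, a hedged sketch of the same two-argument form with a multi-column groupby (the column names and data here are illustrative assumptions; a running `spark` session is assumed):

```
import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

df = spark.createDataFrame(
    [(1, "a", 1.0), (1, "b", 2.0), (2, "a", 3.0), (2, "b", 5.0)],
    ("id", "grp", "v"))

@pandas_udf("id long, grp string, v double", PandasUDFType.GROUPED_MAP)
def mean_per_key(key, pdf):
    # key is a tuple with one element per grouping column, e.g. (1, 'a').
    return pd.DataFrame([key + (pdf.v.mean(),)])

df.groupby("id", "grp").apply(mean_per_key).show()
```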

python/pyspark/sql/tests.py

Lines changed: 110 additions & 11 deletions
@@ -3903,7 +3903,7 @@ def foo(df):
                 return df
         with self.assertRaisesRegexp(ValueError, 'Invalid function'):
             @pandas_udf(returnType='k int, v double', functionType=PandasUDFType.GROUPED_MAP)
-            def foo(k, v):
+            def foo(k, v, w):
                 return k
 
 
@@ -4476,20 +4476,45 @@ def test_supported_types(self):
         from pyspark.sql.functions import pandas_udf, PandasUDFType, array, col
         df = self.data.withColumn("arr", array(col("id")))
 
-        foo_udf = pandas_udf(
+        # Different forms of group map pandas UDF, results of these are the same
+
+        output_schema = StructType(
+            [StructField('id', LongType()),
+             StructField('v', IntegerType()),
+             StructField('arr', ArrayType(LongType())),
+             StructField('v1', DoubleType()),
+             StructField('v2', LongType())])
+
+        udf1 = pandas_udf(
             lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
-            StructType(
-                [StructField('id', LongType()),
-                 StructField('v', IntegerType()),
-                 StructField('arr', ArrayType(LongType())),
-                 StructField('v1', DoubleType()),
-                 StructField('v2', LongType())]),
+            output_schema,
             PandasUDFType.GROUPED_MAP
         )
 
-        result = df.groupby('id').apply(foo_udf).sort('id').toPandas()
-        expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
-        self.assertPandasEqual(expected, result)
+        udf2 = pandas_udf(
+            lambda _, pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
+            output_schema,
+            PandasUDFType.GROUPED_MAP
+        )
+
+        udf3 = pandas_udf(
+            lambda key, pdf: pdf.assign(id=key[0], v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id),
+            output_schema,
+            PandasUDFType.GROUPED_MAP
+        )
+
+        result1 = df.groupby('id').apply(udf1).sort('id').toPandas()
+        expected1 = df.toPandas().groupby('id').apply(udf1.func).reset_index(drop=True)
+
+        result2 = df.groupby('id').apply(udf2).sort('id').toPandas()
+        expected2 = expected1
+
+        result3 = df.groupby('id').apply(udf3).sort('id').toPandas()
+        expected3 = expected1
+
+        self.assertPandasEqual(expected1, result1)
+        self.assertPandasEqual(expected2, result2)
+        self.assertPandasEqual(expected3, result3)
 
     def test_register_grouped_map_udf(self):
         from pyspark.sql.functions import pandas_udf, PandasUDFType
@@ -4648,6 +4673,80 @@ def test_timestamp_dst(self):
         result = df.groupby('time').apply(foo_udf).sort('time')
         self.assertPandasEqual(df.toPandas(), result.toPandas())
 
+    def test_udf_with_key(self):
+        from pyspark.sql.functions import pandas_udf, col, PandasUDFType
+        df = self.data
+        pdf = df.toPandas()
+
+        def foo1(key, pdf):
+            import numpy as np
+            assert type(key) == tuple
+            assert type(key[0]) == np.int64
+
+            return pdf.assign(v1=key[0],
+                              v2=pdf.v * key[0],
+                              v3=pdf.v * pdf.id,
+                              v4=pdf.v * pdf.id.mean())
+
+        def foo2(key, pdf):
+            import numpy as np
+            assert type(key) == tuple
+            assert type(key[0]) == np.int64
+            assert type(key[1]) == np.int32
+
+            return pdf.assign(v1=key[0],
+                              v2=key[1],
+                              v3=pdf.v * key[0],
+                              v4=pdf.v + key[1])
+
+        def foo3(key, pdf):
+            assert type(key) == tuple
+            assert len(key) == 0
+            return pdf.assign(v1=pdf.v * pdf.id)
+
+        # v2 is int because numpy.int64 * pd.Series<int32> results in pd.Series<int32>
+        # v3 is long because pd.Series<int64> * pd.Series<int32> results in pd.Series<int64>
+        udf1 = pandas_udf(
+            foo1,
+            'id long, v int, v1 long, v2 int, v3 long, v4 double',
+            PandasUDFType.GROUPED_MAP)
+
+        udf2 = pandas_udf(
+            foo2,
+            'id long, v int, v1 long, v2 int, v3 int, v4 int',
+            PandasUDFType.GROUPED_MAP)
+
+        udf3 = pandas_udf(
+            foo3,
+            'id long, v int, v1 long',
+            PandasUDFType.GROUPED_MAP)
+
+        # Test groupby column
+        result1 = df.groupby('id').apply(udf1).sort('id', 'v').toPandas()
+        expected1 = pdf.groupby('id')\
+            .apply(lambda x: udf1.func((x.id.iloc[0],), x))\
+            .sort_values(['id', 'v']).reset_index(drop=True)
+        self.assertPandasEqual(expected1, result1)
+
+        # Test groupby expression
+        result2 = df.groupby(df.id % 2).apply(udf1).sort('id', 'v').toPandas()
+        expected2 = pdf.groupby(pdf.id % 2)\
+            .apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))\
+            .sort_values(['id', 'v']).reset_index(drop=True)
+        self.assertPandasEqual(expected2, result2)
+
+        # Test complex groupby
+        result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas()
+        expected3 = pdf.groupby([pdf.id, pdf.v % 2])\
+            .apply(lambda x: udf2.func((x.id.iloc[0], (x.v % 2).iloc[0],), x))\
+            .sort_values(['id', 'v']).reset_index(drop=True)
+        self.assertPandasEqual(expected3, result3)
+
+        # Test empty groupby
+        result4 = df.groupby().apply(udf3).sort('id', 'v').toPandas()
+        expected4 = udf3.func((), pdf)
+        self.assertPandasEqual(expected4, result4)
 
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,
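The type assertions in `foo1`/`foo2` rely on the key elements being numpy scalars rather than Python ints; a pandas-only sketch of that expectation (the dtypes and values below are assumptions chosen to mirror the test data):

```
import numpy as np
import pandas as pd

pdf = pd.DataFrame({"id": pd.Series([1, 1, 2], dtype="int64"),
                    "v": pd.Series([10, 20, 30], dtype="int32")})

for _, group in pdf.groupby("id"):
    key = (group.id.iloc[0],)          # how the tests rebuild the key locally
    assert type(key) == tuple
    assert type(key[0]) == np.int64    # matches the assertion in foo1
    print(key, len(group))
```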

python/pyspark/sql/types.py

Lines changed: 38 additions & 7 deletions
@@ -1695,6 +1695,19 @@ def from_arrow_schema(arrow_schema):
                        for field in arrow_schema])
 
 
+def _check_series_convert_date(series, data_type):
+    """
+    Cast the series to datetime.date if it's a date type, otherwise returns the original series.
+
+    :param series: pandas.Series
+    :param data_type: a Spark data type for the series
+    """
+    if type(data_type) == DateType:
+        return series.dt.date
+    else:
+        return series
+
+
 def _check_dataframe_convert_date(pdf, schema):
     """ Correct date type value to use datetime.date.
 
@@ -1705,8 +1718,7 @@ def _check_dataframe_convert_date(pdf, schema):
     :param schema: a Spark schema of the pandas.DataFrame
     """
     for field in schema:
-        if type(field.dataType) == DateType:
-            pdf[field.name] = pdf[field.name].dt.date
+        pdf[field.name] = _check_series_convert_date(pdf[field.name], field.dataType)
     return pdf
 
 
@@ -1725,6 +1737,29 @@ def _get_local_timezone():
     return os.environ.get('TZ', 'dateutil/:')
 
 
+def _check_series_localize_timestamps(s, timezone):
+    """
+    Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone.
+
+    If the input series is not a timestamp series, then the same series is returned. If the input
+    series is a timestamp series, then a converted series is returned.
+
+    :param s: pandas.Series
+    :param timezone: the timezone to convert. if None then use local timezone
+    :return pandas.Series that have been converted to tz-naive
+    """
+    from pyspark.sql.utils import require_minimum_pandas_version
+    require_minimum_pandas_version()
+
+    from pandas.api.types import is_datetime64tz_dtype
+    tz = timezone or _get_local_timezone()
+    # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
+    if is_datetime64tz_dtype(s.dtype):
+        return s.dt.tz_convert(tz).dt.tz_localize(None)
+    else:
+        return s
+
+
 def _check_dataframe_localize_timestamps(pdf, timezone):
     """
     Convert timezone aware timestamps to timezone-naive in the specified timezone or local timezone
@@ -1736,12 +1771,8 @@ def _check_dataframe_localize_timestamps(pdf, timezone):
     from pyspark.sql.utils import require_minimum_pandas_version
     require_minimum_pandas_version()
 
-    from pandas.api.types import is_datetime64tz_dtype
-    tz = timezone or _get_local_timezone()
     for column, series in pdf.iteritems():
-        # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
-        if is_datetime64tz_dtype(series.dtype):
-            pdf[column] = series.dt.tz_convert(tz).dt.tz_localize(None)
+        pdf[column] = _check_series_localize_timestamps(series, timezone)
     return pdf
 
 
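A standalone sketch of what the new per-series timestamp helper does (pandas only, no Spark; the timezone and sample values are arbitrary):

```
import pandas as pd
from pandas.api.types import is_datetime64tz_dtype

def localize(s, tz):
    # Same shape as _check_series_localize_timestamps: tz-aware -> tz-naive in tz,
    # anything else passes through unchanged.
    if is_datetime64tz_dtype(s.dtype):
        return s.dt.tz_convert(tz).dt.tz_localize(None)
    return s

ts = pd.Series(pd.date_range("2018-01-01", periods=2, freq="H", tz="UTC"))
print(localize(ts, "America/New_York"))       # tz-naive, shifted to New York time
print(localize(pd.Series([1, 2, 3]), "UTC"))  # non-timestamp series is returned as-is
```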

python/pyspark/sql/udf.py

Lines changed: 7 additions & 12 deletions
@@ -17,13 +17,16 @@
 """
 User-defined function related classes and functions
 """
+import sys
+import inspect
 import functools
 
 from pyspark import SparkContext, since
 from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_prefix
 from pyspark.sql.column import Column, _to_java_column, _to_seq
 from pyspark.sql.types import StringType, DataType, ArrayType, StructType, MapType, \
     _parse_datatype_string, to_arrow_type, to_arrow_schema
+from pyspark.util import _get_argspec
 
 __all__ = ["UDFRegistration"]
 
@@ -41,18 +44,10 @@ def _create_udf(f, returnType, evalType):
                     PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
                     PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF):
 
-        import inspect
-        import sys
         from pyspark.sql.utils import require_minimum_pyarrow_version
-
         require_minimum_pyarrow_version()
 
-        if sys.version_info[0] < 3:
-            # `getargspec` is deprecated since python3.0 (incompatible with function annotations).
-            # See SPARK-23569.
-            argspec = inspect.getargspec(f)
-        else:
-            argspec = inspect.getfullargspec(f)
+        argspec = _get_argspec(f)
 
         if evalType == PythonEvalType.SQL_SCALAR_PANDAS_UDF and len(argspec.args) == 0 and \
                 argspec.varargs is None:
@@ -61,11 +56,11 @@ def _create_udf(f, returnType, evalType):
                 "Instead, create a 1-arg pandas_udf and ignore the arg in your function."
             )
 
-        if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF and len(argspec.args) != 1:
+        if evalType == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF \
+                and len(argspec.args) not in (1, 2):
             raise ValueError(
                 "Invalid function: pandas_udfs with function type GROUPED_MAP "
-                "must take a single arg that is a pandas DataFrame."
-            )
+                "must take either one argument (data) or two arguments (key, data).")
 
     # Set the name of the UserDefinedFunction object to be the name of function f
     udf_obj = UserDefinedFunction(
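The relaxed check now accepts one- and two-argument functions and rejects everything else; a quick sketch of the failure mode (assumes pyspark with this patch and pyarrow installed, since the check runs at decoration time):

```
from pyspark.sql.functions import pandas_udf, PandasUDFType

try:
    @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
    def bad_udf(key, pdf, extra):  # three arguments -> rejected
        return pdf
except ValueError as e:
    # "Invalid function: pandas_udfs with function type GROUPED_MAP must take
    #  either one argument (data) or two arguments (key, data)."
    print(e)
```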
udf_obj = UserDefinedFunction(

python/pyspark/util.py

Lines changed: 16 additions & 0 deletions
@@ -15,6 +15,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
+
+import sys
+import inspect
 from py4j.protocol import Py4JJavaError
 
 __all__ = []
@@ -45,6 +48,19 @@ def _exception_message(excp):
     return str(excp)
 
 
+def _get_argspec(f):
+    """
+    Get argspec of a function. Supports both Python 2 and Python 3.
+    """
+    # `getargspec` is deprecated since python3.0 (incompatible with function annotations).
+    # See SPARK-23569.
+    if sys.version_info[0] < 3:
+        argspec = inspect.getargspec(f)
+    else:
+        argspec = inspect.getfullargspec(f)
+    return argspec
+
+
 if __name__ == "__main__":
     import doctest
     (failure_count, test_count) = doctest.testmod()
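A tiny usage sketch of the extracted helper, showing how the declared argument count is what selects between the data-only and (key, data) forms (the example functions are made up):

```
from pyspark.util import _get_argspec

def one_arg(pdf):
    return pdf

def two_args(key, pdf):
    return pdf

print(len(_get_argspec(one_arg).args))   # 1 -> data-only form
print(len(_get_argspec(two_args).args))  # 2 -> (key, data) form
```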
