[SPARK-22239][SQL][Python] Enable grouped aggregate pandas UDFs as window functions with unbounded window frames #21082
```diff
@@ -2580,10 +2580,12 @@ def pandas_udf(f=None, returnType=None, functionType=None):
        The returned scalar can be either a python primitive type, e.g., `int` or `float`
        or a numpy data type, e.g., `numpy.int64` or `numpy.float64`.
 
-       :class:`ArrayType`, :class:`MapType` and :class:`StructType` are currently not supported as
-       output types.
+       :class:`MapType` and :class:`StructType` are currently not supported as output types.
 
-       Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg`
+       Group aggregate UDFs are used with :meth:`pyspark.sql.GroupedData.agg` and
+       :class:`pyspark.sql.Window`
 
        This example shows using grouped aggregated UDFs with groupby:
 
        >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
        >>> df = spark.createDataFrame(
@@ -2600,7 +2602,31 @@ def pandas_udf(f=None, returnType=None, functionType=None):
        |  2|        6.0|
        +---+-----------+
 
-       .. seealso:: :meth:`pyspark.sql.GroupedData.agg`
+       This example shows using grouped aggregated UDFs as window functions. Note that only
+       unbounded window frame is supported at the moment:
+
+       >>> from pyspark.sql.functions import pandas_udf, PandasUDFType
+       >>> from pyspark.sql import Window
+       >>> df = spark.createDataFrame(
+       ...     [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
+       ...     ("id", "v"))
+       >>> @pandas_udf("double", PandasUDFType.GROUPED_AGG)  # doctest: +SKIP
+       ... def mean_udf(v):
+       ...     return v.mean()
+       >>> w = Window.partitionBy('id') \\
+       ...     .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
+       >>> df.withColumn('mean_v', mean_udf(df['v']).over(w)).show()  # doctest: +SKIP
+       +---+----+------+
+       | id|   v|mean_v|
+       +---+----+------+
+       |  1| 1.0|   1.5|
+       |  1| 2.0|   1.5|
+       |  2| 3.0|   6.0|
+       |  2| 5.0|   6.0|
+       |  2|10.0|   6.0|
+       +---+----+------+
+
+       .. seealso:: :meth:`pyspark.sql.GroupedData.agg` and :class:`pyspark.sql.Window`
 
        .. note:: The user-defined functions are considered deterministic by default. Due to
            optimization, duplicate invocations may be eliminated or the function may even be invoked
```
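To make the frame restriction above concrete, here is a minimal, hedged sketch of what a user hits when the window is not unbounded. It assumes an active `spark` session with this change applied; the error text is taken from the `test_invalid_args` test further down in this diff.

```python
from pyspark.sql import Window
from pyspark.sql.functions import pandas_udf, PandasUDFType
from pyspark.sql.utils import AnalysisException

df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))

@pandas_udf("double", PandasUDFType.GROUPED_AGG)
def mean_udf(v):
    return v.mean()

# An ordered window defaults to a bounded (growing) frame, which this change
# does not support for group aggregate pandas UDFs.
ow = Window.partitionBy("id").orderBy("v")
try:
    df.withColumn("mean_v", mean_udf(df["v"]).over(ow)).show()
except AnalysisException as e:
    print(e)  # expected to mention "Only unbounded window frame is supported"
```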
```diff
@@ -5454,6 +5454,15 @@ def test_retain_group_columns(self):
         expected1 = df.groupby(df.id).agg(sum(df.v))
         self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
 
+    def test_array_type(self):
+        from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+        df = self.data
+
+        array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG)
+        result1 = df.groupby('id').agg(array_udf(df['v']).alias('v2'))
+        self.assertEquals(result1.first()['v2'], [1.0, 2.0])
+
     def test_invalid_args(self):
         from pyspark.sql.functions import mean
 
@@ -5479,6 +5488,235 @@ def test_invalid_args(self):
                     'mixture.*aggregate function.*group aggregate pandas UDF'):
                 df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
 
 
+@unittest.skipIf(
+    not _have_pandas or not _have_pyarrow,
+    _pandas_requirement_message or _pyarrow_requirement_message)
+class WindowPandasUDFTests(ReusedSQLTestCase):
+    @property
+    def data(self):
+        from pyspark.sql.functions import array, explode, col, lit
+        return self.spark.range(10).toDF('id') \
+            .withColumn("vs", array([lit(i * 1.0) + col('id') for i in range(20, 30)])) \
+            .withColumn("v", explode(col('vs'))) \
+            .drop('vs') \
+            .withColumn('w', lit(1.0))
+
+    @property
+    def python_plus_one(self):
+        from pyspark.sql.functions import udf
+        return udf(lambda v: v + 1, 'double')
+
+    @property
+    def pandas_scalar_time_two(self):
+        from pyspark.sql.functions import pandas_udf, PandasUDFType
+        return pandas_udf(lambda v: v * 2, 'double')
+
+    @property
+    def pandas_agg_mean_udf(self):
+        from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+        @pandas_udf('double', PandasUDFType.GROUPED_AGG)
+        def avg(v):
+            return v.mean()
+        return avg
+
+    @property
+    def pandas_agg_max_udf(self):
+        from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+        @pandas_udf('double', PandasUDFType.GROUPED_AGG)
+        def max(v):
+            return v.max()
+        return max
+
+    @property
+    def pandas_agg_min_udf(self):
+        from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+        @pandas_udf('double', PandasUDFType.GROUPED_AGG)
+        def min(v):
+            return v.min()
+        return min
+
+    @property
+    def unbounded_window(self):
+        return Window.partitionBy('id') \
+            .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
+
+    @property
+    def ordered_window(self):
+        return Window.partitionBy('id').orderBy('v')
+
+    @property
+    def unpartitioned_window(self):
+        return Window.partitionBy()
+
+    def test_simple(self):
+        from pyspark.sql.functions import pandas_udf, PandasUDFType, percent_rank, mean, max
+
+        df = self.data
+        w = self.unbounded_window
+
+        mean_udf = self.pandas_agg_mean_udf
+
+        result1 = df.withColumn('mean_v', mean_udf(df['v']).over(w))
+        expected1 = df.withColumn('mean_v', mean(df['v']).over(w))
+
+        result2 = df.select(mean_udf(df['v']).over(w))
+        expected2 = df.select(mean(df['v']).over(w))
+
+        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+
+    def test_multiple_udfs(self):
+        from pyspark.sql.functions import max, min, mean
+
+        df = self.data
+        w = self.unbounded_window
+
+        result1 = df.withColumn('mean_v', self.pandas_agg_mean_udf(df['v']).over(w)) \
+            .withColumn('max_v', self.pandas_agg_max_udf(df['v']).over(w)) \
+            .withColumn('min_w', self.pandas_agg_min_udf(df['w']).over(w))
+
+        expected1 = df.withColumn('mean_v', mean(df['v']).over(w)) \
+            .withColumn('max_v', max(df['v']).over(w)) \
+            .withColumn('min_w', min(df['w']).over(w))
+
+        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+    def test_replace_existing(self):
+        from pyspark.sql.functions import mean
+
+        df = self.data
+        w = self.unbounded_window
+
+        result1 = df.withColumn('v', self.pandas_agg_mean_udf(df['v']).over(w))
+        expected1 = df.withColumn('v', mean(df['v']).over(w))
+
+        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+    def test_mixed_sql(self):
+        from pyspark.sql.functions import mean
+
+        df = self.data
+        w = self.unbounded_window
+        mean_udf = self.pandas_agg_mean_udf
+
+        result1 = df.withColumn('v', mean_udf(df['v'] * 2).over(w) + 1)
+        expected1 = df.withColumn('v', mean(df['v'] * 2).over(w) + 1)
+
+        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+
+    def test_mixed_udf(self):
+        from pyspark.sql.functions import mean
+
+        df = self.data
+        w = self.unbounded_window
+
+        plus_one = self.python_plus_one
+        time_two = self.pandas_scalar_time_two
+        mean_udf = self.pandas_agg_mean_udf
+
+        result1 = df.withColumn(
+            'v2',
+            plus_one(mean_udf(plus_one(df['v'])).over(w)))
+        expected1 = df.withColumn(
+            'v2',
+            plus_one(mean(plus_one(df['v'])).over(w)))
+
+        result2 = df.withColumn(
+            'v2',
+            time_two(mean_udf(time_two(df['v'])).over(w)))
+        expected2 = df.withColumn(
+            'v2',
+            time_two(mean(time_two(df['v'])).over(w)))
+
+        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+
+    def test_without_partitionBy(self):
+        from pyspark.sql.functions import mean
+
+        df = self.data
+        w = self.unpartitioned_window
+        mean_udf = self.pandas_agg_mean_udf
+
+        result1 = df.withColumn('v2', mean_udf(df['v']).over(w))
+        expected1 = df.withColumn('v2', mean(df['v']).over(w))
+
+        result2 = df.select(mean_udf(df['v']).over(w))
+        expected2 = df.select(mean(df['v']).over(w))
+
+        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+
+    def test_mixed_sql_and_udf(self):
+        from pyspark.sql.functions import max, min, rank, col
+
+        df = self.data
+        w = self.unbounded_window
+        ow = self.ordered_window
+        max_udf = self.pandas_agg_max_udf
+        min_udf = self.pandas_agg_min_udf
+
+        result1 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min_udf(df['v']).over(w))
+        expected1 = df.withColumn('v_diff', max(df['v']).over(w) - min(df['v']).over(w))
+
+        # Test mixing sql window function and window udf in the same expression
+        result2 = df.withColumn('v_diff', max_udf(df['v']).over(w) - min(df['v']).over(w))
+        expected2 = expected1
+
+        # Test chaining sql aggregate function and udf
+        result3 = df.withColumn('max_v', max_udf(df['v']).over(w)) \
+            .withColumn('min_v', min(df['v']).over(w)) \
+            .withColumn('v_diff', col('max_v') - col('min_v')) \
+            .drop('max_v', 'min_v')
+        expected3 = expected1
+
+        # Test mixing sql window function and udf
+        result4 = df.withColumn('max_v', max_udf(df['v']).over(w)) \
+            .withColumn('rank', rank().over(ow))
+        expected4 = df.withColumn('max_v', max(df['v']).over(w)) \
+            .withColumn('rank', rank().over(ow))
+
+        self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
+        self.assertPandasEqual(expected2.toPandas(), result2.toPandas())
+        self.assertPandasEqual(expected3.toPandas(), result3.toPandas())
+        self.assertPandasEqual(expected4.toPandas(), result4.toPandas())
+
+    def test_array_type(self):
+        from pyspark.sql.functions import pandas_udf, PandasUDFType
+
+        df = self.data
+        w = self.unbounded_window
+
+        array_udf = pandas_udf(lambda x: [1.0, 2.0], 'array<double>', PandasUDFType.GROUPED_AGG)
+        result1 = df.withColumn('v2', array_udf(df['v']).over(w))
+        self.assertEquals(result1.first()['v2'], [1.0, 2.0])
+
+    def test_invalid_args(self):
+        from pyspark.sql.functions import mean, pandas_udf, PandasUDFType
+
+        df = self.data
+        w = self.unbounded_window
+        ow = self.ordered_window
+        mean_udf = self.pandas_agg_mean_udf
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(
+                    AnalysisException,
+                    '.*not supported within a window function'):
+                foo_udf = pandas_udf(lambda x: x, 'v double', PandasUDFType.GROUPED_MAP)
+                df.withColumn('v2', foo_udf(df['v']).over(w))
+
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(
+                    AnalysisException,
+                    '.*Only unbounded window frame is supported.*'):
+                df.withColumn('mean_v', mean_udf(df['v']).over(ow))
+
+
 if __name__ == "__main__":
     from pyspark.sql.tests import *
     if xmlrunner:
```
```diff
@@ -128,6 +128,21 @@ def wrapped(*series):
     return lambda *a: (wrapped(*a), arrow_return_type)
 
 
+def wrap_window_agg_pandas_udf(f, return_type):
+    # This is similar to grouped_agg_pandas_udf, the only difference
+    # is that window_agg_pandas_udf needs to repeat the return value
+    # to match window length, where grouped_agg_pandas_udf just returns
+    # the scalar value.
+    arrow_return_type = to_arrow_type(return_type)
+
+    def wrapped(*series):
+        import pandas as pd
+        result = f(*series)
+        return pd.Series([result]).repeat(len(series[0]))
+
+    return lambda *a: (wrapped(*a), arrow_return_type)
+
+
 def read_single_udf(pickleSer, infile, eval_type):
     num_arg = read_int(infile)
     arg_offsets = [read_int(infile) for i in range(num_arg)]
@@ -151,6 +166,8 @@ def read_single_udf(pickleSer, infile, eval_type):
         return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
     elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
         return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
+    elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
+        return arg_offsets, wrap_window_agg_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
         return arg_offsets, wrap_udf(func, return_type)
     else:
@@ -195,7 +212,8 @@ def read_udfs(pickleSer, infile, eval_type):
     if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
                      PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
-                     PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF):
+                     PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
+                     PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF):
         timezone = utf8_deserializer.loads(infile)
         ser = ArrowStreamPandasSerializer(timezone)
     else:
```
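To clarify the wrapper added above: a group aggregate UDF returns one scalar per frame, but a window expression must produce a value for every row, so the worker repeats the scalar to the frame length. The following is a small standalone sketch of that idea using plain pandas (`wrap_window_agg` is a hypothetical name for illustration, not the actual worker code path):

```python
import pandas as pd

def wrap_window_agg(f):
    # Call the aggregate once per window, then repeat its scalar result so the
    # output length matches the number of input rows.
    def wrapped(*series):
        result = f(*series)
        return pd.Series([result]).repeat(len(series[0]))
    return wrapped

mean_over_window = wrap_window_agg(lambda v: v.mean())
v = pd.Series([3.0, 5.0, 10.0])
print(mean_over_window(v).tolist())  # [6.0, 6.0, 6.0]
```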
@icexelloss, actually should we keep this note? I think this matches https://spark.apache.org/docs/latest/sql-programming-guide.html#supported-sql-types, which we documented there and in SQLConf.
Probably just leaving a link would be fine. Removing it is okay with me too. I think just adding a note for all the pandas UDFs works too.
I am leaning towards keeping this in the API doc and maybe making the sql-programming-guide link to this.
I think most users would look for the API docs first rather than the sql-programming-guide, so it's probably a bit more convenient to have it here?
Yup, I think that works too. I left a comment only because it looked mismatched between this API doc and the SQL programming guide.
Track in: https://issues.apache.org/jira/browse/SPARK-23633