[SPARK-23011][SQL][PYTHON] Support alternative function form with group aggregate pandas UDF #20295
Conversation
|
Test build #86286 has finished for PR 20295 at commit
|
|
cc @ueshin @HyukjinKwon @cloud-fan @viirya This PR implements the discussion here: #20211 (review). There is more refinement to be done, but I'd like to get some early feedback on whether this approach looks good in general. The general idea is to pass the grouping columns as extra columns to the python worker and to use them. I can also implement some kind of de-duplication logic. WDYT? |
|
Test build #86290 has finished for PR 20295 at commit
|
|
How are you going to send the group columns? For a group we have only one group row and a bunch of data rows. |
|
@cloud-fan Currently I send the group columns as extra data columns alongside the regular data columns. For example, if the original DataFrame has ... I implemented it this way because it doesn't change the existing serialization protocol. Alternatively, we could implement a new serialization protocol for the GROUP_MAP eval type, i.e., instead of sending just an Arrow batch, we could send a group row and then an Arrow batch. What do you think? |
|
How do we turn a single group column to a series? just repeat the group column? |
|
Yep, that's correct. |
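A minimal pandas sketch of the repetition being discussed (illustration only; the key value, data, and column names are made up, and this is not the actual worker code):

import pandas as pd

group_key = 1                                 # single value of the grouping column for this group
data = pd.DataFrame({'v': [1.0, 2.0, 3.0]})   # data rows belonging to the group

# Repeat the key so it lines up with the data rows, then ship it alongside them.
key_series = pd.Series([group_key] * len(data), name='id')
batch = pd.concat([key_series, data], axis=1)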
|
To me, this seems roughly fine. I don't have a strong preference on this. |
I'm basically neutral too, but I'd prefer the new serialization if there is a simple way to do it that is performant enough.
python/pyspark/serializers.py
Maybe we can remove the comment above (# NOTE: ...) ?
I actually don't know what the comment above means. @BryanCutler do you remember?
Yeah, the note was because previously we iterated over Arrow columns and converted each to a Series, and then we changed to converting an Arrow batch to a DataFrame and iterating over the DataFrame columns to get each Series. I wasn't sure if there might be a perf decrease, so I left the note, but I'm not sure why it wasn't done like the above in the first place - it seems like it would be just as good as the original. Anyway, yeah, the note can be removed now.
@BryanCutler Thanks for the clarification. I removed the note.
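For context, a rough sketch of the two conversion paths mentioned in the note above (a hedged illustration assuming pyarrow; not the actual serializer code):

import pyarrow as pa

batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 1, 2]), pa.array([1.0, 2.0, 3.0])], ['id', 'v'])

# Previous approach: convert each Arrow column to a pandas Series directly.
series_old = [batch.column(i).to_pandas() for i in range(batch.num_columns)]

# Current approach: convert the whole batch to a pandas DataFrame, then take its columns.
pdf = batch.to_pandas()
series_new = [pdf[name] for name in pdf.columns]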
|
Let me experiment with the new serialization approach. I will update here. |
BryanCutler left a comment:
This sounds fine to me; I think the different serialization is slightly better to avoid duplicating data. Would the group Row be sent as a separate Arrow batch?
Regarding the API, I missed the original discussion, but just as an additional thought: a while back I proposed having an optional kwargs argument for each pandas_udf to deal with 0-param UDFs. If we were to do that, the group Row could be placed in there, and then there wouldn't need to be two types of signatures to allow for an optional key arg. I can see why it might be preferable to have an explicit key though, so it's up to you guys - just thought I'd mention this again.
Force-pushed from c7fccde to 259edc5
|
Hi all, I did some digging and I think adding a serialization form that serializes a key object along with an Arrow record batch is quite complicated, because we are using ArrowStreamReader/Writer for sending batches, and sending extra key data would require using a lower-level Arrow API for sending/receiving batches. I did two things to convince myself the current approach is fine:
We will not send extra grouping columns, because those are already part of the data columns. Instead, we will just use the corresponding data column to get the grouping key to pass to the user function. However, if the user calls ... then an extra column ...
I'd like to leave more flexible Arrow serialization as future work, because it doesn't seem to affect the performance of this patch, and to proceed with the current patch based on the two points above. What do you guys think? |
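The elided case above appears to concern grouping by an expression rather than an existing column; a hedged sketch of the two cases (assuming the grouped map pandas UDF API; `spark` is an existing SparkSession and the DataFrame and UDF are made up):

from pyspark.sql.functions import pandas_udf, PandasUDFType

df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0)], ("id", "v"))

@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
def normalize(pdf):
    return pdf.assign(v=pdf.v - pdf.v.mean())

# Grouping by an existing data column: no extra column is needed; the key can be
# read back from the `id` column that is already part of the data.
df.groupby('id').apply(normalize)

# Grouping by an expression: the expression result is not one of the data columns,
# so an extra column carrying its value would accompany the data.
df.groupby(df.id % 2).apply(normalize)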
|
Test build #86606 has finished for PR 20295 at commit
|
|
Test build #86604 has finished for PR 20295 at commit
|
|
Test build #86651 has finished for PR 20295 at commit
|
python/pyspark/sql/udf.py
We should update the error message here.
Aha good catch. Fixed.
|
For #20295 (comment), I am actually fine without a new serialization protocol. I didn't have a strong preference there because I wasn't sure whether it's worth it - complexity vs. actual gain - and it seems that's now clarified. I am okay with the current approach. @BryanCutler, I think the intention here is to follow a few other APIs and |
|
@HyukjinKwon Thanks for the comment. I will continue with the current approach unless objections are raised. I will work on the comments and refinements in the next day or two. |
|
@HyukjinKwon @ueshin This is ready for review. I addressed the comments so far. @BryanCutler yeah, I think kwargs is another option, but I think the API in this PR is more consistent with the existing APIs. |
|
Test build #86773 has finished for PR 20295 at commit
|
Force-pushed from edb77dc to 2668251
|
Rebased |
|
Test build #86834 has finished for PR 20295 at commit
|
|
Test build #86836 has finished for PR 20295 at commit
|
|
Test build #86837 has finished for PR 20295 at commit
|
Yeah, if it's consistent with other APIs then it sounds fine to me. My concern was with giving the user so many options that it starts to get confusing to make UDFs. If it's a familiar API then that probably won't be the case. |
Force-pushed from 9ed3779 to 722ed50
|
Addressed all comments and manually tested the example in the docstring. |
|
Test build #87968 has finished for PR 20295 at commit
|
|
Will merge this one if there are no more comments or it hasn't been merged within a few days. |
python/pyspark/worker.py
def wrapped(*series):
def wrapped(key_series, value_series):
    import pandas as pd
    argspec = inspect.getargspec(f)
Should this also use getfullargspec for py3, like in udf.py?
Maybe it would be useful to put a function in util.py - what do you guys think?
Good point. Let me take a look at that.
|
LGTM except for @BryanCutler's suggestion (#20295 (comment)). Thanks! |
|
@icexelloss Could you annotate |
|
Test build #88020 has finished for PR 20295 at commit
|
|
retest this please |
python/pyspark/sql/udf.py
sc.pythonVer, broadcast_vars, sc._javaAccumulator)

def _get_argspec(f):
How about putting this in pyspark.util? It might be useful in places other than sql
Makes sense. Moved to pyspark.util.
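For reference, a rough sketch of what such a Python 2/3-compatible helper could look like (not necessarily the exact code that was merged):

import inspect
import sys

def _get_argspec(f):
    # Python 2 has no getfullargspec; on Python 3, getargspec is deprecated.
    if sys.version_info[0] < 3:
        return inspect.getargspec(f)
    return inspect.getfullargspec(f)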
|
Test build #88025 has finished for PR 20295 at commit
|
|
Test build #88049 has finished for PR 20295 at commit
|
|
Test build #88050 has finished for PR 20295 at commit
|
|
Test build #88048 has finished for PR 20295 at commit
|
|
Merged to master. |
|
BTW, let's not forget to fix the doc later.
|
Thanks all for the review! @HyukjinKwon do you mean this doc? I can update it now, or we can update it later in a batch before the 2.4 release. Which do you prefer? |
|
Yup. Maybe we could do that when we are close to 2.4. |
|
Sounds good. Let's track it in https://issues.apache.org/jira/browse/SPARK-23633
self.assertPandasEqual(expected2, result2)

# Test complex groupby
result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas()
Is there any negative test case for when the number of columns specified in groupby is different from the definition of the UDF (foo2)?
For end users, misuse of this alternative function form could be common. For example, do we issue an appropriate error in the following cases?
- result3 = df.groupby(df.id).apply(udf2).sort('id', 'v').toPandas()
- result3 = df.groupby(df.id, df.v % 2, df.id).apply(udf2).sort('id', 'v').toPandas()
In that case, the error will be thrown as-is from the worker.py side, which is read and redirected to the user's end via the JVM. For instance:
from pyspark.sql.functions import pandas_udf, PandasUDFType

def test_func(key, pdf):
    assert len(key) == 0
    return pdf

udf1 = pandas_udf(test_func, "id long, v1 double", PandasUDFType.GROUPED_MAP)
spark.range(10).groupby('id').apply(udf1).sort('id').show()

18/09/04 14:22:52 ERROR TaskSetManager: Task 1 in stage 0.0 failed 1 times; aborting job
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/.../spark/python/pyspark/sql/dataframe.py", line 378, in show
print(self._jdf.showString(n, 20, vertical))
File "/.../spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1257, in __call__
File "/.../spark/python/pyspark/sql/utils.py", line 63, in deco
return f(*a, **kw)
File "/.../spark/python/lib/py4j-0.10.7-src.zip/py4j/protocol.py", line 328, in get_return_value
py4j.protocol.Py4JJavaError: An error occurred while calling o68.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 1 in stage 0.0 failed 1 times, most recent failure: Lost task 1.0 in stage 0.0 (TID 1, localhost, executor driver): org.apache.spark.api.python.PythonException: Traceback (most recent call last):
File "/.../python/lib/pyspark.zip/pyspark/worker.py", line 353, in main
process()
File "/.../python/lib/pyspark.zip/pyspark/worker.py", line 348, in process
serializer.dump_stream(func(split_index, iterator), outfile)
File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 242, in <lambda>
func = lambda _, it: map(mapper, it)
File "<string>", line 1, in <lambda>
File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 110, in wrapped
result = f(key, pd.concat(value_series, axis=1))
File "/.../spark/python/pyspark/util.py", line 99, in wrapper
return f(*args, **kwargs)
File "<stdin>", line 2, in test_func
AssertionError
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:418)
at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:172)
at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:122)
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:372)
at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
at scala.collection.convert.Wrappers$IteratorWrapper.hasNext(Wrappers.scala:30)
at org.spark_project.guava.collect.Ordering.leastOf(Ordering.java:628)
at org.apache.spark.util.collection.Utils$.takeOrdered(Utils.scala:37)
at org.apache.spark.rdd.RDD$$anonfun$takeOrdered$1$$anonfun$29.apply(RDD.scala:1427)
at org.apache.spark.rdd.RDD$$anonfun$takeOrdered$1$$anonfun$29.apply(RDD.scala:1424)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:800)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:800)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:48)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
at org.apache.spark.scheduler.Task.run(Task.scala:128)
at org.apache.spark.executor.Executor$TaskRunner$$anonfun$7.apply(Executor.scala:367)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1348)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:373)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)
Driver stacktrace:
at org.apache.spark.scheduler.DAGScheduler.org$apache$spark$scheduler$DAGScheduler$$failJobAndIndependentStages(DAGScheduler.scala:1822)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1810)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$abortStage$1.apply(DAGScheduler.scala:1809)
at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:1809)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
at org.apache.spark.scheduler.DAGScheduler$$anonfun$handleTaskSetFailed$1.apply(DAGScheduler.scala:926)
at scala.Option.foreach(Option.scala:257)
at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:926)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2043)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1992)
at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:1981)
at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:737)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2061)
at org.apache.spark.SparkContext.runJob(SparkContext.scala:2158)
at org.apache.spark.rdd.RDD$$anonfun$reduce$1.apply(RDD.scala:1029)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
at org.apache.spark.rdd.RDD.reduce(RDD.scala:1011)
at org.apache.spark.rdd.RDD$$anonfun$takeOrdered$1.apply(RDD.scala:1433)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
at org.apache.spark.rdd.RDD.withScope(RDD.scala:363)
at org.apache.spark.rdd.RDD.takeOrdered(RDD.scala:1420)
at org.apache.spark.sql.execution.TakeOrderedAndProjectExec.executeCollect(limit.scala:207)
at org.apache.spark.sql.Dataset.org$apache$spark$sql$Dataset$$collectFromPlan(Dataset.scala:3384)
at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2545)
at org.apache.spark.sql.Dataset$$anonfun$head$1.apply(Dataset.scala:2545)
at org.apache.spark.sql.Dataset$$anonfun$53.apply(Dataset.scala:3365)
at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1.apply(SQLExecution.scala:78)
at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:125)
at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:73)
at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3364)
at org.apache.spark.sql.Dataset.head(Dataset.scala:2545)
at org.apache.spark.sql.Dataset.take(Dataset.scala:2759)
at org.apache.spark.sql.Dataset.getRows(Dataset.scala:255)
at org.apache.spark.sql.Dataset.showString(Dataset.scala:292)
at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
at java.lang.reflect.Method.invoke(Method.java:498)
at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
at py4j.Gateway.invoke(Gateway.java:282)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:238)
at java.lang.Thread.run(Thread.java:748)
Caused by: org.apache.spark.api.python.PythonException: Traceback (most recent call last):
File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 353, in main
process()
File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 348, in process
serializer.dump_stream(func(split_index, iterator), outfile)
File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 242, in <lambda>
func = lambda _, it: map(mapper, it)
File "<string>", line 1, in <lambda>
File "/.../spark/python/lib/pyspark.zip/pyspark/worker.py", line 110, in wrapped
result = f(key, pd.concat(value_series, axis=1))
File "/.../spark/python/pyspark/util.py", line 99, in wrapper
return f(*args, **kwargs)
File "<stdin>", line 2, in test_func
AssertionError
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.handlePythonException(PythonRunner.scala:418)
at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:172)
at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.read(ArrowPythonRunner.scala:122)
at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:372)
at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
at scala.collection.Iterator$$anon$12.hasNext(Iterator.scala:440)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:409)
at scala.collection.convert.Wrappers$IteratorWrapper.hasNext(Wrappers.scala:30)
at org.spark_project.guava.collect.Ordering.leastOf(Ordering.java:628)
at org.apache.spark.util.collection.Utils$.takeOrdered(Utils.scala:37)
at org.apache.spark.rdd.RDD$$anonfun$takeOrdered$1$$anonfun$29.apply(RDD.scala:1427)
at org.apache.spark.rdd.RDD$$anonfun$takeOrdered$1$$anonfun$29.apply(RDD.scala:1424)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:800)
at org.apache.spark.rdd.RDD$$anonfun$mapPartitions$1$$anonfun$apply$23.apply(RDD.scala:800)
at org.apache.spark.rdd.MapPartitionsRDD.compute(MapPartitionsRDD.scala:48)
at org.apache.spark.rdd.RDD.computeOrReadCheckpoint(RDD.scala:324)
at org.apache.spark.rdd.RDD.iterator(RDD.scala:288)
at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:90)
at org.apache.spark.scheduler.Task.run(Task.scala:128)
at org.apache.spark.executor.Executor$TaskRunner$$anonfun$7.apply(Executor.scala:367)
at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1348)
at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:373)
at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
... 1 more
What changes were proposed in this pull request?
This PR proposes to support an alternative function form with group aggregate pandas UDF.
The current form takes a single argument that is a pandas DataFrame.
With this PR, an alternative form is supported: it takes two arguments - a tuple that represents the grouping key, and a pandas DataFrame that represents the data. Both forms are sketched below.
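A sketch of both forms, based on the grouped map pandas UDF API exercised by the tests (`spark` is an existing SparkSession and the data is illustrative):

from pyspark.sql.functions import pandas_udf, PandasUDFType

df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0), (2, 4.0)], ("id", "v"))

# Current form: a single argument that is a pandas DataFrame.
@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
def subtract_mean(pdf):
    return pdf.assign(v=pdf.v - pdf.v.mean())

# Alternative form: (key, pdf), where `key` is a tuple holding the grouping key
# values for the group, e.g. (1,) when grouping by `id`.
@pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
def subtract_mean_with_key(key, pdf):
    return pdf.assign(v=pdf.v - pdf.v.mean())

df.groupby("id").apply(subtract_mean).show()
df.groupby("id").apply(subtract_mean_with_key).show()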
How was this patch tested?
GroupbyApplyTests