Commit c4e4497

wbo4958 authored and WeichenXu123 committed
[SPARK-46812][SQL][PYTHON] Make mapInPandas / mapInArrow support ResourceProfile
### What changes were proposed in this pull request?

Support stage-level scheduling for some PySpark DataFrame APIs (mapInPandas and mapInArrow).

### Why are the changes needed?

The introduction of barrier mode in Spark, as seen in #40520, allows Spark ML cases (pure Python algorithms) to be implemented with DataFrame APIs such as mapInPandas and mapInArrow, so it is necessary to enable stage-level scheduling for these DataFrame APIs.

### Does this PR introduce _any_ user-facing change?

Yes, this PR adds a new argument `profile` to mapInPandas and mapInArrow:

``` python
def mapInPandas(
    self,
    func: "PandasMapIterFunction",
    schema: Union[StructType, str],
    barrier: bool = False,
    profile: Optional[ResourceProfile] = None,
) -> "DataFrame":

def mapInArrow(
    self,
    func: "ArrowMapIterFunction",
    schema: Union[StructType, str],
    barrier: bool = False,
    profile: Optional[ResourceProfile] = None,
) -> "DataFrame":
```

How to use it? Take mapInPandas as an example:

``` python
from pyspark import TaskContext
from pyspark.resource import TaskResourceRequests, ResourceProfileBuilder

def func(iterator):
    tc = TaskContext.get()
    assert tc.cpus() == 3
    for batch in iterator:
        yield batch

df = spark.range(10)
treqs = TaskResourceRequests().cpus(3)
rp = ResourceProfileBuilder().require(treqs).build
df.mapInPandas(func, "id long", False, rp).collect()
```

### How was this patch tested?

The newly added tests pass; some manual tests are still needed with dynamic allocation on and off.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #44852 from wbo4958/df-rp.

Lead-authored-by: Bobby Wang <[email protected]>
Co-authored-by: Bobby Wang <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
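The `profile` argument works the same way for mapInArrow. A minimal sketch mirroring the example above (it assumes the same running `spark` session on a cluster whose executors can satisfy a 3-CPU task request):

``` python
from pyspark import TaskContext
from pyspark.resource import TaskResourceRequests, ResourceProfileBuilder

def arrow_func(iterator):
    # Each element is a pyarrow.RecordBatch; pass the batches through unchanged.
    assert TaskContext.get().cpus() == 3
    for batch in iterator:
        yield batch

treqs = TaskResourceRequests().cpus(3)
rp = ResourceProfileBuilder().require(treqs).build
spark.range(10).mapInArrow(arrow_func, "id long", False, rp).collect()
```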
1 parent 0818096 commit c4e4497

File tree

12 files changed (+206, -25 lines)


connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 4 additions & 2 deletions
@@ -549,13 +549,15 @@ class SparkConnectPlanner(
           pythonUdf,
           DataTypeUtils.toAttributes(pythonUdf.dataType.asInstanceOf[StructType]),
           baseRel,
-          isBarrier)
+          isBarrier,
+          None)
       case PythonEvalType.SQL_MAP_ARROW_ITER_UDF =>
         logical.MapInArrow(
           pythonUdf,
           DataTypeUtils.toAttributes(pythonUdf.dataType.asInstanceOf[StructType]),
           baseRel,
-          isBarrier)
+          isBarrier,
+          None)
       case _ =>
         throw InvalidPlanInput(
           s"Function with EvalType: ${pythonUdf.evalType} is not supported")

dev/sparktestsupport/modules.py

Lines changed: 1 addition & 0 deletions
@@ -530,6 +530,7 @@ def __hash__(self):
         "pyspark.sql.tests.test_udf_profiler",
         "pyspark.sql.tests.test_udtf",
         "pyspark.sql.tests.test_utils",
+        "pyspark.sql.tests.test_resources",
     ],
 )

python/pyspark/sql/pandas/map_ops.py

Lines changed: 56 additions & 5 deletions
@@ -15,9 +15,13 @@
 # limitations under the License.
 #
 import sys
-from typing import Union, TYPE_CHECKING
+from typing import Union, TYPE_CHECKING, Optional

+from py4j.java_gateway import JavaObject
+
+from pyspark.resource.requests import ExecutorResourceRequests, TaskResourceRequests
 from pyspark.rdd import PythonEvalType
+from pyspark.resource import ResourceProfile
 from pyspark.sql.types import StructType

 if TYPE_CHECKING:
@@ -32,7 +36,11 @@ class PandasMapOpsMixin:
     """

     def mapInPandas(
-        self, func: "PandasMapIterFunction", schema: Union[StructType, str], barrier: bool = False
+        self,
+        func: "PandasMapIterFunction",
+        schema: Union[StructType, str],
+        barrier: bool = False,
+        profile: Optional[ResourceProfile] = None,
     ) -> "DataFrame":
         """
         Maps an iterator of batches in the current :class:`DataFrame` using a Python native
@@ -65,6 +73,12 @@ def mapInPandas(

             .. versionadded: 3.5.0

+        profile : :class:`pyspark.resource.ResourceProfile`. The optional ResourceProfile
+            to be used for mapInPandas.
+
+            .. versionadded: 4.0.0
+
+
         Examples
         --------
         >>> df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
@@ -141,11 +155,17 @@ def mapInPandas(
             func, returnType=schema, functionType=PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
         )  # type: ignore[call-overload]
         udf_column = udf(*[self[col] for col in self.columns])
-        jdf = self._jdf.mapInPandas(udf_column._jc.expr(), barrier)
+
+        jrp = self._build_java_profile(profile)
+        jdf = self._jdf.mapInPandas(udf_column._jc.expr(), barrier, jrp)
         return DataFrame(jdf, self.sparkSession)

     def mapInArrow(
-        self, func: "ArrowMapIterFunction", schema: Union[StructType, str], barrier: bool = False
+        self,
+        func: "ArrowMapIterFunction",
+        schema: Union[StructType, str],
+        barrier: bool = False,
+        profile: Optional[ResourceProfile] = None,
     ) -> "DataFrame":
         """
         Maps an iterator of batches in the current :class:`DataFrame` using a Python native
@@ -175,6 +195,11 @@ def mapInArrow(

             .. versionadded: 3.5.0

+        profile : :class:`pyspark.resource.ResourceProfile`. The optional ResourceProfile
+            to be used for mapInArrow.
+
+            .. versionadded: 4.0.0
+
         Examples
         --------
         >>> import pyarrow  # doctest: +SKIP
@@ -220,9 +245,35 @@ def mapInArrow(
             func, returnType=schema, functionType=PythonEvalType.SQL_MAP_ARROW_ITER_UDF
         )  # type: ignore[call-overload]
         udf_column = udf(*[self[col] for col in self.columns])
-        jdf = self._jdf.mapInArrow(udf_column._jc.expr(), barrier)
+
+        jrp = self._build_java_profile(profile)
+        jdf = self._jdf.mapInArrow(udf_column._jc.expr(), barrier, jrp)
         return DataFrame(jdf, self.sparkSession)

+    def _build_java_profile(
+        self, profile: Optional[ResourceProfile] = None
+    ) -> Optional[JavaObject]:
+        """Build the java ResourceProfile based on PySpark ResourceProfile"""
+        from pyspark.sql import DataFrame
+
+        assert isinstance(self, DataFrame)
+
+        jrp = None
+        if profile is not None:
+            if profile._java_resource_profile is not None:
+                jrp = profile._java_resource_profile
+            else:
+                jvm = self.sparkSession.sparkContext._jvm
+                assert jvm is not None
+
+                builder = jvm.org.apache.spark.resource.ResourceProfileBuilder()
+                ereqs = ExecutorResourceRequests(jvm, profile._executor_resource_requests)
+                treqs = TaskResourceRequests(jvm, profile._task_resource_requests)
+                builder.require(ereqs._java_executor_resource_requests)
+                builder.require(treqs._java_task_resource_requests)
+                jrp = builder.build()
+        return jrp
+

 def _test() -> None:
     import doctest
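A note on the helper above: `_build_java_profile` takes the fast path when the PySpark `ResourceProfile` was created while a SparkContext was active (it then already holds a JVM-side `_java_resource_profile`); otherwise it rebuilds the JVM profile from the plain executor and task request objects. A minimal sketch of building a profile carrying both kinds of requests, using only the public `pyspark.resource` API (the specific amounts are illustrative):

``` python
from pyspark.resource import (
    ExecutorResourceRequests,
    TaskResourceRequests,
    ResourceProfileBuilder,
)

# Ask for 4 cores and 4g of memory per executor, and 2 CPUs per task.
# If no SparkContext was active at build time, _build_java_profile() would
# re-materialize the equivalent JVM ResourceProfile from these request objects.
ereqs = ExecutorResourceRequests().cores(4).memory("4g")
treqs = TaskResourceRequests().cpus(2)
rp = ResourceProfileBuilder().require(ereqs).require(treqs).build
```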
python/pyspark/sql/tests/test_resources.py

Lines changed: 104 additions & 0 deletions

@@ -0,0 +1,104 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark import SparkContext, TaskContext
+from pyspark.resource import TaskResourceRequests, ResourceProfileBuilder
+from pyspark.sql import SparkSession
+from pyspark.testing.sqlutils import (
+    have_pandas,
+    have_pyarrow,
+    pandas_requirement_message,
+    pyarrow_requirement_message,
+)
+from pyspark.testing.utils import ReusedPySparkTestCase
+
+
+@unittest.skipIf(
+    not have_pandas or not have_pyarrow,
+    pandas_requirement_message or pyarrow_requirement_message,
+)
+class ResourceProfileTestsMixin(object):
+    def test_map_in_arrow_without_profile(self):
+        def func(iterator):
+            tc = TaskContext.get()
+            assert tc.cpus() == 1
+            for batch in iterator:
+                yield batch
+
+        df = self.spark.range(10)
+        df.mapInArrow(func, "id long").collect()
+
+    def test_map_in_arrow_with_profile(self):
+        def func(iterator):
+            tc = TaskContext.get()
+            assert tc.cpus() == 3
+            for batch in iterator:
+                yield batch
+
+        df = self.spark.range(10)
+
+        treqs = TaskResourceRequests().cpus(3)
+        rp = ResourceProfileBuilder().require(treqs).build
+        df.mapInArrow(func, "id long", False, rp).collect()
+
+    def test_map_in_pandas_without_profile(self):
+        def func(iterator):
+            tc = TaskContext.get()
+            assert tc.cpus() == 1
+            for batch in iterator:
+                yield batch
+
+        df = self.spark.range(10)
+        df.mapInPandas(func, "id long").collect()
+
+    def test_map_in_pandas_with_profile(self):
+        def func(iterator):
+            tc = TaskContext.get()
+            assert tc.cpus() == 3
+            for batch in iterator:
+                yield batch
+
+        df = self.spark.range(10)
+
+        treqs = TaskResourceRequests().cpus(3)
+        rp = ResourceProfileBuilder().require(treqs).build
+        df.mapInPandas(func, "id long", False, rp).collect()
+
+
+class ResourceProfileTests(ResourceProfileTestsMixin, ReusedPySparkTestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.sc = SparkContext("local-cluster[1, 4, 1024]", cls.__name__, conf=cls.conf())
+        cls.spark = SparkSession(cls.sc)
+
+    @classmethod
+    def tearDownClass(cls):
+        super(ResourceProfileTests, cls).tearDownClass()
+        cls.spark.stop()
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.test_resources import *  # noqa: F401
+
+    try:
+        import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)
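For context on the test setup: the `local-cluster[1, 4, 1024]` master string gives the suite one in-process worker with 4 cores and 1024 MB of memory, which is enough to schedule the 3-CPU task requests above. A minimal sketch of the same setup outside the unittest harness (the app name is illustrative):

``` python
from pyspark import SparkContext
from pyspark.sql import SparkSession

# One worker with 4 cores and 1024 MB: a TaskResourceRequests().cpus(3)
# profile fits within a single worker, so the mapped tasks can be scheduled.
sc = SparkContext("local-cluster[1, 4, 1024]", "ResourceProfileDemo")
spark = SparkSession(sc)
```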

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/DeduplicateRelations.scala

Lines changed: 2 additions & 2 deletions
@@ -382,13 +382,13 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
         newVersion.copyTagsFrom(oldVersion)
         Seq((oldVersion, newVersion))

-      case oldVersion @ MapInPandas(_, output, _, _)
+      case oldVersion @ MapInPandas(_, output, _, _, _)
           if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
         val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
         newVersion.copyTagsFrom(oldVersion)
         Seq((oldVersion, newVersion))

-      case oldVersion @ MapInArrow(_, output, _, _)
+      case oldVersion @ MapInArrow(_, output, _, _, _)
           if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
         val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
         newVersion.copyTagsFrom(oldVersion)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala

Lines changed: 5 additions & 2 deletions
@@ -17,6 +17,7 @@

 package org.apache.spark.sql.catalyst.plans.logical

+import org.apache.spark.resource.ResourceProfile
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF, PythonUDTF}
 import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.catalyst.util.truncatedString
@@ -77,7 +78,8 @@ case class MapInPandas(
     functionExpr: Expression,
     output: Seq[Attribute],
     child: LogicalPlan,
-    isBarrier: Boolean) extends UnaryNode {
+    isBarrier: Boolean,
+    profile: Option[ResourceProfile]) extends UnaryNode {

   override val producedAttributes = AttributeSet(output)

@@ -93,7 +95,8 @@ case class MapInArrow(
     functionExpr: Expression,
     output: Seq[Attribute],
     child: LogicalPlan,
-    isBarrier: Boolean) extends UnaryNode {
+    isBarrier: Boolean,
+    profile: Option[ResourceProfile]) extends UnaryNode {

   override val producedAttributes = AttributeSet(output)

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 6 additions & 3 deletions
@@ -709,7 +709,8 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       pythonUdf,
       output,
       project,
-      false)
+      false,
+      None)
     val left = SubqueryAlias("temp0", mapInPandas)
     val right = SubqueryAlias("temp1", mapInPandas)
     val join = Join(left, right, Inner, None, JoinHint.NONE)
@@ -729,7 +730,8 @@
       pythonUdf,
       output,
       project,
-      false)
+      false,
+      None)
     assertAnalysisSuccess(mapInPandas)
   }

@@ -745,7 +747,8 @@
       pythonUdf,
       output,
       project,
-      false)
+      false,
+      None)
     assertAnalysisSuccess(mapInArrow)
   }

sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

Lines changed: 13 additions & 4 deletions
@@ -37,6 +37,7 @@ import org.apache.spark.api.python.{PythonRDD, SerDeUtil}
 import org.apache.spark.api.r.RRDD
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
+import org.apache.spark.resource.ResourceProfile
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QueryPlanningTracker, ScalaReflection, TableIdentifier}
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
@@ -3515,29 +3516,37 @@ class Dataset[T] private[sql](
    * This function uses Apache Arrow as serialization format between Java executors and Python
    * workers.
    */
-  private[sql] def mapInPandas(func: PythonUDF, isBarrier: Boolean = false): DataFrame = {
+  private[sql] def mapInPandas(
+      func: PythonUDF,
+      isBarrier: Boolean = false,
+      profile: ResourceProfile = null): DataFrame = {
     Dataset.ofRows(
       sparkSession,
       MapInPandas(
         func,
         toAttributes(func.dataType.asInstanceOf[StructType]),
         logicalPlan,
-        isBarrier))
+        isBarrier,
+        Option(profile)))
   }

   /**
    * Applies a function to each partition in Arrow format. The user-defined function
    * defines a transformation: `iter(pyarrow.RecordBatch)` -> `iter(pyarrow.RecordBatch)`.
    * Each partition is each iterator consisting of `pyarrow.RecordBatch`s as batches.
    */
-  private[sql] def mapInArrow(func: PythonUDF, isBarrier: Boolean = false): DataFrame = {
+  private[sql] def mapInArrow(
+      func: PythonUDF,
+      isBarrier: Boolean = false,
+      profile: ResourceProfile = null): DataFrame = {
     Dataset.ofRows(
       sparkSession,
       MapInArrow(
         func,
         toAttributes(func.dataType.asInstanceOf[StructType]),
         logicalPlan,
-        isBarrier))
+        isBarrier,
+        Option(profile)))
   }

   /**

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala

Lines changed: 4 additions & 4 deletions
@@ -867,10 +867,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       execution.python.FlatMapCoGroupsInArrowExec(
         f.leftAttributes, f.rightAttributes,
         func, output, planLater(left), planLater(right)) :: Nil
-    case logical.MapInPandas(func, output, child, isBarrier) =>
-      execution.python.MapInPandasExec(func, output, planLater(child), isBarrier) :: Nil
-    case logical.MapInArrow(func, output, child, isBarrier) =>
-      execution.python.MapInArrowExec(func, output, planLater(child), isBarrier) :: Nil
+    case logical.MapInPandas(func, output, child, isBarrier, profile) =>
+      execution.python.MapInPandasExec(func, output, planLater(child), isBarrier, profile) :: Nil
+    case logical.MapInArrow(func, output, child, isBarrier, profile) =>
+      execution.python.MapInArrowExec(func, output, planLater(child), isBarrier, profile) :: Nil
     case logical.AttachDistributedSequence(attr, child) =>
       execution.python.AttachDistributedSequenceExec(attr, planLater(child)) :: Nil
     case logical.MapElements(f, _, _, objAttr, child) =>

sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInArrowExec.scala

Lines changed: 3 additions & 1 deletion
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.python

 import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.resource.ResourceProfile
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.execution.SparkPlan

@@ -29,7 +30,8 @@ case class MapInArrowExec(
     func: Expression,
     output: Seq[Attribute],
     child: SparkPlan,
-    override val isBarrier: Boolean)
+    override val isBarrier: Boolean,
+    override val profile: Option[ResourceProfile])
   extends MapInBatchExec {

   override protected val pythonEvalType: Int = PythonEvalType.SQL_MAP_ARROW_ITER_UDF
