@@ -549,13 +549,15 @@ class SparkConnectPlanner(
pythonUdf,
DataTypeUtils.toAttributes(pythonUdf.dataType.asInstanceOf[StructType]),
baseRel,
isBarrier)
isBarrier,
None)
case PythonEvalType.SQL_MAP_ARROW_ITER_UDF =>
logical.MapInArrow(
pythonUdf,
DataTypeUtils.toAttributes(pythonUdf.dataType.asInstanceOf[StructType]),
baseRel,
isBarrier)
isBarrier,
None)
case _ =>
throw InvalidPlanInput(
s"Function with EvalType: ${pythonUdf.evalType} is not supported")
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -530,6 +530,7 @@ def __hash__(self):
"pyspark.sql.tests.test_udf_profiler",
"pyspark.sql.tests.test_udtf",
"pyspark.sql.tests.test_utils",
"pyspark.sql.tests.test_resources",
],
)

61 changes: 56 additions & 5 deletions python/pyspark/sql/pandas/map_ops.py
@@ -15,9 +15,13 @@
# limitations under the License.
#
import sys
from typing import Union, TYPE_CHECKING
from typing import Union, TYPE_CHECKING, Optional

from py4j.java_gateway import JavaObject

from pyspark.resource.requests import ExecutorResourceRequests, TaskResourceRequests
from pyspark.rdd import PythonEvalType
from pyspark.resource import ResourceProfile
from pyspark.sql.types import StructType

if TYPE_CHECKING:
@@ -32,7 +36,11 @@ class PandasMapOpsMixin:
"""

def mapInPandas(
self, func: "PandasMapIterFunction", schema: Union[StructType, str], barrier: bool = False
self,
func: "PandasMapIterFunction",
schema: Union[StructType, str],
barrier: bool = False,
profile: Optional[ResourceProfile] = None,
) -> "DataFrame":
"""
Maps an iterator of batches in the current :class:`DataFrame` using a Python native
@@ -65,6 +73,12 @@ def mapInPandas(

.. versionadded: 3.5.0

profile : :class:`pyspark.resource.ResourceProfile`. The optional ResourceProfile
to be used for mapInPandas.

.. versionadded: 4.0.0


Examples
--------
>>> df = spark.createDataFrame([(1, 21), (2, 30)], ("id", "age"))
@@ -141,11 +155,17 @@ def mapInPandas(
func, returnType=schema, functionType=PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
) # type: ignore[call-overload]
udf_column = udf(*[self[col] for col in self.columns])
jdf = self._jdf.mapInPandas(udf_column._jc.expr(), barrier)

jrp = self._build_java_profile(profile)
jdf = self._jdf.mapInPandas(udf_column._jc.expr(), barrier, jrp)
return DataFrame(jdf, self.sparkSession)

def mapInArrow(
self, func: "ArrowMapIterFunction", schema: Union[StructType, str], barrier: bool = False
self,
func: "ArrowMapIterFunction",
schema: Union[StructType, str],
barrier: bool = False,
profile: Optional[ResourceProfile] = None,
) -> "DataFrame":
"""
Maps an iterator of batches in the current :class:`DataFrame` using a Python native
@@ -175,6 +195,11 @@ def mapInArrow(

.. versionadded: 3.5.0

profile : :class:`pyspark.resource.ResourceProfile`. The optional ResourceProfile
to be used for mapInArrow.

.. versionadded: 4.0.0

Examples
--------
>>> import pyarrow # doctest: +SKIP
@@ -220,9 +245,35 @@ def mapInArrow(
func, returnType=schema, functionType=PythonEvalType.SQL_MAP_ARROW_ITER_UDF
) # type: ignore[call-overload]
udf_column = udf(*[self[col] for col in self.columns])
jdf = self._jdf.mapInArrow(udf_column._jc.expr(), barrier)

jrp = self._build_java_profile(profile)
jdf = self._jdf.mapInArrow(udf_column._jc.expr(), barrier, jrp)
return DataFrame(jdf, self.sparkSession)

def _build_java_profile(
self, profile: Optional[ResourceProfile] = None
) -> Optional[JavaObject]:
"""Build the java ResourceProfile based on PySpark ResourceProfile"""
from pyspark.sql import DataFrame

assert isinstance(self, DataFrame)

jrp = None
if profile is not None:
if profile._java_resource_profile is not None:
jrp = profile._java_resource_profile
else:
jvm = self.sparkSession.sparkContext._jvm
assert jvm is not None

builder = jvm.org.apache.spark.resource.ResourceProfileBuilder()
ereqs = ExecutorResourceRequests(jvm, profile._executor_resource_requests)
treqs = TaskResourceRequests(jvm, profile._task_resource_requests)
builder.require(ereqs._java_executor_resource_requests)
builder.require(treqs._java_task_resource_requests)
jrp = builder.build()
return jrp


def _test() -> None:
import doctest
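A minimal usage sketch of the new `profile` parameter, mirroring the tests added below; the master URL and the 3-CPU task request are illustrative assumptions, not part of this change:

from pyspark.resource import TaskResourceRequests, ResourceProfileBuilder
from pyspark.sql import SparkSession

# Assumes a master with at least 3 CPUs per task slot, e.g. local-cluster[1, 4, 1024].
spark = SparkSession.builder.master("local-cluster[1, 4, 1024]").getOrCreate()

def echo(iterator):
    # Identity transform over an iterator of pyarrow.RecordBatch batches.
    for batch in iterator:
        yield batch

# Request 3 CPUs per task for this stage only (stage-level scheduling).
rp = ResourceProfileBuilder().require(TaskResourceRequests().cpus(3)).build
spark.range(10).mapInArrow(echo, "id long", profile=rp).collect()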
104 changes: 104 additions & 0 deletions python/pyspark/sql/tests/test_resources.py
@@ -0,0 +1,104 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest

from pyspark import SparkContext, TaskContext
from pyspark.resource import TaskResourceRequests, ResourceProfileBuilder
from pyspark.sql import SparkSession
from pyspark.testing.sqlutils import (
have_pandas,
have_pyarrow,
pandas_requirement_message,
pyarrow_requirement_message,
)
from pyspark.testing.utils import ReusedPySparkTestCase


@unittest.skipIf(
not have_pandas or not have_pyarrow,
pandas_requirement_message or pyarrow_requirement_message,
)
class ResourceProfileTestsMixin(object):
def test_map_in_arrow_without_profile(self):
def func(iterator):
tc = TaskContext.get()
assert tc.cpus() == 1
for batch in iterator:
yield batch

df = self.spark.range(10)
df.mapInArrow(func, "id long").collect()

def test_map_in_arrow_with_profile(self):
def func(iterator):
tc = TaskContext.get()
assert tc.cpus() == 3
for batch in iterator:
yield batch

df = self.spark.range(10)

treqs = TaskResourceRequests().cpus(3)
rp = ResourceProfileBuilder().require(treqs).build
df.mapInArrow(func, "id long", False, rp).collect()

def test_map_in_pandas_without_profile(self):
def func(iterator):
tc = TaskContext.get()
assert tc.cpus() == 1
for batch in iterator:
yield batch

df = self.spark.range(10)
df.mapInPandas(func, "id long").collect()

def test_map_in_pandas_with_profile(self):
def func(iterator):
tc = TaskContext.get()
assert tc.cpus() == 3
for batch in iterator:
yield batch

df = self.spark.range(10)

treqs = TaskResourceRequests().cpus(3)
rp = ResourceProfileBuilder().require(treqs).build
df.mapInPandas(func, "id long", False, rp).collect()


class ResourceProfileTests(ResourceProfileTestsMixin, ReusedPySparkTestCase):
@classmethod
def setUpClass(cls):
cls.sc = SparkContext("local-cluster[1, 4, 1024]", cls.__name__, conf=cls.conf())
cls.spark = SparkSession(cls.sc)

@classmethod
def tearDownClass(cls):
super(ResourceProfileTests, cls).tearDownClass()
cls.spark.stop()


if __name__ == "__main__":
from pyspark.sql.tests.test_resources import * # noqa: F401

try:
import xmlrunner

testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
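The new suite is picked up through the dev/sparktestsupport/modules.py entry above; in a full Spark development checkout it should also be runnable on its own, for example:

python/run-tests --testnames pyspark.sql.tests.test_resources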
@@ -382,13 +382,13 @@ object DeduplicateRelations extends Rule[LogicalPlan] {
newVersion.copyTagsFrom(oldVersion)
Seq((oldVersion, newVersion))

case oldVersion @ MapInPandas(_, output, _, _)
case oldVersion @ MapInPandas(_, output, _, _, _)
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
newVersion.copyTagsFrom(oldVersion)
Seq((oldVersion, newVersion))

case oldVersion @ MapInArrow(_, output, _, _)
case oldVersion @ MapInArrow(_, output, _, _, _)
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
val newVersion = oldVersion.copy(output = output.map(_.newInstance()))
newVersion.copyTagsFrom(oldVersion)
@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.plans.logical

import org.apache.spark.resource.ResourceProfile
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, PythonUDF, PythonUDTF}
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util.truncatedString
@@ -77,7 +78,8 @@ case class MapInPandas(
functionExpr: Expression,
output: Seq[Attribute],
child: LogicalPlan,
isBarrier: Boolean) extends UnaryNode {
isBarrier: Boolean,
profile: Option[ResourceProfile]) extends UnaryNode {

override val producedAttributes = AttributeSet(output)

@@ -93,7 +95,8 @@ case class MapInArrow(
functionExpr: Expression,
output: Seq[Attribute],
child: LogicalPlan,
isBarrier: Boolean) extends UnaryNode {
isBarrier: Boolean,
profile: Option[ResourceProfile]) extends UnaryNode {

override val producedAttributes = AttributeSet(output)

@@ -709,7 +709,8 @@ class AnalysisSuite extends AnalysisTest with Matchers {
pythonUdf,
output,
project,
false)
false,
None)
val left = SubqueryAlias("temp0", mapInPandas)
val right = SubqueryAlias("temp1", mapInPandas)
val join = Join(left, right, Inner, None, JoinHint.NONE)
@@ -729,7 +730,8 @@ class AnalysisSuite extends AnalysisTest with Matchers {
pythonUdf,
output,
project,
false)
false,
None)
assertAnalysisSuccess(mapInPandas)
}

@@ -745,7 +747,8 @@ class AnalysisSuite extends AnalysisTest with Matchers {
pythonUdf,
output,
project,
false)
false,
None)
assertAnalysisSuccess(mapInArrow)
}

17 changes: 13 additions & 4 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -37,6 +37,7 @@ import org.apache.spark.api.python.{PythonRDD, SerDeUtil}
import org.apache.spark.api.r.RRDD
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.resource.ResourceProfile
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QueryPlanningTracker, ScalaReflection, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
@@ -3515,29 +3516,37 @@ class Dataset[T] private[sql](
* This function uses Apache Arrow as serialization format between Java executors and Python
* workers.
*/
private[sql] def mapInPandas(func: PythonUDF, isBarrier: Boolean = false): DataFrame = {
private[sql] def mapInPandas(
func: PythonUDF,
isBarrier: Boolean = false,
profile: ResourceProfile = null): DataFrame = {
Dataset.ofRows(
sparkSession,
MapInPandas(
func,
toAttributes(func.dataType.asInstanceOf[StructType]),
logicalPlan,
isBarrier))
isBarrier,
Option(profile)))
}

/**
* Applies a function to each partition in Arrow format. The user-defined function
* defines a transformation: `iter(pyarrow.RecordBatch)` -> `iter(pyarrow.RecordBatch)`.
* Each partition is each iterator consisting of `pyarrow.RecordBatch`s as batches.
*/
private[sql] def mapInArrow(func: PythonUDF, isBarrier: Boolean = false): DataFrame = {
private[sql] def mapInArrow(
func: PythonUDF,
isBarrier: Boolean = false,
profile: ResourceProfile = null): DataFrame = {
Dataset.ofRows(
sparkSession,
MapInArrow(
func,
toAttributes(func.dataType.asInstanceOf[StructType]),
logicalPlan,
isBarrier))
isBarrier,
Option(profile)))
}

/**
@@ -866,10 +866,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.python.FlatMapCoGroupsInArrowExec(
f.leftAttributes, f.rightAttributes,
func, output, planLater(left), planLater(right)) :: Nil
case logical.MapInPandas(func, output, child, isBarrier) =>
execution.python.MapInPandasExec(func, output, planLater(child), isBarrier) :: Nil
case logical.MapInArrow(func, output, child, isBarrier) =>
execution.python.MapInArrowExec(func, output, planLater(child), isBarrier) :: Nil
case logical.MapInPandas(func, output, child, isBarrier, profile) =>
execution.python.MapInPandasExec(func, output, planLater(child), isBarrier, profile) :: Nil
case logical.MapInArrow(func, output, child, isBarrier, profile) =>
execution.python.MapInArrowExec(func, output, planLater(child), isBarrier, profile) :: Nil
case logical.AttachDistributedSequence(attr, child) =>
execution.python.AttachDistributedSequenceExec(attr, planLater(child)) :: Nil
case logical.MapElements(f, _, _, objAttr, child) =>
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.python

import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.resource.ResourceProfile
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan

Expand All @@ -29,7 +30,8 @@ case class MapInArrowExec(
func: Expression,
output: Seq[Attribute],
child: SparkPlan,
override val isBarrier: Boolean)
override val isBarrier: Boolean,
override val profile: Option[ResourceProfile])
extends MapInBatchExec {

override protected val pythonEvalType: Int = PythonEvalType.SQL_MAP_ARROW_ITER_UDF
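Only the profile override is declared on MapInArrowExec (and, symmetrically, MapInPandasExec) in this excerpt; the shared parent MapInBatchExec, not shown here, is presumably where the optional ResourceProfile gets attached to the stage, most likely through the RDD stage-level scheduling API (RDD.withResources).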