Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
update
Signed-off-by: Weichen Xu <[email protected]>
  • Loading branch information
WeichenXu123 committed Mar 22, 2023
commit e95fd9554c319fbd60827fb92cc584a765e994f8
Original file line number Diff line number Diff line change
Expand Up @@ -485,12 +485,14 @@ class SparkConnectPlanner(val session: SparkSession) {
logical.MapInPandas(
pythonUdf,
pythonUdf.dataType.asInstanceOf[StructType].toAttributes,
transformRelation(rel.getInput))
transformRelation(rel.getInput),
false)
case PythonEvalType.SQL_MAP_ARROW_ITER_UDF =>
logical.PythonMapInArrow(
pythonUdf,
pythonUdf.dataType.asInstanceOf[StructType].toAttributes,
transformRelation(rel.getInput))
transformRelation(rel.getInput),
false)
case _ =>
throw InvalidPlanInput(s"Function with EvalType: ${pythonUdf.evalType} is not supported")
}
Expand Down
10 changes: 8 additions & 2 deletions python/pyspark/sql/pandas/map_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ class PandasMapOpsMixin:
"""

def mapInPandas(
self, func: "PandasMapIterFunction", schema: Union[StructType, str], is_barrier: bool
self,
func: "PandasMapIterFunction", schema: Union[StructType, str],
is_barrier: bool = False
) -> "DataFrame":
"""
Maps an iterator of batches in the current :class:`DataFrame` using a Python native
Expand Down Expand Up @@ -60,6 +62,7 @@ def mapInPandas(
schema : :class:`pyspark.sql.types.DataType` or str
the return type of the `func` in PySpark. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
is_barrier : Use barrier mode execution if True.

Examples
--------
Expand Down Expand Up @@ -97,7 +100,9 @@ def mapInPandas(
return DataFrame(jdf, self.sparkSession)

def mapInArrow(
self, func: "ArrowMapIterFunction", schema: Union[StructType, str], is_barrier: bool
self,
func: "ArrowMapIterFunction", schema: Union[StructType, str],
is_barrier: bool = False
) -> "DataFrame":
"""
Maps an iterator of batches in the current :class:`DataFrame` using a Python native
Expand All @@ -122,6 +127,7 @@ def mapInArrow(
schema : :class:`pyspark.sql.types.DataType` or str
the return type of the `func` in PySpark. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
is_barrier : Use barrier mode execution if True.

Examples
--------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
override def outputPartitioning: Partitioning = child.outputPartitioning

override protected def doExecute(): RDD[InternalRow] = {
val resultRDD = child.execute().mapPartitionsInternal { inputIter =>
def mapper(inputIter: Iterator[InternalRow]): Iterator[InternalRow] = {
// Single function with one struct.
val argOffsets = Array(Array(0))
val chainedFunc = Seq(ChainedPythonFunctions(Seq(pythonFunction)))
Expand Down Expand Up @@ -92,10 +92,11 @@ trait MapInBatchExec extends UnaryExecNode with PythonSQLMetrics {
flattenedBatch.rowIterator.asScala
}.map(unsafeProj)
}

if (isBarrier) {
resultRDD.barrier().mapPartitions(iter => iter)
child.execute().barrier().mapPartitions(mapper)
} else {
resultRDD
child.execute().mapPartitionsInternal(mapper)
}
}
}