Commit 6835704

Merge remote-tracking branch 'upstream/master'
2 parents: 01e4cdf + c939c70

9 files changed: 38 additions, 20 deletions

core/src/main/scala/org/apache/spark/rdd/RDD.scala

Lines changed: 18 additions & 0 deletions
@@ -705,6 +705,24 @@ abstract class RDD[T: ClassTag](
       preservesPartitioning)
   }
 
+  /**
+   * [performance] Spark's internal mapPartitions method which skips closure cleaning. It is a
+   * performance API to be used carefully only if we are sure that the RDD elements are
+   * serializable and don't require closure cleaning.
+   *
+   * @param preservesPartitioning indicates whether the input function preserves the partitioner,
+   * which should be `false` unless this is a pair RDD and the input function doesn't modify
+   * the keys.
+   */
+  private[spark] def mapPartitionsInternal[U: ClassTag](
+      f: Iterator[T] => Iterator[U],
+      preservesPartitioning: Boolean = false): RDD[U] = withScope {
+    new MapPartitionsRDD(
+      this,
+      (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter),
+      preservesPartitioning)
+  }
+
   /**
    * Return a new RDD by applying a function to each partition of this RDD, while tracking the index
    * of the original partition.

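The new method is the existing mapPartitions minus the closure-cleaning step described in its scaladoc, so internal callers that already know their closures are serializable avoid that per-call cost. As a rough, self-contained illustration of the trade-off (a toy MiniRDD with invented mapPartitionsChecked/mapPartitionsUnchecked methods, not Spark code):

import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream}

object MapPartitionsSketch {
  // Toy stand-in for an RDD: data is just a sequence of partitions held locally.
  final class MiniRDD[T](private val partitions: Seq[Seq[T]]) {

    // "Public" path: verify up front that the closure can be serialized, loosely
    // analogous to the checking Spark performs before shipping a closure to executors.
    def mapPartitionsChecked[U](f: Iterator[T] => Iterator[U]): MiniRDD[U] = {
      new ObjectOutputStream(new ByteArrayOutputStream()).writeObject(f)
      new MiniRDD[U](partitions.map(p => f(p.iterator).toList))
    }

    // "Internal" path: the caller guarantees the closure is safe, so the check is
    // skipped, which is the role mapPartitionsInternal plays for Spark's own operators.
    def mapPartitionsUnchecked[U](f: Iterator[T] => Iterator[U]): MiniRDD[U] =
      new MiniRDD[U](partitions.map(p => f(p.iterator).toList))

    def collect(): Seq[T] = partitions.flatten
  }

  def main(args: Array[String]): Unit = {
    val rdd = new MiniRDD(Seq(Seq(1, 2), Seq(3, 4)))
    println(rdd.mapPartitionsUnchecked(it => it.map(_ * 2)).collect()) // List(2, 4, 6, 8)

    // A closure that captures a non-serializable object is rejected up front by the
    // checked variant; the unchecked variant would only fail later, if at all.
    class NotSer { val factor = 10 }
    val bad = new NotSer
    try rdd.mapPartitionsChecked(it => it.map(_ * bad.factor))
    catch { case _: NotSerializableException => println("rejected up front") }
  }
}

The SQL call sites changed below switch to the internal variant on exactly the premise stated in the scaladoc: their closures are known to be serializable and not to need cleaning.
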
sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala

Lines changed: 2 additions & 2 deletions
@@ -125,7 +125,7 @@ private[sql] case class InMemoryRelation(
 
   private def buildBuffers(): Unit = {
     val output = child.output
-    val cached = child.execute().mapPartitions { rowIterator =>
+    val cached = child.execute().mapPartitionsInternal { rowIterator =>
       new Iterator[CachedBatch] {
         def next(): CachedBatch = {
           val columnBuilders = output.map { attribute =>
@@ -292,7 +292,7 @@ private[sql] case class InMemoryColumnarTableScan(
     val relOutput = relation.output
     val buffers = relation.cachedColumnBuffers
 
-    buffers.mapPartitions { cachedBatchIterator =>
+    buffers.mapPartitionsInternal { cachedBatchIterator =>
       val partitionFilter = newPredicate(
         partitionFilters.reduceOption(And).getOrElse(Literal(true)),
         schema)

sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala

Lines changed: 3 additions & 3 deletions
@@ -168,7 +168,7 @@ case class Exchange(
       case RangePartitioning(sortingExpressions, numPartitions) =>
         // Internally, RangePartitioner runs a job on the RDD that samples keys to compute
         // partition bounds. To get accurate samples, we need to copy the mutable keys.
-        val rddForSampling = rdd.mapPartitions { iter =>
+        val rddForSampling = rdd.mapPartitionsInternal { iter =>
           val mutablePair = new MutablePair[InternalRow, Null]()
           iter.map(row => mutablePair.update(row.copy(), null))
         }
@@ -200,12 +200,12 @@ case class Exchange(
     }
     val rddWithPartitionIds: RDD[Product2[Int, InternalRow]] = {
       if (needToCopyObjectsBeforeShuffle(part, serializer)) {
-        rdd.mapPartitions { iter =>
+        rdd.mapPartitionsInternal { iter =>
           val getPartitionKey = getPartitionKeyExtractor()
           iter.map { row => (part.getPartition(getPartitionKey(row)), row.copy()) }
         }
       } else {
-        rdd.mapPartitions { iter =>
+        rdd.mapPartitionsInternal { iter =>
          val getPartitionKey = getPartitionKeyExtractor()
          val mutablePair = new MutablePair[Int, InternalRow]()
          iter.map { row => mutablePair.update(part.getPartition(getPartitionKey(row)), row) }

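The comment about copying the mutable keys is the subtle part of this hunk: SQL operators may reuse a single mutable row object per partition, so any code that retains sampled elements without copying ends up holding many references to one object. A small stand-alone sketch of that failure mode (MutableThing is an invented stand-in for a reused mutable row, not a Spark class):

object CopyBeforeSamplingSketch {
  // Invented mutable record playing the role of a reused mutable row.
  final class MutableThing(var value: Int) {
    def copy(): MutableThing = new MutableThing(value)
    override def toString: String = value.toString
  }

  def main(args: Array[String]): Unit = {
    val reused = new MutableThing(0)
    // Each call yields a fresh iterator, but every element it returns is the same object.
    def rows: Iterator[MutableThing] =
      (1 to 3).iterator.map { i => reused.value = i; reused }

    val keptWithoutCopy = rows.toList               // three references to one object
    val keptWithCopy    = rows.map(_.copy()).toList // three independent snapshots

    println(keptWithoutCopy.mkString(", ")) // 3, 3, 3  -- the "samples" all collapsed
    println(keptWithCopy.mkString(", "))    // 1, 2, 3  -- accurate samples
  }
}

That is why rddForSampling maps each row through row.copy() before RangePartitioner collects its sample.
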
sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala

Lines changed: 2 additions & 2 deletions
@@ -59,7 +59,7 @@ case class Generate(
   protected override def doExecute(): RDD[InternalRow] = {
     // boundGenerator.terminate() should be triggered after all of the rows in the partition
     if (join) {
-      child.execute().mapPartitions { iter =>
+      child.execute().mapPartitionsInternal { iter =>
        val generatorNullRow = InternalRow.fromSeq(Seq.fill[Any](generator.elementTypes.size)(null))
        val joinedRow = new JoinedRow
 
@@ -79,7 +79,7 @@
         }
       }
     } else {
-      child.execute().mapPartitions { iter =>
+      child.execute().mapPartitionsInternal { iter =>
         iter.flatMap(row => boundGenerator.eval(row)) ++
           LazyIterator(() => boundGenerator.terminate())
       }

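The comment in doExecute pins down an ordering requirement: boundGenerator.terminate() may only run after every row in the partition has been processed, which is what appending a lazily evaluated iterator achieves. A self-contained sketch of the pattern (the LazyIterator class below is a simplified stand-in written for this example, not Spark's):

object LazyTerminateSketch {
  // Simplified lazily constructed iterator: func() is not invoked until the
  // iterator is actually consulted.
  final class LazyIterator[A](func: () => Iterator[A]) extends Iterator[A] {
    private lazy val underlying = func()
    def hasNext: Boolean = underlying.hasNext
    def next(): A = underlying.next()
  }

  def main(args: Array[String]): Unit = {
    def terminate(): Iterator[String] = {
      println("terminate() called")
      Iterator("<rows emitted on termination>")
    }

    // `++` on iterators is lazy in its right-hand side, so terminate() only runs
    // once the per-row iterator is exhausted.
    val all = Iterator("row1", "row2") ++ new LazyIterator(() => terminate())
    println(all.next()) // row1            (terminate() has not run yet)
    println(all.next()) // row2
    println(all.next()) // prints "terminate() called", then the final row
  }
}
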
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@ case class SortBasedAggregate(
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
     val numInputRows = longMetric("numInputRows")
     val numOutputRows = longMetric("numOutputRows")
-    child.execute().mapPartitions { iter =>
+    child.execute().mapPartitionsInternal { iter =>
       // Because the constructor of an aggregation iterator will read at least the first row,
       // we need to get the value of iter.hasNext first.
       val hasInput = iter.hasNext

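The two-line comment carried along in this hunk is worth unpacking: the aggregation iterator's constructor consumes at least one input row, so the only reliable emptiness test is to read iter.hasNext before constructing it. A toy illustration (EagerFirstRow is an invented stand-in for the aggregation iterator, not Spark code):

object HasInputSketch {
  // Invented stand-in for an aggregation iterator whose constructor eagerly
  // reads the first row of its input.
  final class EagerFirstRow(rows: Iterator[Int]) {
    private val first: Option[Int] = if (rows.hasNext) Some(rows.next()) else None
    def total: Int = first.getOrElse(0) + rows.sum
  }

  def main(args: Array[String]): Unit = {
    val iter = Iterator(42)           // a partition with exactly one row
    val hasInput = iter.hasNext       // must be captured before construction
    val agg = new EagerFirstRow(iter) // the constructor consumes the only row
    println(iter.hasNext)             // false: no longer usable as an emptiness check
    if (hasInput) println(agg.total) else println("empty partition") // prints 42
  }
}
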
sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala

Lines changed: 8 additions & 8 deletions
@@ -43,7 +43,7 @@ case class TungstenProject(projectList: Seq[NamedExpression], child: SparkPlan)
 
   protected override def doExecute(): RDD[InternalRow] = {
     val numRows = longMetric("numRows")
-    child.execute().mapPartitions { iter =>
+    child.execute().mapPartitionsInternal { iter =>
       val project = UnsafeProjection.create(projectList, child.output,
         subexpressionEliminationEnabled)
       iter.map { row =>
@@ -67,7 +67,7 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode {
   protected override def doExecute(): RDD[InternalRow] = {
     val numInputRows = longMetric("numInputRows")
     val numOutputRows = longMetric("numOutputRows")
-    child.execute().mapPartitions { iter =>
+    child.execute().mapPartitionsInternal { iter =>
       val predicate = newPredicate(condition, child.output)
       iter.filter { row =>
         numInputRows += 1
@@ -161,19 +161,19 @@ case class Limit(limit: Int, child: SparkPlan)
 
   protected override def doExecute(): RDD[InternalRow] = {
     val rdd: RDD[_ <: Product2[Boolean, InternalRow]] = if (sortBasedShuffleOn) {
-      child.execute().mapPartitions { iter =>
+      child.execute().mapPartitionsInternal { iter =>
         iter.take(limit).map(row => (false, row.copy()))
       }
     } else {
-      child.execute().mapPartitions { iter =>
+      child.execute().mapPartitionsInternal { iter =>
         val mutablePair = new MutablePair[Boolean, InternalRow]()
         iter.take(limit).map(row => mutablePair.update(false, row))
       }
     }
     val part = new HashPartitioner(1)
     val shuffled = new ShuffledRDD[Boolean, InternalRow, InternalRow](rdd, part)
     shuffled.setSerializer(new SparkSqlSerializer(child.sqlContext.sparkContext.getConf))
-    shuffled.mapPartitions(_.take(limit).map(_._2))
+    shuffled.mapPartitionsInternal(_.take(limit).map(_._2))
   }
 }
 
@@ -294,7 +294,7 @@ case class MapPartitions[T, U](
     child: SparkPlan) extends UnaryNode {
 
   override protected def doExecute(): RDD[InternalRow] = {
-    child.execute().mapPartitions { iter =>
+    child.execute().mapPartitionsInternal { iter =>
       val tBoundEncoder = tEncoder.bind(child.output)
       func(iter.map(tBoundEncoder.fromRow)).map(uEncoder.toRow)
     }
@@ -318,7 +318,7 @@ case class AppendColumns[T, U](
   override def output: Seq[Attribute] = child.output ++ newColumns
 
   override protected def doExecute(): RDD[InternalRow] = {
-    child.execute().mapPartitions { iter =>
+    child.execute().mapPartitionsInternal { iter =>
       val tBoundEncoder = tEncoder.bind(child.output)
       val combiner = GenerateUnsafeRowJoiner.create(tEncoder.schema, uEncoder.schema)
       iter.map { row =>
@@ -350,7 +350,7 @@ case class MapGroups[K, T, U](
     Seq(groupingAttributes.map(SortOrder(_, Ascending)))
 
   override protected def doExecute(): RDD[InternalRow] = {
-    child.execute().mapPartitions { iter =>
+    child.execute().mapPartitionsInternal { iter =>
       val grouped = GroupedIterator(iter, groupingAttributes, child.output)
       val groupKeyEncoder = kEncoder.bind(groupingAttributes)
      val groupDataEncoder = tEncoder.bind(child.output)

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala

Lines changed: 2 additions & 2 deletions
@@ -54,15 +54,15 @@ case class BroadcastLeftSemiJoinHash(
       val hashSet = buildKeyHashSet(input.toIterator, SQLMetrics.nullLongMetric)
       val broadcastedRelation = sparkContext.broadcast(hashSet)
 
-      left.execute().mapPartitions { streamIter =>
+      left.execute().mapPartitionsInternal { streamIter =>
        hashSemiJoin(streamIter, numLeftRows, broadcastedRelation.value, numOutputRows)
      }
    } else {
      val hashRelation =
        HashedRelation(input.toIterator, SQLMetrics.nullLongMetric, rightKeyGenerator, input.size)
      val broadcastedRelation = sparkContext.broadcast(hashRelation)
 
-      left.execute().mapPartitions { streamIter =>
+      left.execute().mapPartitionsInternal { streamIter =>
        val hashedRelation = broadcastedRelation.value
        hashedRelation match {
          case unsafe: UnsafeHashedRelation =>

sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod
         row.copy()
       }
 
-    leftResults.cartesian(rightResults).mapPartitions { iter =>
+    leftResults.cartesian(rightResults).mapPartitionsInternal { iter =>
       val joinedRow = new JoinedRow
       iter.map { r =>
         numOutputRows += 1

sql/core/src/main/scala/org/apache/spark/sql/execution/sort.scala

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ case class Sort(
     if (global) OrderedDistribution(sortOrder) :: Nil else UnspecifiedDistribution :: Nil
 
   protected override def doExecute(): RDD[InternalRow] = attachTree(this, "sort") {
-    child.execute().mapPartitions( { iterator =>
+    child.execute().mapPartitionsInternal( { iterator =>
       val ordering = newOrdering(sortOrder, child.output)
       val sorter = new ExternalSorter[InternalRow, Null, InternalRow](
         TaskContext.get(), ordering = Some(ordering))
