diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 3b0062efeff0..ef4c421cbf82 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -33,7 +33,7 @@ import org.apache.spark.broadcast.{Broadcast, BroadcastManager} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} -import org.apache.spark.scheduler.MapStatus +import org.apache.spark.scheduler.{ExecutorCacheTaskLocation, MapStatus} import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} import org.apache.spark.util._ @@ -337,6 +337,21 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging endPartition: Int) : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] + /** + * Called from executors to get the server URIs and output sizes for each shuffle block that + * needs to be read from a given range of map output partitions (startPartition is included but + * endPartition is excluded from the range) and a given mapId. + * + * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId, + * and the second item is a sequence of (shuffle block id, shuffle block size, map index) + * tuples describing the shuffle blocks that are stored at that block manager. + */ + def getMapSizesByExecutorId( + shuffleId: Int, + startPartition: Int, + endPartition: Int, + mapId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] + /** * Deletes map output status information for the specified shuffle stage. */ @@ -668,6 +683,31 @@ private[spark] class MapOutputTrackerMaster( None } + /** + * Return the location where the given mapper ran. The location includes both a host and an + * executor id on that host. + * + * @param dep shuffle dependency object + * @param mapId the map id + * @return a sequence of locations where the map task ran.
+ */ + def getMapLocation(dep: ShuffleDependency[_, _, _], mapId: Int): Seq[String] = + { + val shuffleStatus = shuffleStatuses.get(dep.shuffleId).orNull + if (shuffleStatus != null) { + shuffleStatus.withMapStatuses { statuses => + if (mapId >= 0 && mapId < statuses.length) { + Seq(ExecutorCacheTaskLocation(statuses(mapId).location.host, + statuses(mapId).location.executorId).toString) + } else { + Nil + } + } + } else { + Nil + } + } + def incrementEpoch(): Unit = { epochLock.synchronized { epoch += 1 @@ -701,6 +741,29 @@ private[spark] class MapOutputTrackerMaster( } } + override def getMapSizesByExecutorId( + shuffleId: Int, + startPartition: Int, + endPartition: Int, + mapId: Int) + : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + logDebug(s"Fetching outputs for shuffle $shuffleId, mapId $mapId, " + + s"partitions $startPartition-$endPartition") + shuffleStatuses.get(shuffleId) match { + case Some(shuffleStatus) => + shuffleStatus.withMapStatuses { statuses => + MapOutputTracker.convertMapStatuses( + shuffleId, + startPartition, + endPartition, + statuses, + Some(mapId)) + } + case None => + Iterator.empty + } + } + override def stop(): Unit = { mapOutputRequests.offer(PoisonPill) threadpool.shutdown() @@ -746,6 +809,25 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr } } + override def getMapSizesByExecutorId( + shuffleId: Int, + startPartition: Int, + endPartition: Int, + mapId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + logDebug(s"Fetching outputs for shuffle $shuffleId, mapId $mapId, " + + s"partitions $startPartition-$endPartition") + val statuses = getStatuses(shuffleId) + try { + MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, + statuses, Some(mapId)) + } catch { + case e: MetadataFetchFailedException => + // We experienced a fetch failure so our mapStatuses cache is outdated; clear it: + mapStatuses.clear() + throw e + } + } + /** * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize * on this array when reading it, because on the driver, we may be changing it in place.
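Aside for reviewers, not part of the patch: both mapId-aware overloads above delegate to MapOutputTracker.convertMapStatuses with Some(mapId), which narrows the usual indexed iteration over map statuses to a single map output. A minimal standalone sketch of that Option-filter idiom, using hypothetical names (MapIdFilterSketch, sizesFor) and made-up block sizes:

object MapIdFilterSketch {
  // Pair each per-mapper block size with its map index, then optionally keep only the entry
  // written by one mapper; this mirrors the shape of the filter added in convertMapStatuses.
  def sizesFor(blockSizes: Array[Long], mapId: Option[Int]): List[(Int, Long)] = {
    val indexed = blockSizes.iterator.zipWithIndex.map { case (size, mapIndex) => (mapIndex, size) }
    mapId.map(id => indexed.filter(_._1 == id)).getOrElse(indexed).toList
  }

  def main(args: Array[String]): Unit = {
    val sizes = Array(10L, 20L, 30L)     // made-up sizes for three map outputs
    println(sizesFor(sizes, None))       // List((0,10), (1,20), (2,30)): the regular read path
    println(sizesFor(sizes, Some(1)))    // List((1,20)): the mapId-aware read path
  }
}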
@@ -888,10 +970,12 @@ private[spark] object MapOutputTracker extends Logging { shuffleId: Int, startPartition: Int, endPartition: Int, - statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { + statuses: Array[MapStatus], + mapId: Option[Int] = None): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = { assert (statuses != null) val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]] - for ((status, mapIndex) <- statuses.iterator.zipWithIndex) { + val iter = statuses.iterator.zipWithIndex + for ((status, mapIndex) <- mapId.map(id => iter.filter(_._2 == id)).getOrElse(iter)) { if (status == null) { val errorMessage = s"Missing an output location for shuffle $shuffleId" logError(errorMessage) @@ -906,6 +990,7 @@ private[spark] object MapOutputTracker extends Logging { } } } + splitsByAddress.iterator } } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 4329824b1b62..242442ac9d8f 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -36,18 +36,31 @@ private[spark] class BlockStoreShuffleReader[K, C]( readMetrics: ShuffleReadMetricsReporter, serializerManager: SerializerManager = SparkEnv.get.serializerManager, blockManager: BlockManager = SparkEnv.get.blockManager, - mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker) + mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker, + mapId: Option[Int] = None) extends ShuffleReader[K, C] with Logging { private val dep = handle.dependency /** Read the combined key-values for this reduce task */ override def read(): Iterator[Product2[K, C]] = { + val blocksByAddress = mapId match { + case Some(mapId) => mapOutputTracker.getMapSizesByExecutorId( + handle.shuffleId, + startPartition, + endPartition, + mapId) + case None => mapOutputTracker.getMapSizesByExecutorId( + handle.shuffleId, + startPartition, + endPartition) + } + val wrappedStreams = new ShuffleBlockFetcherIterator( context, blockManager.blockStoreClient, blockManager, - mapOutputTracker.getMapSizesByExecutorId(handle.shuffleId, startPartition, endPartition), + blocksByAddress, serializerManager.wrapStream, // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility SparkEnv.get.conf.get(config.REDUCER_MAX_SIZE_IN_FLIGHT) * 1024 * 1024, diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala index a717ef242ea7..0041dca507c0 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -54,6 +54,19 @@ private[spark] trait ShuffleManager { context: TaskContext, metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive) to + * read from a single map output, identified by mapId. + * Called on executors by reduce tasks. + */ + def getMapReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter, + mapId: Int): ShuffleReader[K, C] + /** * Remove a shuffle's metadata from the ShuffleManager.
* @return true if the metadata removed successfully, otherwise false. diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index d96bcb3d073d..b21ce9ce0fc7 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -127,6 +127,27 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager startPartition, endPartition, context, metrics) } + /** + * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive) to + * read from a single map output, identified by mapId. + * Called on executors by reduce tasks. + */ + override def getMapReader[K, C]( + handle: ShuffleHandle, + startPartition: Int, + endPartition: Int, + context: TaskContext, + metrics: ShuffleReadMetricsReporter, + mapId: Int): ShuffleReader[K, C] = { + new BlockStoreShuffleReader( + handle.asInstanceOf[BaseShuffleHandle[K, _, C]], + startPartition, + endPartition, + context, + metrics, + mapId = Some(mapId)) + } + /** Get a writer for a given partition. Called on executors by map tasks. */ override def getWriter[K, V]( handle: ShuffleHandle, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index bcb3153a3ca4..f00a4b545ee3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -394,6 +394,14 @@ object SQLConf { "must be a positive integer.") .createOptional + val OPTIMIZE_LOCAL_SHUFFLE_READER_ENABLED = + buildConf("spark.sql.adaptive.optimizedLocalShuffleReader.enabled") + .doc("When true and adaptive execution is enabled, this enables the optimization of" + + " converting the regular shuffle reader to a local shuffle reader for the shuffle" + + " exchange on the probe side of a broadcast hash join.") + .booleanConf + .createWithDefault(true) + val SUBEXPRESSION_ELIMINATION_ENABLED = buildConf("spark.sql.subexpressionElimination.enabled") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala index 8c7752c4bb74..459311df22d2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec} +import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, LocalShuffleReaderExec, QueryStageExec} import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.metric.SQLMetricInfo import org.apache.spark.sql.internal.SQLConf @@ -56,6 +56,7 @@ private[execution] object SparkPlanInfo { case ReusedSubqueryExec(child) => child :: Nil case a: AdaptiveSparkPlanExec => a.executedPlan :: Nil case stage: QueryStageExec => stage.plan :: Nil + case localReader: LocalShuffleReaderExec => localReader.child :: Nil case _ => plan.children ++ plan.subqueries } val metrics = plan.metrics.toSeq.map { case (key, metric) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala index 524cacc11484..f45e3560b2cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala @@ -84,6 +84,7 @@ case class AdaptiveSparkPlanExec( // plan should reach a final status of query stages (i.e., no more addition or removal of // Exchange nodes) after running these rules. private def queryStagePreparationRules: Seq[Rule[SparkPlan]] = Seq( + OptimizeLocalShuffleReader(conf), ensureRequirements ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala index 0ec8710e4db4..94e66b0c3a43 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala @@ -125,6 +125,7 @@ trait AdaptiveSparkPlanHelper { private def allChildren(p: SparkPlan): Seq[SparkPlan] = p match { case a: AdaptiveSparkPlanExec => Seq(a.executedPlan) case s: QueryStageExec => Seq(s.plan) + case l: LocalShuffleReaderExec => Seq(l.child) case _ => p.children } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LocalShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LocalShuffledRowRDD.scala new file mode 100644 index 000000000000..9ad1ebaf6f37 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LocalShuffledRowRDD.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import org.apache.spark._ +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsReporter} + + +/** + * The [[Partition]] used by [[LocalShuffledRowRDD]]. Each partition corresponds to a single + * pre-shuffle partition (i.e. the output of one map task), identified by + * `preShufflePartitionIndex`. + */ +private final class LocalShuffleRowRDDPartition( + val preShufflePartitionIndex: Int) extends Partition { + override val index: Int = preShufflePartitionIndex +} + +/** + * This is a specialized version of [[org.apache.spark.sql.execution.ShuffledRowRDD]]. This is used + * in Spark SQL adaptive execution when a shuffle join is converted to broadcast join at runtime + * because the map output of one input table is small enough for broadcast. This RDD represents the
This RDD represents the + * data of another input table of the join that reads from shuffle. Each partition of the RDD reads + * the whole data from just one mapper output locally. So actually there is no data transferred + * from the network. + + * This RDD takes a [[ShuffleDependency]] (`dependency`). + * + * The `dependency` has the parent RDD of this RDD, which represents the dataset before shuffle + * (i.e. map output). Elements of this RDD are (partitionId, Row) pairs. + * Partition ids should be in the range [0, numPartitions - 1]. + * `dependency.partitioner.numPartitions` is the number of pre-shuffle partitions. (i.e. the number + * of partitions of the map output). The post-shuffle partition number is the same to the parent + * RDD's partition number. + */ +class LocalShuffledRowRDD( + var dependency: ShuffleDependency[Int, InternalRow, InternalRow], + metrics: Map[String, SQLMetric]) + extends RDD[InternalRow](dependency.rdd.context, Nil) { + + private[this] val numReducers = dependency.partitioner.numPartitions + private[this] val numMappers = dependency.rdd.partitions.length + + override def getDependencies: Seq[Dependency[_]] = List(dependency) + + override def getPartitions: Array[Partition] = { + + Array.tabulate[Partition](numMappers) { i => + new LocalShuffleRowRDDPartition(i) + } + } + + override def getPreferredLocations(partition: Partition): Seq[String] = { + val tracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + tracker.getMapLocation(dependency, partition.index) + } + + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + val localRowPartition = split.asInstanceOf[LocalShuffleRowRDDPartition] + val mapId = localRowPartition.index + val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics() + // `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator, + // as well as the `tempMetrics` for basic shuffle metrics. + val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics) + + val reader = SparkEnv.get.shuffleManager.getMapReader( + dependency.shuffleHandle, + 0, + numReducers, + context, + sqlMetricsReporter, + mapId) + reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2) + } + + override def clearDependencies() { + super.clearDependencies() + dependency = null + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala new file mode 100644 index 000000000000..308e65e793d8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeLocalShuffleReader.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{LeafExecNode, SparkPlan} +import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchangeExec} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildLeft, BuildRight} +import org.apache.spark.sql.internal.SQLConf + +case class OptimizeLocalShuffleReader(conf: SQLConf) extends Rule[SparkPlan] { + + def canUseLocalShuffleReaderLeft(join: BroadcastHashJoinExec): Boolean = { + join.buildSide == BuildRight && ShuffleQueryStageExec.isShuffleQueryStageExec(join.left) + } + + def canUseLocalShuffleReaderRight(join: BroadcastHashJoinExec): Boolean = { + join.buildSide == BuildLeft && ShuffleQueryStageExec.isShuffleQueryStageExec(join.right) + } + + override def apply(plan: SparkPlan): SparkPlan = { + if (!conf.getConf(SQLConf.OPTIMIZE_LOCAL_SHUFFLE_READER_ENABLED)) { + return plan + } + + val optimizedPlan = plan.transformDown { + case join: BroadcastHashJoinExec if canUseLocalShuffleReaderRight(join) => + val localReader = LocalShuffleReaderExec(join.right.asInstanceOf[QueryStageExec]) + join.copy(right = localReader) + case join: BroadcastHashJoinExec if canUseLocalShuffleReaderLeft(join) => + val localReader = LocalShuffleReaderExec(join.left.asInstanceOf[QueryStageExec]) + join.copy(left = localReader) + } + + def numExchanges(plan: SparkPlan): Int = { + plan.collect { + case e: ShuffleExchangeExec => e + }.length + } + + val numExchangeBefore = numExchanges(EnsureRequirements(conf).apply(plan)) + val numExchangeAfter = numExchanges(EnsureRequirements(conf).apply(optimizedPlan)) + + if (numExchangeAfter > numExchangeBefore) { + logWarning("OptimizeLocalShuffleReader rule is not applied because" + + " additional shuffles would be introduced.") + plan + } else { + optimizedPlan + } + } +} + +case class LocalShuffleReaderExec(child: QueryStageExec) extends LeafExecNode { + + override def output: Seq[Attribute] = child.output + + override def doCanonicalize(): SparkPlan = child.canonicalized + + override def outputPartitioning: Partitioning = { + + def tryReserveChildPartitioning(stage: ShuffleQueryStageExec): Partitioning = { + val initialPartitioning = stage.plan.child.outputPartitioning + if (initialPartitioning.isInstanceOf[UnknownPartitioning]) { + UnknownPartitioning(stage.plan.shuffleDependency.rdd.partitions.length) + } else { + initialPartitioning + } + } + + child match { + case stage: ShuffleQueryStageExec => + tryReserveChildPartitioning(stage) + case ReusedQueryStageExec(_, stage: ShuffleQueryStageExec, _) => + tryReserveChildPartitioning(stage) + } + } + + private var cachedShuffleRDD: RDD[InternalRow] = null + + override protected def doExecute(): RDD[InternalRow] = { + if (cachedShuffleRDD == null) { + cachedShuffleRDD = child match { + case stage: ShuffleQueryStageExec => + stage.plan.createLocalShuffleRDD() + case ReusedQueryStageExec(_, stage: ShuffleQueryStageExec, _) => + stage.plan.createLocalShuffleRDD() + } + } + cachedShuffleRDD + } + + override def generateTreeString( + depth: Int, + lastChildren: Seq[Boolean], + append: String => Unit, + verbose: Boolean, + 
prefix: String = "", + addSuffix: Boolean = false, + maxFields: Int, + printNodeId: Boolean): Unit = { + super.generateTreeString(depth, + lastChildren, + append, + verbose, + prefix, + addSuffix, + maxFields, + printNodeId) + child.generateTreeString( + depth + 1, lastChildren :+ true, append, verbose, "", false, maxFields, printNodeId) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 2f4c5734469f..2f94c522712b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Uns import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.adaptive.LocalShuffledRowRDD import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics, SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType @@ -82,6 +83,10 @@ case class ShuffleExchangeExec( new ShuffledRowRDD(shuffleDependency, readMetrics, partitionStartIndices) } + def createLocalShuffleRDD(): LocalShuffledRowRDD = { + new LocalShuffledRowRDD(shuffleDependency, readMetrics) + } + /** * Caches the created ShuffleRowRDD so we can reuse that. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 37b106c3ea53..cd0bf726da9a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -77,6 +77,13 @@ class AdaptiveQueryExecSuite } } + private def checkNumLocalShuffleReaders(plan: SparkPlan, expected: Int): Unit = { + val localReaders = plan.collect { + case reader: LocalShuffleReaderExec => reader + } + assert(localReaders.length === expected) + } + test("Change merge join to broadcast join") { withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", @@ -87,6 +94,7 @@ class AdaptiveQueryExecSuite assert(smj.size == 1) val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 1) + checkNumLocalShuffleReaders(adaptivePlan, 1) } } @@ -103,14 +111,7 @@ class AdaptiveQueryExecSuite val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 1) - val shuffleReaders = adaptivePlan.collect { - case reader: CoalescedShuffleReaderExec => reader - } - assert(shuffleReaders.length === 1) - // The pre-shuffle partition size is [0, 72, 0, 72, 126] - shuffleReaders.foreach { reader => - assert(reader.outputPartitioning.numPartitions === 2) - } + checkNumLocalShuffleReaders(adaptivePlan, 1) } } @@ -125,6 +126,7 @@ class AdaptiveQueryExecSuite assert(smj.size == 1) val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 1) + checkNumLocalShuffleReaders(adaptivePlan, 1) } } @@ -139,6 +141,8 @@ class AdaptiveQueryExecSuite assert(smj.size == 1) val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 1) + + checkNumLocalShuffleReaders(adaptivePlan, 1) } } @@ -160,6 +164,8 
@@ class AdaptiveQueryExecSuite assert(smj.size == 3) val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 3) + + checkNumLocalShuffleReaders(adaptivePlan, 1) } } @@ -183,6 +189,8 @@ class AdaptiveQueryExecSuite assert(smj.size == 3) val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 3) + + checkNumLocalShuffleReaders(adaptivePlan, 0) } } @@ -206,6 +214,7 @@ class AdaptiveQueryExecSuite assert(smj.size == 3) val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) assert(bhj.size == 3) + checkNumLocalShuffleReaders(adaptivePlan, 0) } } @@ -355,6 +364,27 @@ class AdaptiveQueryExecSuite } } + test("Change merge join to broadcast join without local shuffle reader") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.OPTIMIZE_LOCAL_SHUFFLE_READER_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "30") { + val (plan, adaptivePlan) = runAdaptiveAndVerifyResult( + """ + |SELECT * FROM testData t1 join testData2 t2 + |ON t1.key = t2.a join testData3 t3 on t2.a = t3.a + |where t1.value = 1 + """.stripMargin + ) + val smj = findTopLevelSortMergeJoin(plan) + assert(smj.size == 2) + val bhj = findTopLevelBroadcastHashJoin(adaptivePlan) + assert(bhj.size == 1) + // An additional shuffle would be introduced, so OptimizeLocalShuffleReader is reverted. + checkNumLocalShuffleReaders(adaptivePlan, 0) + } + } + test("Avoid changing merge join to broadcast join if too many empty partitions on build plan") { withSQLConf( SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",