diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index ef4c421cbf82..24da855633db 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -340,17 +340,17 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
   /**
    * Called from executors to get the server URIs and output sizes for each shuffle block that
    * needs to be read from a given range of map output partitions (startPartition is included but
-   * endPartition is excluded from the range) and a given mapId.
+   * endPartition is excluded from the range) and is produced by a specific mapper.
    *
    * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
    *         and the second item is a sequence of (shuffle block id, shuffle block size, map index)
    *         tuples describing the shuffle blocks that are stored at that block manager.
    */
-  def getMapSizesByExecutorId(
+  def getMapSizesByMapIndex(
       shuffleId: Int,
+      mapIndex: Int,
       startPartition: Int,
-      endPartition: Int,
-      mapId: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
+      endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
 
   /**
    * Deletes map output status information for the specified shuffle stage.
@@ -741,13 +741,12 @@ private[spark] class MapOutputTrackerMaster(
     }
   }
 
-  override def getMapSizesByExecutorId(
+  override def getMapSizesByMapIndex(
       shuffleId: Int,
+      mapIndex: Int,
       startPartition: Int,
-      endPartition: Int,
-      mapId: Int)
-    : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
-    logDebug(s"Fetching outputs for shuffle $shuffleId, mapId $mapId" +
+      endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    logDebug(s"Fetching outputs for shuffle $shuffleId, mapIndex $mapIndex" +
       s"partitions $startPartition-$endPartition")
     shuffleStatuses.get(shuffleId) match {
       case Some (shuffleStatus) =>
@@ -757,7 +756,7 @@ private[spark] class MapOutputTrackerMaster(
             startPartition,
             endPartition,
             statuses,
-            Some(mapId))
+            Some(mapIndex))
         }
       case None =>
         Iterator.empty
@@ -809,17 +808,17 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
     }
   }
 
-  override def getMapSizesByExecutorId(
+  override def getMapSizesByMapIndex(
       shuffleId: Int,
+      mapIndex: Int,
       startPartition: Int,
-      endPartition: Int,
-      mapId: Int) : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
-    logDebug(s"Fetching outputs for shuffle $shuffleId, mapId $mapId" +
+      endPartition: Int): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+    logDebug(s"Fetching outputs for shuffle $shuffleId, mapIndex $mapIndex" +
       s"partitions $startPartition-$endPartition")
     val statuses = getStatuses(shuffleId)
     try {
       MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition,
-        statuses, Some(mapId))
+        statuses, Some(mapIndex))
     } catch {
       case e: MetadataFetchFailedException =>
         // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
@@ -962,6 +961,7 @@ private[spark] object MapOutputTracker extends Logging {
    * @param startPartition Start of map output partition ID range (included in range)
    * @param endPartition End of map output partition ID range (excluded from range)
    * @param statuses List of map statuses, indexed by map partition index.
+   * @param mapIndex When specified, only shuffle blocks from this mapper will be processed.
    * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
    *         and the second item is a sequence of (shuffle block id, shuffle block size, map index)
    *         tuples describing the shuffle blocks that are stored at that block manager.
@@ -971,11 +971,11 @@ private[spark] object MapOutputTracker extends Logging {
       startPartition: Int,
       endPartition: Int,
       statuses: Array[MapStatus],
-      mapId : Option[Int] = None): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
+      mapIndex : Option[Int] = None): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
     assert (statuses != null)
     val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]]
     val iter = statuses.iterator.zipWithIndex
-    for ((status, mapIndex) <- mapId.map(id => iter.filter(_._2 == id)).getOrElse(iter)) {
+    for ((status, mapIndex) <- mapIndex.map(index => iter.filter(_._2 == index)).getOrElse(iter)) {
       if (status == null) {
         val errorMessage = s"Missing an output location for shuffle $shuffleId"
         logError(errorMessage)
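For orientation, the new mapIndex parameter threaded through convertMapStatuses reduces to a simple filter over the zipped status iterator. Below is a minimal, Spark-free sketch of that selection logic; the stub statuses array and the MapIndexFilterSketch object are illustrative only and not part of this patch.

```scala
// Minimal sketch (no Spark dependencies) of the Option[Int] mapIndex filter used by
// convertMapStatuses above; MapStatus is stubbed as String purely for illustration.
object MapIndexFilterSketch extends App {
  val statuses: Array[String] = Array("status-of-map-0", "status-of-map-1", "status-of-map-2")
  val mapIndex: Option[Int] = Some(1)

  val iter = statuses.iterator.zipWithIndex
  val selected = mapIndex.map(index => iter.filter(_._2 == index)).getOrElse(iter)

  // With Some(1) this prints only map 1; with None it would print all three entries.
  selected.foreach { case (status, index) => println(s"map $index -> $status") }
}
```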
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index 242442ac9d8f..3737102a1aba 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -20,7 +20,7 @@ package org.apache.spark.shuffle
 import org.apache.spark._
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.serializer.SerializerManager
-import org.apache.spark.storage.{BlockManager, ShuffleBlockFetcherIterator}
+import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator}
 import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter
 
@@ -30,34 +30,18 @@ import org.apache.spark.util.collection.ExternalSorter
  */
 private[spark] class BlockStoreShuffleReader[K, C](
     handle: BaseShuffleHandle[K, _, C],
-    startPartition: Int,
-    endPartition: Int,
+    blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
     context: TaskContext,
     readMetrics: ShuffleReadMetricsReporter,
     serializerManager: SerializerManager = SparkEnv.get.serializerManager,
     blockManager: BlockManager = SparkEnv.get.blockManager,
-    mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker,
-    mapId: Option[Int] = None)
+    mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
   extends ShuffleReader[K, C] with Logging {
 
   private val dep = handle.dependency
 
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
-    val blocksByAddress = mapId match {
-      case (Some(mapId)) => mapOutputTracker.getMapSizesByExecutorId(
-        handle.shuffleId,
-        startPartition,
-        endPartition,
-        mapId)
-      case (None) => mapOutputTracker.getMapSizesByExecutorId(
-        handle.shuffleId,
-        startPartition,
-        endPartition)
-      case (_) => throw new IllegalArgumentException(
-        "mapId should be both set or unset")
-    }
-
     val wrappedStreams = new ShuffleBlockFetcherIterator(
       context,
       blockManager.blockStoreClient,
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
index 0041dca507c0..01aa43eb9763 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
@@ -55,17 +55,16 @@ private[spark] trait ShuffleManager {
       metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C]
 
   /**
-   * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive) to
-   * read from mapId.
-   * Called on executors by reduce tasks.
+   * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive)
+   * that are produced by one specific mapper. Called on executors by reduce tasks.
    */
-  def getMapReader[K, C](
+  def getReaderForOneMapper[K, C](
      handle: ShuffleHandle,
+      mapIndex: Int,
       startPartition: Int,
       endPartition: Int,
       context: TaskContext,
-      metrics: ShuffleReadMetricsReporter,
-      mapId: Int): ShuffleReader[K, C]
+      metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C]
 
   /**
    * Remove a shuffle's metadata from the ShuffleManager.
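A consequence of the constructor change above is that BlockStoreShuffleReader no longer consults the MapOutputTracker itself; callers resolve the block locations first and pass them in. The sketch below shows roughly what such a blocksByAddress iterator looks like. The executor id, host, port, and block size are made-up values, and the snippet assumes it is compiled inside a private[spark] package, since these types are not public.

```scala
package org.apache.spark.shuffle // assumed location; BlockManagerId's factory is private[spark]

import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}

// Hypothetical construction of the blocksByAddress argument now expected by
// BlockStoreShuffleReader: one entry per block manager, each carrying
// (shuffle block id, block size in bytes, map index) tuples.
object BlocksByAddressSketch {
  def singleBlock(): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
    val address = BlockManagerId("exec-1", "localhost", 7337)
    val blocks = Seq[(BlockId, Long, Int)]((ShuffleBlockId(0, 0, 0), 100L, 0))
    Iterator((address, blocks))
  }
}
```

In production code the iterator comes from getMapSizesByExecutorId or the new getMapSizesByMapIndex, as the SortShuffleManager change below shows.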
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index b21ce9ce0fc7..3cb94c1cbdd7 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -122,30 +122,23 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
       endPartition: Int,
       context: TaskContext,
       metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+    val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(
+      handle.shuffleId, startPartition, endPartition)
     new BlockStoreShuffleReader(
-      handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
-      startPartition, endPartition, context, metrics)
+      handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics)
   }
 
-  /**
-   * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive) to
-   * read from mapId.
-   * Called on executors by reduce tasks.
-   */
-  override def getMapReader[K, C](
+  override def getReaderForOneMapper[K, C](
       handle: ShuffleHandle,
+      mapIndex: Int,
       startPartition: Int,
       endPartition: Int,
       context: TaskContext,
-      metrics: ShuffleReadMetricsReporter,
-      mapId: Int): ShuffleReader[K, C] = {
+      metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = {
+    val blocksByAddress = SparkEnv.get.mapOutputTracker.getMapSizesByMapIndex(
+      handle.shuffleId, mapIndex, startPartition, endPartition)
     new BlockStoreShuffleReader(
-      handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
-      startPartition,
-      endPartition,
-      context,
-      metrics,
-      mapId = Some(mapId))
+      handle.asInstanceOf[BaseShuffleHandle[K, _, C]], blocksByAddress, context, metrics)
   }
 
   /** Get a writer for a given partition. Called on executors by map tasks. */
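Putting the pieces together, an executor-side caller obtains a per-mapper reader roughly as follows. This is a sketch only: the OneMapperReadSketch object and readOneMapperOutput helper are invented for illustration, and the code assumes it lives in a private[spark] package with a live SparkEnv and TaskContext. The LocalShuffledRowRDD change further down is the real call site.

```scala
package org.apache.spark.shuffle // assumed location; ShuffleHandle and ShuffleReader are private[spark]

import org.apache.spark.{SparkEnv, TaskContext}

// Hypothetical helper: read everything the mapper at `mapIndex` produced for
// reduce partitions [0, numReducePartitions), via the renamed per-mapper API.
object OneMapperReadSketch {
  def readOneMapperOutput[K, C](
      handle: ShuffleHandle,
      mapIndex: Int,
      numReducePartitions: Int,
      context: TaskContext,
      metrics: ShuffleReadMetricsReporter): Iterator[Product2[K, C]] = {
    SparkEnv.get.shuffleManager
      .getReaderForOneMapper[K, C](handle, mapIndex, 0, numReducePartitions, context, metrics)
      .read()
  }
}
```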
diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
index 74ec8abb22ad..3f9536e224de 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
@@ -130,15 +130,15 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
 
     val taskContext = TaskContext.empty()
     val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
+    val blocksByAddress = mapOutputTracker.getMapSizesByExecutorId(
+      shuffleId, reduceId, reduceId + 1)
     val shuffleReader = new BlockStoreShuffleReader(
       shuffleHandle,
-      reduceId,
-      reduceId + 1,
+      blocksByAddress,
       taskContext,
       metrics,
       serializerManager,
-      blockManager,
-      mapOutputTracker)
+      blockManager)
 
     assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LocalShuffledRowRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LocalShuffledRowRDD.scala
index 9ad1ebaf6f37..5fccb5ce6578 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LocalShuffledRowRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/LocalShuffledRowRDD.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLShuffleReadMetricsRe
  * (identified by `preShufflePartitionIndex`) contains a range of post-shuffle partitions
  * (`startPostShufflePartitionIndex` to `endPostShufflePartitionIndex - 1`, inclusive).
  */
-private final class LocalShuffleRowRDDPartition(
+private final class LocalShuffledRowRDDPartition(
     val preShufflePartitionIndex: Int) extends Partition {
   override val index: Int = preShufflePartitionIndex
 }
@@ -63,7 +63,7 @@ class LocalShuffledRowRDD(
 
   override def getPartitions: Array[Partition] = {
     Array.tabulate[Partition](numMappers) { i =>
-      new LocalShuffleRowRDDPartition(i)
+      new LocalShuffledRowRDDPartition(i)
     }
   }
 
@@ -73,20 +73,20 @@ class LocalShuffledRowRDD(
   }
 
   override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
-    val localRowPartition = split.asInstanceOf[LocalShuffleRowRDDPartition]
-    val mapId = localRowPartition.index
+    val localRowPartition = split.asInstanceOf[LocalShuffledRowRDDPartition]
+    val mapIndex = localRowPartition.index
     val tempMetrics = context.taskMetrics().createTempShuffleReadMetrics()
     // `SQLShuffleReadMetricsReporter` will update its own metrics for SQL exchange operator,
     // as well as the `tempMetrics` for basic shuffle metrics.
     val sqlMetricsReporter = new SQLShuffleReadMetricsReporter(tempMetrics, metrics)
 
-    val reader = SparkEnv.get.shuffleManager.getMapReader(
+    val reader = SparkEnv.get.shuffleManager.getReaderForOneMapper(
       dependency.shuffleHandle,
+      mapIndex,
       0,
       numReducers,
       context,
-      sqlMetricsReporter,
-      mapId)
+      sqlMetricsReporter)
     reader.read().asInstanceOf[Iterator[Product2[Int, InternalRow]]].map(_._2)
   }
 
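To summarize the read layout LocalShuffledRowRDD implements after this change: instead of one task per reduce partition that pulls a block from every mapper, there is one task per mapper that reads all of that mapper's reduce partitions, so each task touches the output of a single mapper only. The following Spark-free sketch enumerates the two layouts with made-up mapper and reducer counts, purely for illustration.

```scala
// Spark-free sketch of the two read layouts; block (m, r) stands for the shuffle data
// that mapper m wrote for reduce partition r. The counts below are made up.
object LocalShuffleReadLayoutSketch extends App {
  val numMappers = 3
  val numReducers = 4

  // Regular shuffle read: one task per reduce partition, pulling one block from every mapper.
  val regularReadTasks = (0 until numReducers).map(r => (0 until numMappers).map(m => (m, r)))

  // Local read (LocalShuffledRowRDD): one task per mapper, reading reduce partitions
  // [0, numReducers) of that mapper, mirroring getReaderForOneMapper(_, m, 0, numReducers, _, _).
  val localReadTasks = (0 until numMappers).map(m => (0 until numReducers).map(r => (m, r)))

  println(s"regular read tasks: $regularReadTasks")
  println(s"local read tasks:   $localReadTasks")
}
```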