@@ -18,6 +18,7 @@
package org.apache.spark.executor

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.shuffle.ShuffleMetricsReporter
import org.apache.spark.util.LongAccumulator


@@ -123,12 +124,13 @@ class ShuffleReadMetrics private[spark] () extends Serializable {
}
}


/**
* A temporary shuffle read metrics holder that is used to collect shuffle read metrics for each
* shuffle dependency, and all temporary metrics will be merged into the [[ShuffleReadMetrics]] at
* last.
*/
private[spark] class TempShuffleReadMetrics {
Contributor Author:
this was moved to TempShuffleReadMetrics

private[spark] class TempShuffleReadMetrics extends ShuffleMetricsReporter {
private[this] var _remoteBlocksFetched = 0L
private[this] var _localBlocksFetched = 0L
private[this] var _remoteBytesRead = 0L
@@ -137,13 +139,13 @@ private[spark] class TempShuffleReadMetrics {
private[this] var _fetchWaitTime = 0L
private[this] var _recordsRead = 0L

def incRemoteBlocksFetched(v: Long): Unit = _remoteBlocksFetched += v
def incLocalBlocksFetched(v: Long): Unit = _localBlocksFetched += v
def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead += v
def incRemoteBytesReadToDisk(v: Long): Unit = _remoteBytesReadToDisk += v
def incLocalBytesRead(v: Long): Unit = _localBytesRead += v
def incFetchWaitTime(v: Long): Unit = _fetchWaitTime += v
def incRecordsRead(v: Long): Unit = _recordsRead += v
override def incRemoteBlocksFetched(v: Long): Unit = _remoteBlocksFetched += v
override def incLocalBlocksFetched(v: Long): Unit = _localBlocksFetched += v
override def incRemoteBytesRead(v: Long): Unit = _remoteBytesRead += v
override def incRemoteBytesReadToDisk(v: Long): Unit = _remoteBytesReadToDisk += v
override def incLocalBytesRead(v: Long): Unit = _localBytesRead += v
override def incFetchWaitTime(v: Long): Unit = _fetchWaitTime += v
override def incRecordsRead(v: Long): Unit = _recordsRead += v

def remoteBlocksFetched: Long = _remoteBlocksFetched
def localBlocksFetched: Long = _localBlocksFetched
4 changes: 3 additions & 1 deletion core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -143,8 +143,10 @@ class CoGroupedRDD[K: ClassTag](

case shuffleDependency: ShuffleDependency[_, _, _] =>
// Read map outputs of shuffle
val metrics = context.taskMetrics().createTempShuffleReadMetrics()
val it = SparkEnv.get.shuffleManager
.getReader(shuffleDependency.shuffleHandle, split.index, split.index + 1, context)
.getReader(
shuffleDependency.shuffleHandle, split.index, split.index + 1, context, metrics)
.read()
rddIterators += ((it, depNum))
}
4 changes: 3 additions & 1 deletion core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -101,7 +101,9 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag](

override def compute(split: Partition, context: TaskContext): Iterator[(K, C)] = {
val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, C]]
SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
val metrics = context.taskMetrics().createTempShuffleReadMetrics()
SparkEnv.get.shuffleManager.getReader(
dep.shuffleHandle, split.index, split.index + 1, context, metrics)
.read()
.asInstanceOf[Iterator[(K, C)]]
}
7 changes: 6 additions & 1 deletion core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
@@ -107,9 +107,14 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
.asInstanceOf[Iterator[Product2[K, V]]].foreach(op)

case shuffleDependency: ShuffleDependency[_, _, _] =>
val metrics = context.taskMetrics().createTempShuffleReadMetrics()
val iter = SparkEnv.get.shuffleManager
.getReader(
shuffleDependency.shuffleHandle, partition.index, partition.index + 1, context)
shuffleDependency.shuffleHandle,
partition.index,
partition.index + 1,
context,
metrics)
.read()
iter.foreach(op)
}
@@ -33,6 +33,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
startPartition: Int,
endPartition: Int,
context: TaskContext,
readMetrics: ShuffleMetricsReporter,
serializerManager: SerializerManager = SparkEnv.get.serializerManager,
blockManager: BlockManager = SparkEnv.get.blockManager,
mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
@@ -53,7 +54,8 @@
SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue),
SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS),
SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM),
SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true))
SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true),
readMetrics)

val serializerInstance = dep.serializer.newInstance()

@@ -66,7 +68,6 @@
}

// Update the context task metrics for each record read.
val readMetrics = context.taskMetrics.createTempShuffleReadMetrics()
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
recordIter.map { record =>
readMetrics.incRecordsRead(1)
@@ -48,7 +48,8 @@ private[spark] trait ShuffleManager {
handle: ShuffleHandle,
startPartition: Int,
endPartition: Int,
context: TaskContext): ShuffleReader[K, C]
context: TaskContext,
metrics: ShuffleMetricsReporter): ShuffleReader[K, C]
Contributor:
IIUC, we should pass a read metrics reporter here, as this method is getReader which is called by the reducers to read shuffle files.

Contributor Author:
It is actually a read metrics reporter here. In the write PR this gets renamed to ShuffleReadMetricsReporter.


/**
* Remove a shuffle's metadata from the ShuffleManager.
@@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle

/**
* An interface for reporting shuffle information, for each shuffle. This interface assumes
Contributor:
for each shuffle -> for each reducer of a shuffle?

 * all the methods are called from a single thread, i.e. concrete implementations would not need
* to synchronize anything.
*/
private[spark] trait ShuffleMetricsReporter {
def incRemoteBlocksFetched(v: Long): Unit
def incLocalBlocksFetched(v: Long): Unit
def incRemoteBytesRead(v: Long): Unit
def incRemoteBytesReadToDisk(v: Long): Unit
def incLocalBytesRead(v: Long): Unit
def incFetchWaitTime(v: Long): Unit
def incRecordsRead(v: Long): Unit
}
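
For illustration only (not part of this change): since ShuffleMetricsReporter assumes single-threaded callers, a concrete reporter can be a plain class with unsynchronized state. The sketch below uses a made-up class name and simply fans each callback out to several underlying reporters, for example the task-level TempShuffleReadMetrics plus an additional per-query reporter.

```scala
package org.apache.spark.shuffle

// Hypothetical sketch: forwards every metric update to a list of underlying reporters.
// Plain method bodies with no locking, because the trait assumes single-threaded callers.
private[spark] class ForwardingShuffleMetricsReporter(
    reporters: Seq[ShuffleMetricsReporter]) extends ShuffleMetricsReporter {

  override def incRemoteBlocksFetched(v: Long): Unit =
    reporters.foreach(_.incRemoteBlocksFetched(v))
  override def incLocalBlocksFetched(v: Long): Unit =
    reporters.foreach(_.incLocalBlocksFetched(v))
  override def incRemoteBytesRead(v: Long): Unit =
    reporters.foreach(_.incRemoteBytesRead(v))
  override def incRemoteBytesReadToDisk(v: Long): Unit =
    reporters.foreach(_.incRemoteBytesReadToDisk(v))
  override def incLocalBytesRead(v: Long): Unit =
    reporters.foreach(_.incLocalBytesRead(v))
  override def incFetchWaitTime(v: Long): Unit =
    reporters.foreach(_.incFetchWaitTime(v))
  override def incRecordsRead(v: Long): Unit =
    reporters.foreach(_.incRecordsRead(v))
}
```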
52 changes: 52 additions & 0 deletions core/src/main/scala/org/apache/spark/shuffle/metrics.scala
@@ -0,0 +1,52 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.shuffle

/**
* An interface for reporting shuffle read metrics, for each shuffle. This interface assumes
 * all the methods are called from a single thread, i.e. concrete implementations would not need
* to synchronize.
*
* All methods have additional Spark visibility modifier to allow public, concrete implementations
* that still have these methods marked as private[spark].
*/
private[spark] trait ShuffleReadMetricsReporter {
Contributor:
how do we plan to use this interface later on? It's not used in this PR.

Contributor Author:
@xuanyuanking just submitted a PR on how to use it :)

Member:
#23128 :)

private[spark] def incRemoteBlocksFetched(v: Long): Unit
private[spark] def incLocalBlocksFetched(v: Long): Unit
private[spark] def incRemoteBytesRead(v: Long): Unit
private[spark] def incRemoteBytesReadToDisk(v: Long): Unit
private[spark] def incLocalBytesRead(v: Long): Unit
private[spark] def incFetchWaitTime(v: Long): Unit
private[spark] def incRecordsRead(v: Long): Unit
}


/**
* An interface for reporting shuffle write metrics. This interface assumes all the methods are
 * called from a single thread, i.e. concrete implementations would not need to synchronize.
*
* All methods have additional Spark visibility modifier to allow public, concrete implementations
* that still have these methods marked as private[spark].
*/
private[spark] trait ShuffleWriteMetricsReporter {
private[spark] def incBytesWritten(v: Long): Unit
private[spark] def incRecordsWritten(v: Long): Unit
private[spark] def incWriteTime(v: Long): Unit
private[spark] def decBytesWritten(v: Long): Unit
private[spark] def decRecordsWritten(v: Long): Unit
}
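
For illustration only (not part of this PR): the visibility note above means a concrete reporter class can be public while its callback methods keep the private[spark] qualifier declared on the trait, as long as the class lives somewhere under the org.apache.spark package. A hedged sketch with a hypothetical class name and package:

```scala
package org.apache.spark.sql.execution.metric // hypothetical location, for illustration only

import org.apache.spark.shuffle.ShuffleWriteMetricsReporter

// Hypothetical write-side reporter: the class itself is public, but the callback
// methods keep the private[spark] qualifier declared on the trait.
class ExampleSQLShuffleWriteReporter extends ShuffleWriteMetricsReporter {
  private[this] var bytesWritten = 0L
  private[this] var recordsWritten = 0L
  private[this] var writeTime = 0L

  override private[spark] def incBytesWritten(v: Long): Unit = bytesWritten += v
  override private[spark] def incRecordsWritten(v: Long): Unit = recordsWritten += v
  override private[spark] def incWriteTime(v: Long): Unit = writeTime += v
  override private[spark] def decBytesWritten(v: Long): Unit = bytesWritten -= v
  override private[spark] def decRecordsWritten(v: Long): Unit = recordsWritten -= v
}
```

Only code inside the org.apache.spark namespace can invoke the inc/dec callbacks, which is exactly the point of the doc comment above.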
@@ -114,9 +114,11 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
handle: ShuffleHandle,
startPartition: Int,
endPartition: Int,
context: TaskContext): ShuffleReader[K, C] = {
context: TaskContext,
metrics: ShuffleMetricsReporter): ShuffleReader[K, C] = {
new BlockStoreShuffleReader(
handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
handle.asInstanceOf[BaseShuffleHandle[K, _, C]],
startPartition, endPartition, context, metrics)
}

/** Get a writer for a given partition. Called on executors by map tasks. */
@@ -30,7 +30,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.shuffle._
import org.apache.spark.network.util.TransportConf
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.shuffle.{FetchFailedException, ShuffleMetricsReporter}
import org.apache.spark.util.Utils
import org.apache.spark.util.io.ChunkedByteBufferOutputStream

@@ -51,14 +51,15 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream
* For each block we also require the size (in bytes as a long field) in
* order to throttle the memory usage. Note that zero-sized blocks are
* already excluded, which happened in
* [[MapOutputTracker.convertMapStatuses]].
* [[org.apache.spark.MapOutputTracker.convertMapStatuses]].
* @param streamWrapper A function to wrap the returned input stream.
* @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
* @param maxReqsInFlight max number of remote requests to fetch blocks at any given point.
* @param maxBlocksInFlightPerAddress max number of shuffle blocks being fetched at any given point
* for a given remote host:port.
* @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory.
* @param detectCorrupt whether to detect any corruption in fetched blocks.
* @param shuffleMetrics used to report shuffle metrics.
*/
private[spark]
final class ShuffleBlockFetcherIterator(
@@ -71,7 +72,8 @@ final class ShuffleBlockFetcherIterator(
maxReqsInFlight: Int,
maxBlocksInFlightPerAddress: Int,
maxReqSizeShuffleToMem: Long,
detectCorrupt: Boolean)
detectCorrupt: Boolean,
shuffleMetrics: ShuffleMetricsReporter)
extends Iterator[(BlockId, InputStream)] with DownloadFileManager with Logging {

import ShuffleBlockFetcherIterator._
@@ -137,8 +139,6 @@ final class ShuffleBlockFetcherIterator(
*/
private[this] val corruptedBlocks = mutable.HashSet[BlockId]()

private[this] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics()

/**
* Whether the iterator is still active. If isZombie is true, the callback interface will no
* longer place fetched blocks into [[results]].
6 changes: 4 additions & 2 deletions core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -397,8 +397,10 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
mapTrackerMaster.registerMapOutput(0, 0, mapStatus)
}

val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1,
new TaskContextImpl(1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem))
val taskContext = new TaskContextImpl(
1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem)
val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, taskContext, metrics)
val readData = reader.read().toIndexedSeq
assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq)

@@ -104,8 +104,9 @@ class CustomShuffledRDD[K, V, C](

override def compute(p: Partition, context: TaskContext): Iterator[(K, C)] = {
val part = p.asInstanceOf[CustomShuffledRDDPartition]
val metrics = context.taskMetrics().createTempShuffleReadMetrics()
SparkEnv.get.shuffleManager.getReader(
dependency.shuffleHandle, part.startIndexInParent, part.endIndexInParent, context)
dependency.shuffleHandle, part.startIndexInParent, part.endIndexInParent, context, metrics)
.read()
.asInstanceOf[Iterator[(K, C)]]
}
@@ -126,11 +126,14 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
.set("spark.shuffle.compress", "false")
.set("spark.shuffle.spill.compress", "false"))

val taskContext = TaskContext.empty()
val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
val shuffleReader = new BlockStoreShuffleReader(
shuffleHandle,
reduceId,
reduceId + 1,
TaskContext.empty(),
taskContext,
metrics,
serializerManager,
blockManager,
mapOutputTracker)