Still need mapId for the fetch fail scenario
xuanyuanking committed Sep 17, 2019
commit 578c2338c47819f7e247052c0aa7af6b389b2933
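In short, this commit widens the per-block tuples that flow from MapOutputTracker into ShuffleBlockFetcherIterator from (BlockId, Long) to (BlockId, Long, Int), where the new Int is the block's map index. A minimal standalone sketch of the new shape (the stand-in types below are simplified illustrations, not the real Spark classes):

object MapIdShapeSketch extends App {
  // Simplified stand-ins for the Spark types touched by this commit.
  case class BlockManagerId(executorId: String, host: String, port: Int)
  case class ShuffleBlockId(shuffleId: Int, mapTaskAttemptId: Long, reduceId: Int)

  // Before: (block id, size). After: (block id, size, map index), so that a
  // fetch failure can name the exact map output to recompute.
  type BlockInfo = (ShuffleBlockId, Long, Int)

  val blocksByAddress: Iterator[(BlockManagerId, Seq[BlockInfo])] = Iterator(
    BlockManagerId("exec-1", "host-a", 7337) ->
      Seq((ShuffleBlockId(0, 0L, 2), 1024L, 0)))

  blocksByAddress.foreach { case (bm, blocks) => println(s"$bm -> $blocks") }
}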
core/src/main/scala/org/apache/spark/MapOutputTracker.scala (24 changes: 12 additions & 12 deletions)
@@ -282,7 +282,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
 
   // For testing
   def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
-      : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
+      : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
     getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)
   }

@@ -292,11 +292,11 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
    * endPartition is excluded from the range).
    *
    * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
-   *         and the second item is a sequence of (shuffle block id, shuffle block size) tuples
-   *         describing the shuffle blocks that are stored at that block manager.
+   *         and the second item is a sequence of (shuffle block id, shuffle block size, map id)
+   *         tuples describing the shuffle blocks that are stored at that block manager.
    */
   def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
-      : Iterator[(BlockManagerId, Seq[(BlockId, Long)])]
+      : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]
 
   /**
    * Deletes map output status information for the specified shuffle stage.
@@ -646,7 +646,7 @@ private[spark] class MapOutputTrackerMaster(
   // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result.
   // This method is only called in local-mode.
   def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
-      : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
+      : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
     logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
     shuffleStatuses.get(shuffleId) match {
       case Some (shuffleStatus) =>
@@ -686,7 +686,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
 
   // Get blocks sizes by executor Id. Note that zero-sized blocks are excluded in the result.
   override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
-      : Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
+      : Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
     logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
     val statuses = getStatuses(shuffleId)
     try {
@@ -834,17 +834,17 @@ private[spark] object MapOutputTracker extends Logging {
    * @param endPartition End of map output partition ID range (excluded from range)
    * @param statuses List of map statuses, indexed by map ID.
    * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
-   *         and the second item is a sequence of (shuffle block ID, shuffle block size) tuples
-   *         describing the shuffle blocks that are stored at that block manager.
+   *         and the second item is a sequence of (shuffle block id, shuffle block size, map id)
+   *         tuples describing the shuffle blocks that are stored at that block manager.
    */
   def convertMapStatuses(
       shuffleId: Int,
       startPartition: Int,
       endPartition: Int,
-      statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long)])] = {
+      statuses: Array[MapStatus]): Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])] = {
     assert (statuses != null)
-    val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long)]]
-    statuses.foreach { status =>
+    val splitsByAddress = new HashMap[BlockManagerId, ListBuffer[(BlockId, Long, Int)]]
+    for ((status, mapId) <- statuses.iterator.zipWithIndex) {
       if (status == null) {
         val errorMessage = s"Missing an output location for shuffle $shuffleId"
         logError(errorMessage)
@@ -854,7 +854,7 @@ private[spark] object MapOutputTracker extends Logging {
           val size = status.getSizeForBlock(part)
           if (size != 0) {
             splitsByAddress.getOrElseUpdate(status.location, ListBuffer()) +=
-              ((ShuffleBlockId(shuffleId, status.mapTaskAttemptId, part), size))
+              ((ShuffleBlockId(shuffleId, status.mapTaskAttemptId, part), size, mapId))
           }
         }
       }
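The heart of the change above is swapping statuses.foreach for statuses.iterator.zipWithIndex, so the loop recovers each status's array index (the map index) and stores it as the third tuple element. A self-contained sketch of the pattern, with a simplified MapStatus (the real one also carries a separate mapTaskAttemptId, which is what goes into the block name):

object ConvertStatusesSketch extends App {
  case class MapStatus(location: String, sizesByPartition: Array[Long])

  val statuses = Array(
    MapStatus("exec-1", Array(10L, 0L)),
    MapStatus("exec-2", Array(5L, 7L)))

  // zipWithIndex recovers the map index that the old foreach threw away.
  val blocks =
    for {
      (status, mapId) <- statuses.iterator.zipWithIndex
      (size, part) <- status.sizesByPartition.zipWithIndex
      if size != 0  // zero-sized blocks are excluded, as in the real code
    } yield (status.location, (s"shuffle_0_${mapId}_$part", size, mapId))

  blocks.foreach(println)
}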
core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala

@@ -49,9 +49,10 @@ import org.apache.spark.util.{CompletionIterator, TaskCompletionListener, Utils}
  * @param shuffleClient [[BlockStoreClient]] for fetching remote blocks
  * @param blockManager [[BlockManager]] for reading local blocks
  * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
- *                        For each block we also require the size (in bytes as a long field) in
- *                        order to throttle the memory usage. Note that zero-sized blocks are
- *                        already excluded, which happened in
+ *                        For each block we also require two pieces of info: 1. the size (in bytes
+ *                        as a long field), in order to throttle the memory usage; 2. the mapId
+ *                        for this block, which indicates the block's index in the map stage.
+ *                        Note that zero-sized blocks are already excluded, which happened in
  *                        [[org.apache.spark.MapOutputTracker.convertMapStatuses]].
  * @param streamWrapper A function to wrap the returned input stream.
  * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
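To make the reworded doc concrete: the size drives throttling, while the map index exists purely so fetch failures can be reported precisely. A hedged sketch of a consumer reading both fields (simplified stand-in types, not the Spark API):

object BlocksByAddressSketch extends App {
  case class BlockManagerId(executorId: String)
  type BlockInfo = (String, Long, Int)  // (block id, size in bytes, map index)

  val blocksByAddress: Iterator[(BlockManagerId, Seq[BlockInfo])] = Iterator(
    BlockManagerId("exec-1") ->
      Seq(("shuffle_0_0_0", 1024L, 0), ("shuffle_0_1_0", 2048L, 1)))

  for ((address, blocks) <- blocksByAddress) {
    val bytes = blocks.map(_._2).sum  // sizes throttle memory usage
    val maps = blocks.map(_._3)       // map indexes identify outputs on failure
    println(s"${address.executorId}: $bytes bytes from map outputs $maps")
  }
}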
@@ -67,7 +68,7 @@ final class ShuffleBlockFetcherIterator(
     context: TaskContext,
     shuffleClient: BlockStoreClient,
     blockManager: BlockManager,
-    blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long)])],
+    blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
     streamWrapper: (BlockId, InputStream) => InputStream,
     maxBytesInFlight: Long,
     maxReqsInFlight: Int,
@@ -97,7 +98,7 @@ final class ShuffleBlockFetcherIterator(
   private[this] val startTimeNs = System.nanoTime()
 
   /** Local blocks to fetch, excluding zero-sized blocks. */
-  private[this] val localBlocks = scala.collection.mutable.LinkedHashSet[BlockId]()
+  private[this] val localBlocks = scala.collection.mutable.LinkedHashSet[(BlockId, Int)]()
 
   /** Remote blocks to fetch, excluding zero-sized blocks. */
   private[this] val remoteBlocks = new HashSet[BlockId]()
@@ -199,7 +200,7 @@ final class ShuffleBlockFetcherIterator(
     while (iter.hasNext) {
       val result = iter.next()
       result match {
-        case SuccessFetchResult(_, address, _, buf, _) =>
+        case SuccessFetchResult(_, _, address, _, buf, _) =>
           if (address != blockManager.blockManagerId) {
             shuffleMetrics.incRemoteBytesRead(buf.size)
             if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
@@ -224,9 +225,11 @@ final class ShuffleBlockFetcherIterator(
     bytesInFlight += req.size
     reqsInFlight += 1
 
-    // so we can look up the size of each blockID
-    val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
-    val remainingBlocks = new HashSet[String]() ++= sizeMap.keys
+    // so we can look up the block info of each blockID
+    val infoMap = req.blocks.map {
+      case (blockId, size, mapId) => (blockId.toString, (size, mapId))
+    }.toMap
+    val remainingBlocks = new HashSet[String]() ++= infoMap.keys
     val blockIds = req.blocks.map(_._1.toString)
     val address = req.address
@@ -240,8 +243,8 @@ final class ShuffleBlockFetcherIterator(
           // This needs to be released after use.
           buf.retain()
           remainingBlocks -= blockId
-          results.put(new SuccessFetchResult(BlockId(blockId), address, sizeMap(blockId), buf,
-            remainingBlocks.isEmpty))
+          results.put(new SuccessFetchResult(BlockId(blockId), infoMap(blockId)._2,
+            address, infoMap(blockId)._1, buf, remainingBlocks.isEmpty))
           logDebug("remainingBlocks: " + remainingBlocks)
         }
       }
@@ -250,7 +253,7 @@ final class ShuffleBlockFetcherIterator(
 
       override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
         logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
-        results.put(new FailureFetchResult(BlockId(blockId), address, e))
+        results.put(new FailureFetchResult(BlockId(blockId), infoMap(blockId)._2, address, e))
       }
     }

@@ -283,28 +286,28 @@ final class ShuffleBlockFetcherIterator(
     for ((address, blockInfos) <- blocksByAddress) {
       if (address.executorId == blockManager.blockManagerId.executorId) {
         blockInfos.find(_._2 <= 0) match {
-          case Some((blockId, size)) if size < 0 =>
+          case Some((blockId, size, _)) if size < 0 =>
             throw new BlockException(blockId, "Negative block size " + size)
-          case Some((blockId, size)) if size == 0 =>
+          case Some((blockId, size, _)) if size == 0 =>
             throw new BlockException(blockId, "Zero-sized blocks should be excluded.")
           case None => // do nothing.
         }
-        localBlocks ++= blockInfos.map(_._1)
+        localBlocks ++= blockInfos.map(info => (info._1, info._3))
         localBlockBytes += blockInfos.map(_._2).sum
         numBlocksToFetch += localBlocks.size
       } else {
         val iterator = blockInfos.iterator
         var curRequestSize = 0L
-        var curBlocks = new ArrayBuffer[(BlockId, Long)]
+        var curBlocks = new ArrayBuffer[(BlockId, Long, Int)]
         while (iterator.hasNext) {
-          val (blockId, size) = iterator.next()
+          val (blockId, size, mapId) = iterator.next()
           remoteBlockBytes += size
           if (size < 0) {
             throw new BlockException(blockId, "Negative block size " + size)
           } else if (size == 0) {
             throw new BlockException(blockId, "Zero-sized blocks should be excluded.")
           } else {
-            curBlocks += ((blockId, size))
+            curBlocks += ((blockId, size, mapId))
             remoteBlocks += blockId
             numBlocksToFetch += 1
             curRequestSize += size
@@ -315,7 +318,7 @@ final class ShuffleBlockFetcherIterator(
             remoteRequests += new FetchRequest(address, curBlocks)
             logDebug(s"Creating fetch request of $curRequestSize at $address "
               + s"with ${curBlocks.size} blocks")
-            curBlocks = new ArrayBuffer[(BlockId, Long)]
+            curBlocks = new ArrayBuffer[(BlockId, Long, Int)]
             curRequestSize = 0
           }
         }
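For context on the loop this hunk sits in: remote blocks are accumulated into curBlocks until the running size reaches a target, then flushed as one FetchRequest, which keeps several modest requests in flight instead of one huge one. A standalone sketch of that batching (the targetRequestSize value here is an assumption; in Spark it is derived from maxBytesInFlight):

object BatchingSketch extends App {
  type BlockInfo = (String, Long, Int)  // (block id, size, map index)

  def batch(blocks: Seq[BlockInfo], targetRequestSize: Long): Seq[Seq[BlockInfo]] = {
    val requests = Seq.newBuilder[Seq[BlockInfo]]
    var cur = Vector.empty[BlockInfo]
    var curSize = 0L
    for (block <- blocks) {
      cur :+= block
      curSize += block._2
      if (curSize >= targetRequestSize) {  // flush the current batch
        requests += cur
        cur = Vector.empty
        curSize = 0L
      }
    }
    if (cur.nonEmpty) requests += cur
    requests.result()
  }

  println(batch(Seq(("b0", 30L, 0), ("b1", 40L, 1), ("b2", 50L, 2)), 60L))
}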
@@ -341,13 +344,13 @@ final class ShuffleBlockFetcherIterator(
     logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}")
     val iter = localBlocks.iterator
     while (iter.hasNext) {
-      val blockId = iter.next()
+      val (blockId, mapId) = iter.next()
       try {
         val buf = blockManager.getBlockData(blockId)
         shuffleMetrics.incLocalBlocksFetched(1)
         shuffleMetrics.incLocalBytesRead(buf.size)
         buf.retain()
-        results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId,
+        results.put(new SuccessFetchResult(blockId, mapId, blockManager.blockManagerId,
           buf.size(), buf, false))
       } catch {
         // If we see an exception, stop immediately.
@@ -360,7 +363,7 @@ final class ShuffleBlockFetcherIterator(
             logError("Error occurred while fetching local blocks, " + ce.getMessage)
           case ex: Exception => logError("Error occurred while fetching local blocks", ex)
         }
-        results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
+        results.put(new FailureFetchResult(blockId, mapId, blockManager.blockManagerId, e))
         return
       }
     }
@@ -420,7 +423,7 @@ final class ShuffleBlockFetcherIterator(
       shuffleMetrics.incFetchWaitTime(fetchWaitTime)
 
       result match {
-        case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) =>
+        case r @ SuccessFetchResult(blockId, mapId, address, size, buf, isNetworkReqDone) =>
           if (address != blockManager.blockManagerId) {
             numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1
             shuffleMetrics.incRemoteBytesRead(buf.size)
@@ -429,7 +432,7 @@ final class ShuffleBlockFetcherIterator(
             }
             shuffleMetrics.incRemoteBlocksFetched(1)
           }
-          if (!localBlocks.contains(blockId)) {
+          if (!localBlocks.contains((blockId, mapId))) {
             bytesInFlight -= size
           }
           if (isNetworkReqDone) {
@@ -453,7 +456,7 @@ final class ShuffleBlockFetcherIterator(
           // since the last call.
           val msg = s"Received a zero-size buffer for block $blockId from $address " +
             s"(expectedApproxSize = $size, isNetworkReqDone=$isNetworkReqDone)"
-          throwFetchFailedException(blockId, address, new IOException(msg))
+          throwFetchFailedException(blockId, mapId, address, new IOException(msg))
         }
 
         val in = try {
val in = try {
Expand All @@ -469,7 +472,7 @@ final class ShuffleBlockFetcherIterator(
case e: IOException => logError("Failed to create input stream from local block", e)
}
buf.release()
throwFetchFailedException(blockId, address, e)
throwFetchFailedException(blockId, mapId, address, e)
}
try {
input = streamWrapper(blockId, in)
@@ -487,11 +490,11 @@ final class ShuffleBlockFetcherIterator(
             buf.release()
             if (buf.isInstanceOf[FileSegmentManagedBuffer]
                 || corruptedBlocks.contains(blockId)) {
-              throwFetchFailedException(blockId, address, e)
+              throwFetchFailedException(blockId, mapId, address, e)
             } else {
               logWarning(s"got an corrupted block $blockId from $address, fetch again", e)
               corruptedBlocks += blockId
-              fetchRequests += FetchRequest(address, Array((blockId, size)))
+              fetchRequests += FetchRequest(address, Array((blockId, size, mapId)))
               result = null
             }
           } finally {
@@ -503,8 +506,8 @@ final class ShuffleBlockFetcherIterator(
         }
       }
 
-      case FailureFetchResult(blockId, address, e) =>
-        throwFetchFailedException(blockId, address, e)
+      case FailureFetchResult(blockId, mapId, address, e) =>
+        throwFetchFailedException(blockId, mapId, address, e)
     }
 
     // Send fetch requests up to maxBytesInFlight
@@ -517,6 +520,7 @@ final class ShuffleBlockFetcherIterator(
           input,
           this,
           currentResult.blockId,
+          currentResult.mapId,
           currentResult.address,
           detectCorrupt && streamCompressedOrEncrypted))
     }
@@ -583,10 +587,11 @@ final class ShuffleBlockFetcherIterator(
 
   private[storage] def throwFetchFailedException(
       blockId: BlockId,
+      mapId: Int,
       address: BlockManagerId,
       e: Throwable) = {
     blockId match {
-      case ShuffleBlockId(shufId, mapId, reduceId) =>
+      case ShuffleBlockId(shufId, _, reduceId) =>
         throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
       case _ =>
         throw new SparkException(
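Worth noting in this hunk: the map index used to be pulled out of the ShuffleBlockId pattern match, but it now arrives as an explicit parameter, and the id embedded in the block name is ignored for this purpose. A simplified sketch of the resulting reporting path (stand-in types, not the real FetchFailedException):

object FetchFailedSketch extends App {
  case class BlockManagerId(executorId: String)
  case class ShuffleBlockId(shuffleId: Int, mapTaskAttemptId: Long, reduceId: Int)
  case class FetchFailedException(address: BlockManagerId, shuffleId: Int,
      mapIndex: Int, reduceId: Int, cause: Throwable) extends Exception(cause)

  // mapId comes from the caller; the id embedded in the block name is a task
  // attempt id after this change and no longer doubles as the map index.
  def throwFetchFailed(blockId: ShuffleBlockId, mapId: Int,
      address: BlockManagerId, e: Throwable): Nothing =
    throw FetchFailedException(address, blockId.shuffleId, mapId, blockId.reduceId, e)

  try throwFetchFailed(ShuffleBlockId(0, 123L, 4), mapId = 1,
    BlockManagerId("exec-1"), new java.io.IOException("corrupt stream"))
  catch { case f: FetchFailedException => println(s"rerun map ${f.mapIndex}") }
}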
@@ -604,6 +609,7 @@ private class BufferReleasingInputStream(
     private[storage] val delegate: InputStream,
     private val iterator: ShuffleBlockFetcherIterator,
     private val blockId: BlockId,
+    private val mapId: Int,
     private val address: BlockManagerId,
     private val detectCorruption: Boolean)
   extends InputStream {
@@ -615,7 +621,7 @@ private class BufferReleasingInputStream(
     } catch {
       case e: IOException if detectCorruption =>
         IOUtils.closeQuietly(this)
-        iterator.throwFetchFailedException(blockId, address, e)
+        iterator.throwFetchFailedException(blockId, mapId, address, e)
     }
   }

@@ -637,7 +643,7 @@ private class BufferReleasingInputStream(
     } catch {
       case e: IOException if detectCorruption =>
         IOUtils.closeQuietly(this)
-        iterator.throwFetchFailedException(blockId, address, e)
+        iterator.throwFetchFailedException(blockId, mapId, address, e)
     }
   }

@@ -649,7 +655,7 @@ private class BufferReleasingInputStream(
     } catch {
       case e: IOException if detectCorruption =>
         IOUtils.closeQuietly(this)
-        iterator.throwFetchFailedException(blockId, address, e)
+        iterator.throwFetchFailedException(blockId, mapId, address, e)
     }
   }

@@ -659,7 +665,7 @@ private class BufferReleasingInputStream(
     } catch {
       case e: IOException if detectCorruption =>
         IOUtils.closeQuietly(this)
-        iterator.throwFetchFailedException(blockId, address, e)
+        iterator.throwFetchFailedException(blockId, mapId, address, e)
     }
   }

@@ -694,9 +700,10 @@ object ShuffleBlockFetcherIterator {
    * A request to fetch blocks from a remote BlockManager.
    * @param address remote BlockManager to fetch from.
    * @param blocks Sequence of tuple, where the first element is the block id,
-   *               and the second element is the estimated size, used to calculate bytesInFlight.
+   *               the second element is the estimated size, used to calculate bytesInFlight,
+   *               and the third element is the mapId.
    */
-  case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long)]) {
+  case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long, Int)]) {
Contributor:
Now it's a tuple3 with int and long elements. I think it's better to create a class for it to make the code easier to read.

Member (author):
Thanks, added a FetchBlockInfo class for this in d2215b2.
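A minimal sketch of what the suggested wrapper could look like; the actual FetchBlockInfo added in d2215b2 may differ in fields and placement:

object FetchBlockInfoSketch extends App {
  case class BlockId(name: String)
  case class BlockManagerId(executorId: String)

  // Hypothetical shape only -- see commit d2215b2 for the real definition.
  case class FetchBlockInfo(blockId: BlockId, size: Long, mapId: Int)

  case class FetchRequest(address: BlockManagerId, blocks: Seq[FetchBlockInfo]) {
    // Named fields read better than tuple accessors like _._2.
    val size: Long = blocks.map(_.size).sum
  }

  val req = FetchRequest(BlockManagerId("exec-1"),
    Seq(FetchBlockInfo(BlockId("shuffle_0_0_0"), 1024L, 0)))
  println(req.size)
}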

val size = blocks.map(_._2).sum
}

@@ -711,6 +718,7 @@ object ShuffleBlockFetcherIterator {
   /**
    * Result of a fetch from a remote block successfully.
    * @param blockId block id
+   * @param mapId mapId for this block
    * @param address BlockManager that the block was fetched from.
    * @param size estimated size of the block. Note that this is NOT the exact bytes.
    *             Size of remote block is used to calculate bytesInFlight.
@@ -719,6 +727,7 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage] case class SuccessFetchResult(
       blockId: BlockId,
+      mapId: Int,
Contributor:
Why do we need the map index here?

Contributor:
If I follow correctly, the reason is that even a SuccessFetchResult can still result in a FetchFailure back to the driver (e.g. an error decompressing the buffer). And the FetchFailure needs the mapIndex, because the map status is still stored by mapIndex, so this tells us what to remove in the handling in the DAGScheduler.

Member (author):
Yeah, that's right. Here we need to guarantee that all paths to throwFetchFailedException have the mapIndex passed through, since even a SuccessFetchResult can still trigger a fetch failed exception.
address: BlockManagerId,
size: Long,
buf: ManagedBuffer,
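The point of the thread above, as a runnable sketch: a block can arrive "successfully" and still fail afterwards (e.g. during decompression), so even the success result must carry the map index for the eventual failure report. Simplified stand-in types:

object SuccessCanStillFailSketch extends App {
  case class BlockManagerId(executorId: String)
  sealed trait FetchResult
  case class SuccessFetchResult(blockId: String, mapId: Int,
      address: BlockManagerId, bytes: Array[Byte]) extends FetchResult
  case class FailureFetchResult(blockId: String, mapId: Int,
      address: BlockManagerId, e: Throwable) extends FetchResult

  def decompress(bytes: Array[Byte]): Array[Byte] =
    if (bytes.isEmpty) throw new java.io.IOException("corrupt buffer") else bytes

  def handle(result: FetchResult): Unit = result match {
    case SuccessFetchResult(id, mapId, addr, bytes) =>
      try {
        decompress(bytes)
        println(s"block $id ok")
      } catch {
        // The fetch "succeeded", yet we still need mapId to tell the
        // scheduler which map output to recompute.
        case e: java.io.IOException =>
          println(s"FetchFailed(map $mapId, $id, ${addr.executorId}): $e")
      }
    case FailureFetchResult(id, mapId, addr, e) =>
      println(s"FetchFailed(map $mapId, $id, ${addr.executorId}): $e")
  }

  handle(SuccessFetchResult("shuffle_0_1_2", 1, BlockManagerId("exec-1"), Array.empty))
}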
@@ -730,11 +739,13 @@ object ShuffleBlockFetcherIterator {
 
   /**
    * Result of a fetch from a remote block unsuccessfully.
    * @param blockId block id
+   * @param mapId mapId for this block
    * @param address BlockManager that the block was attempted to be fetched from
    * @param e the failure exception
    */
   private[storage] case class FailureFetchResult(
       blockId: BlockId,
+      mapId: Int,
       address: BlockManagerId,
       e: Throwable)
     extends FetchResult