Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,10 @@ final class ShuffleBlockFetcherIterator(
hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum
} else {
remoteBlockBytes += blockInfos.map(_._2).sum
collectFetchRequests(address, blockInfos, collectedRemoteRequests)
val (_, timeCost) = Utils.timeTakenMs[Unit] {
collectFetchRequests(address, blockInfos, collectedRemoteRequests)
}
logDebug(s"Collected remote fetch requests for $address in $timeCost ms")
}
}
val numRemoteBlocks = collectedRemoteRequests.map(_.blocks.size).sum
Expand Down Expand Up @@ -408,10 +411,10 @@ final class ShuffleBlockFetcherIterator(
curBlocks: Seq[FetchBlockInfo],
address: BlockManagerId,
isLast: Boolean,
collectedRemoteRequests: ArrayBuffer[FetchRequest]): Seq[FetchBlockInfo] = {
collectedRemoteRequests: ArrayBuffer[FetchRequest]): ArrayBuffer[FetchBlockInfo] = {
val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, doBatchFetch)
numBlocksToFetch += mergedBlocks.size
var retBlocks = Seq.empty[FetchBlockInfo]
val retBlocks = new ArrayBuffer[FetchBlockInfo]
if (mergedBlocks.length <= maxBlocksInFlightPerAddress) {
collectedRemoteRequests += createFetchRequest(mergedBlocks, address)
} else {
Expand All @@ -421,7 +424,7 @@ final class ShuffleBlockFetcherIterator(
} else {
// The last group does not exceed `maxBlocksInFlightPerAddress`. Put it back
// to `curBlocks`.
retBlocks = blocks
retBlocks ++= blocks
numBlocksToFetch -= blocks.size
}
}
Expand All @@ -435,26 +438,24 @@ final class ShuffleBlockFetcherIterator(
collectedRemoteRequests: ArrayBuffer[FetchRequest]): Unit = {
val iterator = blockInfos.iterator
var curRequestSize = 0L
var curBlocks = Seq.empty[FetchBlockInfo]
var curBlocks = new ArrayBuffer[FetchBlockInfo]()

while (iterator.hasNext) {
val (blockId, size, mapIndex) = iterator.next()
assertPositiveBlockSize(blockId, size)
curBlocks = curBlocks ++ Seq(FetchBlockInfo(blockId, size, mapIndex))
curBlocks += FetchBlockInfo(blockId, size, mapIndex)
curRequestSize += size
// For batch fetch, the actual block in flight should count for merged block.
val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress
if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) {
curBlocks = createFetchRequests(curBlocks, address, isLast = false,
curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = false,
collectedRemoteRequests)
curRequestSize = curBlocks.map(_.size).sum
}
}
// Add in the final request
if (curBlocks.nonEmpty) {
curBlocks = createFetchRequests(curBlocks, address, isLast = true,
collectedRemoteRequests)
curRequestSize = curBlocks.map(_.size).sum
createFetchRequests(curBlocks.toSeq, address, isLast = true, collectedRemoteRequests)
}
}

Expand Down Expand Up @@ -994,7 +995,7 @@ object ShuffleBlockFetcherIterator {
blocks: Seq[FetchBlockInfo],
doBatchFetch: Boolean): Seq[FetchBlockInfo] = {
val result = if (doBatchFetch) {
var curBlocks = new ArrayBuffer[FetchBlockInfo]
val curBlocks = new ArrayBuffer[FetchBlockInfo]
val mergedBlockInfo = new ArrayBuffer[FetchBlockInfo]

def mergeFetchBlockInfo(toBeMerged: ArrayBuffer[FetchBlockInfo]): FetchBlockInfo = {
Expand Down