diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 17746734df082..c0fcf99fecd2a 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -287,6 +287,17 @@ package object config { .bytesConf(ByteUnit.BYTE) .createWithDefault(100 * 1024 * 1024) + private[spark] val REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS = + ConfigBuilder("spark.reducer.maxBlocksInFlightPerAddress") + .doc("This configuration limits the number of remote blocks being fetched per reduce task" + + " from a given host port. When a large number of blocks are being requested from a given" + + " address in a single fetch or simultaneously, this could crash the serving executor or" + + " Node Manager. This is especially useful to reduce the load on the Node Manager when" + + " external shuffle is enabled. You can mitigate the issue by setting it to a lower value.") + .intConf + .checkValue(_ > 0, "The max no. of blocks in flight cannot be non-positive.") + .createWithDefault(Int.MaxValue) + private[spark] val REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM = ConfigBuilder("spark.reducer.maxReqSizeShuffleToMem") .doc("The blocks of a shuffle request will be fetched to disk when size of the request is " + diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 2fbac79a2305b..c8d1460300934 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -51,6 +51,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), + SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), SparkEnv.get.conf.get(config.REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM), SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true)) diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 7e2bcf7683bff..a91bbf71b68dd 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -23,7 +23,7 @@ import java.util.concurrent.LinkedBlockingQueue import javax.annotation.concurrent.GuardedBy import scala.collection.mutable -import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging @@ -52,6 +52,8 @@ import org.apache.spark.util.io.ChunkedByteBufferOutputStream * @param streamWrapper A function to wrap the returned input stream. * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point. * @param maxReqsInFlight max number of remote requests to fetch blocks at any given point. + * @param maxBlocksInFlightPerAddress max number of shuffle blocks being fetched at any given point + * for a given remote host:port. * @param maxReqSizeShuffleToMem max size (in bytes) of a request that can be shuffled to memory. * @param detectCorrupt whether to detect any corruption in fetched blocks. */ @@ -64,6 +66,7 @@ final class ShuffleBlockFetcherIterator( streamWrapper: (BlockId, InputStream) => InputStream, maxBytesInFlight: Long, maxReqsInFlight: Int, + maxBlocksInFlightPerAddress: Int, maxReqSizeShuffleToMem: Long, detectCorrupt: Boolean) extends Iterator[(BlockId, InputStream)] with TempShuffleFileManager with Logging { @@ -110,12 +113,21 @@ final class ShuffleBlockFetcherIterator( */ private[this] val fetchRequests = new Queue[FetchRequest] + /** + * Queue of fetch requests which could not be issued the first time they were dequeued. These + * requests are tried again when the fetch constraints are satisfied. + */ + private[this] val deferredFetchRequests = new HashMap[BlockManagerId, Queue[FetchRequest]]() + /** Current bytes in flight from our requests */ private[this] var bytesInFlight = 0L /** Current number of requests in flight */ private[this] var reqsInFlight = 0 + /** Current number of blocks in flight per host:port */ + private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, Int]() + /** * The blocks that can't be decompressed successfully, it is used to guarantee that we retry * at most once for those corrupted blocks. @@ -245,7 +257,8 @@ final class ShuffleBlockFetcherIterator( // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5 // nodes, rather than blocking on reading output from one node. val targetRequestSize = math.max(maxBytesInFlight / 5, 1L) - logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize) + logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize + + ", maxBlocksInFlightPerAddress: " + maxBlocksInFlightPerAddress) // Split local and remote blocks. Remote blocks are further split into FetchRequests of size // at most maxBytesInFlight in order to limit the amount of data in flight. @@ -274,11 +287,13 @@ final class ShuffleBlockFetcherIterator( } else if (size < 0) { throw new BlockException(blockId, "Negative block size " + size) } - if (curRequestSize >= targetRequestSize) { + if (curRequestSize >= targetRequestSize || + curBlocks.size >= maxBlocksInFlightPerAddress) { // Add this FetchRequest remoteRequests += new FetchRequest(address, curBlocks) + logDebug(s"Creating fetch request of $curRequestSize at $address " + + s"with ${curBlocks.size} blocks") curBlocks = new ArrayBuffer[(BlockId, Long)] - logDebug(s"Creating fetch request of $curRequestSize at $address") curRequestSize = 0 } } @@ -372,6 +387,7 @@ final class ShuffleBlockFetcherIterator( result match { case r @ SuccessFetchResult(blockId, address, size, buf, isNetworkReqDone) => if (address != blockManager.blockManagerId) { + numBlocksInFlightPerAddress(address) = numBlocksInFlightPerAddress(address) - 1 shuffleMetrics.incRemoteBytesRead(buf.size) shuffleMetrics.incRemoteBlocksFetched(1) } @@ -437,12 +453,57 @@ final class ShuffleBlockFetcherIterator( } private def fetchUpToMaxBytes(): Unit = { - // Send fetch requests up to maxBytesInFlight - while (fetchRequests.nonEmpty && - (bytesInFlight == 0 || - (reqsInFlight + 1 <= maxReqsInFlight && - bytesInFlight + fetchRequests.front.size <= maxBytesInFlight))) { - sendRequest(fetchRequests.dequeue()) + // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a remote host + // immediately, defer the request until the next time it can be processed. + + // Process any outstanding deferred fetch requests if possible. + if (deferredFetchRequests.nonEmpty) { + for ((remoteAddress, defReqQueue) <- deferredFetchRequests) { + while (isRemoteBlockFetchable(defReqQueue) && + !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) { + val request = defReqQueue.dequeue() + logDebug(s"Processing deferred fetch request for $remoteAddress with " + + s"${request.blocks.length} blocks") + send(remoteAddress, request) + if (defReqQueue.isEmpty) { + deferredFetchRequests -= remoteAddress + } + } + } + } + + // Process any regular fetch requests if possible. + while (isRemoteBlockFetchable(fetchRequests)) { + val request = fetchRequests.dequeue() + val remoteAddress = request.address + if (isRemoteAddressMaxedOut(remoteAddress, request)) { + logDebug(s"Deferring fetch request for $remoteAddress with ${request.blocks.size} blocks") + val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new Queue[FetchRequest]()) + defReqQueue.enqueue(request) + deferredFetchRequests(remoteAddress) = defReqQueue + } else { + send(remoteAddress, request) + } + } + + def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = { + sendRequest(request) + numBlocksInFlightPerAddress(remoteAddress) = + numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size + } + + def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = { + fetchReqQueue.nonEmpty && + (bytesInFlight == 0 || + (reqsInFlight + 1 <= maxReqsInFlight && + bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight)) + } + + // Checks if sending a new fetch request will exceed the max no. of blocks being fetched from a + // given remote address. + def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: FetchRequest): Boolean = { + numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + request.blocks.size > + maxBlocksInFlightPerAddress } } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 6a70cedf769b8..c371cbcf8dff5 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -110,6 +110,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + Int.MaxValue, true) // 3 local blocks fetched in initialization @@ -187,6 +188,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + Int.MaxValue, true) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() @@ -254,6 +256,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + Int.MaxValue, true) // Continue only after the mock calls onBlockFetchFailure @@ -319,6 +322,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + Int.MaxValue, true) // Continue only after the mock calls onBlockFetchFailure @@ -400,6 +404,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024, Int.MaxValue, Int.MaxValue, + Int.MaxValue, false) // Continue only after the mock calls onBlockFetchFailure @@ -457,6 +462,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT (_, in) => in, maxBytesInFlight = Int.MaxValue, maxReqsInFlight = Int.MaxValue, + maxBlocksInFlightPerAddress = Int.MaxValue, maxReqSizeShuffleToMem = 200, detectCorrupt = true) } diff --git a/docs/configuration.md b/docs/configuration.md index f47967cab669b..3c9da8151094c 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -520,6 +520,15 @@ Apart from these, the following properties are also available, and may be useful + spark.reducer.maxBlocksInFlightPerAddress + Int.MaxValue + + This configuration limits the number of remote blocks being fetched per reduce task from a + given host port. When a large number of blocks are being requested from a given address in a + single fetch or simultaneously, this could crash the serving executor or Node Manager. This + is especially useful to reduce the load on the Node Manager when external shuffle is enabled. + You can mitigate this issue by setting it to a lower value. + spark.reducer.maxReqSizeShuffleToMem Long.MaxValue