diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index 1745d52c8192..e34c796b60c1 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -18,7 +18,9 @@ package org.apache.spark.network import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.storage.{BlockId, StorageLevel} +import org.apache.spark.storage.BlockId +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.StorageLevel private[spark] trait BlockDataManager { @@ -29,6 +31,12 @@ trait BlockDataManager { */ def getBlockData(blockId: BlockId): ManagedBuffer + /** + * Interface to get local block data managed by given BlockManagerId. + * Throws an exception if the block cannot be found or cannot be read successfully. + */ + def getBlockData(blockId: BlockId, blockManagerId: BlockManagerId): ManagedBuffer + /** * Put the block locally, using the given storage level. 
*/ diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index cc5f933393ad..d8b96aa27fc0 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -98,8 +98,9 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) } } - override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { - val file = blockManager.diskBlockManager.getFile(blockId) + override def getBlockData(blockId: ShuffleBlockId, blockManagerId: BlockManagerId) + : ManagedBuffer = { + val file = blockManager.diskBlockManager.getFile(blockId, blockManagerId) new FileSegmentManagedBuffer(transportConf, file, 0, file.length) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index fadb8fe7ed0a..6c5134411ceb 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -49,13 +49,19 @@ private[spark] class IndexShuffleBlockResolver( private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") - def getDataFile(shuffleId: Int, mapId: Int): File = { - blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) - } + def getDataFile(shuffleId: Int, mapId: Int): File = + getDataFile(shuffleId, mapId, blockManager.blockManagerId) - private def getIndexFile(shuffleId: Int, mapId: Int): File = { - blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) - } + private def getDataFile(shuffleId: Int, mapId: Int, blockManagerId: BlockManagerId): File = + blockManager.diskBlockManager.getFile( + ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID), 
blockManagerId) + + private def getIndexFile(shuffleId: Int, mapId: Int): File = + getIndexFile(shuffleId, mapId, blockManager.blockManagerId) + + private def getIndexFile(shuffleId: Int, mapId: Int, blockManagerId: BlockManagerId): File = + blockManager.diskBlockManager.getFile( + ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID), blockManagerId) /** * Remove data file and index file that contain the output data from one map. @@ -183,10 +189,12 @@ private[spark] class IndexShuffleBlockResolver( } } - override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { + override def getBlockData( + blockId: ShuffleBlockId, + blockManagerId: BlockManagerId): ManagedBuffer = { // The block is actually going to be a range of a single map output file for this map, so // find out the consolidated file, then the offset within that from our index - val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId) + val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId, blockManagerId) val in = new DataInputStream(new FileInputStream(indexFile)) try { @@ -195,7 +203,7 @@ private[spark] class IndexShuffleBlockResolver( val nextOffset = in.readLong() new FileSegmentManagedBuffer( transportConf, - getDataFile(blockId.shuffleId, blockId.mapId), + getDataFile(blockId.shuffleId, blockId.mapId, blockManagerId), offset, nextOffset - offset) } finally { diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala index 4342b0d598b1..907ef68ecf88 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala @@ -17,9 +17,8 @@ package org.apache.spark.shuffle -import java.nio.ByteBuffer import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} private[spark] /** @@ -35,7 
+34,7 @@ trait ShuffleBlockResolver { * Retrieve the data for the specified block. If the data for that block is not available, * throws an unspecified exception. */ - def getBlockData(blockId: ShuffleBlockId): ManagedBuffer + def getBlockData(blockId: ShuffleBlockId, blockManagerId: BlockManagerId): ManagedBuffer def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index ab0007fb7899..03bf8ac7fda2 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -187,7 +187,8 @@ private[spark] class BlockManager( blockManagerId } - master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint) + master.registerBlockManager( + blockManagerId, maxMemory, diskBlockManager.getLocalDirsPath(), slaveEndpoint) // Register Executors' configuration with the local shuffle service, if one should exist. if (externalShuffleServiceEnabled && !blockManagerId.isDriver) { @@ -250,7 +251,8 @@ private[spark] class BlockManager( def reregister(): Unit = { // TODO: We might need to rate limit re-registering. logInfo("BlockManager re-registering with master") - master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint) + master.registerBlockManager( + blockManagerId, maxMemory, diskBlockManager.getLocalDirsPath(), slaveEndpoint) reportAllBlocks() } @@ -286,9 +288,10 @@ private[spark] class BlockManager( * Interface to get local block data. Throws an exception if the block cannot be found or * cannot be read successfully. 
*/ - override def getBlockData(blockId: BlockId): ManagedBuffer = { + override def getBlockData(blockId: BlockId, blockManagerId: BlockManagerId): ManagedBuffer = { if (blockId.isShuffle) { - shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) + shuffleManager.shuffleBlockResolver.getBlockData( + blockId.asInstanceOf[ShuffleBlockId], blockManagerId) } else { val blockBytesOpt = doGetLocal(blockId, asBlockResult = false) .asInstanceOf[Option[ByteBuffer]] @@ -301,6 +304,10 @@ private[spark] class BlockManager( } } + override def getBlockData(blockId: BlockId): ManagedBuffer = { + getBlockData(blockId, this.blockManagerId) + } + /** * Put the block locally, using the given storage level. */ @@ -432,7 +439,8 @@ private[spark] class BlockManager( // TODO: This should gracefully handle case where local block is not available. Currently // downstream code will throw an exception. Option( - shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer()) + shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId], blockManagerId) + .nioByteBuffer()) } else { doGetLocal(blockId, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index f45bff34d4db..3c3f77793323 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -43,9 +43,12 @@ class BlockManagerMaster( /** Register the BlockManager's id with the driver. 
*/ def registerBlockManager( - blockManagerId: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef): Unit = { + blockManagerId: BlockManagerId, + maxMemSize: Long, + localDirsPath: Array[String], + slaveEndpoint: RpcEndpointRef): Unit = { logInfo("Trying to register BlockManager") - tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint)) + tell(RegisterBlockManager(blockManagerId, maxMemSize, localDirsPath, slaveEndpoint)) logInfo("Registered BlockManager") } @@ -74,6 +77,12 @@ class BlockManagerMaster( GetLocationsMultipleBlockIds(blockIds)) } + /** Return the local dirs of the other block managers on the same host as blockManagerId */ + def getLocalDirsPath(blockManagerId: BlockManagerId): Map[BlockManagerId, Array[String]] = { + driverEndpoint.askWithRetry[Map[BlockManagerId, Array[String]]]( + GetLocalDirsPath(blockManagerId)) + } + /** * Check if block manager master has a block. Note that this can be used to check for only * those blocks that are reported to block manager master. 
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 7db6035553ae..e6450b097e3f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -19,17 +19,16 @@ package org.apache.spark.storage import java.util.{HashMap => JHashMap} -import scala.collection.immutable.HashSet -import scala.collection.mutable import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.concurrent.{ExecutionContext, Future} -import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, ThreadSafeRpcEndpoint} -import org.apache.spark.{Logging, SparkConf} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util.{ThreadUtils, Utils} +import org.apache.spark.{Logging, SparkConf} /** * BlockManagerMasterEndpoint is an [[ThreadSafeRpcEndpoint]] on the master node to track statuses @@ -56,8 +55,8 @@ class BlockManagerMasterEndpoint( private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool) override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint) => - register(blockManagerId, maxMemSize, slaveEndpoint) + case RegisterBlockManager(blockManagerId, maxMemSize, localDirsPath, slaveEndpoint) => + register(blockManagerId, maxMemSize, localDirsPath, slaveEndpoint) context.reply(true) case _updateBlockInfo @ UpdateBlockInfo( @@ -81,6 +80,9 @@ class BlockManagerMasterEndpoint( case GetMemoryStatus => context.reply(memoryStatus) + case GetLocalDirsPath(blockManagerId) => + 
context.reply(getLocalDirsPath(blockManagerId)) + case GetStorageStatus => context.reply(storageStatus) @@ -235,11 +237,20 @@ class BlockManagerMasterEndpoint( // Return a map from the block manager id to max memory and remaining memory. private def memoryStatus: Map[BlockManagerId, (Long, Long)] = { - blockManagerInfo.map { case(blockManagerId, info) => + blockManagerInfo.map { case (blockManagerId, info) => (blockManagerId, (info.maxMem, info.remainingMem)) }.toMap } + // Return the local dirs of the other block managers on the same host as blockManagerId + private def getLocalDirsPath(blockManagerId: BlockManagerId) + : Map[BlockManagerId, Array[String]] = { + blockManagerInfo + .filter { case (id, _) => id != blockManagerId && id.host == blockManagerId.host } + .mapValues { info => info.localDirsPath } + .toMap + } + private def storageStatus: Array[StorageStatus] = { blockManagerInfo.map { case (blockManagerId, info) => new StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala) @@ -299,7 +310,11 @@ class BlockManagerMasterEndpoint( ).map(_.flatten.toSeq) } - private def register(id: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef) { + private def register( + id: BlockManagerId, + maxMemSize: Long, + localDirsPath: Array[String], + slaveEndpoint: RpcEndpointRef): Unit = { val time = System.currentTimeMillis() if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { @@ -316,7 +331,7 @@ class BlockManagerMasterEndpoint( blockManagerIdByExecutor(id.executorId) = id blockManagerInfo(id) = new BlockManagerInfo( - id, System.currentTimeMillis(), maxMemSize, slaveEndpoint) + id, System.currentTimeMillis(), maxMemSize, localDirsPath, slaveEndpoint) } listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize)) } @@ -423,6 +438,7 @@ private[spark] class BlockManagerInfo( val blockManagerId: BlockManagerId, timeMs: Long, val maxMem: Long, + val localDirsPath: Array[String], val slaveEndpoint: RpcEndpointRef) extends 
Logging { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 376e9eb48843..2c32ca174c1c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -50,6 +50,7 @@ private[spark] object BlockManagerMessages { case class RegisterBlockManager( blockManagerId: BlockManagerId, maxMemSize: Long, + localDirsPath: Array[String], sender: RpcEndpointRef) extends ToBlockManagerMaster @@ -109,4 +110,6 @@ private[spark] object BlockManagerMessages { case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster case class HasCachedBlocks(executorId: String) extends ToBlockManagerMaster + + case class GetLocalDirsPath(blockManagerId: BlockManagerId) extends ToBlockManagerMaster } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index f7e84a2c2e14..6bb504f657f4 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -20,6 +20,8 @@ package org.apache.spark.storage import java.util.UUID import java.io.{IOException, File} +import scala.collection.mutable + import org.apache.spark.{SparkConf, Logging} import org.apache.spark.executor.ExecutorExitCode import org.apache.spark.util.{ShutdownHookManager, Utils} @@ -51,16 +53,26 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon // of subDirs(i) is protected by the lock of subDirs(i) private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir)) + // Cache local directories for other block managers + private val localDirsByOtherBlkMgr = new mutable.HashMap[BlockManagerId, Array[String]] + private val shutdownHook = addShutdownHook() + 
def blockManagerId: BlockManagerId = blockManager.blockManagerId + + def getLocalDirsPath(): Array[String] = { + localDirs.map(_.getAbsolutePath) + } + + def getLocalDirsPath(blockManagerId: BlockManagerId): Map[BlockManagerId, Array[String]] = { + blockManager.master.getLocalDirsPath(blockManagerId) + } + /** Looks up a file by hashing it into one of our local subdirectories. */ // This method should be kept in sync with // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getFile(). def getFile(filename: String): File = { - // Figure out which local directory it hashes to, and which subdirectory in that - val hash = Utils.nonNegativeHash(filename) - val dirId = hash % localDirs.length - val subDirId = (hash / localDirs.length) % subDirsPerLocalDir + val (dirId, subDirId) = getDirInfo(filename, localDirs.length) // Create the subdirectory if it doesn't already exist val subDir = subDirs(dirId).synchronized { @@ -82,6 +94,39 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon def getFile(blockId: BlockId): File = getFile(blockId.name) + def getFile(blockId: BlockId, blockManagerId: BlockManagerId): File = { + if (this.blockManagerId == blockManagerId) { + getFile(blockId) + } else { + // Get a file from another block manager with given blockManagerId + val dirs = localDirsByOtherBlkMgr.synchronized { + localDirsByOtherBlkMgr.getOrElse(blockManagerId, { + localDirsByOtherBlkMgr ++= getLocalDirsPath(this.blockManagerId) + localDirsByOtherBlkMgr.getOrElse(blockManagerId, { + throw new IOException(s"Block manager (${blockManagerId}) not found " + + s"in host '${this.blockManagerId.host}'") + }) + }) + } + val (dirId, subDirId) = getDirInfo(blockId.name, dirs.length) + val file = new File(new File(dirs(dirId), "%02x".format(subDirId)), blockId.name) + if (!file.exists()) { + throw new IOException(s"File '${file}' not found in local dir") + } + logInfo(s"${this.blockManagerId} bypasses network access and " + + s"directly reads 
file '${file}' in local dir") + file + } + } + + def getDirInfo(filename: String, numDirs: Int): (Int, Int) = { + // Figure out which local directory it hashes to, and which subdirectory in that + val hash = Utils.nonNegativeHash(filename) + val dirId = hash % numDirs + val subDirId = (hash / numDirs) % subDirsPerLocalDir + (dirId, subDirId) + } + /** Check if disk block manager has a block. */ def containsBlock(blockId: BlockId): Boolean = { getFile(blockId.name).exists() @@ -166,7 +211,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon // Only perform cleanup if an external service is not serving our shuffle files. // Also blockManagerId could be null if block manager is not initialized properly. if (!blockManager.externalShuffleServiceEnabled || - (blockManager.blockManagerId != null && blockManager.blockManagerId.isDriver)) { + (this.blockManagerId != null && this.blockManagerId.isDriver)) { localDirs.foreach { localDir => if (localDir.isDirectory() && localDir.exists()) { try { diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 0d0448feb5b0..54383ec8d648 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -20,7 +20,7 @@ package org.apache.spark.storage import java.io.InputStream import java.util.concurrent.LinkedBlockingQueue -import scala.collection.mutable.{ArrayBuffer, HashSet, Queue} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} import scala.util.control.NonFatal import org.apache.spark.{Logging, SparkException, TaskContext} @@ -58,6 +58,17 @@ final class ShuffleBlockFetcherIterator( import ShuffleBlockFetcherIterator._ + private[this] val enableExternalShuffleService = + blockManager.conf.getBoolean("spark.shuffle.service.enabled", 
false) + + /** + * If this option is enabled, bypass unnecessary network interaction when multiple + * block managers run on the same host. Ignored when the external shuffle service is enabled. + */ + private[this] val enableBypassNetworkAccess = + blockManager.conf.getBoolean("spark.shuffle.bypassNetworkAccess", false) && + !enableExternalShuffleService + /** * Total number of blocks to fetch. This can be smaller than the total number of blocks * in [[blocksByAddress]] because we filter out zero-sized blocks in [[initialize]]. @@ -74,8 +85,12 @@ final class ShuffleBlockFetcherIterator( private[this] val startTime = System.currentTimeMillis - /** Local blocks to fetch, excluding zero-sized blocks. */ - private[this] val localBlocks = new ArrayBuffer[BlockId]() + /** + * Local blocks to fetch, excluding zero-sized blocks, keyed by the block manager that holds + * them. When network bypass is enabled, this also includes blocks held by other block + * managers on the same host. + */ + private[this] val localBlocks = new HashMap[BlockManagerId, ArrayBuffer[BlockId]] /** Remote blocks to fetch, excluding zero-sized blocks. 
*/ private[this] val remoteBlocks = new HashSet[BlockId]() @@ -188,10 +203,12 @@ final class ShuffleBlockFetcherIterator( var totalBlocks = 0 for ((address, blockInfos) <- blocksByAddress) { totalBlocks += blockInfos.size - if (address.executorId == blockManager.blockManagerId.executorId) { + if (address.executorId == blockManager.blockManagerId.executorId || + (enableBypassNetworkAccess && blockManager.blockManagerId.host == address.host)) { // Filter out zero-sized blocks - localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1) - numBlocksToFetch += localBlocks.size + val blocks = blockInfos.filter(_._2 != 0).map(_._1) + localBlocks.getOrElseUpdate(address, ArrayBuffer()) ++= blocks + numBlocksToFetch += blocks.size } else { val iterator = blockInfos.iterator var curRequestSize = 0L @@ -233,19 +250,28 @@ final class ShuffleBlockFetcherIterator( private[this] def fetchLocalBlocks() { val iter = localBlocks.iterator while (iter.hasNext) { - val blockId = iter.next() - try { - val buf = blockManager.getBlockData(blockId) - shuffleMetrics.incLocalBlocksFetched(1) - shuffleMetrics.incLocalBytesRead(buf.size) - buf.retain() - results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf)) - } catch { - case e: Exception => - // If we see an exception, stop immediately. 
- logError(s"Error occurred while fetching local blocks", e) - results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e)) - return + val (blockManagerId, blockIds) = iter.next() + val blockIter = blockIds.iterator + while (blockIter.hasNext) { + val blockId = blockIter.next() + assert(blockId.isShuffle) + try { + val buf = if (!enableExternalShuffleService) { + blockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId], blockManagerId) + } else { + blockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) + } + shuffleMetrics.incLocalBlocksFetched(1) + shuffleMetrics.incLocalBytesRead(buf.size) + buf.retain() + results.put(new SuccessFetchResult(blockId, blockManagerId, 0, buf)) + } catch { + case NonFatal(e) => + // If we see an exception, stop immediately. + logError(s"Error occurred while fetching local blocks", e) + results.put(new FailureFetchResult(blockId, blockManagerId, e)) + return + } } } } diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 1c3f2bc315dd..676948856a0a 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -86,6 +86,18 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex assert(groups.map(_._2).sum === 2000) } + test("bypass remote access") { + val conf = new SparkConf().set("spark.shuffle.bypassNetworkAccess", "true") + Seq("hash", "sort", "tungsten-sort").foreach { shuffle => + sc = new SparkContext(clusterUrl, "test", conf.clone.set("spark.shuffle.manager", shuffle)) + val rdd = sc.parallelize((0 until 1000).map(x => (x % 4, 1)), 5) + val groups = rdd.reduceByKey(_ + _).collect + assert(groups.size === 4) + assert(groups.forall(_._2 == 250)) + resetSparkContext() + } + } + test("accumulators") { sc = new SparkContext(clusterUrl, "test") val accum = sc.accumulator(0) diff --git 
a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 26a372d6a905..6bd5135c1984 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -77,6 +77,8 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // can ensure retain() and release() are properly called. val blockManager = mock(classOf[BlockManager]) + when(blockManager.conf).thenReturn(testConf) + // Create a return function to use for the mocked wrapForCompression method that just returns // the original input stream. val dummyCompressionFunction = new Answer[InputStream] { @@ -104,7 +106,8 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // Setup the blockManager mock so the buffer gets returned when the shuffle code tries to // fetch shuffle data. 
val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) - when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer) + when(blockManager.getBlockData(shuffleBlockId, localBlockManagerId)) + .thenReturn(managedBuffer) when(blockManager.wrapForCompression(meq(shuffleBlockId), isA(classOf[InputStream]))) .thenAnswer(dummyCompressionFunction) diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index 688f56f4665f..f566491e380a 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -17,21 +17,22 @@ package org.apache.spark.storage -import java.io.{File, FileWriter} +import java.io.{File, FileWriter, IOException} import scala.language.reflectiveCalls -import org.mockito.Mockito.{mock, when} -import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} +import org.mockito.Matchers.{eq => meq} +import org.mockito.Mockito.{mock, times, verify, when} +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, PrivateMethodTester} -import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.util.Utils +import org.apache.spark.{SparkConf, SparkFunSuite} -class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with BeforeAndAfterAll { +class DiskBlockManagerSuite extends SparkFunSuite + with BeforeAndAfterEach with BeforeAndAfterAll with PrivateMethodTester { private val testConf = new SparkConf(false) private var rootDir0: File = _ private var rootDir1: File = _ - private var rootDirs: String = _ val blockManager = mock(classOf[BlockManager]) when(blockManager.conf).thenReturn(testConf) @@ -41,7 +42,7 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B super.beforeAll() rootDir0 = Utils.createTempDir() rootDir1 = Utils.createTempDir() - rootDirs = rootDir0.getAbsolutePath 
+ "," + rootDir1.getAbsolutePath + testConf.set("spark.local.dir", rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath) } override def afterAll() { @@ -51,9 +52,7 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B } override def beforeEach() { - val conf = testConf.clone - conf.set("spark.local.dir", rootDirs) - diskBlockManager = new DiskBlockManager(blockManager, conf) + diskBlockManager = new DiskBlockManager(blockManager, testConf.clone) } override def afterEach() { @@ -81,4 +80,58 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B for (i <- 0 until numBytes) writer.write(i) writer.close() } + + test("bypassing network access") { + val mockBlockManagerMaster = mock(classOf[BlockManagerMaster]) + val mockBlockManager = mock(classOf[BlockManager]) + + // Assume two executors in an identical host + val localBmId1 = BlockManagerId("test-exec1", "test-client1", 1) + val localBmId2 = BlockManagerId("test-exec2", "test-client1", 2) + + // Assume that localBmId2 holds 'shuffle_1_0_0' + val blockIdInLocalBmId2 = ShuffleBlockId(1, 0, 0) + val tempDir = Utils.createTempDir() + try { + // Create mock classes for testing + when(mockBlockManagerMaster.getLocalDirsPath(meq(localBmId1))) + .thenReturn(Map(localBmId2 -> Array(tempDir.getAbsolutePath))) + when(mockBlockManager.conf).thenReturn(testConf) + when(mockBlockManager.master).thenReturn(mockBlockManagerMaster) + when(mockBlockManager.blockManagerId).thenReturn(localBmId1) + + val testDiskBlockManager = new DiskBlockManager(mockBlockManager, testConf.clone) + + val getBlockDir: String => File = (s: String) => { + val (_, subDirId) = testDiskBlockManager.getDirInfo(s, 1) + new File(tempDir, "%02x".format(subDirId)) + } + + // Create a dummy file for a shuffle block + val blockDir = getBlockDir(blockIdInLocalBmId2.name) + assert(blockDir.mkdir()) + val dummyBlockFile = new File(blockDir, blockIdInLocalBmId2.name) + 
assert(dummyBlockFile.createNewFile()) + + val file = testDiskBlockManager.getFile( + blockIdInLocalBmId2, localBmId2) + assert(dummyBlockFile.getName === file.getName) + assert(file.getAbsolutePath.contains(tempDir.getAbsolutePath)) + + verify(mockBlockManagerMaster, times(1)).getLocalDirsPath(meq(localBmId1)) + verify(mockBlockManager, times(1)).conf + verify(mockBlockManager, times(1)).master + verify(mockBlockManager, times(3)).blockManagerId + + // Throw an IOException if given shuffle file not found + val blockIdNotInLocalBmId2 = ShuffleBlockId(2, 0, 0) + val errMsg = intercept[IOException] { + testDiskBlockManager.getFile(blockIdNotInLocalBmId2, localBmId2) + } + assert(errMsg.getMessage contains s"File '${getBlockDir(blockIdNotInLocalBmId2.name)}/" + + s"${blockIdNotInLocalBmId2}' not found in local dir") + } finally { + Utils.deleteRecursively(tempDir) + } + } } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 828153bdbfc4..d7b3b4247420 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -27,16 +27,25 @@ import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer -import org.scalatest.PrivateMethodTester +import org.scalatest.{BeforeAndAfterAll, PrivateMethodTester} -import org.apache.spark.{SparkFunSuite, TaskContext} +import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.shuffle.FetchFailedException -class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester { +class 
ShuffleBlockFetcherIteratorSuite + extends SparkFunSuite with BeforeAndAfterAll with PrivateMethodTester +{ + private val testConf = new SparkConf(false) + + override def beforeAll() { + super.beforeAll() + testConf.set("spark.shuffle.bypassNetworkAccess", "false") + } + // Some of the tests are quite tricky because we are testing the cleanup behavior // in the presence of faults. @@ -70,15 +79,16 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT test("successful 3 local reads + 2 remote reads") { val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(testConf).when(blockManager).conf doReturn(localBmId).when(blockManager).blockManagerId // Make sure blockManager.getBlockData would return the blocks - val localBlocks = Map[BlockId, ManagedBuffer]( + val localBlocks = Map[ShuffleBlockId, ManagedBuffer]( ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()) localBlocks.foreach { case (blockId, buf) => - doReturn(buf).when(blockManager).getBlockData(meq(blockId)) + doReturn(buf).when(blockManager).getBlockData(meq(blockId), meq(localBmId)) } // Make sure remote blocks would return @@ -102,14 +112,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024) // 3 local blocks fetched in initialization - verify(blockManager, times(3)).getBlockData(any()) + verify(blockManager, times(3)).getBlockData(any(), any()) for (i <- 0 until 5) { assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") val (blockId, inputStream) = iterator.next() // Make sure we release buffers when a wrapped input stream is closed. 
- val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId)) + val mockBuf = localBlocks.getOrElse( + blockId.asInstanceOf[ShuffleBlockId], remoteBlocks(blockId)) // Note: ShuffleBlockFetcherIterator wraps input streams in a BufferReleasingInputStream val wrappedInputStream = inputStream.asInstanceOf[BufferReleasingInputStream] verify(mockBuf, times(0)).release() @@ -126,13 +137,70 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // 3 local blocks, and 2 remote blocks // (but from the same block manager so one call to fetchBlocks) - verify(blockManager, times(3)).getBlockData(any()) + verify(blockManager, times(3)).getBlockData(any(), any()) verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any()) } + test("bypass unnecessary network access if block managers share an identical host") { + val blockManager = mock(classOf[BlockManager]) + + // Two block managers (executors) that share the same host, test-client1 + val localBmId1 = BlockManagerId("test-exec1", "test-client1", 1) + val localBmId2 = BlockManagerId("test-exec2", "test-client1", 2) + + // Enable the bypass on a copy of testConf so the shared testConf is left unchanged + doReturn(testConf.clone.set("spark.shuffle.bypassNetworkAccess", "true")) + .when(blockManager).conf + doReturn(localBmId1).when(blockManager).blockManagerId + + // Stub getBlockData for the blocks held by each of the two block managers + val localBlocksInBmId1 = Map[ShuffleBlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()) + localBlocksInBmId1.foreach { case (blockId, buf) => + doReturn(buf).when(blockManager).getBlockData(meq(blockId), meq(localBmId1)) + } + val localBlocksInBmId2 = Map[ShuffleBlockId, ManagedBuffer]( + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()) + localBlocksInBmId2.foreach { case (blockId, buf) => + doReturn(buf).when(blockManager).getBlockData(meq(blockId), meq(localBmId2)) + } + + // Mock transfer whose fetchBlocks is a no-op; the test expects zero calls to it + val transfer = 
mock(classOf[BlockTransferService]) + when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + override def answer(invocation: InvocationOnMock): Unit = {} + }) + + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (localBmId1, localBlocksInBmId1.keys.map(blockId => (blockId, 1L)).toSeq), + (localBmId2, localBlocksInBmId2.keys.map(blockId => (blockId, 1L)).toSeq) + ) + + val iterator = new ShuffleBlockFetcherIterator( + TaskContext.empty(), + transfer, + blockManager, + blocksByAddress, + 48 * 1024 * 1024) + + // All 3 blocks (1 from bmId1 + 2 from bmId2) are read locally during initialization + verify(blockManager, times(3)).getBlockData(any(), any()) + + for (i <- 0 until 3) { + assert(iterator.hasNext, s"iterator should have 3 elements but actually has $i elements") + iterator.next() + } + + // Draining the iterator adds no reads: still 3 local reads and no network fetch + verify(blockManager, times(3)).getBlockData(any(), any()) + verify(transfer, times(0)).fetchBlocks(any(), any(), any(), any(), any()) + } + test("release current unexhausted buffer in case the task completes early") { val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(testConf).when(blockManager).conf doReturn(localBmId).when(blockManager).blockManagerId // Make sure remote blocks would return @@ -194,6 +262,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT test("fail all blocks if any of the remote request fails") { val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(testConf).when(blockManager).conf doReturn(localBmId).when(blockManager).blockManagerId // Make sure remote blocks would return diff --git a/docs/configuration.md b/docs/configuration.md index 741d6b2b37a8..b267919ebfaa 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -477,6 +477,14 @@ Apart from these, the following 
properties are also available, and may be useful spark.io.compression.codec. + + spark.shuffle.bypassNetworkAccess + false + + Whether to read shuffle blocks directly from local disk, bypassing the network, when they + are held by another block manager on the same host (e.g., multiple executors on one host). + + #### Spark UI