diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
index 1745d52c8192..e34c796b60c1 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
@@ -18,7 +18,9 @@
package org.apache.spark.network
import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.storage.{BlockId, StorageLevel}
+import org.apache.spark.storage.BlockId
+import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.storage.StorageLevel
private[spark]
trait BlockDataManager {
@@ -29,6 +31,12 @@ trait BlockDataManager {
*/
def getBlockData(blockId: BlockId): ManagedBuffer
+ /**
+ * Interface to get local block data managed by given BlockManagerId.
+ * Throws an exception if the block cannot be found or cannot be read successfully.
+ */
+ def getBlockData(blockId: BlockId, blockManagerId: BlockManagerId): ManagedBuffer
+
/**
* Put the block locally, using the given storage level.
*/
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
index cc5f933393ad..d8b96aa27fc0 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
@@ -98,8 +98,9 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
}
}
- override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
- val file = blockManager.diskBlockManager.getFile(blockId)
+ override def getBlockData(blockId: ShuffleBlockId, blockManagerId: BlockManagerId)
+ : ManagedBuffer = {
+ val file = blockManager.diskBlockManager.getFile(blockId, blockManagerId)
new FileSegmentManagedBuffer(transportConf, file, 0, file.length)
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index fadb8fe7ed0a..6c5134411ceb 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -49,13 +49,19 @@ private[spark] class IndexShuffleBlockResolver(
private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle")
- def getDataFile(shuffleId: Int, mapId: Int): File = {
- blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID))
- }
+ def getDataFile(shuffleId: Int, mapId: Int): File =
+ getDataFile(shuffleId, mapId, blockManager.blockManagerId)
- private def getIndexFile(shuffleId: Int, mapId: Int): File = {
- blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID))
- }
+ private def getDataFile(shuffleId: Int, mapId: Int, blockManagerId: BlockManagerId): File =
+ blockManager.diskBlockManager.getFile(
+ ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID), blockManagerId)
+
+ private def getIndexFile(shuffleId: Int, mapId: Int): File =
+ getIndexFile(shuffleId, mapId, blockManager.blockManagerId)
+
+ private def getIndexFile(shuffleId: Int, mapId: Int, blockManagerId: BlockManagerId): File =
+ blockManager.diskBlockManager.getFile(
+ ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID), blockManagerId)
/**
* Remove data file and index file that contain the output data from one map.
@@ -183,10 +189,12 @@ private[spark] class IndexShuffleBlockResolver(
}
}
- override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
+ override def getBlockData(
+ blockId: ShuffleBlockId,
+ blockManagerId: BlockManagerId): ManagedBuffer = {
// The block is actually going to be a range of a single map output file for this map, so
// find out the consolidated file, then the offset within that from our index
- val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)
+ val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId, blockManagerId)
val in = new DataInputStream(new FileInputStream(indexFile))
try {
@@ -195,7 +203,7 @@ private[spark] class IndexShuffleBlockResolver(
val nextOffset = in.readLong()
new FileSegmentManagedBuffer(
transportConf,
- getDataFile(blockId.shuffleId, blockId.mapId),
+ getDataFile(blockId.shuffleId, blockId.mapId, blockManagerId),
offset,
nextOffset - offset)
} finally {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala
index 4342b0d598b1..907ef68ecf88 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala
@@ -17,9 +17,8 @@
package org.apache.spark.shuffle
-import java.nio.ByteBuffer
import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId}
private[spark]
/**
@@ -35,7 +34,7 @@ trait ShuffleBlockResolver {
* Retrieve the data for the specified block. If the data for that block is not available,
* throws an unspecified exception.
*/
- def getBlockData(blockId: ShuffleBlockId): ManagedBuffer
+ def getBlockData(blockId: ShuffleBlockId, blockManagerId: BlockManagerId): ManagedBuffer
def stop(): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index ab0007fb7899..03bf8ac7fda2 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -187,7 +187,8 @@ private[spark] class BlockManager(
blockManagerId
}
- master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint)
+ master.registerBlockManager(
+ blockManagerId, maxMemory, diskBlockManager.getLocalDirsPath(), slaveEndpoint)
// Register Executors' configuration with the local shuffle service, if one should exist.
if (externalShuffleServiceEnabled && !blockManagerId.isDriver) {
@@ -250,7 +251,8 @@ private[spark] class BlockManager(
def reregister(): Unit = {
// TODO: We might need to rate limit re-registering.
logInfo("BlockManager re-registering with master")
- master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint)
+ master.registerBlockManager(
+ blockManagerId, maxMemory, diskBlockManager.getLocalDirsPath(), slaveEndpoint)
reportAllBlocks()
}
@@ -286,9 +288,10 @@ private[spark] class BlockManager(
* Interface to get local block data. Throws an exception if the block cannot be found or
* cannot be read successfully.
*/
- override def getBlockData(blockId: BlockId): ManagedBuffer = {
+ override def getBlockData(blockId: BlockId, blockManagerId: BlockManagerId): ManagedBuffer = {
if (blockId.isShuffle) {
- shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
+ shuffleManager.shuffleBlockResolver.getBlockData(
+ blockId.asInstanceOf[ShuffleBlockId], blockManagerId)
} else {
val blockBytesOpt = doGetLocal(blockId, asBlockResult = false)
.asInstanceOf[Option[ByteBuffer]]
@@ -301,6 +304,10 @@ private[spark] class BlockManager(
}
}
+ override def getBlockData(blockId: BlockId): ManagedBuffer = {
+ getBlockData(blockId, this.blockManagerId)
+ }
+
/**
* Put the block locally, using the given storage level.
*/
@@ -432,7 +439,8 @@ private[spark] class BlockManager(
// TODO: This should gracefully handle case where local block is not available. Currently
// downstream code will throw an exception.
Option(
- shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer())
+ shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId], blockManagerId)
+ .nioByteBuffer())
} else {
doGetLocal(blockId, asBlockResult = false).asInstanceOf[Option[ByteBuffer]]
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index f45bff34d4db..3c3f77793323 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -43,9 +43,12 @@ class BlockManagerMaster(
/** Register the BlockManager's id with the driver. */
def registerBlockManager(
- blockManagerId: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef): Unit = {
+ blockManagerId: BlockManagerId,
+ maxMemSize: Long,
+ localDirsPath: Array[String],
+ slaveEndpoint: RpcEndpointRef): Unit = {
logInfo("Trying to register BlockManager")
- tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint))
+ tell(RegisterBlockManager(blockManagerId, maxMemSize, localDirsPath, slaveEndpoint))
logInfo("Registered BlockManager")
}
@@ -74,6 +77,12 @@ class BlockManagerMaster(
GetLocationsMultipleBlockIds(blockIds))
}
+ /** Return other blockmanager's local dirs with the given blockManagerId */
+ def getLocalDirsPath(blockManagerId: BlockManagerId): Map[BlockManagerId, Array[String]] = {
+ driverEndpoint.askWithRetry[Map[BlockManagerId, Array[String]]](
+ GetLocalDirsPath(blockManagerId))
+ }
+
/**
* Check if block manager master has a block. Note that this can be used to check for only
* those blocks that are reported to block manager master.
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
index 7db6035553ae..e6450b097e3f 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -19,17 +19,16 @@ package org.apache.spark.storage
import java.util.{HashMap => JHashMap}
-import scala.collection.immutable.HashSet
-import scala.collection.mutable
import scala.collection.JavaConverters._
+import scala.collection.mutable
import scala.concurrent.{ExecutionContext, Future}
-import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, ThreadSafeRpcEndpoint}
-import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler._
import org.apache.spark.storage.BlockManagerMessages._
import org.apache.spark.util.{ThreadUtils, Utils}
+import org.apache.spark.{Logging, SparkConf}
/**
* BlockManagerMasterEndpoint is an [[ThreadSafeRpcEndpoint]] on the master node to track statuses
@@ -56,8 +55,8 @@ class BlockManagerMasterEndpoint(
private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool)
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
- case RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint) =>
- register(blockManagerId, maxMemSize, slaveEndpoint)
+ case RegisterBlockManager(blockManagerId, maxMemSize, localDirsPath, slaveEndpoint) =>
+ register(blockManagerId, maxMemSize, localDirsPath, slaveEndpoint)
context.reply(true)
case _updateBlockInfo @ UpdateBlockInfo(
@@ -81,6 +80,9 @@ class BlockManagerMasterEndpoint(
case GetMemoryStatus =>
context.reply(memoryStatus)
+ case GetLocalDirsPath(blockManagerId) =>
+ context.reply(getLocalDirsPath(blockManagerId))
+
case GetStorageStatus =>
context.reply(storageStatus)
@@ -235,11 +237,20 @@ class BlockManagerMasterEndpoint(
// Return a map from the block manager id to max memory and remaining memory.
private def memoryStatus: Map[BlockManagerId, (Long, Long)] = {
- blockManagerInfo.map { case(blockManagerId, info) =>
+ blockManagerInfo.map { case (blockManagerId, info) =>
(blockManagerId, (info.maxMem, info.remainingMem))
}.toMap
}
+ // Return the local dirs of a block manager with the given blockManagerId
+ private def getLocalDirsPath(blockManagerId: BlockManagerId)
+ : Map[BlockManagerId, Array[String]] = {
+ blockManagerInfo
+ .filter { case (id, _) => id != blockManagerId && id.host == blockManagerId.host }
+ .mapValues { info => info.localDirsPath }
+ .toMap
+ }
+
private def storageStatus: Array[StorageStatus] = {
blockManagerInfo.map { case (blockManagerId, info) =>
new StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala)
@@ -299,7 +310,11 @@ class BlockManagerMasterEndpoint(
).map(_.flatten.toSeq)
}
- private def register(id: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef) {
+ private def register(
+ id: BlockManagerId,
+ maxMemSize: Long,
+ localDirsPath: Array[String],
+ slaveEndpoint: RpcEndpointRef): Unit = {
val time = System.currentTimeMillis()
if (!blockManagerInfo.contains(id)) {
blockManagerIdByExecutor.get(id.executorId) match {
@@ -316,7 +331,7 @@ class BlockManagerMasterEndpoint(
blockManagerIdByExecutor(id.executorId) = id
blockManagerInfo(id) = new BlockManagerInfo(
- id, System.currentTimeMillis(), maxMemSize, slaveEndpoint)
+ id, System.currentTimeMillis(), maxMemSize, localDirsPath, slaveEndpoint)
}
listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize))
}
@@ -423,6 +438,7 @@ private[spark] class BlockManagerInfo(
val blockManagerId: BlockManagerId,
timeMs: Long,
val maxMem: Long,
+ val localDirsPath: Array[String],
val slaveEndpoint: RpcEndpointRef)
extends Logging {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index 376e9eb48843..2c32ca174c1c 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -50,6 +50,7 @@ private[spark] object BlockManagerMessages {
case class RegisterBlockManager(
blockManagerId: BlockManagerId,
maxMemSize: Long,
+ localDirsPath: Array[String],
sender: RpcEndpointRef)
extends ToBlockManagerMaster
@@ -109,4 +110,6 @@ private[spark] object BlockManagerMessages {
case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
case class HasCachedBlocks(executorId: String) extends ToBlockManagerMaster
+
+ case class GetLocalDirsPath(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
}
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index f7e84a2c2e14..6bb504f657f4 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -20,6 +20,8 @@ package org.apache.spark.storage
import java.util.UUID
import java.io.{IOException, File}
+import scala.collection.mutable
+
import org.apache.spark.{SparkConf, Logging}
import org.apache.spark.executor.ExecutorExitCode
import org.apache.spark.util.{ShutdownHookManager, Utils}
@@ -51,16 +53,26 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
// of subDirs(i) is protected by the lock of subDirs(i)
private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
+ // Cache local directories for other block managers
+ private val localDirsByOtherBlkMgr = new mutable.HashMap[BlockManagerId, Array[String]]
+
private val shutdownHook = addShutdownHook()
+ def blockManagerId: BlockManagerId = blockManager.blockManagerId
+
+ def getLocalDirsPath(): Array[String] = {
+ localDirs.map(_.getAbsolutePath)
+ }
+
+ def getLocalDirsPath(blockManagerId: BlockManagerId): Map[BlockManagerId, Array[String]] = {
+ blockManager.master.getLocalDirsPath(blockManagerId)
+ }
+
/** Looks up a file by hashing it into one of our local subdirectories. */
// This method should be kept in sync with
// org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getFile().
def getFile(filename: String): File = {
- // Figure out which local directory it hashes to, and which subdirectory in that
- val hash = Utils.nonNegativeHash(filename)
- val dirId = hash % localDirs.length
- val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
+ val (dirId, subDirId) = getDirInfo(filename, localDirs.length)
// Create the subdirectory if it doesn't already exist
val subDir = subDirs(dirId).synchronized {
@@ -82,6 +94,39 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
def getFile(blockId: BlockId): File = getFile(blockId.name)
+ def getFile(blockId: BlockId, blockManagerId: BlockManagerId): File = {
+ if (this.blockManagerId == blockManagerId) {
+ getFile(blockId)
+ } else {
+ // Get a file from another block manager with given blockManagerId
+ val dirs = localDirsByOtherBlkMgr.synchronized {
+ localDirsByOtherBlkMgr.getOrElse(blockManagerId, {
+ localDirsByOtherBlkMgr ++= getLocalDirsPath(this.blockManagerId)
+ localDirsByOtherBlkMgr.getOrElse(blockManagerId, {
+ throw new IOException(s"Block manager (${blockManagerId}) not found " +
+ s"in host '${this.blockManagerId.host}'")
+ })
+ })
+ }
+ val (dirId, subDirId) = getDirInfo(blockId.name, dirs.length)
+ val file = new File(new File(dirs(dirId), "%02x".format(subDirId)), blockId.name)
+ if (!file.exists()) {
+ throw new IOException(s"File '${file}' not found in local dir")
+ }
+ logInfo(s"${this.blockManagerId} bypasses network access and " +
+ s"directly reads file '${file}' in local dir")
+ file
+ }
+ }
+
+ def getDirInfo(filename: String, numDirs: Int): (Int, Int) = {
+ // Figure out which local directory it hashes to, and which subdirectory in that
+ val hash = Utils.nonNegativeHash(filename)
+ val dirId = hash % numDirs
+ val subDirName = (hash / numDirs) % subDirsPerLocalDir
+ (dirId, subDirName)
+ }
+
/** Check if disk block manager has a block. */
def containsBlock(blockId: BlockId): Boolean = {
getFile(blockId.name).exists()
@@ -166,7 +211,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
// Only perform cleanup if an external service is not serving our shuffle files.
// Also blockManagerId could be null if block manager is not initialized properly.
if (!blockManager.externalShuffleServiceEnabled ||
- (blockManager.blockManagerId != null && blockManager.blockManagerId.isDriver)) {
+ (this.blockManagerId != null && blockManager.blockManagerId.isDriver)) {
localDirs.foreach { localDir =>
if (localDir.isDirectory() && localDir.exists()) {
try {
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 0d0448feb5b0..54383ec8d648 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.storage
import java.io.InputStream
import java.util.concurrent.LinkedBlockingQueue
-import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
+import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet, Queue}
import scala.util.control.NonFatal
import org.apache.spark.{Logging, SparkException, TaskContext}
@@ -58,6 +58,17 @@ final class ShuffleBlockFetcherIterator(
import ShuffleBlockFetcherIterator._
+ private[this] val enableExternalShuffleService =
+ blockManager.conf.getBoolean("spark.shuffle.service.enabled", false)
+
+ /**
+ * If this option enabled, bypass unnecessary network interaction
+ * if multiple block managers work in a single host.
+ */
+ private[this] val enableBypassNetworkAccess =
+ blockManager.conf.getBoolean("spark.shuffle.bypassNetworkAccess", false) &&
+ !enableExternalShuffleService
+
/**
* Total number of blocks to fetch. This can be smaller than the total number of blocks
* in [[blocksByAddress]] because we filter out zero-sized blocks in [[initialize]].
@@ -74,8 +85,12 @@ final class ShuffleBlockFetcherIterator(
private[this] val startTime = System.currentTimeMillis
- /** Local blocks to fetch, excluding zero-sized blocks. */
- private[this] val localBlocks = new ArrayBuffer[BlockId]()
+ /**
+ * Local blocks to fetch, excluding zero-sized blocks.
+ * This iterator bypasses remote access to fetch the blocks that
+ * other block managers holds in an identical host.
+ */
+ private[this] val localBlocks = new HashMap[BlockManagerId, ArrayBuffer[BlockId]]
/** Remote blocks to fetch, excluding zero-sized blocks. */
private[this] val remoteBlocks = new HashSet[BlockId]()
@@ -188,10 +203,12 @@ final class ShuffleBlockFetcherIterator(
var totalBlocks = 0
for ((address, blockInfos) <- blocksByAddress) {
totalBlocks += blockInfos.size
- if (address.executorId == blockManager.blockManagerId.executorId) {
+ if (address.executorId == blockManager.blockManagerId.executorId ||
+ (enableBypassNetworkAccess && blockManager.blockManagerId.host == address.host)) {
// Filter out zero-sized blocks
- localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
- numBlocksToFetch += localBlocks.size
+ val blocks = blockInfos.filter(_._2 != 0).map(_._1)
+ localBlocks.getOrElseUpdate(address, ArrayBuffer()) ++= blocks
+ numBlocksToFetch += blocks.size
} else {
val iterator = blockInfos.iterator
var curRequestSize = 0L
@@ -233,19 +250,28 @@ final class ShuffleBlockFetcherIterator(
private[this] def fetchLocalBlocks() {
val iter = localBlocks.iterator
while (iter.hasNext) {
- val blockId = iter.next()
- try {
- val buf = blockManager.getBlockData(blockId)
- shuffleMetrics.incLocalBlocksFetched(1)
- shuffleMetrics.incLocalBytesRead(buf.size)
- buf.retain()
- results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf))
- } catch {
- case e: Exception =>
- // If we see an exception, stop immediately.
- logError(s"Error occurred while fetching local blocks", e)
- results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
- return
+ val (blockManagerId, blockIds) = iter.next()
+ val blockIter = blockIds.iterator
+ while (blockIter.hasNext) {
+ val blockId = blockIter.next()
+ assert(blockId.isShuffle)
+ try {
+ val buf = if (!enableExternalShuffleService) {
+ blockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId], blockManagerId)
+ } else {
+ blockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
+ }
+ shuffleMetrics.incLocalBlocksFetched(1)
+ shuffleMetrics.incLocalBytesRead(buf.size)
+ buf.retain()
+ results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf))
+ } catch {
+ case NonFatal(e) =>
+ // If we see an exception, stop immediately.
+ logError(s"Error occurred while fetching local blocks", e)
+ results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
+ return
+ }
}
}
}
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 1c3f2bc315dd..676948856a0a 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -86,6 +86,18 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
assert(groups.map(_._2).sum === 2000)
}
+ test("bypass remote access") {
+ val conf = new SparkConf().set("spark.shuffle.bypassNetworkAccess", "true")
+ Seq("hash", "sort", "tungsten-sort").map { shuffle =>
+ sc = new SparkContext(clusterUrl, "test", conf.clone.set("spark.shuffle.manager", shuffle))
+ val rdd = sc.parallelize((0 until 1000).map(x => (x % 4, 1)), 5)
+ val groups = rdd.reduceByKey(_ + _).collect
+ assert(groups.size === 4)
+ assert(groups.forall(_._2 == 250))
+ resetSparkContext()
+ }
+ }
+
test("accumulators") {
sc = new SparkContext(clusterUrl, "test")
val accum = sc.accumulator(0)
diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
index 26a372d6a905..6bd5135c1984 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
@@ -77,6 +77,8 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
// can ensure retain() and release() are properly called.
val blockManager = mock(classOf[BlockManager])
+ when(blockManager.conf).thenReturn(testConf)
+
// Create a return function to use for the mocked wrapForCompression method that just returns
// the original input stream.
val dummyCompressionFunction = new Answer[InputStream] {
@@ -104,7 +106,8 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
// Setup the blockManager mock so the buffer gets returned when the shuffle code tries to
// fetch shuffle data.
val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
- when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer)
+ when(blockManager.getBlockData(shuffleBlockId, localBlockManagerId))
+ .thenReturn(managedBuffer)
when(blockManager.wrapForCompression(meq(shuffleBlockId), isA(classOf[InputStream])))
.thenAnswer(dummyCompressionFunction)
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
index 688f56f4665f..f566491e380a 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
@@ -17,21 +17,22 @@
package org.apache.spark.storage
-import java.io.{File, FileWriter}
+import java.io.{File, FileWriter, IOException}
import scala.language.reflectiveCalls
-import org.mockito.Mockito.{mock, when}
-import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
+import org.mockito.Matchers.{eq => meq}
+import org.mockito.Mockito.{mock, times, verify, when}
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, PrivateMethodTester}
-import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.util.Utils
+import org.apache.spark.{SparkConf, SparkFunSuite}
-class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
+class DiskBlockManagerSuite extends SparkFunSuite
+ with BeforeAndAfterEach with BeforeAndAfterAll with PrivateMethodTester {
private val testConf = new SparkConf(false)
private var rootDir0: File = _
private var rootDir1: File = _
- private var rootDirs: String = _
val blockManager = mock(classOf[BlockManager])
when(blockManager.conf).thenReturn(testConf)
@@ -41,7 +42,7 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B
super.beforeAll()
rootDir0 = Utils.createTempDir()
rootDir1 = Utils.createTempDir()
- rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath
+ testConf.set("spark.local.dir", rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath)
}
override def afterAll() {
@@ -51,9 +52,7 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B
}
override def beforeEach() {
- val conf = testConf.clone
- conf.set("spark.local.dir", rootDirs)
- diskBlockManager = new DiskBlockManager(blockManager, conf)
+ diskBlockManager = new DiskBlockManager(blockManager, testConf.clone)
}
override def afterEach() {
@@ -81,4 +80,58 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B
for (i <- 0 until numBytes) writer.write(i)
writer.close()
}
+
+ test("bypassing network access") {
+ val mockBlockManagerMaster = mock(classOf[BlockManagerMaster])
+ val mockBlockManager = mock(classOf[BlockManager])
+
+ // Assume two executors in an identical host
+ val localBmId1 = BlockManagerId("test-exec1", "test-client1", 1)
+ val localBmId2 = BlockManagerId("test-exec2", "test-client1", 2)
+
+ // Assume that localBmId2 holds 'shuffle_1_0_0'
+ val blockIdInLocalBmId2 = ShuffleBlockId(1, 0, 0)
+ val tempDir = Utils.createTempDir()
+ try {
+ // Create mock classes for testing
+ when(mockBlockManagerMaster.getLocalDirsPath(meq(localBmId1)))
+ .thenReturn(Map(localBmId2 -> Array(tempDir.getAbsolutePath)))
+ when(mockBlockManager.conf).thenReturn(testConf)
+ when(mockBlockManager.master).thenReturn(mockBlockManagerMaster)
+ when(mockBlockManager.blockManagerId).thenReturn(localBmId1)
+
+ val testDiskBlockManager = new DiskBlockManager(mockBlockManager, testConf.clone)
+
+ val getBlockDir: String => File = (s: String) => {
+ val (_, subDirId) = testDiskBlockManager.getDirInfo(s, 1)
+ new File(tempDir, "%02x".format(subDirId))
+ }
+
+ // Create a dummy file for a shuffle block
+ val blockDir = getBlockDir(blockIdInLocalBmId2.name)
+ assert(blockDir.mkdir())
+ val dummyBlockFile = new File(blockDir, blockIdInLocalBmId2.name)
+ assert(dummyBlockFile.createNewFile())
+
+ val file = testDiskBlockManager.getFile(
+ blockIdInLocalBmId2, localBmId2)
+ assert(dummyBlockFile.getName === file.getName)
+ assert(dummyBlockFile.toString.contains(tempDir.toString))
+
+ verify(mockBlockManagerMaster, times(1)).getLocalDirsPath(meq(localBmId1))
+ verify(mockBlockManager, times(1)).conf
+ verify(mockBlockManager, times(1)).master
+ verify(mockBlockManager, times(3)).blockManagerId
+
+ // Throw an IOException if given shuffle file not found
+ val blockIdNotInLocalBmId2 = ShuffleBlockId(2, 0, 0)
+ val errMsg = intercept[IOException] {
+ testDiskBlockManager.getFile(blockIdNotInLocalBmId2, localBmId2)
+ }
+ assert(errMsg.getMessage contains s"File '${getBlockDir(blockIdNotInLocalBmId2.name)}/" +
+ s"${blockIdNotInLocalBmId2}' not found in local dir")
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 828153bdbfc4..d7b3b4247420 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -27,16 +27,25 @@ import org.mockito.Matchers.{any, eq => meq}
import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
-import org.scalatest.PrivateMethodTester
+import org.scalatest.{BeforeAndAfterAll, PrivateMethodTester}
-import org.apache.spark.{SparkFunSuite, TaskContext}
+import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext}
import org.apache.spark.network._
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.shuffle.BlockFetchingListener
import org.apache.spark.shuffle.FetchFailedException
-class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester {
+class ShuffleBlockFetcherIteratorSuite
+ extends SparkFunSuite with BeforeAndAfterAll with PrivateMethodTester
+{
+ private val testConf = new SparkConf(false)
+
+ override def beforeAll() {
+ super.beforeAll()
+ testConf.set("spark.shuffle.bypassNetworkAccess", "false")
+ }
+
// Some of the tests are quite tricky because we are testing the cleanup behavior
// in the presence of faults.
@@ -70,15 +79,16 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
test("successful 3 local reads + 2 remote reads") {
val blockManager = mock(classOf[BlockManager])
val localBmId = BlockManagerId("test-client", "test-client", 1)
+ doReturn(testConf).when(blockManager).conf
doReturn(localBmId).when(blockManager).blockManagerId
// Make sure blockManager.getBlockData would return the blocks
- val localBlocks = Map[BlockId, ManagedBuffer](
+ val localBlocks = Map[ShuffleBlockId, ManagedBuffer](
ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer())
localBlocks.foreach { case (blockId, buf) =>
- doReturn(buf).when(blockManager).getBlockData(meq(blockId))
+ doReturn(buf).when(blockManager).getBlockData(meq(blockId), meq(localBmId))
}
// Make sure remote blocks would return
@@ -102,14 +112,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
48 * 1024 * 1024)
// 3 local blocks fetched in initialization
- verify(blockManager, times(3)).getBlockData(any())
+ verify(blockManager, times(3)).getBlockData(any(), any())
for (i <- 0 until 5) {
assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements")
val (blockId, inputStream) = iterator.next()
// Make sure we release buffers when a wrapped input stream is closed.
- val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId))
+ val mockBuf = localBlocks.getOrElse(
+ blockId.asInstanceOf[ShuffleBlockId], remoteBlocks(blockId))
// Note: ShuffleBlockFetcherIterator wraps input streams in a BufferReleasingInputStream
val wrappedInputStream = inputStream.asInstanceOf[BufferReleasingInputStream]
verify(mockBuf, times(0)).release()
@@ -126,13 +137,70 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
// 3 local blocks, and 2 remote blocks
// (but from the same block manager so one call to fetchBlocks)
- verify(blockManager, times(3)).getBlockData(any())
+ verify(blockManager, times(3)).getBlockData(any(), any())
verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any())
}
+ test("bypass unnecessary network access if block managers share an identical host") {
+ val blockManager = mock(classOf[BlockManager])
+
+ // Assume two executors in an identical host
+ val localBmId1 = BlockManagerId("test-exec1", "test-client1", 1)
+ val localBmId2 = BlockManagerId("test-exec2", "test-client1", 2)
+
+ // Enable an option to bypass network access
+ doReturn(testConf.clone.set("spark.shuffle.bypassNetworkAccess", "true"))
+ .when(blockManager).conf
+ doReturn(localBmId1).when(blockManager).blockManagerId
+
+ // Make sure blockManager.getBlockData would return the blocks
+ val localBlocksInBmId1 = Map[ShuffleBlockId, ManagedBuffer](
+ ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer())
+ localBlocksInBmId1.foreach { case (blockId, buf) =>
+ doReturn(buf).when(blockManager).getBlockData(meq(blockId), meq(localBmId1))
+ }
+ val localBlocksInBmId2 = Map[ShuffleBlockId, ManagedBuffer](
+ ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
+ ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer())
+ localBlocksInBmId2.foreach { case (blockId, buf) =>
+ doReturn(buf).when(blockManager).getBlockData(meq(blockId), meq(localBmId2))
+ }
+
+ // Create mock transfer
+ val transfer = mock(classOf[BlockTransferService])
+ when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+ override def answer(invocation: InvocationOnMock): Unit = {}
+ })
+
+ val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
+ (localBmId1, localBlocksInBmId1.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq),
+ (localBmId2, localBlocksInBmId2.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)
+ )
+
+ val iterator = new ShuffleBlockFetcherIterator(
+ TaskContext.empty(),
+ transfer,
+ blockManager,
+ blocksByAddress,
+ 48 * 1024 * 1024)
+
+ // Skip unnecessary remote reads
+ verify(blockManager, times(3)).getBlockData(any(), any())
+
+ for (i <- 0 until 3) {
+ assert(iterator.hasNext, s"iterator should have 3 elements but actually has $i elements")
+ iterator.next()
+ }
+
+ // As a result, only 3 local reads (2 remote access skipped)
+ verify(blockManager, times(3)).getBlockData(any(), any())
+ verify(transfer, times(0)).fetchBlocks(any(), any(), any(), any(), any())
+ }
+
test("release current unexhausted buffer in case the task completes early") {
val blockManager = mock(classOf[BlockManager])
val localBmId = BlockManagerId("test-client", "test-client", 1)
+ doReturn(testConf).when(blockManager).conf
doReturn(localBmId).when(blockManager).blockManagerId
// Make sure remote blocks would return
@@ -194,6 +262,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
test("fail all blocks if any of the remote request fails") {
val blockManager = mock(classOf[BlockManager])
val localBmId = BlockManagerId("test-client", "test-client", 1)
+ doReturn(testConf).when(blockManager).conf
doReturn(localBmId).when(blockManager).blockManagerId
// Make sure remote blocks would return
diff --git a/docs/configuration.md b/docs/configuration.md
index 741d6b2b37a8..b267919ebfaa 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -477,6 +477,14 @@ Apart from these, the following properties are also available, and may be useful
spark.io.compression.codec.
+
spark.shuffle.bypassNetworkAccess