Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Fix the interface issues
  • Loading branch information
maropu committed Nov 27, 2015
commit 303abcd0f37d137c6a8ce4a0147466bb8feb9d9e
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
package org.apache.spark.network

import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.storage.{ShuffleBlockId, BlockManagerId, BlockId, StorageLevel}
import org.apache.spark.storage.BlockId
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.storage.StorageLevel

private[spark]
trait BlockDataManager {
Expand All @@ -30,11 +32,10 @@ trait BlockDataManager {
def getBlockData(blockId: BlockId): ManagedBuffer

/**
* Interface to get the shuffle block data that the block manager with the given blockManagerId
* holds on the local host. Throws an exception if the block cannot be found or
* cannot be read successfully.
* Interface to get local block data managed by the given BlockManagerId.
* Throws an exception if the block cannot be found or cannot be read successfully.
*/
def getShuffleBlockData(blockId: ShuffleBlockId, blockManagerId: BlockManagerId): ManagedBuffer
def getBlockData(blockId: BlockId, blockManagerId: BlockManagerId): ManagedBuffer

/**
* Put the block locally, using the given storage level.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)

override def getBlockData(blockId: ShuffleBlockId, blockManagerId: BlockManagerId)
: ManagedBuffer = {
val file = if (blockManager.blockManagerId != blockManagerId) {
blockManager.diskBlockManager.getShuffleFileBypassNetworkAccess(blockId, blockManagerId)
} else {
blockManager.diskBlockManager.getFile(blockId)
}
val file = blockManager.diskBlockManager.getFile(blockId, blockManagerId)
new FileSegmentManagedBuffer(transportConf, file, 0, file.length)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,35 +49,19 @@ private[spark] class IndexShuffleBlockResolver(

private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle")

/**
* Resolves the shuffle data file for the given shuffle id and map id.
*
* When `blockManagerId` differs from this block manager's own id, the file is looked up through
* `diskBlockManager.getShuffleFileBypassNetworkAccess`; otherwise the plain
* `diskBlockManager.getFile` lookup is used. `blockManagerId` defaults to this block manager's
* own id.
*/
private def getDataFile(
shuffleId: Int,
mapId: Int,
blockManagerId: BlockManagerId = blockManager.blockManagerId)
: File = {
if (blockManager.blockManagerId != blockManagerId) {
blockManager.diskBlockManager.getShuffleFileBypassNetworkAccess(
ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID), blockManagerId)
} else {
blockManager.diskBlockManager.getFile(
ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID))
}
}

/**
* Returns the shuffle data file for the given shuffle id and map id, delegating to the
* three-argument overload with this block manager's own `blockManagerId`.
*/
def getDataFile(shuffleId: Int, mapId: Int): File =
getDataFile(shuffleId, mapId, blockManager.blockManagerId)

/**
* Resolves the shuffle index file for the given shuffle id and map id.
*
* When `blockManagerId` differs from this block manager's own id, the file is looked up through
* `diskBlockManager.getShuffleFileBypassNetworkAccess`; otherwise the plain
* `diskBlockManager.getFile` lookup is used. `blockManagerId` defaults to this block manager's
* own id.
*/
private def getIndexFile(
shuffleId: Int,
mapId: Int,
blockManagerId: BlockManagerId = blockManager.blockManagerId): File = {
if (blockManager.blockManagerId != blockManagerId) {
blockManager.diskBlockManager.getShuffleFileBypassNetworkAccess(
ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID), blockManagerId)
} else {
blockManager.diskBlockManager.getFile(
ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID))
}
}
/**
* Returns the shuffle data file for the given shuffle id and map id by asking the disk block
* manager for the `ShuffleDataBlockId` file under the given `blockManagerId`.
*/
private def getDataFile(shuffleId: Int, mapId: Int, blockManagerId: BlockManagerId): File =
blockManager.diskBlockManager.getFile(
ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID), blockManagerId)

/**
* Returns the shuffle index file for the given shuffle id and map id, delegating to the
* three-argument overload with this block manager's own `blockManagerId`.
*/
private def getIndexFile(shuffleId: Int, mapId: Int): File =
getIndexFile(shuffleId, mapId, blockManager.blockManagerId)

/**
* Returns the shuffle index file for the given shuffle id and map id by asking the disk block
* manager for the `ShuffleIndexBlockId` file under the given `blockManagerId`.
*/
private def getIndexFile(shuffleId: Int, mapId: Int, blockManagerId: BlockManagerId): File =
blockManager.diskBlockManager.getFile(
ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID), blockManagerId)

/**
* Remove data file and index file that contain the output data from one map.
Expand Down
13 changes: 6 additions & 7 deletions core/src/main/scala/org/apache/spark/storage/BlockManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.ExternalShuffleClient
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
import org.apache.spark.rpc.RpcEnv
import org.apache.spark.serializer.{SerializerInstance, Serializer}
import org.apache.spark.serializer.{Serializer, SerializerInstance}
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.util._

Expand Down Expand Up @@ -288,9 +288,10 @@ private[spark] class BlockManager(
* Interface to get local block data. Throws an exception if the block cannot be found or
* cannot be read successfully.
*/
override def getBlockData(blockId: BlockId): ManagedBuffer = {
override def getBlockData(blockId: BlockId, blockManagerId: BlockManagerId): ManagedBuffer = {
if (blockId.isShuffle) {
getShuffleBlockData(blockId.asInstanceOf[ShuffleBlockId], blockManagerId)
shuffleManager.shuffleBlockResolver.getBlockData(
blockId.asInstanceOf[ShuffleBlockId], blockManagerId)
} else {
val blockBytesOpt = doGetLocal(blockId, asBlockResult = false)
.asInstanceOf[Option[ByteBuffer]]
Expand All @@ -303,10 +304,8 @@ private[spark] class BlockManager(
}
}

override def getShuffleBlockData(
blockId: ShuffleBlockId,
blockManagerId: BlockManagerId): ManagedBuffer = {
shuffleManager.shuffleBlockResolver.getBlockData(blockId, blockManagerId)
override def getBlockData(blockId: BlockId): ManagedBuffer = {
getBlockData(blockId, this.blockManagerId)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon

/** Returns the file for the given block by delegating to `getFile` with `blockId.name`. */
def getFile(blockId: BlockId): File = getFile(blockId.name)

def getShuffleFileBypassNetworkAccess(blockId: BlockId, blockManagerId: BlockManagerId): File = {
def getFile(blockId: BlockId, blockManagerId: BlockManagerId): File = {
if (this.blockManagerId == blockManagerId) {
getFile(blockId)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ final class ShuffleBlockFetcherIterator(
assert(blockId.isShuffle)
try {
val buf = if (!enableExternalShuffleService) {
blockManager.getShuffleBlockData(blockId.asInstanceOf[ShuffleBlockId], blockManagerId)
blockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId], blockManagerId)
} else {
blockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
// Setup the blockManager mock so the buffer gets returned when the shuffle code tries to
// fetch shuffle data.
val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
when(blockManager.getShuffleBlockData(shuffleBlockId, localBlockManagerId))
when(blockManager.getBlockData(shuffleBlockId, localBlockManagerId))
.thenReturn(managedBuffer)
when(blockManager.wrapForCompression(meq(shuffleBlockId), isA(classOf[InputStream])))
.thenAnswer(dummyCompressionFunction)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ class DiskBlockManagerSuite extends SparkFunSuite
val dummyBlockFile = new File(blockDir, blockIdInLocalBmId2.name)
assert(dummyBlockFile.createNewFile())

val file = testDiskBlockManager.getShuffleFileBypassNetworkAccess(
val file = testDiskBlockManager.getFile(
blockIdInLocalBmId2, localBmId2)
assert(dummyBlockFile.getName === file.getName)
assert(dummyBlockFile.toString.contains(tempDir.toString))
Expand All @@ -126,7 +126,7 @@ class DiskBlockManagerSuite extends SparkFunSuite
// Throws an IOException if the given shuffle file is not found
val blockIdNotInLocalBmId2 = ShuffleBlockId(2, 0, 0)
val errMsg = intercept[IOException] {
testDiskBlockManager.getShuffleFileBypassNetworkAccess(blockIdNotInLocalBmId2, localBmId2)
testDiskBlockManager.getFile(blockIdNotInLocalBmId2, localBmId2)
}
assert(errMsg.getMessage contains s"File '${getBlockDir(blockIdNotInLocalBmId2.name)}/" +
s"${blockIdNotInLocalBmId2}' not found in local dir")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class ShuffleBlockFetcherIteratorSuite
ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer())
localBlocks.foreach { case (blockId, buf) =>
doReturn(buf).when(blockManager).getShuffleBlockData(meq(blockId), meq(localBmId))
doReturn(buf).when(blockManager).getBlockData(meq(blockId), meq(localBmId))
}

// Make sure remote blocks would return
Expand All @@ -112,7 +112,7 @@ class ShuffleBlockFetcherIteratorSuite
48 * 1024 * 1024)

// 3 local blocks fetched in initialization
verify(blockManager, times(3)).getShuffleBlockData(any(), any())
verify(blockManager, times(3)).getBlockData(any(), any())

for (i <- 0 until 5) {
assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements")
Expand All @@ -137,7 +137,7 @@ class ShuffleBlockFetcherIteratorSuite

// 3 local blocks, and 2 remote blocks
// (but from the same block manager so one call to fetchBlocks)
verify(blockManager, times(3)).getShuffleBlockData(any(), any())
verify(blockManager, times(3)).getBlockData(any(), any())
verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any())
}

Expand All @@ -157,13 +157,13 @@ class ShuffleBlockFetcherIteratorSuite
val localBlocksInBmId1 = Map[ShuffleBlockId, ManagedBuffer](
ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer())
localBlocksInBmId1.foreach { case (blockId, buf) =>
doReturn(buf).when(blockManager).getShuffleBlockData(meq(blockId), meq(localBmId1))
doReturn(buf).when(blockManager).getBlockData(meq(blockId), meq(localBmId1))
}
val localBlocksInBmId2 = Map[ShuffleBlockId, ManagedBuffer](
ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer())
localBlocksInBmId2.foreach { case (blockId, buf) =>
doReturn(buf).when(blockManager).getShuffleBlockData(meq(blockId), meq(localBmId2))
doReturn(buf).when(blockManager).getBlockData(meq(blockId), meq(localBmId2))
}

// Create mock transfer
Expand All @@ -185,15 +185,15 @@ class ShuffleBlockFetcherIteratorSuite
48 * 1024 * 1024)

// Skip unnecessary remote reads
verify(blockManager, times(3)).getShuffleBlockData(any(), any())
verify(blockManager, times(3)).getBlockData(any(), any())

for (i <- 0 until 3) {
assert(iterator.hasNext, s"iterator should have 3 elements but actually has $i elements")
iterator.next()
}

// As a result, only 3 local reads (2 remote accesses skipped)
verify(blockManager, times(3)).getShuffleBlockData(any(), any())
verify(blockManager, times(3)).getBlockData(any(), any())
verify(transfer, times(0)).fetchBlocks(any(), any(), any(), any(), any())
}

Expand Down