Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
1e752f1
Added unpersist method to Broadcast.
Feb 5, 2014
80dd977
Fix for Broadcast unpersist patch.
Feb 6, 2014
c7ccef1
Merge branch 'bc-unpersist-merge' of github.com:ignatich/incubator-sp…
andrewor14 Mar 26, 2014
ba52e00
Refactor broadcast classes
andrewor14 Mar 26, 2014
d0edef3
Add framework for broadcast cleanup
andrewor14 Mar 26, 2014
544ac86
Clean up broadcast blocks through BlockManager*
andrewor14 Mar 26, 2014
e95479c
Add tests for unpersisting broadcast
andrewor14 Mar 27, 2014
f201a8d
Test broadcast cleanup in ContextCleanerSuite + remove BoundedHashMap
andrewor14 Mar 27, 2014
c92e4d9
Merge github.com:apache/spark into cleanup
andrewor14 Mar 27, 2014
0d17060
Import, comments, and style fixes (minor)
andrewor14 Mar 28, 2014
34f436f
Generalize BroadcastBlockId to remove BroadcastHelperBlockId
andrewor14 Mar 28, 2014
fbfeec8
Add functionality to query executors for their local BlockStatuses
andrewor14 Mar 29, 2014
88904a3
Make TimeStampedWeakValueHashMap a wrapper of TimeStampedHashMap
andrewor14 Mar 29, 2014
e442246
Merge github.com:apache/spark into cleanup
andrewor14 Mar 29, 2014
8557c12
Merge github.com:apache/spark into cleanup
andrewor14 Mar 30, 2014
634a097
Merge branch 'state-cleanup' of github.com:tdas/spark into cleanup
andrewor14 Mar 31, 2014
7ed72fb
Fix style test fail + remove verbose test message regarding broadcast
andrewor14 Mar 31, 2014
5016375
Address TD's comments
andrewor14 Apr 1, 2014
f0aabb1
Correct semantics for TimeStampedWeakValueHashMap + add tests
andrewor14 Apr 2, 2014
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -641,8 +641,13 @@ class SparkContext(
* Broadcast a read-only variable to the cluster, returning a
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
* The variable will be sent to each cluster only once.
*
* If `registerBlocks` is true, workers will notify driver about blocks they create
* and these blocks will be dropped when `unpersist` method of the broadcast variable is called.
*/
def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
def broadcast[T](value: T, registerBlocks: Boolean = false) = {
env.broadcastManager.newBroadcast[T](value, isLocal, registerBlocks)
}

/**
* Add a file to be downloaded with this Spark job on every node.
Expand Down
13 changes: 11 additions & 2 deletions core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ import org.apache.spark._
abstract class Broadcast[T](val id: Long) extends Serializable {
def value: T

/**
* Removes all blocks of this broadcast from memory (and disk if removeSource is true).
*
* @param removeSource Whether to remove data from disk as well.
* Will cause errors if broadcast is accessed on workers afterwards
* (e.g. in case of RDD re-computation due to executor failure).
*/
def unpersist(removeSource: Boolean = false)

// We cannot have an abstract readObject here due to some weird issues with
// readObject having to be 'private' in sub-classes.

Expand Down Expand Up @@ -92,8 +101,8 @@ class BroadcastManager(val _isDriver: Boolean, conf: SparkConf, securityManager:

private val nextBroadcastId = new AtomicLong(0)

def newBroadcast[T](value_ : T, isLocal: Boolean) =
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
def newBroadcast[T](value_ : T, isLocal: Boolean, registerBlocks: Boolean) =
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement(), registerBlocks)

def isDriver = _isDriver
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.SparkConf
* entire Spark job.
*/
trait BroadcastFactory {
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
def newBroadcast[T](value: T, isLocal: Boolean, id: Long, registerBlocks: Boolean): Broadcast[T]
def stop(): Unit
}
49 changes: 37 additions & 12 deletions core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,24 @@ import org.apache.spark.io.CompressionCodec
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}

private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean)
extends Broadcast[T](id) with Logging with Serializable {

def value = value_

def unpersist(removeSource: Boolean) {
HttpBroadcast.synchronized {
SparkEnv.get.blockManager.master.removeBlock(blockId)
SparkEnv.get.blockManager.removeBlock(blockId)
}

if (removeSource) {
HttpBroadcast.synchronized {
HttpBroadcast.cleanupById(id)
}
}
}

def blockId = BroadcastBlockId(id)

HttpBroadcast.synchronized {
Expand All @@ -54,7 +67,7 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
logInfo("Started reading broadcast variable " + id)
val start = System.nanoTime
value_ = HttpBroadcast.read[T](id)
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, registerBlocks)
val time = (System.nanoTime - start) / 1e9
logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
Expand All @@ -71,8 +84,8 @@ class HttpBroadcastFactory extends BroadcastFactory {
HttpBroadcast.initialize(isDriver, conf, securityMgr)
}

def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new HttpBroadcast[T](value_, isLocal, id)
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) =
new HttpBroadcast[T](value_, isLocal, id, registerBlocks)

def stop() { HttpBroadcast.stop() }
}
Expand Down Expand Up @@ -136,8 +149,10 @@ private object HttpBroadcast extends Logging {
logInfo("Broadcast server started at " + serverUri)
}

def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name)

def write(id: Long, value: Any) {
val file = new File(broadcastDir, BroadcastBlockId(id).name)
val file = getFile(id)
val out: OutputStream = {
if (compress) {
compressionCodec.compressedOutputStream(new FileOutputStream(file))
Expand Down Expand Up @@ -183,20 +198,30 @@ private object HttpBroadcast extends Logging {
obj
}

def deleteFile(fileName: String) {
try {
new File(fileName).delete()
logInfo("Deleted broadcast file '" + fileName + "'")
} catch {
case e: Exception => logWarning("Could not delete broadcast file '" + fileName + "'", e)
}
}

def cleanup(cleanupTime: Long) {
val iterator = files.internalMap.entrySet().iterator()
while(iterator.hasNext) {
val entry = iterator.next()
val (file, time) = (entry.getKey, entry.getValue)
if (time < cleanupTime) {
try {
iterator.remove()
new File(file.toString).delete()
logInfo("Deleted broadcast file '" + file + "'")
} catch {
case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e)
}
iterator.remove()
deleteFile(file)
}
}
}

def cleanupById(id: Long) {
val file = getFile(id).getAbsolutePath
files.internalMap.remove(file)
deleteFile(file)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,68 @@ import org.apache.spark._
import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel}
import org.apache.spark.util.Utils

private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean)
extends Broadcast[T](id) with Logging with Serializable {

def value = value_

def unpersist(removeSource: Boolean) {
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.master.removeBlock(broadcastId)
SparkEnv.get.blockManager.removeBlock(broadcastId)
}

if (!removeSource) {
//We can't tell BlockManager master to remove blocks from all nodes except driver,
//so we need to save them here in order to store them on disk later.
//This may be inefficient if blocks were already dropped to disk,
//but since unpersist is supposed to be called right after working with
//a broadcast this should not happen (and getting them from memory is cheap).
arrayOfBlocks = new Array[TorrentBlock](totalBlocks)

for (pid <- 0 until totalBlocks) {
val pieceId = pieceBlockId(pid)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(pieceId) match {
case Some(x) =>
arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
case None =>
throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
}
}
}
}

for (pid <- 0 until totalBlocks) {
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.master.removeBlock(pieceBlockId(pid))
}
}

if (removeSource) {
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.removeBlock(metaId)
}
} else {
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.dropFromMemory(metaId)
}

for (i <- 0 until totalBlocks) {
val pieceId = pieceBlockId(i)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(
pieceId, arrayOfBlocks(i), StorageLevel.DISK_ONLY, true)
}
}
arrayOfBlocks = null
}
}

def broadcastId = BroadcastBlockId(id)
private def metaId = BroadcastHelperBlockId(broadcastId, "meta")
private def pieceBlockId(pid: Int) = BroadcastHelperBlockId(broadcastId, "piece" + pid)
private def pieceIds = Array.iterate(0, totalBlocks)(_ + 1).toList

TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
Expand All @@ -54,7 +110,6 @@ extends Broadcast[T](id) with Logging with Serializable {
hasBlocks = tInfo.totalBlocks

// Store meta-info
val metaId = BroadcastHelperBlockId(broadcastId, "meta")
val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(
Expand All @@ -63,7 +118,7 @@ extends Broadcast[T](id) with Logging with Serializable {

// Store individual pieces
for (i <- 0 until totalBlocks) {
val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i)
val pieceId = pieceBlockId(i)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(
pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true)
Expand Down Expand Up @@ -93,7 +148,7 @@ extends Broadcast[T](id) with Logging with Serializable {
// This creates a tradeoff between memory usage and latency.
// Storing copy doubles the memory footprint; not storing doubles deserialization cost.
SparkEnv.get.blockManager.putSingle(
broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
broadcastId, value_, StorageLevel.MEMORY_AND_DISK, registerBlocks)

// Remove arrayOfBlocks from memory once value_ is on local cache
resetWorkerVariables()
Expand All @@ -116,7 +171,6 @@ extends Broadcast[T](id) with Logging with Serializable {

def receiveBroadcast(variableID: Long): Boolean = {
// Receive meta-info
val metaId = BroadcastHelperBlockId(broadcastId, "meta")
var attemptId = 10
while (attemptId > 0 && totalBlocks == -1) {
TorrentBroadcast.synchronized {
Expand All @@ -139,9 +193,9 @@ extends Broadcast[T](id) with Logging with Serializable {
}

// Receive actual blocks
val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
val recvOrder = new Random().shuffle(pieceIds)
for (pid <- recvOrder) {
val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid)
val pieceId = pieceBlockId(pid)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(pieceId) match {
case Some(x) =>
Expand Down Expand Up @@ -245,8 +299,8 @@ class TorrentBroadcastFactory extends BroadcastFactory {
TorrentBroadcast.initialize(isDriver, conf)
}

def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new TorrentBroadcast[T](value_, isLocal, id)
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) =
new TorrentBroadcast[T](value_, isLocal, id, registerBlocks)

def stop() { TorrentBroadcast.stop() }
}
12 changes: 12 additions & 0 deletions core/src/main/scala/org/apache/spark/storage/BlockManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,11 @@ private[spark] class BlockManager(
}
}

/**
* For testing. Returns number of blocks BlockManager knows about that are in memory.
*/
def numberOfBlocksInMemory() = blockInfo.keys.count(memoryStore.contains(_))

/**
* Get storage level of local block. If no info exists for the block, then returns null.
*/
Expand Down Expand Up @@ -812,6 +817,13 @@ private[spark] class BlockManager(
}

/**
* Drop a block from memory, possibly putting it on disk if applicable.
*/
def dropFromMemory(blockId: BlockId) {
memoryStore.asInstanceOf[MemoryStore].dropFromMemory(blockId)
}

/**
* Remove all blocks belonging to the given RDD.
* @return The number of blocks removed.
*/
Expand Down
38 changes: 22 additions & 16 deletions core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,27 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}

/**
* Try to free up a given amount of space to store a particular block, but can fail if
* either the block is bigger than our memory or it would require replacing another block
* from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that
* Drop a block from memory, possibly putting it on disk if applicable.
*/
def dropFromMemory(blockId: BlockId) {
val entry = entries.synchronized { entries.get(blockId) }
// This should never be null if called from ensureFreeSpace as only one
// thread should be dropping blocks and removing entries.
// However the check is required in other cases.
if (entry != null) {
val data = if (entry.deserialized) {
Left(entry.value.asInstanceOf[ArrayBuffer[Any]])
} else {
Right(entry.value.asInstanceOf[ByteBuffer].duplicate())
}
blockManager.dropFromMemory(blockId, data)
}
}

/**
* Tries to free up a given amount of space to store a particular block, but can fail and return
* false if either the block is bigger than our memory or it would require replacing another
* block from the same RDD (which leads to a wasteful cyclic replacement pattern for RDDs that
* don't fit into memory that we want to avoid).
*
* Assume that a lock is held by the caller to ensure only one thread is dropping blocks.
Expand Down Expand Up @@ -254,19 +272,7 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
if (maxMemory - (currentMemory - selectedMemory) >= space) {
logInfo(selectedBlocks.size + " blocks selected for dropping")
for (blockId <- selectedBlocks) {
val entry = entries.synchronized { entries.get(blockId) }
// This should never be null as only one thread should be dropping
// blocks and removing entries. However the check is still here for
// future safety.
if (entry != null) {
val data = if (entry.deserialized) {
Left(entry.value.asInstanceOf[ArrayBuffer[Any]])
} else {
Right(entry.value.asInstanceOf[ByteBuffer].duplicate())
}
val droppedBlockStatus = blockManager.dropFromMemory(blockId, data)
droppedBlockStatus.foreach { status => droppedBlocks += ((blockId, status)) }
}
dropFromMemory(blockId)
}
return ResultWithDroppedBlocks(success = true, droppedBlocks)
} else {
Expand Down
Loading