Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
1e752f1
Added unpersist method to Broadcast.
Feb 5, 2014
80dd977
Fix for Broadcast unpersist patch.
Feb 6, 2014
c7ccef1
Merge branch 'bc-unpersist-merge' of github.com:ignatich/incubator-sp…
andrewor14 Mar 26, 2014
ba52e00
Refactor broadcast classes
andrewor14 Mar 26, 2014
d0edef3
Add framework for broadcast cleanup
andrewor14 Mar 26, 2014
544ac86
Clean up broadcast blocks through BlockManager*
andrewor14 Mar 26, 2014
e95479c
Add tests for unpersisting broadcast
andrewor14 Mar 27, 2014
f201a8d
Test broadcast cleanup in ContextCleanerSuite + remove BoundedHashMap
andrewor14 Mar 27, 2014
c92e4d9
Merge github.com:apache/spark into cleanup
andrewor14 Mar 27, 2014
0d17060
Import, comments, and style fixes (minor)
andrewor14 Mar 28, 2014
34f436f
Generalize BroadcastBlockId to remove BroadcastHelperBlockId
andrewor14 Mar 28, 2014
fbfeec8
Add functionality to query executors for their local BlockStatuses
andrewor14 Mar 29, 2014
88904a3
Make TimeStampedWeakValueHashMap a wrapper of TimeStampedHashMap
andrewor14 Mar 29, 2014
e442246
Merge github.com:apache/spark into cleanup
andrewor14 Mar 29, 2014
8557c12
Merge github.com:apache/spark into cleanup
andrewor14 Mar 30, 2014
634a097
Merge branch 'state-cleanup' of github.com:tdas/spark into cleanup
andrewor14 Mar 31, 2014
7ed72fb
Fix style test fail + remove verbose test message regarding broadcast
andrewor14 Mar 31, 2014
5016375
Address TD's comments
andrewor14 Apr 1, 2014
f0aabb1
Correct semantics for TimeStampedWeakValueHashMap + add tests
andrewor14 Apr 2, 2014
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Refactor broadcast classes
  • Loading branch information
andrewor14 committed Mar 26, 2014
commit ba52e00303896e46ce9cb5122e78e12d7cae7864
7 changes: 1 addition & 6 deletions core/src/main/scala/org/apache/spark/SparkContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -641,13 +641,8 @@ class SparkContext(
* Broadcast a read-only variable to the cluster, returning a
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
* The variable will be sent to each cluster only once.
*
* If `registerBlocks` is true, workers will notify driver about blocks they create
* and these blocks will be dropped when `unpersist` method of the broadcast variable is called.
*/
def broadcast[T](value: T, registerBlocks: Boolean = false) = {
env.broadcastManager.newBroadcast[T](value, isLocal, registerBlocks)
}
def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)

/**
* Add a file to be downloaded with this Spark job on every node.
Expand Down
51 changes: 0 additions & 51 deletions core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
package org.apache.spark.broadcast

import java.io.Serializable
import java.util.concurrent.atomic.AtomicLong

import org.apache.spark._

/**
* A broadcast variable. Broadcast variables allow the programmer to keep a read-only variable
Expand Down Expand Up @@ -53,56 +50,8 @@ import org.apache.spark._
abstract class Broadcast[T](val id: Long) extends Serializable {
def value: T

/**
* Removes all blocks of this broadcast from memory (and disk if removeSource is true).
*
* @param removeSource Whether to remove data from disk as well.
* Will cause errors if broadcast is accessed on workers afterwards
* (e.g. in case of RDD re-computation due to executor failure).
*/
def unpersist(removeSource: Boolean = false)

// We cannot have an abstract readObject here due to some weird issues with
// readObject having to be 'private' in sub-classes.

override def toString = "Broadcast(" + id + ")"
}

private[spark]
class BroadcastManager(val _isDriver: Boolean, conf: SparkConf, securityManager: SecurityManager)
extends Logging with Serializable {

private var initialized = false
private var broadcastFactory: BroadcastFactory = null

initialize()

// Called by SparkContext or Executor before using Broadcast
private def initialize() {
synchronized {
if (!initialized) {
val broadcastFactoryClass = conf.get(
"spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")

broadcastFactory =
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]

// Initialize appropriate BroadcastFactory and BroadcastObject
broadcastFactory.initialize(isDriver, conf, securityManager)

initialized = true
}
}
}

def stop() {
broadcastFactory.stop()
}

private val nextBroadcastId = new AtomicLong(0)

def newBroadcast[T](value_ : T, isLocal: Boolean, registerBlocks: Boolean) =
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement(), registerBlocks)

def isDriver = _isDriver
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,6 @@ import org.apache.spark.SparkConf
*/
trait BroadcastFactory {
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
def newBroadcast[T](value: T, isLocal: Boolean, id: Long, registerBlocks: Boolean): Broadcast[T]
def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
def stop(): Unit
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.broadcast

import java.util.concurrent.atomic.AtomicLong

import org.apache.spark._

private[spark] class BroadcastManager(
val isDriver: Boolean,
conf: SparkConf,
securityManager: SecurityManager)
extends Logging with Serializable {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know why BroadcastManager needs to be serializable?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No idea... I'm not sure if it does (I pulled this directly from Broadcast.scala)


private var initialized = false
private var broadcastFactory: BroadcastFactory = null

initialize()

// Called by SparkContext or Executor before using Broadcast
private def initialize() {
synchronized {
if (!initialized) {
val broadcastFactoryClass =
conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")

broadcastFactory =
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]

// Initialize appropriate BroadcastFactory and BroadcastObject
broadcastFactory.initialize(isDriver, conf, securityManager)

initialized = true
}
}
}

def stop() {
broadcastFactory.stop()
}

private val nextBroadcastId = new AtomicLong(0)

def newBroadcast[T](value_ : T, isLocal: Boolean) = {
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
}

}
59 changes: 10 additions & 49 deletions core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,11 @@ import org.apache.spark.io.CompressionCodec
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}

private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean)
private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {

def value = value_

def unpersist(removeSource: Boolean) {
HttpBroadcast.synchronized {
SparkEnv.get.blockManager.master.removeBlock(blockId)
SparkEnv.get.blockManager.removeBlock(blockId)
}

if (removeSource) {
HttpBroadcast.synchronized {
HttpBroadcast.cleanupById(id)
}
}
}

def blockId = BroadcastBlockId(id)

HttpBroadcast.synchronized {
Expand All @@ -67,7 +54,7 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
logInfo("Started reading broadcast variable " + id)
val start = System.nanoTime
value_ = HttpBroadcast.read[T](id)
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, registerBlocks)
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
val time = (System.nanoTime - start) / 1e9
logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
Expand All @@ -76,20 +63,6 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea
}
}

/**
* A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium.
*/
class HttpBroadcastFactory extends BroadcastFactory {
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
HttpBroadcast.initialize(isDriver, conf, securityMgr)
}

def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long, registerBlocks: Boolean) =
new HttpBroadcast[T](value_, isLocal, id, registerBlocks)

def stop() { HttpBroadcast.stop() }
}

private object HttpBroadcast extends Logging {
private var initialized = false

Expand Down Expand Up @@ -149,10 +122,8 @@ private object HttpBroadcast extends Logging {
logInfo("Broadcast server started at " + serverUri)
}

def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name)

def write(id: Long, value: Any) {
val file = getFile(id)
val file = new File(broadcastDir, BroadcastBlockId(id).name)
val out: OutputStream = {
if (compress) {
compressionCodec.compressedOutputStream(new FileOutputStream(file))
Expand Down Expand Up @@ -198,30 +169,20 @@ private object HttpBroadcast extends Logging {
obj
}

def deleteFile(fileName: String) {
try {
new File(fileName).delete()
logInfo("Deleted broadcast file '" + fileName + "'")
} catch {
case e: Exception => logWarning("Could not delete broadcast file '" + fileName + "'", e)
}
}

def cleanup(cleanupTime: Long) {
val iterator = files.internalMap.entrySet().iterator()
while(iterator.hasNext) {
val entry = iterator.next()
val (file, time) = (entry.getKey, entry.getValue)
if (time < cleanupTime) {
iterator.remove()
deleteFile(file)
try {
iterator.remove()
new File(file.toString).delete()
logInfo("Deleted broadcast file '" + file + "'")
} catch {
case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e)
}
}
}
}

def cleanupById(id: Long) {
val file = getFile(id).getAbsolutePath
files.internalMap.remove(file)
deleteFile(file)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.broadcast

import org.apache.spark.{SecurityManager, SparkConf}

/**
* A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium.
*/
class HttpBroadcastFactory extends BroadcastFactory {
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
HttpBroadcast.initialize(isDriver, conf, securityMgr)
}

def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new HttpBroadcast[T](value_, isLocal, id)

def stop() { HttpBroadcast.stop() }
}
Loading