Add support to both explicitly and automatically clean up broadcasts #1
Changes from 16 commits
File: ContextCleaner.scala

@@ -21,105 +21,106 @@ import java.lang.ref.{ReferenceQueue, WeakReference}

 import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}

+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD

-/** Listener class used for testing when any item has been cleaned by the Cleaner class */
-private[spark] trait CleanerListener {
-  def rddCleaned(rddId: Int)
-  def shuffleCleaned(shuffleId: Int)
-}
+/**
+ * Classes that represent cleaning tasks.
+ */
+private sealed trait CleanupTask
+private case class CleanRDD(rddId: Int) extends CleanupTask
+private case class CleanShuffle(shuffleId: Int) extends CleanupTask
+private case class CleanBroadcast(broadcastId: Long) extends CleanupTask

 /**
- * Cleans RDDs and shuffle data.
+ * A WeakReference associated with a CleanupTask.
+ *
+ * When the referent object becomes only weakly reachable, the corresponding
+ * CleanupTaskWeakReference is automatically added to the given reference queue.
  */
+private class CleanupTaskWeakReference(
+    val task: CleanupTask,
+    referent: AnyRef,
+    referenceQueue: ReferenceQueue[AnyRef])
+  extends WeakReference(referent, referenceQueue)
+
+/**
+ * An asynchronous cleaner for RDD, shuffle, and broadcast state.
+ *
+ * This maintains a weak reference for each RDD, ShuffleDependency, and Broadcast of interest,
+ * to be processed when the associated object goes out of scope of the application. Actual
+ * cleanup is performed in a separate daemon thread.
+ */
 private[spark] class ContextCleaner(sc: SparkContext) extends Logging {

-  /** Classes to represent cleaning tasks */
-  private sealed trait CleanupTask
-  private case class CleanRDD(rddId: Int) extends CleanupTask
-  private case class CleanShuffle(shuffleId: Int) extends CleanupTask
-  // TODO: add CleanBroadcast
+  private val referenceBuffer = new ArrayBuffer[CleanupTaskWeakReference]
+    with SynchronizedBuffer[CleanupTaskWeakReference]

-  private val referenceBuffer = new ArrayBuffer[WeakReferenceWithCleanupTask]
-    with SynchronizedBuffer[WeakReferenceWithCleanupTask]
   private val referenceQueue = new ReferenceQueue[AnyRef]

   private val listeners = new ArrayBuffer[CleanerListener]
     with SynchronizedBuffer[CleanerListener]

   private val cleaningThread = new Thread() { override def run() { keepCleaning() }}

-  private val REF_QUEUE_POLL_TIMEOUT = 100

   @volatile private var stopped = false

-  private class WeakReferenceWithCleanupTask(referent: AnyRef, val task: CleanupTask)
-    extends WeakReference(referent, referenceQueue)
+  /** Attach a listener object to get information of when objects are cleaned. */
+  def attachListener(listener: CleanerListener) {
+    listeners += listener
+  }

-  /** Start the cleaner */
+  /** Start the cleaner. */
   def start() {
     cleaningThread.setDaemon(true)
     cleaningThread.setName("ContextCleaner")
     cleaningThread.start()
   }

-  /** Stop the cleaner */
+  /** Stop the cleaner. */
   def stop() {
     stopped = true
     cleaningThread.interrupt()
   }

-  /**
-   * Register a RDD for cleanup when it is garbage collected.
-   */
+  /** Register a RDD for cleanup when it is garbage collected. */
   def registerRDDForCleanup(rdd: RDD[_]) {
     registerForCleanup(rdd, CleanRDD(rdd.id))
   }

-  /**
-   * Register a shuffle dependency for cleanup when it is garbage collected.
-   */
+  /** Register a ShuffleDependency for cleanup when it is garbage collected. */
   def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _]) {
     registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId))
   }

-  /** Cleanup RDD. */
-  def cleanupRDD(rdd: RDD[_]) {
-    doCleanupRDD(rdd.id)
-  }
-
-  /** Cleanup shuffle. */
-  def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) {
-    doCleanupShuffle(shuffleDependency.shuffleId)
-  }
-
-  /** Attach a listener object to get information of when objects are cleaned. */
-  def attachListener(listener: CleanerListener) {
-    listeners += listener
+  /** Register a Broadcast for cleanup when it is garbage collected. */
+  def registerBroadcastForCleanup[T](broadcast: Broadcast[T]) {
+    registerForCleanup(broadcast, CleanBroadcast(broadcast.id))
   }

   /** Register an object for cleanup. */
   private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask) {
-    referenceBuffer += new WeakReferenceWithCleanupTask(objectForCleanup, task)
+    referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue)
   }

-  /** Keep cleaning RDDs and shuffle data */
+  /** Keep cleaning RDD, shuffle, and broadcast state. */
   private def keepCleaning() {
-    while (!isStopped) {
+    while (!stopped) {
       try {
-        val reference = Option(referenceQueue.remove(REF_QUEUE_POLL_TIMEOUT))
-          .map(_.asInstanceOf[WeakReferenceWithCleanupTask])
+        val reference = Option(referenceQueue.remove(ContextCleaner.REF_QUEUE_POLL_TIMEOUT))
+          .map(_.asInstanceOf[CleanupTaskWeakReference])
         reference.map(_.task).foreach { task =>
           logDebug("Got cleaning task " + task)
           referenceBuffer -= reference.get
           task match {
             case CleanRDD(rddId) => doCleanupRDD(rddId)
             case CleanShuffle(shuffleId) => doCleanupShuffle(shuffleId)
+            case CleanBroadcast(broadcastId) => doCleanupBroadcast(broadcastId)
           }
         }
       } catch {
         case ie: InterruptedException =>
-          if (!isStopped) logWarning("Cleaning thread interrupted")
+          if (!stopped) logWarning("Cleaning thread interrupted")
         case t: Throwable => logError("Error in cleaning thread", t)
       }
     }
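
The CleanupTaskWeakReference introduced above leans on a standard JVM idiom: a WeakReference constructed with a ReferenceQueue is enqueued by the garbage collector once its referent becomes only weakly reachable, which is exactly how the cleaner learns that an RDD, ShuffleDependency, or Broadcast has gone out of scope. A minimal, self-contained sketch of that idiom (the object name and the forced GC are illustrative, not part of the patch):

```scala
import java.lang.ref.{ReferenceQueue, WeakReference}

object WeakRefDemo extends App {
  val queue = new ReferenceQueue[AnyRef]
  var referent: AnyRef = new Array[Byte](1 << 20)       // strongly reachable for now
  val ref = new WeakReference[AnyRef](referent, queue)  // watched by the queue

  referent = null  // drop the last strong reference
  System.gc()      // best-effort request, so the result below is not guaranteed

  // Block for up to one second, as the cleaner's daemon thread does with its poll timeout.
  val enqueued = Option(queue.remove(1000))
  println("reference enqueued after GC: " + enqueued.isDefined)
}
```

The short poll timeout in keepCleaning serves the same purpose as remove(1000) here: it keeps the loop responsive to the stopped flag instead of blocking indefinitely.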
@@ -129,7 +130,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   private def doCleanupRDD(rddId: Int) {
     try {
       logDebug("Cleaning RDD " + rddId)
-      sc.unpersistRDD(rddId, false)
+      sc.unpersistRDD(rddId, blocking = false)
       listeners.foreach(_.rddCleaned(rddId))
       logInfo("Cleaned RDD " + rddId)
     } catch {
@@ -150,10 +151,47 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
     }
   }

-  private def mapOutputTrackerMaster =
-    sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+  /** Perform broadcast cleanup. */
+  private def doCleanupBroadcast(broadcastId: Long) {
+    try {
+      logDebug("Cleaning broadcast " + broadcastId)
+      broadcastManager.unbroadcast(broadcastId, removeFromDriver = true)
+      listeners.foreach(_.broadcastCleaned(broadcastId))
+      logInfo("Cleaned broadcast " + broadcastId)
+    } catch {
+      case t: Throwable => logError("Error cleaning broadcast " + broadcastId, t)
+    }
+  }

   private def blockManagerMaster = sc.env.blockManager.master
+  private def broadcastManager = sc.env.broadcastManager
+  private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]

   // Used for testing

+  private[spark] def cleanupRDD(rdd: RDD[_]) {
+    doCleanupRDD(rdd.id)
+  }
+
+  private[spark] def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) {
+    doCleanupShuffle(shuffleDependency.shuffleId)
+  }
+
-  private def isStopped = stopped
+  private[spark] def cleanupBroadcast[T](broadcast: Broadcast[T]) {
+    doCleanupBroadcast(broadcast.id)
+  }
+

Owner: extra space.

 }

+private object ContextCleaner {
+  private val REF_QUEUE_POLL_TIMEOUT = 100
+}
+
+/**
+ * Listener class used for testing when any item has been cleaned by the Cleaner class.
+ */
+private[spark] trait CleanerListener {
+  def rddCleaned(rddId: Int)
+  def shuffleCleaned(shuffleId: Int)
+  def broadcastCleaned(broadcastId: Long)
+}
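
With broadcastCleaned added to CleanerListener, a test can now observe all three cleanup paths. A sketch of such a listener, assuming code living in the org.apache.spark package (the recording class itself is hypothetical, not part of the patch):

```scala
import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}

// Records cleanup events; the buffers are synchronized because callbacks
// arrive on the cleaner's daemon thread, not the test thread.
class RecordingCleanerListener extends CleanerListener {
  val cleanedRDDs = new ArrayBuffer[Int] with SynchronizedBuffer[Int]
  val cleanedShuffles = new ArrayBuffer[Int] with SynchronizedBuffer[Int]
  val cleanedBroadcasts = new ArrayBuffer[Long] with SynchronizedBuffer[Long]

  def rddCleaned(rddId: Int) { cleanedRDDs += rddId }
  def shuffleCleaned(shuffleId: Int) { cleanedShuffles += shuffleId }
  def broadcastCleaned(broadcastId: Long) { cleanedBroadcasts += broadcastId }
}

// Wiring it up in a test (sketch): sc.cleaner.attachListener(new RecordingCleanerListener)
```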

File: SparkContext.scala

@@ -35,7 +35,6 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHad
 import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
 import org.apache.mesos.MesosNativeLibrary

 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
 import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
 import org.apache.spark.rdd._

@@ -230,6 +229,7 @@ class SparkContext(
   private[spark] val cleaner = new ContextCleaner(this)
   cleaner.start()
+

Owner: Is this space intentional?
Author: Yes

   postEnvironmentUpdate()

   /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
@@ -643,7 +643,11 @@ class SparkContext(
    * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
    * The variable will be sent to each cluster only once.
    */
-  def broadcast[T](value: T): Broadcast[T] = env.broadcastManager.newBroadcast[T](value, isLocal)
+  def broadcast[T](value: T) = {
+    val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
+    cleaner.registerBroadcastForCleanup(bc)
+    bc
+  }

   /**
    * Add a file to be downloaded with this Spark job on every node.
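
Folding registration into SparkContext.broadcast means user code gets both cleanup paths without any API change. A sketch of how the two paths play out, assuming a live SparkContext named sc:

```scala
val lookup = sc.broadcast(Map(1 -> "a", 2 -> "b"))  // registered with the cleaner on creation
val mapped = sc.parallelize(1 to 2).map(i => lookup.value(i))
println(mapped.collect().mkString(","))             // prints: a,b

// Explicit path: drop persisted copies on the executors now. Passing
// removeFromDriver = true would also invalidate the broadcast for reuse.
lookup.unpersist(removeFromDriver = false)

// Automatic path: once `lookup` is no longer strongly reachable, a GC enqueues its
// weak reference and the cleaner's daemon thread runs the CleanBroadcast task.
```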

File: SparkEnv.scala

@@ -185,6 +185,7 @@ object SparkEnv extends Logging {
   } else {
     new MapOutputTrackerWorker(conf)
   }
+

Owner: Is this space intentional?
Author: Yes

   // Have to assign trackerActor after initialization as MapOutputTrackerActor
   // requires the MapOutputTracker itself
   mapOutputTracker.trackerActor = registerOrLookup(

File: Broadcast.scala

@@ -18,9 +18,6 @@
 package org.apache.spark.broadcast

-import java.io.Serializable
-import java.util.concurrent.atomic.AtomicLong

 import org.apache.spark._

 /**
  * A broadcast variable. Broadcast variables allow the programmer to keep a read-only variable
@@ -51,49 +48,26 @@ import org.apache.spark._
  * @tparam T Type of the data contained in the broadcast variable.
  */
 abstract class Broadcast[T](val id: Long) extends Serializable {
-  def value: T
-
-  // We cannot have an abstract readObject here due to some weird issues with
-  // readObject having to be 'private' in sub-classes.
-
-  override def toString = "Broadcast(" + id + ")"
-}
-
-private[spark]
-class BroadcastManager(val _isDriver: Boolean, conf: SparkConf, securityManager: SecurityManager)
-  extends Logging with Serializable {
-
-  private var initialized = false
-  private var broadcastFactory: BroadcastFactory = null
+  /**
+   * Whether this Broadcast is actually usable. This should be false once persisted state is
+   * removed from the driver.
+   */
+  protected var isValid: Boolean = true

Owner: Maybe this isValid should be a public function. If we are providing a way to completely clean up a broadcast from the driver such that it is invalidated, then we should provide a way to identify that as well.

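One way to act on that suggestion, sketched against the class as it stands in this diff (the private backing field and the invalidate() helper are hypothetical, not part of the patch):

```scala
abstract class Broadcast[T](val id: Long) extends Serializable {
  // Keep the mutable state private and expose a read-only view instead of a protected var.
  private var _isValid: Boolean = true

  /** Whether this broadcast is still usable; false once driver-side state is removed. */
  def isValid: Boolean = _isValid

  /** For subclasses to call when unpersist(removeFromDriver = true) completes. */
  protected def invalidate() { _isValid = false }

  def value: T
  def unpersist(removeFromDriver: Boolean)
  override def toString = "Broadcast(" + id + ")"
}
```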

-  initialize()
-
-  // Called by SparkContext or Executor before using Broadcast
-  private def initialize() {
-    synchronized {
-      if (!initialized) {
-        val broadcastFactoryClass = conf.get(
-          "spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
-
-        broadcastFactory =
-          Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
-
-        // Initialize appropriate BroadcastFactory and BroadcastObject
-        broadcastFactory.initialize(isDriver, conf, securityManager)
-
-        initialized = true
-      }
-    }
-  }
-
-  def stop() {
-    broadcastFactory.stop()
-  }
+  def value: T

-  private val nextBroadcastId = new AtomicLong(0)
+  /**
+   * Remove all persisted state associated with this broadcast. Overriding implementations
+   * should set isValid to false if persisted state is also removed from the driver.
+   *
+   * @param removeFromDriver Whether to remove state from the driver.
+   *                         If true, the resulting broadcast should no longer be valid.
+   */
+  def unpersist(removeFromDriver: Boolean)

Owner: There should be an unpersist() (that is, one with no parameters) for the simple use case. Many people will not know the implications of removeFromDriver, and will not know what to set it to if that is the only option: unpersist() = unpersist(false).
Author: Also, note to self: add this to the Java and/or Python API.

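The suggested convenience overload is a one-liner inside Broadcast; a sketch, assuming executor-only removal is the intended safe default:

```scala
/** Remove copies of this broadcast from the executors, keeping it valid for reuse. */
def unpersist() { unpersist(removeFromDriver = false) }
```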

-  def newBroadcast[T](value_ : T, isLocal: Boolean) =
-    broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
+  // We cannot define abstract readObject and writeObject here due to some weird issues

Owner: This is not a weird issue! It is a known fact that readObject and writeObject have to be private, and so obviously cannot be abstract. It is not as if Object has these functions and you are overriding them. These functions are defined here to explicitly control the serialization of the fields of this class and only this class (not the fields of subclasses). If subclasses want to customize their serialization, they need to define their own readObject and writeObject. ObjectOutputStream actually uses reflection to test whether these functions are defined on the instance being serialized, and calls them. It is well-known behavior and doesn't deserve a comment :)
Author: Ha, actually it was already there. All I did was move it, and GitHub's not smart enough to figure that out.

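For reference, the contract described here: Java serialization discovers private writeObject and readObject methods reflectively on each class in the hierarchy, so a superclass cannot declare them as abstract members for subclasses to override. A sketch of the idiom, using a hypothetical class:

```scala
import java.io.{ObjectInputStream, ObjectOutputStream}

class Counter(var count: Int) extends Serializable {
  @transient var scratch: Array[Int] = new Array[Int](16)

  // Found reflectively by ObjectOutputStream; controls only this class's fields.
  private def writeObject(out: ObjectOutputStream) {
    out.defaultWriteObject()  // writes `count`, skips the @transient `scratch`
  }

  // Must be private, which is why it cannot be declared abstract in a superclass.
  private def readObject(in: ObjectInputStream) {
    in.defaultReadObject()        // restores `count`
    scratch = new Array[Int](16)  // rebuild transient state after deserialization
  }
}
```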
+  // with these methods having to be 'private' in sub-classes.

-  def isDriver = _isDriver
+  override def toString = "Broadcast(" + id + ")"
 }

Does it make sense to have private[spark] here since the whole class is private[spark]? Might as well make it public.