core/src/main/scala/org/apache/spark/MapOutputTracker.scala (5 additions, 20 deletions)
@@ -18,7 +18,6 @@
package org.apache.spark

import java.io._
import java.util.Arrays
import java.util.concurrent.ConcurrentHashMap
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

@@ -267,8 +266,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
}

/**
* MapOutputTracker for the driver. This uses TimeStampedHashMap to keep track of map
* output information, which allows old output information based on a TTL.
* MapOutputTracker for the driver.
*/
private[spark] class MapOutputTrackerMaster(conf: SparkConf)
extends MapOutputTracker(conf) {
@@ -291,17 +289,10 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
// can be read locally, but may lead to more delay in scheduling if those locations are busy.
private val REDUCER_PREF_LOCS_FRACTION = 0.2

/**
* Timestamp based HashMap for storing mapStatuses and cached serialized statuses in the driver,
* so that statuses are dropped only by explicit de-registering or by TTL-based cleaning (if set).
* Other than these two scenarios, nothing should be dropped from this HashMap.
*/
protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]()
private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]()

// For cleaning up TimeStampedHashMaps
private val metadataCleaner =
new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup, conf)
// HashMaps for storing mapStatuses and cached serialized statuses in the driver.
// Statuses are dropped only by explicit de-registering.
protected val mapStatuses = new HashMap[Int, Array[MapStatus]]()
Contributor Author commented:
I think that these need to be ConcurrentHashMaps in order to preserve the old code's thread-safety guarantees.
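A minimal sketch of the variant that comment proposes, assuming the surrounding MapOutputTrackerMaster class and Spark's MapStatus type; this is not the change in the diff above, just one way the ConcurrentHashMap-backed fields could look:

```scala
import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters._

// Back both maps with java.util.concurrent.ConcurrentHashMap so concurrent
// registerShuffle/de-register calls and status reads stay safe without locking.
protected val mapStatuses: scala.collection.concurrent.Map[Int, Array[MapStatus]] =
  new ConcurrentHashMap[Int, Array[MapStatus]]().asScala
private val cachedSerializedStatuses: scala.collection.concurrent.Map[Int, Array[Byte]] =
  new ConcurrentHashMap[Int, Array[Byte]]().asScala
```

Exposing the Java map through scala.collection.concurrent.Map via asScala keeps the existing call sites (put, clear, apply) unchanged while restoring the thread-safety the TimeStampedHashMap previously provided.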

private val cachedSerializedStatuses = new HashMap[Int, Array[Byte]]()

def registerShuffle(shuffleId: Int, numMaps: Int) {
if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
@@ -462,14 +453,8 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
sendTracker(StopMapOutputTracker)
mapStatuses.clear()
trackerEndpoint = null
metadataCleaner.cancel()
cachedSerializedStatuses.clear()
}

private def cleanup(cleanupTime: Long) {
mapStatuses.clearOldValues(cleanupTime)
cachedSerializedStatuses.clearOldValues(cleanupTime)
}
}

/**
core/src/main/scala/org/apache/spark/SparkContext.scala (2 additions, 15 deletions)
@@ -33,6 +33,7 @@ import scala.collection.mutable.HashMap
import scala.reflect.{ClassTag, classTag}
import scala.util.control.NonFatal

import com.google.common.collect.MapMaker
import org.apache.commons.lang.SerializationUtils
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
@@ -221,7 +222,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
private var _eventLogDir: Option[URI] = None
private var _eventLogCodec: Option[String] = None
private var _env: SparkEnv = _
private var _metadataCleaner: MetadataCleaner = _
private var _jobProgressListener: JobProgressListener = _
private var _statusTracker: SparkStatusTracker = _
private var _progressBar: Option[ConsoleProgressBar] = None
@@ -295,8 +295,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
private[spark] val addedJars = HashMap[String, Long]()

// Keeps track of all persisted RDDs
private[spark] val persistentRdds = new TimeStampedWeakValueHashMap[Int, RDD[_]]
private[spark] def metadataCleaner: MetadataCleaner = _metadataCleaner
private[spark] val persistentRdds = new MapMaker().weakValues().makeMap[Int, RDD[_]]().asScala
Contributor Author commented:
I'll add a comment here to clarify that MapMaker returns a ConcurrentHashMap (or will just use an explicit return type, if Guava supports that).
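For reference, a hedged sketch of the explicit-return-type option mentioned in that comment; Guava's MapMaker.makeMap() is declared to return a java.util.concurrent.ConcurrentMap, so annotating the Scala-side type records that guarantee at the declaration (RDD is Spark's type, and the field name matches the one in this diff):

```scala
import com.google.common.collect.MapMaker
import scala.collection.JavaConverters._

// makeMap() returns a ConcurrentMap with weak values; the explicit type makes the
// concurrency guarantee visible to readers of SparkContext.
private[spark] val persistentRdds: scala.collection.concurrent.Map[Int, RDD[_]] =
  new MapMaker().weakValues().makeMap[Int, RDD[_]]().asScala
```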

private[spark] def jobProgressListener: JobProgressListener = _jobProgressListener

def statusTracker: SparkStatusTracker = _statusTracker
@@ -463,8 +462,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
_conf.set("spark.repl.class.uri", replUri)
}

_metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, _conf)

_statusTracker = new SparkStatusTracker(this)

_progressBar =
@@ -1721,11 +1718,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
env.metricsSystem.report()
}
}
if (metadataCleaner != null) {
Utils.tryLogNonFatalError {
metadataCleaner.cancel()
}
}
Utils.tryLogNonFatalError {
_cleaner.foreach(_.stop())
}
@@ -2193,11 +2185,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}
}

/** Called by MetadataCleaner to clean up the persistentRdds map periodically */
private[spark] def cleanup(cleanupTime: Long) {
persistentRdds.clearOldValues(cleanupTime)
}

// In order to prevent multiple SparkContexts from being active at the same time, mark this
// context as having finished construction.
// NOTE: this must be placed at the end of the SparkContext constructor.
core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala (3 additions, 24 deletions)
@@ -22,12 +22,13 @@ import java.io.{BufferedInputStream, BufferedOutputStream}
import java.net.{URL, URLConnection, URI}
import java.util.concurrent.TimeUnit

import scala.collection.mutable
import scala.reflect.ClassTag

import org.apache.spark.{HttpServer, Logging, SecurityManager, SparkConf, SparkEnv}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}
import org.apache.spark.util.Utils

/**
* A [[org.apache.spark.broadcast.Broadcast]] implementation that uses HTTP server
@@ -112,10 +113,9 @@ private[broadcast] object HttpBroadcast extends Logging {
private var securityManager: SecurityManager = null

// TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist
private val files = new TimeStampedHashSet[File]
private val files = new mutable.HashSet[File]
private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES).toInt
private var compressionCodec: CompressionCodec = null
private var cleaner: MetadataCleaner = null

def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
synchronized {
@@ -128,7 +128,6 @@
conf.set("spark.httpBroadcast.uri", serverUri)
}
serverUri = conf.get("spark.httpBroadcast.uri")
cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup, conf)
compressionCodec = CompressionCodec.createCodec(conf)
initialized = true
}
@@ -141,10 +140,6 @@
server.stop()
server = null
}
if (cleaner != null) {
cleaner.cancel()
cleaner = null
}
compressionCodec = null
initialized = false
}
@@ -236,22 +231,6 @@ private[broadcast] object HttpBroadcast extends Logging {
}
}

/**
* Periodically clean up old broadcasts by removing the associated map entries and
* deleting the associated files.
*/
private def cleanup(cleanupTime: Long) {
val iterator = files.internalMap.entrySet().iterator()
while(iterator.hasNext) {
val entry = iterator.next()
val (file, time) = (entry.getKey, entry.getValue)
if (time < cleanupTime) {
iterator.remove()
deleteBroadcastFile(file)
}
}
}

private def deleteBroadcastFile(file: File) {
try {
if (file.exists) {
@@ -26,7 +26,7 @@ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage._
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
import org.apache.spark.util.Utils
import org.apache.spark.{Logging, SparkConf, SparkEnv}

/** A group of writers for a ShuffleMapTask, one writer per reducer. */
@@ -63,10 +63,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
val completedMapTasks = new ConcurrentLinkedQueue[Int]()
}

private val shuffleStates = new TimeStampedHashMap[ShuffleId, ShuffleState]

private val metadataCleaner =
new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup, conf)
private val shuffleStates = new scala.collection.mutable.HashMap[ShuffleId, ShuffleState]
Contributor Author commented:
I'll also make this into a ConcurrentHashMap.
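A rough sketch of that planned follow-up, assuming the enclosing class's ShuffleId alias and ShuffleState class from this diff; the putIfAbsent form is shown because getOrElseUpdate on a plain mutable map is not atomic under concurrent map-task registration:

```scala
import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters._

// Concurrent map so map tasks registering the same shuffle from different threads
// do not race; ShuffleId and ShuffleState come from the enclosing resolver class.
private val shuffleStates = new ConcurrentHashMap[ShuffleId, ShuffleState]().asScala

// Inside forMapTask: putIfAbsent is atomic on the backing ConcurrentHashMap, so at
// most one ShuffleState ends up registered per shuffleId.
shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numReducers))
private val shuffleState = shuffleStates(shuffleId)
```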


/**
* Get a ShuffleWriterGroup for the given map task, which will register it as complete
@@ -75,9 +72,8 @@
def forMapTask(shuffleId: Int, mapId: Int, numReducers: Int, serializer: Serializer,
writeMetrics: ShuffleWriteMetrics): ShuffleWriterGroup = {
new ShuffleWriterGroup {
shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numReducers))
private val shuffleState = shuffleStates(shuffleId)

private val shuffleState =
shuffleStates.getOrElseUpdate(shuffleId, new ShuffleState(numReducers))
val openStartTime = System.nanoTime
val serializerInstance = serializer.newInstance()
val writers: Array[DiskBlockObjectWriter] = {
@@ -131,11 +127,5 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
}
}

private def cleanup(cleanupTime: Long) {
shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId))
}

override def stop() {
metadataCleaner.cancel()
}
override def stop(): Unit = {}
}