diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 41a1b51a43154..2f17862107de2 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -19,6 +19,7 @@ package org.apache.spark import java.io.{ByteArrayInputStream, InputStream, IOException, ObjectInputStream, ObjectOutputStream} import java.nio.ByteBuffer +import java.util.{HashMap => JHashMap} import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit} import java.util.concurrent.locks.ReentrantReadWriteLock @@ -711,6 +712,10 @@ private[spark] class MapOutputTrackerMaster( private[spark] val isLocal: Boolean) extends MapOutputTracker(conf) { + // Keep track of last access times for shuffle based TTL. Note: we don't use concurrent + // here because we don't care about overwriting times that are "close." + private[spark] val shuffleAccessTime = new JHashMap[Int, Long] + // The size at which we use Broadcast to send the map output statuses to the executors private val minSizeForBroadcast = conf.get(SHUFFLE_MAPOUTPUT_MIN_SIZE_FOR_BROADCAST).toInt @@ -745,6 +750,16 @@ private[spark] class MapOutputTrackerMaster( private val pushBasedShuffleEnabled = Utils.isPushBasedShuffleEnabled(conf, isDriver = true) + private[spark] val cleanerThreadpool: Option[ThreadPoolExecutor] = { + if (conf.get(SPARK_TTL_SHUFFLE_BLOCK_CLEANER).isDefined) { + val pool = ThreadUtils.newDaemonFixedThreadPool(1, "map-output-ttl-cleaner") + pool.execute(new TTLCleaner) + Some(pool) + } else { + None + } + } + // Thread pool used for handling map output status requests. This is a separate thread pool // to ensure we don't block the normal dispatcher threads. private val threadpool: ThreadPoolExecutor = { @@ -758,6 +773,68 @@ private[spark] class MapOutputTrackerMaster( private val availableProcessors = Runtime.getRuntime.availableProcessors() + def updateShuffleAtime(shuffleId: Int): Unit = { + if (conf.get(SPARK_TTL_SHUFFLE_BLOCK_CLEANER).isDefined) { + shuffleAccessTime.put(shuffleId, System.currentTimeMillis()) + } + } + + private class TTLCleaner extends Runnable { + override def run(): Unit = { + try { + // Poll the shuffle access times if we're configured for it. + conf.get(SPARK_TTL_SHUFFLE_BLOCK_CLEANER) match { + case Some(ttl) => + while (true) { + val maxAge = System.currentTimeMillis() - ttl + // Find the elements to be removed & update oldest remaining time (if any) + var oldest = System.currentTimeMillis() + // Make a copy here to reduce the chance of CME + try { + val toBeRemoved = shuffleAccessTime.asScala.toList.flatMap { + case (shuffleId, atime) => + if (atime < maxAge) { + Some(shuffleId) + } else { + if (atime < oldest) { + oldest = atime + } + None + } + }.toList + toBeRemoved.map { shuffleId => + try { + // Remove the shuffle access time regardless of + // if we cleanup the shuffle successfully or not + // since we could have a shuffle that's already + // been cleaned up elsewhere. 
+ shuffleAccessTime.remove(shuffleId) + unregisterAllMapAndMergeOutput(shuffleId) + } catch { + case NonFatal(e) => + logDebug( + log"Error removing shuffle ${MDC(SHUFFLE_ID, shuffleId)}", e) + } + } + // Wait until the next possible element to be removed + val delay = math.max((oldest + ttl) - System.currentTimeMillis(), 100) + Thread.sleep(delay) + } catch { + case _: java.util.ConcurrentModificationException => + // Just retry, blocks were stored while we were iterating + Thread.sleep(100) + } + } + case None => + logDebug("Tried to start TTL cleaner when not configured.") + } + } catch { + case _: InterruptedException => + logInfo("MapOutputTrackerMaster TTLCleaner thread interrupted, exiting.") + } + } + } + // Make sure that we aren't going to exceed the max RPC message size by making sure // we use broadcast to send large map output statuses. if (minSizeForBroadcast > maxRpcMessageSize) { @@ -783,6 +860,7 @@ private[spark] class MapOutputTrackerMaster( val shuffleStatus = shuffleStatuses.get(shuffleId).head logDebug(s"Handling request to send ${if (needMergeOutput) "map/merge" else "map"}" + s" output locations for shuffle $shuffleId to $hostPort") + updateShuffleAtime(shuffleId) if (needMergeOutput) { context.reply( shuffleStatus. @@ -834,6 +912,7 @@ private[spark] class MapOutputTrackerMaster( } def registerShuffle(shuffleId: Int, numMaps: Int, numReduces: Int): Unit = { + updateShuffleAtime(shuffleId) if (pushBasedShuffleEnabled) { if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps, numReduces)).isDefined) { throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice") @@ -880,6 +959,7 @@ private[spark] class MapOutputTrackerMaster( shuffleStatus.removeOutputsByFilter(x => true) shuffleStatus.removeMergeResultsByFilter(x => true) shuffleStatus.removeShuffleMergerLocations() + shuffleAccessTime.remove(shuffleId) incrementEpoch() } @@ -1257,12 +1337,14 @@ private[spark] class MapOutputTrackerMaster( // This method is only called in local-mode. override def getShufflePushMergerLocations(shuffleId: Int): Seq[BlockManagerId] = { + updateShuffleAtime(shuffleId) shuffleStatuses.get(shuffleId).map(_.getShufflePushMergerLocations).getOrElse(Seq.empty) } override def stop(): Unit = { mapOutputTrackerMasterMessages.offer(PoisonPill) threadpool.shutdown() + cleanerThreadpool.map(_.shutdownNow()) try { sendTracker(StopMapOutputTracker) } catch { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 331d798a3d768..287f6e815ba98 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -2919,4 +2919,24 @@ package object config { .checkValue(v => v.forall(Set("stdout", "stderr").contains), "The value only can be one or more of 'stdout, stderr'.") .createWithDefault(Seq("stdout", "stderr")) + + private[spark] val SPARK_TTL_BLOCK_CLEANER = + ConfigBuilder("spark.cleaner.ttl.all") + .doc("Add a TTL for all blocks tracked in Spark. By default blocks are only removed after " + + " GC on driver which with DataFrames or RDDs at the global scope will not occur. " + + "This must be configured before starting the SparkContext (e.g. 
cannot be added to " +
+        "a running Spark instance.)")
+      .version("4.1.0")
+      .timeConf(TimeUnit.MILLISECONDS)
+      .createOptional
+
+  private[spark] val SPARK_TTL_SHUFFLE_BLOCK_CLEANER =
+    ConfigBuilder("spark.cleaner.ttl.shuffle")
+      .doc("Add a TTL for shuffle blocks tracked in Spark. By default blocks are only removed " +
+        "after GC on the driver, which will not occur with DataFrames or RDDs referenced at " +
+        "the global scope. This must be configured before starting the SparkContext (e.g. " +
+        "cannot be added to a running Spark instance.)")
+      .version("4.1.0")
+      .timeConf(TimeUnit.MILLISECONDS)
+      .createOptional
 }
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 3e46a53ee082c..56c0f540400fd 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -45,12 +45,18 @@ sealed abstract class BlockId {
     (isInstanceOf[ShuffleBlockId] || isInstanceOf[ShuffleBlockBatchId] ||
      isInstanceOf[ShuffleDataBlockId] || isInstanceOf[ShuffleIndexBlockId])
   }
+  def asShuffleId: Option[ShuffleId] = if (isShuffle) Some(asInstanceOf[ShuffleId]) else None
   def isShuffleChunk: Boolean = isInstanceOf[ShuffleBlockChunkId]
   def isBroadcast: Boolean = isInstanceOf[BroadcastBlockId]
   override def toString: String = name
 }
+@DeveloperApi
+trait ShuffleId {
+  def shuffleId: Int
+}
+
 @DeveloperApi
 case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId {
   override def name: String = "rdd_" + rddId + "_" + splitIndex
@@ -59,7 +65,8 @@ case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId {
 // Format of the shuffle block ids (including data and index) should be kept in sync with
 // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getBlockData().
@DeveloperApi -case class ShuffleBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId { +case class ShuffleBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId + with ShuffleId { override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId } @@ -69,7 +76,7 @@ case class ShuffleBlockBatchId( shuffleId: Int, mapId: Long, startReduceId: Int, - endReduceId: Int) extends BlockId { + endReduceId: Int) extends BlockId with ShuffleId { override def name: String = { "shuffle_" + shuffleId + "_" + mapId + "_" + startReduceId + "_" + endReduceId } @@ -81,18 +88,20 @@ case class ShuffleBlockChunkId( shuffleId: Int, shuffleMergeId: Int, reduceId: Int, - chunkId: Int) extends BlockId { + chunkId: Int) extends BlockId with ShuffleId { override def name: String = "shuffleChunk_" + shuffleId + "_" + shuffleMergeId + "_" + reduceId + "_" + chunkId } @DeveloperApi -case class ShuffleDataBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId { +case class ShuffleDataBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId + with ShuffleId { override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".data" } @DeveloperApi -case class ShuffleIndexBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId { +case class ShuffleIndexBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId + with ShuffleId { override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index" } @@ -108,7 +117,7 @@ case class ShufflePushBlockId( shuffleId: Int, shuffleMergeId: Int, mapIndex: Int, - reduceId: Int) extends BlockId { + reduceId: Int) extends BlockId with ShuffleId { override def name: String = "shufflePush_" + shuffleId + "_" + shuffleMergeId + "_" + mapIndex + "_" + reduceId + "" } @@ -118,7 +127,7 @@ case class ShufflePushBlockId( case class ShuffleMergedBlockId( shuffleId: Int, shuffleMergeId: Int, - reduceId: Int) extends BlockId { + reduceId: Int) extends BlockId with ShuffleId { override def name: String = "shuffleMerged_" + shuffleId + "_" + shuffleMergeId + "_" + reduceId } @@ -129,7 +138,7 @@ case class ShuffleMergedDataBlockId( appId: String, shuffleId: Int, shuffleMergeId: Int, - reduceId: Int) extends BlockId { + reduceId: Int) extends BlockId with ShuffleId { override def name: String = RemoteBlockPushResolver.MERGED_SHUFFLE_FILE_NAME_PREFIX + "_" + appId + "_" + shuffleId + "_" + shuffleMergeId + "_" + reduceId + ".data" } @@ -140,7 +149,7 @@ case class ShuffleMergedIndexBlockId( appId: String, shuffleId: Int, shuffleMergeId: Int, - reduceId: Int) extends BlockId { + reduceId: Int) extends BlockId with ShuffleId { override def name: String = RemoteBlockPushResolver.MERGED_SHUFFLE_FILE_NAME_PREFIX + "_" + appId + "_" + shuffleId + "_" + shuffleMergeId + "_" + reduceId + ".index" } @@ -151,7 +160,7 @@ case class ShuffleMergedMetaBlockId( appId: String, shuffleId: Int, shuffleMergeId: Int, - reduceId: Int) extends BlockId { + reduceId: Int) extends BlockId with ShuffleId { override def name: String = RemoteBlockPushResolver.MERGED_SHUFFLE_FILE_NAME_PREFIX + "_" + appId + "_" + shuffleId + "_" + shuffleMergeId + "_" + reduceId + ".meta" } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 9d6539e09f452..5064e979dd230 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ 
b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -19,7 +19,7 @@ package org.apache.spark.storage import java.io.IOException import java.util.{HashMap => JHashMap} -import java.util.concurrent.TimeUnit +import java.util.concurrent.{ThreadPoolExecutor, TimeUnit} import scala.collection.mutable import scala.concurrent.{ExecutionContext, ExecutionContextExecutorService, Future, TimeoutException} @@ -85,6 +85,11 @@ class BlockManagerMasterEndpoint( // Mapping from block id to the set of block managers that have the block. private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] + // Keep track of last access times if we're using block TTLs + // We intentionally use a non-concurrent datastructure since "close" + // is good enough for atimes and reducing update cost matters. + private[spark] val rddAccessTime = new JHashMap[Int, Long] + // Mapping from task id to the set of rdd blocks which are generated from the task. private val tidToRddBlockIds = new mutable.HashMap[Long, mutable.HashSet[RDDBlockId]] // Record the RDD blocks which are not visible yet, a block will be removed from this collection @@ -104,6 +109,17 @@ class BlockManagerMasterEndpoint( private implicit val askExecutionContext: ExecutionContextExecutorService = ExecutionContext.fromExecutorService(askThreadPool) + + private[spark] val cleanerThreadpool: Option[ThreadPoolExecutor] = { + if (conf.get(config.SPARK_TTL_BLOCK_CLEANER).isDefined) { + val pool = ThreadUtils.newDaemonFixedThreadPool(1, "rdd-ttl-cleaner") + pool.execute(new TTLCleaner) + Some(pool) + } else { + None + } + } + private val topologyMapper = { val topologyMapperClassName = conf.get( config.STORAGE_REPLICATION_TOPOLOGY_MAPPER) @@ -143,6 +159,8 @@ class BlockManagerMasterEndpoint( case _updateBlockInfo @ UpdateBlockInfo(blockManagerId, blockId, storageLevel, deserializedSize, size) => + // We don't update the block access times here because the update block infos are triggered by + // migrations rather than actual access. @inline def handleResult(success: Boolean): Unit = { // SPARK-30594: we should not post `SparkListenerBlockUpdated` when updateBlockInfo // returns false since the block info would be updated again later. @@ -249,6 +267,77 @@ class BlockManagerMasterEndpoint( context.reply(updateRDDBlockVisibility(taskId, visible)) } + private def updateBlockAtime(blockId: BlockId) = { + // First handle "regular" blocks + if (!blockId.isShuffle) { + // Only update access times if we have the cleaner enabled. + if (conf.get(config.SPARK_TTL_BLOCK_CLEANER).isDefined) { + // Note: we don't _really_ care about concurrency here too much, if we have + // conflicting updates in time they're going to "close enough" to be a wash + // so we don't bother checking the return value here. + // For now we only do RDD blocks, because I'm not convinced it's safe to TTL + // clean Broadcast blocks, but maybe we can revisit that. + blockId.asRDDId.map { r => rddAccessTime.put(r.rddId, System.currentTimeMillis()) } + } + } else if (conf.get(config.SPARK_TTL_SHUFFLE_BLOCK_CLEANER).isDefined) { + // We track shuffles in the mapoutput tracker. + blockId.asShuffleId.map(s => mapOutputTracker.updateShuffleAtime(s.shuffleId)) + } + } + + + private class TTLCleaner extends Runnable { + override def run(): Unit = { + try { + // Poll the shuffle access times if we're configured for it. 
+ conf.get(config.SPARK_TTL_BLOCK_CLEANER) match { + case Some(ttl) => + while (true) { + val maxAge = System.currentTimeMillis() - ttl + // Find the elements to be removed & update oldest remaining time (if any) + var oldest = System.currentTimeMillis() + // Make a copy here to reduce the chance of CME + try { + val toBeRemoved = rddAccessTime.asScala.toList.flatMap { case (rddId, atime) => + if (atime < maxAge) { + Some(rddId) + } else { + if (atime < oldest) { + oldest = atime + } + None + } + }.toList + toBeRemoved.map { rddId => + try { + // Always remove the RDD from our tracking list first incase an error occurs. + rddAccessTime.remove(rddId) + removeRdd(rddId) + } catch { + case NonFatal(e) => + logDebug(log"Error removing rdd ${MDC(RDD_ID, rddId)} with TTL cleaner", e) + } + } + // Wait until the next possible element to be removed + val delay = math.max((oldest + ttl) - System.currentTimeMillis(), 100) + Thread.sleep(delay) + } catch { + case _: java.util.ConcurrentModificationException => + // Just retry, blocks were stored while we were iterating + Thread.sleep(10) + } + } + case None => + logDebug("Tried to start TTL cleaner when not configured.") + } + } catch { + case _: InterruptedException => + // Exit gracefully + logInfo("RDD TTL cleaner thread interrupted, shutting down.") + } + } + } + private def isRDDBlockVisible(blockId: RDDBlockId): Boolean = { if (trackingCacheVisibility) { blockLocations.containsKey(blockId) && @@ -345,7 +434,17 @@ class BlockManagerMasterEndpoint( } private def removeRdd(rddId: Int): Future[Seq[Int]] = { - // First remove the metadata for the given RDD, and then asynchronously remove the blocks + // Drop the RDD from TTL tracking. + try { + if (conf.get(config.SPARK_TTL_BLOCK_CLEANER).isDefined) { + rddAccessTime.remove(rddId) + } + } catch { + case NonFatal(e) => + logWarning(log"Error removing ${MDC(RDD_ID, rddId)} from RDD TTL tracking", e) + } + + // Then remove the metadata for the given RDD, and then asynchronously remove the blocks // from the storage endpoints. // The message sent to the storage endpoints to remove the RDD @@ -411,7 +510,11 @@ class BlockManagerMasterEndpoint( Future.sequence(removeRddFromExecutorsFutures ++ removeRddBlockViaExtShuffleServiceFutures) } + // For testing. + private[spark] def getMapOutputTrackerMaster(): MapOutputTrackerMaster = mapOutputTracker + private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = { + // Start with removing shuffle blocks without an associated executor (e.g. ESS only). // Find all shuffle blocks on executors that are no longer running val blocksToDeleteByShuffleService = new mutable.HashMap[BlockManagerId, mutable.HashSet[BlockId]] @@ -465,6 +568,7 @@ class BlockManagerMasterEndpoint( } }.getOrElse(Seq.empty) + // Remove shuffle blocks from running executors. 
val removeMsg = RemoveShuffle(shuffleId) val removeShuffleFromExecutorsFutures = blockManagerInfo.values.map { bm => bm.storageEndpoint.ask[Boolean](removeMsg).recover { @@ -547,6 +651,7 @@ class BlockManagerMasterEndpoint( } private def addMergerLocation(blockManagerId: BlockManagerId): Unit = { + logDebug(log"Adding merger location ${MDC(BLOCK_MANAGER_ID, blockManagerId)}") if (!blockManagerId.isDriver && !shuffleMergerLocations.contains(blockManagerId.host)) { val shuffleServerId = BlockManagerId(BlockManagerId.SHUFFLE_MERGER_IDENTIFIER, blockManagerId.host, externalShuffleServicePort) @@ -766,7 +871,8 @@ class BlockManagerMasterEndpoint( } private def updateShuffleBlockInfo(blockId: BlockId, blockManagerId: BlockManagerId) - : Future[Boolean] = { + : Future[Boolean] = { + logDebug(s"Updating shuffle block info ${blockId} on ${blockManagerId}") blockId match { case ShuffleIndexBlockId(shuffleId, mapId, _) => // SPARK-36782: Invoke `MapOutputTracker.updateMapOutput` within the thread @@ -823,6 +929,8 @@ class BlockManagerMasterEndpoint( } else { locations = new mutable.HashSet[BlockManagerId] blockLocations.put(blockId, locations) + // Since it's the initial put we register this as an access as well. + updateBlockAtime(blockId) } if (storageLevel.isValid) { @@ -863,12 +971,14 @@ class BlockManagerMasterEndpoint( } private def getLocations(blockId: BlockId): Seq[BlockManagerId] = { + updateBlockAtime(blockId) if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty } private def getLocationsAndStatus( blockId: BlockId, requesterHost: String): Option[BlockLocationsAndStatus] = { + updateBlockAtime(blockId) val allLocations = Option(blockLocations.get(blockId)).map(_.toSeq).getOrElse(Seq.empty) val blockStatusWithBlockManagerId: Option[(BlockStatus, BlockManagerId)] = (if (externalShuffleServiceRddFetchEnabled && blockId.isRDD) { @@ -983,6 +1093,7 @@ class BlockManagerMasterEndpoint( override def onStop(): Unit = { askThreadPool.shutdownNow() + cleanerThreadpool.map(_.shutdownNow()) } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockTTLIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockTTLIntegrationSuite.scala new file mode 100644 index 0000000000000..7476c4ad5ed86 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/storage/BlockTTLIntegrationSuite.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.storage + +import scala.jdk.CollectionConverters._ + +import org.scalatest.concurrent.Eventually +import org.scalatest.time._ + +import org.apache.spark._ +import org.apache.spark.internal.config +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef} +import org.apache.spark.util.ResetSystemProperties + +class BlockTTLIntegrationSuite extends SparkFunSuite with LocalSparkContext + with ResetSystemProperties with Eventually { + + implicit override val patienceConfig: PatienceConfig = + PatienceConfig(timeout = scaled(Span(20, Seconds)), interval = scaled(Span(5, Millis))) + + val blockTTL = 5000L + + val numExecs = 3 + val numParts = 3 + val TaskStarted = "TASK_STARTED" + val TaskEnded = "TASK_ENDED" + val JobEnded = "JOB_ENDED" + + // TODO(holden): This is shared with MapOutputTrackerSuite move to a BlockTestUtils or similar. + private def fetchDeclaredField(value: AnyRef, fieldName: String): AnyRef = { + val field = value.getClass.getDeclaredField(fieldName) + field.setAccessible(true) + field.get(value) + } + + private def lookupBlockManagerMasterEndpoint(sc: SparkContext): BlockManagerMasterEndpoint = { + val rpcEnv = sc.env.rpcEnv + val dispatcher = fetchDeclaredField(rpcEnv, "dispatcher") + fetchDeclaredField(dispatcher, "endpointRefs"). + asInstanceOf[java.util.Map[RpcEndpoint, RpcEndpointRef]].asScala. + filter(_._1.isInstanceOf[BlockManagerMasterEndpoint]). + head._1.asInstanceOf[BlockManagerMasterEndpoint] + } + + private def lookupMapOutputTrackerMaster(sc: SparkContext): MapOutputTrackerMaster = { + val bme = lookupBlockManagerMasterEndpoint(sc) + bme.getMapOutputTrackerMaster() + } + + test("Test that cache blocks are recorded.") { + val conf = new SparkConf() + .setAppName("test-blockmanager-decommissioner") + .setMaster("local-cluster[2, 1, 1024]") + .set(config.SPARK_TTL_BLOCK_CLEANER, blockTTL) + .set(config.SPARK_TTL_SHUFFLE_BLOCK_CLEANER, blockTTL) + sc = new SparkContext(conf) + sc.setLogLevel("DEBUG") + TestUtils.waitUntilExecutorsUp(sc, 2, 60000) + val managerMasterEndpoint = lookupBlockManagerMasterEndpoint(sc) + assert(managerMasterEndpoint.rddAccessTime.isEmpty) + // Make some cache blocks + val input = sc.parallelize(1.to(100)).cache() + input.count() + // Check that the blocks were registered with the TTL tracker + assert(!managerMasterEndpoint.rddAccessTime.isEmpty) + val trackedRDDBlocks = managerMasterEndpoint.rddAccessTime.asScala.keys + assert(!trackedRDDBlocks.isEmpty) + } + + test("Test that shuffle blocks are tracked properly and removed after TTL") { + val conf = new SparkConf() + .setAppName("test-blockmanager-ttls-shuffle-only") + .setMaster("local-cluster[2, 1, 1024]") + .set(config.SPARK_TTL_BLOCK_CLEANER, blockTTL) + .set(config.SPARK_TTL_SHUFFLE_BLOCK_CLEANER, blockTTL) + sc = new SparkContext(conf) + sc.setLogLevel("DEBUG") + TestUtils.waitUntilExecutorsUp(sc, 2, 60000) + val managerMasterEndpoint = lookupBlockManagerMasterEndpoint(sc) + val mapOutputTracker = lookupMapOutputTrackerMaster(sc) + // Make sure it's empty at the start + assert(managerMasterEndpoint.rddAccessTime.isEmpty) + assert(mapOutputTracker.shuffleAccessTime.isEmpty) + // Make some cache blocks + val input = sc.parallelize(1.to(100)).groupBy(_ % 10) + input.count() + // Make sure we've got the tracker threads defined + assert(mapOutputTracker.cleanerThreadpool.isDefined) + // Check that the shuffle blocks were NOT registered with the RDD TTL tracker. 
+ assert(managerMasterEndpoint.rddAccessTime.isEmpty) + // Check that the shuffle blocks are registered with the map output TTL + eventually { assert(!mapOutputTracker.shuffleAccessTime.isEmpty) } + // It should be expired! + eventually { + val t = System.currentTimeMillis() + assert( + mapOutputTracker.shuffleAccessTime.isEmpty, + s"We should have no blocks since we are now at time ${t} with ttl of ${blockTTL}") + } + } + + + test(s"Test that all blocks are tracked properly and removed after TTL") { + val conf = new SparkConf() + .setAppName("test-blockmanager-ttls-enabled") + .setMaster("local-cluster[2, 1, 1024]") + .set(config.SPARK_TTL_BLOCK_CLEANER, blockTTL) + .set(config.SPARK_TTL_SHUFFLE_BLOCK_CLEANER, blockTTL) + sc = new SparkContext(conf) + sc.setLogLevel("DEBUG") + TestUtils.waitUntilExecutorsUp(sc, 2, 60000) + val managerMasterEndpoint = lookupBlockManagerMasterEndpoint(sc) + val mapOutputTracker = lookupMapOutputTrackerMaster(sc) + assert(managerMasterEndpoint.rddAccessTime.isEmpty) + // Make some cache blocks + val input = sc.parallelize(1.to(100)).groupBy(_ % 10) + val cachedInput = input.cache() + cachedInput.count() + // Check that we have both shuffle & RDD blocks registered + eventually { assert(!managerMasterEndpoint.rddAccessTime.isEmpty) } + eventually { assert(!mapOutputTracker.shuffleAccessTime.isEmpty) } + // Both should be expired! + eventually { + val t = System.currentTimeMillis() + assert(mapOutputTracker.shuffleAccessTime.isEmpty, + s"We should have no blocks since we are now at time ${t} with ttl of ${blockTTL}") + assert(managerMasterEndpoint.rddAccessTime.isEmpty, + s"We should have no blocks since we are now at time ${t} with ttl of ${blockTTL}") + } + // And redoing the count should work and everything should come back. + input.count() + eventually { + assert(!managerMasterEndpoint.rddAccessTime.isEmpty) + assert(!mapOutputTracker.shuffleAccessTime.isEmpty) + } + } + + test("Test that blocks TTLS are not tracked when not enabled") { + val conf = new SparkConf() + .setAppName("test-blockmanager-decommissioner") + .setMaster("local-cluster[2, 1, 1024]") + sc = new SparkContext(conf) + sc.setLogLevel("DEBUG") + TestUtils.waitUntilExecutorsUp(sc, 2, 60000) + val managerMasterEndpoint = lookupBlockManagerMasterEndpoint(sc) + assert(managerMasterEndpoint.rddAccessTime.isEmpty) + // Make some cache blocks + val input = sc.parallelize(1.to(100)).groupBy(_ % 10).cache() + input.count() + // Check that no RDD blocks are tracked + assert(managerMasterEndpoint.rddAccessTime.isEmpty) + // Check that the no shuffle blocks are tracked. + val mapOutputTracker = lookupMapOutputTrackerMaster(sc) + assert(mapOutputTracker.shuffleAccessTime.isEmpty) + } +}
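A quick usage sketch of the two new settings (not part of the patch itself): both are plain time configs and, per their doc text, have to be set before the SparkContext is created. The master, app name, and 30 minute TTL below are arbitrary example values.

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Sketch only: enable both TTL cleaners with a 30 minute TTL. The keys match the
// ConfigBuilder entries added in this patch; values accept the usual time suffixes.
val conf = new SparkConf()
  .setMaster("local[2]")                      // example master; any deployment works
  .setAppName("ttl-cleaner-example")          // arbitrary example app name
  .set("spark.cleaner.ttl.all", "30min")      // TTL for RDD blocks tracked on the driver
  .set("spark.cleaner.ttl.shuffle", "30min")  // TTL for shuffle state in MapOutputTrackerMaster
val sc = new SparkContext(conf)               // must happen after the configs are set
```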
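To illustrate the BlockId change (again, a sketch rather than code from the patch): the new ShuffleId trait plus BlockId.asShuffleId let callers recover a shuffle id from any shuffle-related block id without matching on each concrete case class, which is how updateBlockAtime in BlockManagerMasterEndpoint uses it. The helper name shuffleIdOf is made up for the example.

```scala
import org.apache.spark.storage.{BlockId, ShuffleBlockId, ShuffleIndexBlockId}

// Sketch only: asShuffleId returns Some(...) for the block id types covered by
// BlockId.isShuffle, so both ids below map back to shuffle 3.
def shuffleIdOf(blockId: BlockId): Option[Int] = blockId.asShuffleId.map(_.shuffleId)

val ids = Seq[BlockId](ShuffleBlockId(3, 7L, 0), ShuffleIndexBlockId(3, 7L, 0))
assert(ids.flatMap(shuffleIdOf) == Seq(3, 3))
```

Note that asShuffleId follows isShuffle, so it still returns None for push-based ids such as ShufflePushBlockId and ShuffleMergedBlockId even though they now mix in ShuffleId.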
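Finally, a condensed, standalone sketch of the sweep both TTLCleaner threads perform (in MapOutputTrackerMaster and BlockManagerMasterEndpoint), assuming the same non-concurrent java.util.HashMap the patch uses for access times; sweepOnce and the cleanup callback are made-up names for illustration, not part of the patch.

```scala
import scala.jdk.CollectionConverters._

// Sketch only: one pass of the TTL sweep. Entries last touched before (now - ttl) are
// dropped from tracking and handed to the cleanup callback; the return value is how long
// to sleep until the oldest survivor would itself expire, floored so the cleaner thread
// never busy-spins.
def sweepOnce(accessTime: java.util.HashMap[Int, Long], ttl: Long)(cleanup: Int => Unit): Long = {
  val now = System.currentTimeMillis()
  val snapshot = accessTime.asScala.toList   // copy first to reduce the chance of a CME
  val (expired, live) = snapshot.partition { case (_, atime) => atime < now - ttl }
  expired.foreach { case (id, _) =>
    accessTime.remove(id)                    // drop from tracking even if cleanup fails later
    cleanup(id)
  }
  val oldest = if (live.isEmpty) now else live.map(_._2).min
  math.max((oldest + ttl) - now, 100L)       // milliseconds until the next sweep is worthwhile
}
```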