33 changes: 30 additions & 3 deletions core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -55,7 +55,8 @@ private class ShuffleStatus(numPartitions: Int) {
* locations is so small that we choose to ignore that case and store only a single location
* for each output.
*/
- private[this] val mapStatuses = new Array[MapStatus](numPartitions)
+ // Exposed for testing
+ val mapStatuses = new Array[MapStatus](numPartitions)

/**
* The cached result of serializing the map statuses array. This cache is lazily populated when
@@ -105,14 +106,30 @@ private class ShuffleStatus(numPartitions: Int) {
}
}

/**
* Removes all shuffle outputs associated with this host. Note that this will also remove
* outputs which are served by an external shuffle server (if one exists).
*/
def removeOutputsOnHost(host: String): Unit = {
removeOutputsByFilter(x => x.host == host)
}

/**
* Removes all map outputs associated with the specified executor. Note that this will also
* remove outputs which are served by an external shuffle server (if one exists), as they are
* still registered with that execId.
*/
def removeOutputsOnExecutor(execId: String): Unit = synchronized {
removeOutputsByFilter(x => x.executorId == execId)
}

/**
* Removes all shuffle outputs which satisfy the filter. Note that this will also
* remove outputs which are served by an external shuffle server (if one exists).
*/
def removeOutputsByFilter(f: (BlockManagerId) => Boolean): Unit = synchronized {
for (mapId <- 0 until mapStatuses.length) {
- if (mapStatuses(mapId) != null && mapStatuses(mapId).location.executorId == execId) {
+ if (mapStatuses(mapId) != null && f(mapStatuses(mapId).location)) {
_numAvailableOutputs -= 1
mapStatuses(mapId) = null
invalidateSerializedMapOutputStatusCache()
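[Editor's note] The refactor above collapses host-level and executor-level removal into a single predicate over each output's BlockManagerId. A self-contained sketch of the same pattern, with simplified stand-in types rather than the actual Spark classes:

// Illustrative sketch only: Location stands in for BlockManagerId and
// OutputRegistry for ShuffleStatus; all names here are hypothetical.
case class Location(executorId: String, host: String)

class OutputRegistry(numPartitions: Int) {
  private val outputs = new Array[Location](numPartitions)
  private var numAvailable = 0

  def register(mapId: Int, loc: Location): Unit = synchronized {
    if (outputs(mapId) == null) numAvailable += 1
    outputs(mapId) = loc
  }

  // Both public removal paths reduce to one filter, mirroring
  // removeOutputsOnHost / removeOutputsOnExecutor above.
  def removeOnHost(host: String): Unit = removeByFilter(_.host == host)
  def removeOnExecutor(execId: String): Unit = removeByFilter(_.executorId == execId)

  private def removeByFilter(f: Location => Boolean): Unit = synchronized {
    for (mapId <- outputs.indices) {
      if (outputs(mapId) != null && f(outputs(mapId))) {
        numAvailable -= 1
        outputs(mapId) = null
      }
    }
  }
}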
@@ -317,7 +334,8 @@ private[spark] class MapOutputTrackerMaster(

// HashMap for storing shuffleStatuses in the driver.
// Statuses are dropped only by explicit de-registering.
- private val shuffleStatuses = new ConcurrentHashMap[Int, ShuffleStatus]().asScala
+ // Exposed for testing
+ val shuffleStatuses = new ConcurrentHashMap[Int, ShuffleStatus]().asScala

private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)

@@ -415,6 +433,15 @@ private[spark] class MapOutputTrackerMaster(
}
}

/**
* Removes all shuffle outputs associated with this host. Note that this will also remove
* outputs which are served by an external shuffle server (if one exists).
*/
def removeOutputsOnHost(host: String): Unit = {
shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnHost(host) }
incrementEpoch()
}
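[Editor's note] The incrementEpoch() call is what tells executors that their cached map output locations may be stale after outputs are dropped. A rough sketch of that epoch handshake, under the assumption (from the surrounding tracker code, not shown in this hunk) that the driver bumps an epoch counter and workers clear their cache on seeing a newer one; all names hypothetical:

// Worker-side cache invalidated by an epoch the driver increments on change.
class EpochedCache[K, V](fetchFromDriver: () => (Long, Map[K, V])) {
  private var epoch = -1L
  private var cache = Map.empty[K, V]

  // Called with the epoch piggybacked on task launch; drops stale entries.
  def updateEpoch(newEpoch: Long): Unit = synchronized {
    if (newEpoch > epoch) {
      epoch = newEpoch
      cache = Map.empty
    }
  }

  def get(key: K): Option[V] = synchronized {
    if (cache.isEmpty) {
      val (latestEpoch, statuses) = fetchFromDriver()
      epoch = latestEpoch
      cache = statuses
    }
    cache.get(key)
  }
}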

/**
* Removes all shuffle outputs associated with this executor. Note that this will also remove
* outputs which are served by an external shuffle server (if one exists), as they are still
core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -151,6 +151,14 @@ package object config {
.createOptional
// End blacklist confs

private[spark] val UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE =
ConfigBuilder("spark.files.fetchFailure.unRegisterOutputOnHost")
.doc("Whether to un-register all the outputs on the host in condition that we receive " +
" a FetchFailure. This is set default to false, which means, we only un-register the " +
" outputs related to the exact executor(instead of the host) on a FetchFailure.")
.booleanConf
.createWithDefault(false)
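[Editor's note] Enabling the new behavior takes both this flag and the external shuffle service, since the DAGScheduler change below checks both. A minimal setup mirroring the configuration the new test uses:

import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.shuffle.service.enabled", "true")
  .set("spark.files.fetchFailure.unRegisterOutputOnHost", "true")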

private[spark] val LISTENER_BUS_EVENT_QUEUE_CAPACITY =
ConfigBuilder("spark.scheduler.listenerbus.eventqueue.capacity")
.withAlternative("spark.scheduler.listenerbus.eventqueue.size")
67 changes: 55 additions & 12 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -35,6 +35,7 @@ import org.apache.commons.lang3.SerializationUtils
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.internal.config
import org.apache.spark.internal.Logging
import org.apache.spark.network.util.JavaUtils
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
@@ -187,6 +188,14 @@ class DAGScheduler(
/** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */
private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false)

/**
* Whether to unregister all the outputs on the host when we receive a FetchFailure. This
* defaults to false, which means we only unregister the outputs related to the exact
* executor (instead of the whole host) on a FetchFailure.
*/
private[scheduler] val unRegisterOutputOnHostOnFetchFailure =
sc.getConf.get(config.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE)

/**
* Number of consecutive stage attempts allowed before a stage is aborted.
*/
@@ -1336,7 +1345,21 @@

// TODO: mark the executor as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
- handleExecutorLost(bmAddress.executorId, filesLost = true, Some(task.epoch))
+ val hostToUnregisterOutputs = if (env.blockManager.externalShuffleServiceEnabled &&
+ unRegisterOutputOnHostOnFetchFailure) {
+ // We had a fetch failure with the external shuffle service, so we
+ // assume all shuffle data on the node is bad.
+ Some(bmAddress.host)
+ } else {
+ // Unregister shuffle data just for one executor (we don't have any
+ // reason to believe shuffle data has been lost for the entire host).
+ None
+ }
+ removeExecutorAndUnregisterOutputs(
+ execId = bmAddress.executorId,
+ fileLost = true,
+ hostToUnregisterOutputs = hostToUnregisterOutputs,
+ maybeEpoch = Some(task.epoch))
}
}
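[Editor's note] The branch above is the core decision of this PR. Restated as a small pure function with its truth table; this is a hypothetical helper for illustration, not code from the patch:

// Returns Some(host) when every output on the host should be unregistered.
def hostToUnregister(
    externalShuffleServiceEnabled: Boolean,
    unRegisterOutputOnHostOnFetchFailure: Boolean,
    host: String): Option[String] = {
  if (externalShuffleServiceEnabled && unRegisterOutputOnHostOnFetchFailure) {
    Some(host) // the shuffle service may serve bad data for every executor on the host
  } else {
    None       // only the failing executor's outputs are unregistered
  }
}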

@@ -1370,22 +1393,42 @@
*/
private[scheduler] def handleExecutorLost(
execId: String,
- filesLost: Boolean,
- maybeEpoch: Option[Long] = None) {
+ workerLost: Boolean): Unit = {
+ // if the cluster manager explicitly tells us that the entire worker was lost, then
+ // we know to unregister shuffle output. (Note that "worker" specifically refers to the process
+ // from a Standalone cluster, where the shuffle service lives in the Worker.)
+ val fileLost = workerLost || !env.blockManager.externalShuffleServiceEnabled
+ removeExecutorAndUnregisterOutputs(
+ execId = execId,
+ fileLost = fileLost,
+ hostToUnregisterOutputs = None,
Review comment (Contributor): one more question: if worker lost, shouldn't we unregister outputs on that worker/host?
Review comment (Contributor): seems we can't get worker id here, nvm
+ maybeEpoch = None)
+ }
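[Editor's note] The fileLost derivation above compresses a three-case argument into one line; spelled out as a restatement of the same logic, not anything new:

// workerLost  externalShuffleServiceEnabled  fileLost
// true        (any)                          true    // the Worker hosting the shuffle service died
// false       false                          true    // the executor served its own shuffle files
// false       true                           false   // files survive on the external shuffle service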

private def removeExecutorAndUnregisterOutputs(
execId: String,
fileLost: Boolean,
hostToUnregisterOutputs: Option[String],
maybeEpoch: Option[Long] = None): Unit = {
val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch)
if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) {
failedEpoch(execId) = currentEpoch
logInfo("Executor lost: %s (epoch %d)".format(execId, currentEpoch))
blockManagerMaster.removeExecutor(execId)

- if (filesLost || !env.blockManager.externalShuffleServiceEnabled) {
- logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch))
- mapOutputTracker.removeOutputsOnExecutor(execId)
+ if (fileLost) {
+ hostToUnregisterOutputs match {
+ case Some(host) =>
+ logInfo("Shuffle files lost for host: %s (epoch %d)".format(host, currentEpoch))
+ mapOutputTracker.removeOutputsOnHost(host)
+ case None =>
+ logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch))
+ mapOutputTracker.removeOutputsOnExecutor(execId)
+ }
clearCacheLocs()

+ } else {
+ logDebug("Additional executor lost message for %s (epoch %d)".format(execId, currentEpoch))
}
- } else {
- logDebug("Additional executor lost message for " + execId +
- "(epoch " + currentEpoch + ")")
}
}
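[Editor's note] The failedEpoch check above is what keeps duplicate loss events for the same executor from unregistering outputs twice. A minimal sketch of the guard in isolation, with hypothetical names:

import scala.collection.mutable

val failedEpoch = mutable.HashMap.empty[String, Long]

// Returns true at most once per (executor, epoch); later duplicates are ignored.
def shouldHandleLoss(execId: String, currentEpoch: Long): Boolean = {
  if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) {
    failedEpoch(execId) = currentEpoch
    true
  } else {
    false
  }
}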

@@ -1678,11 +1721,11 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
dagScheduler.handleExecutorAdded(execId, host)

case ExecutorLost(execId, reason) =>
- val filesLost = reason match {
+ val workerLost = reason match {
case SlaveLost(_, true) => true
case _ => false
}
- dagScheduler.handleExecutorLost(execId, filesLost)
+ dagScheduler.handleExecutorLost(execId, workerLost)

case BeginEvent(task, taskInfo) =>
dagScheduler.handleBeginEvent(task, taskInfo)
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -396,6 +396,73 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
assertDataStructuresEmpty()
}

test("All shuffle files on the slave should be cleaned up when slave lost") {
// reset the test context with the right shuffle service config
afterEach()
val conf = new SparkConf()
conf.set("spark.shuffle.service.enabled", "true")
conf.set("spark.files.fetchFailure.unRegisterOutputOnHost", "true")
init(conf)
runEvent(ExecutorAdded("exec-hostA1", "hostA"))
runEvent(ExecutorAdded("exec-hostA2", "hostA"))
runEvent(ExecutorAdded("exec-hostB", "hostB"))
val firstRDD = new MyRDD(sc, 3, Nil)
val firstShuffleDep = new ShuffleDependency(firstRDD, new HashPartitioner(3))
val firstShuffleId = firstShuffleDep.shuffleId
val shuffleMapRdd = new MyRDD(sc, 3, List(firstShuffleDep))
val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(3))
val secondShuffleId = shuffleDep.shuffleId
val reduceRdd = new MyRDD(sc, 1, List(shuffleDep))
submit(reduceRdd, Array(0))
// map stage1 completes successfully, with one task on each executor
complete(taskSets(0), Seq(
(Success,
MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))),
(Success,
MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))),
(Success, makeMapStatus("hostB", 1))
))
// map stage2 completes successfully, with one task on each executor
complete(taskSets(1), Seq(
(Success,
MapStatus(BlockManagerId("exec-hostA1", "hostA", 12345), Array.fill[Long](1)(2))),
(Success,
MapStatus(BlockManagerId("exec-hostA2", "hostA", 12345), Array.fill[Long](1)(2))),
(Success, makeMapStatus("hostB", 1))
))
// make sure our test setup is correct
val initialMapStatus1 = mapOutputTracker.shuffleStatuses(firstShuffleId).mapStatuses
assert(initialMapStatus1.count(_ != null) === 3)
assert(initialMapStatus1.map{_.location.executorId}.toSet ===
Set("exec-hostA1", "exec-hostA2", "exec-hostB"))

val initialMapStatus2 = mapOutputTracker.shuffleStatuses(secondShuffleId).mapStatuses
assert(initialMapStatus2.count(_ != null) === 3)
assert(initialMapStatus2.map{_.location.executorId}.toSet ===
Set("exec-hostA1", "exec-hostA2", "exec-hostB"))

// reduce stage fails with a fetch failure from one host
complete(taskSets(2), Seq(
(FetchFailed(BlockManagerId("exec-hostA2", "hostA", 12345), firstShuffleId, 0, 0, "ignored"),
null)
))

// Here is the main assertion -- make sure that we de-register
// the map outputs for both map stages from both executors on hostA

val mapStatus1 = mapOutputTracker.shuffleStatuses(firstShuffleId).mapStatuses
assert(mapStatus1.count(_ != null) === 1)
assert(mapStatus1(2).location.executorId === "exec-hostB")
assert(mapStatus1(2).location.host === "hostB")

val mapStatus2 = mapOutputTracker.shuffleStatuses(secondShuffleId).mapStatuses
assert(mapStatus2.count(_ != null) === 1)
assert(mapStatus2(2).location.executorId === "exec-hostB")
assert(mapStatus2(2).location.host === "hostB")
}
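[Editor's note] One note on the assertions: index 2 is hostB's task only because the two hostA tasks occupy partitions 0 and 1 in the completion order above. An order-insensitive phrasing of the same check would be (equivalent, just a suggestion):

assert(mapStatus1.filter(_ != null).map(_.location.host).toSet === Set("hostB"))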

test("zero split job") {
var numResults = 0
var failureReason: Option[Exception] = None