Skip to content

Commit fac6085

Browse files
kayousterhoutpwendell
authored andcommitted
[SPARK-1397] Notify SparkListeners when stages fail or are cancelled.
[I wanted to post this for folks to comment but it depends on (and thus includes the changes in) a currently outstanding PR, #305. You can look at just the second commit: kayousterhout@93f08ba to see just the changes relevant to this PR] Previously, when stages fail or get cancelled, the SparkListener is only notified indirectly through the SparkListenerJobEnd, where we sometimes pass in a single stage that failed. This worked before job cancellation, because jobs would only fail due to a single stage failure. However, with job cancellation, multiple running stages can fail when a job gets cancelled. Right now, this is not handled correctly, which results in stages that get stuck in the “Running Stages” window in the UI even though they’re dead. This PR changes the SparkListenerStageCompleted event to a SparkListenerStageEnded event, and uses this event to tell SparkListeners when stages fail in addition to when they complete successfully. This change is NOT publicly backward compatible for two reasons. First, it changes the SparkListener interface. We could alternately add a new event, SparkListenerStageFailed, and keep the existing SparkListenerStageCompleted. However, this is less consistent with the listener events for tasks / jobs ending, and will result in some code duplication for listeners (because failed and completed stages are handled in similar ways). Note that I haven’t finished updating the JSON code to correctly handle the new event because I’m waiting for feedback on whether this is a good or bad idea (hence the “WIP”). It is also not backwards compatible because it changes the publicly visible JobWaiter.jobFailed() method to no longer include a stage that caused the failure. I think this change should definitely stay, because with cancellation (as described above), a failure isn’t necessarily caused by a single stage. Author: Kay Ousterhout <kayousterhout@gmail.com> Closes #309 from kayousterhout/stage_cancellation and squashes the following commits: 5533ecd [Kay Ousterhout] Fixes in response to Mark's review 320c7c7 [Kay Ousterhout] Notify SparkListeners when stages fail or are cancelled.
1 parent e25b593 commit fac6085

File tree

11 files changed

+151
-78
lines changed

11 files changed

+151
-78
lines changed

core/src/main/scala/org/apache/spark/FutureAction.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
141141
private def awaitResult(): Try[T] = {
142142
jobWaiter.awaitResult() match {
143143
case JobSucceeded => scala.util.Success(resultFunc)
144-
case JobFailed(e: Exception, _) => scala.util.Failure(e)
144+
case JobFailed(e: Exception) => scala.util.Failure(e)
145145
}
146146
}
147147
}

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 76 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -342,22 +342,24 @@ class DAGScheduler(
342342
}
343343

344344
/**
345-
* Removes job and any stages that are not needed by any other job. Returns the set of ids for
346-
* stages that were removed. The associated tasks for those stages need to be cancelled if we
347-
* got here via job cancellation.
345+
* Removes state for job and any stages that are not needed by any other job. Does not
346+
* handle cancelling tasks or notifying the SparkListener about finished jobs/stages/tasks.
347+
*
348+
* @param job The job whose state to cleanup.
349+
* @param resultStage Specifies the result stage for the job; if set to None, this method
350+
* searches resultStagesToJob to find and cleanup the appropriate result stage.
348351
*/
349-
private def removeJobAndIndependentStages(jobId: Int): Set[Int] = {
350-
val registeredStages = jobIdToStageIds(jobId)
351-
val independentStages = new HashSet[Int]()
352-
if (registeredStages.isEmpty) {
353-
logError("No stages registered for job " + jobId)
352+
private def cleanupStateForJobAndIndependentStages(job: ActiveJob, resultStage: Option[Stage]) {
353+
val registeredStages = jobIdToStageIds.get(job.jobId)
354+
if (registeredStages.isEmpty || registeredStages.get.isEmpty) {
355+
logError("No stages registered for job " + job.jobId)
354356
} else {
355-
stageIdToJobIds.filterKeys(stageId => registeredStages.contains(stageId)).foreach {
357+
stageIdToJobIds.filterKeys(stageId => registeredStages.get.contains(stageId)).foreach {
356358
case (stageId, jobSet) =>
357-
if (!jobSet.contains(jobId)) {
359+
if (!jobSet.contains(job.jobId)) {
358360
logError(
359361
"Job %d not registered for stage %d even though that stage was registered for the job"
360-
.format(jobId, stageId))
362+
.format(job.jobId, stageId))
361363
} else {
362364
def removeStage(stageId: Int) {
363365
// data structures based on Stage
@@ -394,23 +396,28 @@ class DAGScheduler(
394396
.format(stageId, stageIdToStage.size))
395397
}
396398

397-
jobSet -= jobId
399+
jobSet -= job.jobId
398400
if (jobSet.isEmpty) { // no other job needs this stage
399-
independentStages += stageId
400401
removeStage(stageId)
401402
}
402403
}
403404
}
404405
}
405-
independentStages.toSet
406-
}
406+
jobIdToStageIds -= job.jobId
407+
jobIdToActiveJob -= job.jobId
408+
activeJobs -= job
407409

408-
private def jobIdToStageIdsRemove(jobId: Int) {
409-
if (!jobIdToStageIds.contains(jobId)) {
410-
logDebug("Trying to remove unregistered job " + jobId)
410+
if (resultStage.isEmpty) {
411+
// Clean up result stages.
412+
val resultStagesForJob = resultStageToJob.keySet.filter(
413+
stage => resultStageToJob(stage).jobId == job.jobId)
414+
if (resultStagesForJob.size != 1) {
415+
logWarning(
416+
s"${resultStagesForJob.size} result stages for job ${job.jobId} (expect exactly 1)")
417+
}
418+
resultStageToJob --= resultStagesForJob
411419
} else {
412-
removeJobAndIndependentStages(jobId)
413-
jobIdToStageIds -= jobId
420+
resultStageToJob -= resultStage.get
414421
}
415422
}
416423

@@ -460,7 +467,7 @@ class DAGScheduler(
460467
val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties)
461468
waiter.awaitResult() match {
462469
case JobSucceeded => {}
463-
case JobFailed(exception: Exception, _) =>
470+
case JobFailed(exception: Exception) =>
464471
logInfo("Failed to run " + callSite)
465472
throw exception
466473
}
@@ -606,7 +613,16 @@ class DAGScheduler(
606613
for (job <- activeJobs) {
607614
val error = new SparkException("Job cancelled because SparkContext was shut down")
608615
job.listener.jobFailed(error)
609-
listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error, -1)))
616+
// Tell the listeners that all of the running stages have ended. Don't bother
617+
// cancelling the stages because if the DAG scheduler is stopped, the entire application
618+
// is in the process of getting stopped.
619+
val stageFailedMessage = "Stage cancelled because SparkContext was shut down"
620+
runningStages.foreach { stage =>
621+
val info = stageToInfos(stage)
622+
info.stageFailed(stageFailedMessage)
623+
listenerBus.post(SparkListenerStageCompleted(info))
624+
}
625+
listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
610626
}
611627
return true
612628
}
@@ -676,7 +692,7 @@ class DAGScheduler(
676692
}
677693
} catch {
678694
case e: Exception =>
679-
jobResult = JobFailed(e, job.finalStage.id)
695+
jobResult = JobFailed(e)
680696
job.listener.jobFailed(e)
681697
} finally {
682698
val s = job.finalStage
@@ -826,11 +842,8 @@ class DAGScheduler(
826842
job.numFinished += 1
827843
// If the whole job has finished, remove it
828844
if (job.numFinished == job.numPartitions) {
829-
jobIdToActiveJob -= stage.jobId
830-
activeJobs -= job
831-
resultStageToJob -= stage
832845
markStageAsFinished(stage)
833-
jobIdToStageIdsRemove(job.jobId)
846+
cleanupStateForJobAndIndependentStages(job, Some(stage))
834847
listenerBus.post(SparkListenerJobEnd(job.jobId, JobSucceeded))
835848
}
836849
job.listener.taskSucceeded(rt.outputId, event.result)
@@ -982,7 +995,7 @@ class DAGScheduler(
982995
if (!jobIdToStageIds.contains(jobId)) {
983996
logDebug("Trying to cancel unregistered job " + jobId)
984997
} else {
985-
failJobAndIndependentStages(jobIdToActiveJob(jobId), s"Job $jobId cancelled")
998+
failJobAndIndependentStages(jobIdToActiveJob(jobId), s"Job $jobId cancelled", None)
986999
}
9871000
}
9881001

@@ -999,7 +1012,8 @@ class DAGScheduler(
9991012
stageToInfos(failedStage).completionTime = Some(System.currentTimeMillis())
10001013
for (resultStage <- dependentStages) {
10011014
val job = resultStageToJob(resultStage)
1002-
failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason")
1015+
failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason",
1016+
Some(resultStage))
10031017
}
10041018
if (dependentStages.isEmpty) {
10051019
logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done")
@@ -1008,28 +1022,45 @@ class DAGScheduler(
10081022

10091023
/**
10101024
* Fails a job and all stages that are only used by that job, and cleans up relevant state.
1025+
*
1026+
* @param resultStage The result stage for the job, if known. Used to cleanup state for the job
1027+
* slightly more efficiently than when not specified.
10111028
*/
1012-
private def failJobAndIndependentStages(job: ActiveJob, failureReason: String) {
1029+
private def failJobAndIndependentStages(job: ActiveJob, failureReason: String,
1030+
resultStage: Option[Stage]) {
10131031
val error = new SparkException(failureReason)
10141032
job.listener.jobFailed(error)
10151033

1016-
// Cancel all tasks in independent stages.
1017-
val independentStages = removeJobAndIndependentStages(job.jobId)
1018-
independentStages.foreach(taskScheduler.cancelTasks)
1019-
1020-
// Clean up remaining state we store for the job.
1021-
jobIdToActiveJob -= job.jobId
1022-
activeJobs -= job
1023-
jobIdToStageIds -= job.jobId
1024-
val resultStagesForJob = resultStageToJob.keySet.filter(
1025-
stage => resultStageToJob(stage).jobId == job.jobId)
1026-
if (resultStagesForJob.size != 1) {
1027-
logWarning(
1028-
s"${resultStagesForJob.size} result stages for job ${job.jobId} (expect exactly 1)")
1034+
// Cancel all independent, running stages.
1035+
val stages = jobIdToStageIds(job.jobId)
1036+
if (stages.isEmpty) {
1037+
logError("No stages registered for job " + job.jobId)
10291038
}
1030-
resultStageToJob --= resultStagesForJob
1039+
stages.foreach { stageId =>
1040+
val jobsForStage = stageIdToJobIds.get(stageId)
1041+
if (jobsForStage.isEmpty || !jobsForStage.get.contains(job.jobId)) {
1042+
logError(
1043+
"Job %d not registered for stage %d even though that stage was registered for the job"
1044+
.format(job.jobId, stageId))
1045+
} else if (jobsForStage.get.size == 1) {
1046+
if (!stageIdToStage.contains(stageId)) {
1047+
logError("Missing Stage for stage with id $stageId")
1048+
} else {
1049+
// This is the only job that uses this stage, so fail the stage if it is running.
1050+
val stage = stageIdToStage(stageId)
1051+
if (runningStages.contains(stage)) {
1052+
taskScheduler.cancelTasks(stageId)
1053+
val stageInfo = stageToInfos(stage)
1054+
stageInfo.stageFailed(failureReason)
1055+
listenerBus.post(SparkListenerStageCompleted(stageToInfos(stage)))
1056+
}
1057+
}
1058+
}
1059+
}
1060+
1061+
cleanupStateForJobAndIndependentStages(job, resultStage)
10311062

1032-
listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error, job.finalStage.id)))
1063+
listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
10331064
}
10341065

10351066
/**

core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,11 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener
191191
*/
192192
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) {
193193
val stageId = stageCompleted.stageInfo.stageId
194-
stageLogInfo(stageId, "STAGE_ID=%d STATUS=COMPLETED".format(stageId))
194+
if (stageCompleted.stageInfo.failureReason.isEmpty) {
195+
stageLogInfo(stageId, s"STAGE_ID=$stageId STATUS=COMPLETED")
196+
} else {
197+
stageLogInfo(stageId, s"STAGE_ID=$stageId STATUS=FAILED")
198+
}
195199
}
196200

197201
/**
@@ -227,7 +231,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener
227231
var info = "JOB_ID=" + jobId
228232
jobEnd.jobResult match {
229233
case JobSucceeded => info += " STATUS=SUCCESS"
230-
case JobFailed(exception, _) =>
234+
case JobFailed(exception) =>
231235
info += " STATUS=FAILED REASON="
232236
exception.getMessage.split("\\s+").foreach(info += _ + "_")
233237
case _ =>

core/src/main/scala/org/apache/spark/scheduler/JobResult.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,5 +24,4 @@ private[spark] sealed trait JobResult
2424

2525
private[spark] case object JobSucceeded extends JobResult
2626

27-
// A failed stage ID of -1 means there is not a particular stage that caused the failure
28-
private[spark] case class JobFailed(exception: Exception, failedStageId: Int) extends JobResult
27+
private[spark] case class JobFailed(exception: Exception) extends JobResult

core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ private[spark] class JobWaiter[T](
6464

6565
override def jobFailed(exception: Exception): Unit = synchronized {
6666
_jobFinished = true
67-
jobResult = JobFailed(exception, -1)
67+
jobResult = JobFailed(exception)
6868
this.notifyAll()
6969
}
7070

core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ private[spark] case object SparkListenerShutdown extends SparkListenerEvent
7171
*/
7272
trait SparkListener {
7373
/**
74-
* Called when a stage is completed, with information on the completed stage
74+
* Called when a stage completes successfully or fails, with information on the completed stage.
7575
*/
7676
def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { }
7777

core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,17 @@ private[spark]
2626
class StageInfo(val stageId: Int, val name: String, val numTasks: Int, val rddInfo: RDDInfo) {
2727
/** When this stage was submitted from the DAGScheduler to a TaskScheduler. */
2828
var submissionTime: Option[Long] = None
29+
/** Time when all tasks in the stage completed or when the stage was cancelled. */
2930
var completionTime: Option[Long] = None
31+
/** If the stage failed, the reason why. */
32+
var failureReason: Option[String] = None
33+
3034
var emittedTaskSizeWarning = false
35+
36+
def stageFailed(reason: String) {
37+
failureReason = Some(reason)
38+
completionTime = Some(System.currentTimeMillis)
39+
}
3140
}
3241

3342
private[spark]

core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,13 @@ private[ui] class JobProgressListener(conf: SparkConf) extends SparkListener {
7474
// Remove by stageId, rather than by StageInfo, in case the StageInfo is from storage
7575
poolToActiveStages(stageIdToPool(stageId)).remove(stageId)
7676
activeStages.remove(stageId)
77-
completedStages += stage
78-
trimIfNecessary(completedStages)
77+
if (stage.failureReason.isEmpty) {
78+
completedStages += stage
79+
trimIfNecessary(completedStages)
80+
} else {
81+
failedStages += stage
82+
trimIfNecessary(failedStages)
83+
}
7984
}
8085

8186
/** If stages is too large, remove and garbage collect old stages */
@@ -215,20 +220,6 @@ private[ui] class JobProgressListener(conf: SparkConf) extends SparkListener {
215220
}
216221
}
217222

218-
override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized {
219-
jobEnd.jobResult match {
220-
case JobFailed(_, stageId) =>
221-
activeStages.get(stageId).foreach { s =>
222-
// Remove by stageId, rather than by StageInfo, in case the StageInfo is from storage
223-
activeStages.remove(s.stageId)
224-
poolToActiveStages(stageIdToPool(stageId)).remove(s.stageId)
225-
failedStages += s
226-
trimIfNecessary(failedStages)
227-
}
228-
case _ =>
229-
}
230-
}
231-
232223
override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) {
233224
synchronized {
234225
val schedulingModeName =

core/src/main/scala/org/apache/spark/util/JsonProtocol.scala

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -166,12 +166,14 @@ private[spark] object JsonProtocol {
166166
val rddInfo = rddInfoToJson(stageInfo.rddInfo)
167167
val submissionTime = stageInfo.submissionTime.map(JInt(_)).getOrElse(JNothing)
168168
val completionTime = stageInfo.completionTime.map(JInt(_)).getOrElse(JNothing)
169+
val failureReason = stageInfo.failureReason.map(JString(_)).getOrElse(JNothing)
169170
("Stage ID" -> stageInfo.stageId) ~
170171
("Stage Name" -> stageInfo.name) ~
171172
("Number of Tasks" -> stageInfo.numTasks) ~
172173
("RDD Info" -> rddInfo) ~
173174
("Submission Time" -> submissionTime) ~
174175
("Completion Time" -> completionTime) ~
176+
("Failure Reason" -> failureReason) ~
175177
("Emitted Task Size Warning" -> stageInfo.emittedTaskSizeWarning)
176178
}
177179

@@ -259,9 +261,7 @@ private[spark] object JsonProtocol {
259261
val json = jobResult match {
260262
case JobSucceeded => Utils.emptyJson
261263
case jobFailed: JobFailed =>
262-
val exception = exceptionToJson(jobFailed.exception)
263-
("Exception" -> exception) ~
264-
("Failed Stage ID" -> jobFailed.failedStageId)
264+
JObject("Exception" -> exceptionToJson(jobFailed.exception))
265265
}
266266
("Result" -> result) ~ json
267267
}
@@ -442,11 +442,13 @@ private[spark] object JsonProtocol {
442442
val rddInfo = rddInfoFromJson(json \ "RDD Info")
443443
val submissionTime = Utils.jsonOption(json \ "Submission Time").map(_.extract[Long])
444444
val completionTime = Utils.jsonOption(json \ "Completion Time").map(_.extract[Long])
445+
val failureReason = Utils.jsonOption(json \ "Failure Reason").map(_.extract[String])
445446
val emittedTaskSizeWarning = (json \ "Emitted Task Size Warning").extract[Boolean]
446447

447448
val stageInfo = new StageInfo(stageId, stageName, numTasks, rddInfo)
448449
stageInfo.submissionTime = submissionTime
449450
stageInfo.completionTime = completionTime
451+
stageInfo.failureReason = failureReason
450452
stageInfo.emittedTaskSizeWarning = emittedTaskSizeWarning
451453
stageInfo
452454
}
@@ -561,8 +563,7 @@ private[spark] object JsonProtocol {
561563
case `jobSucceeded` => JobSucceeded
562564
case `jobFailed` =>
563565
val exception = exceptionFromJson(json \ "Exception")
564-
val failedStageId = (json \ "Failed Stage ID").extract[Int]
565-
new JobFailed(exception, failedStageId)
566+
new JobFailed(exception)
566567
}
567568
}
568569

0 commit comments

Comments
 (0)