diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 3bca59e0646d..75db8ea058d7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -290,6 +290,26 @@ class DAGScheduler( */ private def getParentStages(rdd: RDD[_], jobId: Int): List[Stage] = { val parents = new HashSet[Stage] + try { + getParentStagesInner(rdd, jobId, parents) + } catch { + case e: Exception => + parents.foreach { stage => + val jobSet = stage.jobIds + jobSet -= jobId + if (jobSet.isEmpty) { + removeStage(stage.id) + } + } + throw e + } + parents.toList + } + + /** + * Inner method for [[getParentStages]] + */ + private def getParentStagesInner(rdd: RDD[_], jobId: Int, parents: HashSet[Stage]) = { val visited = new HashSet[RDD[_]] // We are manually maintaining a stack here to prevent StackOverflowError caused by recursively visiting @@ -313,7 +333,6 @@ class DAGScheduler( while (!waitingForVisit.isEmpty) { visit(waitingForVisit.pop()) } - parents.toList } // Find ancestor missing shuffle dependencies and register into shuffleToMapStage @@ -410,6 +429,35 @@ class DAGScheduler( updateJobIdStageIdMapsList(List(stage)) } + /** + * Clean up data structures based on Stage for job and any stages + * that are not needed by any other job. 
+ */ + private def removeStage(stageId: Int) { + // data structures based on Stage + for (stage <- stageIdToStage.get(stageId)) { + if (runningStages.contains(stage)) { + logDebug("Removing running stage %d".format(stageId)) + runningStages -= stage + } + for ((k, v) <- shuffleToMapStage.find(_._2 == stage)) { + shuffleToMapStage.remove(k) + } + if (waitingStages.contains(stage)) { + logDebug("Removing stage %d from waiting set.".format(stageId)) + waitingStages -= stage + } + if (failedStages.contains(stage)) { + logDebug("Removing stage %d from failed set.".format(stageId)) + failedStages -= stage + } + } + // data structures based on StageId + stageIdToStage -= stageId + logDebug("After removal of stage %d, remaining stages = %d" + .format(stageId, stageIdToStage.size)) + } + /** * Removes state for job and any stages that are not needed by any other job. Does not * handle cancelling tasks or notifying the SparkListener about finished jobs/stages/tasks. @@ -429,31 +477,6 @@ class DAGScheduler( "Job %d not registered for stage %d even though that stage was registered for the job" .format(job.jobId, stageId)) } else { - def removeStage(stageId: Int) { - // data structures based on Stage - for (stage <- stageIdToStage.get(stageId)) { - if (runningStages.contains(stage)) { - logDebug("Removing running stage %d".format(stageId)) - runningStages -= stage - } - for ((k, v) <- shuffleToMapStage.find(_._2 == stage)) { - shuffleToMapStage.remove(k) - } - if (waitingStages.contains(stage)) { - logDebug("Removing stage %d from waiting set.".format(stageId)) - waitingStages -= stage - } - if (failedStages.contains(stage)) { - logDebug("Removing stage %d from failed set.".format(stageId)) - failedStages -= stage - } - } - // data structures based on StageId - stageIdToStage -= stageId - logDebug("After removal of stage %d, remaining stages = %d" - .format(stageId, stageIdToStage.size)) - } - jobSet -= job.jobId if (jobSet.isEmpty) { // no other job needs this stage 
removeStage(stageId)