1 change: 1 addition & 0 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -70,6 +70,7 @@ class SparkEnv (
val outputCommitCoordinator: OutputCommitCoordinator,
val conf: SparkConf) extends Logging {

var currentStage: Int = -1
private[spark] var isStopped = false
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()

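The new currentStage field lets executor-side code look up which stage the task it is running belongs to. A minimal sketch of how it might be consulted, assuming (as the rest of this patch implies) that a stageExInfos map is exposed on the BlockManager; the snippet is illustrative and not part of the change:

    // Executor-side lookup of the per-stage execution info registered for the running stage.
    val env = org.apache.spark.SparkEnv.get
    val stageId = env.currentStage // set from the task name in CoarseGrainedExecutorBackend below
    env.blockManager.stageExInfos.get(stageId) match {
      case Some(info) => println(s"stage $stageId depends on cached RDDs: " + info.depMap)
      case None => println(s"no StageExInfo registered yet for stage $stageId")
    }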
31 changes: 31 additions & 0 deletions core/src/main/scala/org/apache/spark/StageExInfo.scala
@@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark

import scala.collection.mutable

/**
 * Per-stage execution info about cached RDDs and their cached ancestors; registered on
 * executors and consulted by the block manager when computing RDD blocks.
 */
class StageExInfo(val stageId: Int,
val alreadyPerRddSet: Set[Int], // "prs": ids of RDDs that are already persisted
val afterPerRddSet: Set[Int], // "aprs": ids of RDDs to be persisted after this stage
val depMap: mutable.HashMap[Int, Set[Int]],
val curRunningRddMap: mutable.HashMap[Int, Set[Int]]) {

}
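For illustration, a hypothetical construction of a StageExInfo; the executor code later in this patch passes null for the two persisted-RDD sets and fills only depMap and curRunningRddMap, and the ids below are made up:

    import scala.collection.mutable
    import org.apache.spark.StageExInfo

    // Stage 3 runs a cached RDD 7 whose nearest cached narrow ancestor is RDD 2.
    val depMap = mutable.HashMap[Int, Set[Int]](7 -> Set(2))
    val running = mutable.HashMap[Int, Set[Int]](7 -> Set(2))
    val info = new StageExInfo(3, alreadyPerRddSet = null, afterPerRddSet = null,
      depMap = depMap, curRunningRddMap = running)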
core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -92,6 +92,12 @@ private[spark] class CoarseGrainedExecutorBackend(
} else {
val taskDesc = ser.deserialize[TaskDescription](data.value)
logInfo("Got assigned task " + taskDesc.taskId)
// The task name has the form "task <index>.<attempt> in stage <stageId>.<stageAttempt>",
// so the stage id is the text between the last space and the last dot.
val currentStageId = taskDesc.name.substring(taskDesc.name.lastIndexOf(' ') + 1,
  taskDesc.name.lastIndexOf('.')).toInt
env.currentStage = currentStageId
env.blockManager.currentStage = currentStageId
// logDebug("this stage has ExInfo: " + env.blockManager.stageExInfos(currentStageId))

executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber,
taskDesc.name, taskDesc.serializedTask)
}
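The stage id is recovered by slicing the task's display name, which in this code base has the form "task <index>.<attempt> in stage <stageId>.<stageAttempt>". A small worked example of that slicing (a sketch, not part of the patch):

    val name = "task 2.0 in stage 5.0" // the form produced by TaskSetManager
    val stageId = name.substring(name.lastIndexOf(' ') + 1, name.lastIndexOf('.')).toInt
    // lastIndexOf(' ') + 1 points at "5.0" and lastIndexOf('.') at its dot, so stageId == 5

Carrying the stage id explicitly on TaskDescription would be less brittle than parsing the name, but the name-based approach avoids changing the serialized task format.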
5 changes: 4 additions & 1 deletion core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -220,7 +220,10 @@ private[spark] class Executor(
// for the task.
throw new TaskKilledException
}

if (!env.blockManager.stageExInfos.contains(task.stageId)) {
env.blockManager.stageExInfos.put(task.stageId,
new StageExInfo(task.stageId, null, null, task.depMap, task.curRunningRddMap))
}
logDebug("Task " + taskId + "'s epoch is " + task.epoch)
env.mapOutputTracker.updateEpoch(task.epoch)

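The guard above registers a StageExInfo at most once per stage on each executor. Because several tasks of the same stage can reach this point concurrently, an alternative sketch would serialize the registration, assuming stageExInfos is a scala.collection.mutable.HashMap as the rest of the patch implies:

    env.blockManager.stageExInfos.synchronized {
      env.blockManager.stageExInfos.getOrElseUpdate(task.stageId,
        new StageExInfo(task.stageId, null, null, task.depMap, task.curRunningRddMap))
    }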
85 changes: 71 additions & 14 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -38,7 +38,7 @@ import org.apache.spark.partial.BoundedDouble
import org.apache.spark.partial.CountEvaluator
import org.apache.spark.partial.GroupedCountEvaluator
import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
import org.apache.spark.storage.{BlockExInfo, RDDBlockId, StorageLevel}
import org.apache.spark.util.{BoundedPriorityQueue, Utils}
import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler,
@@ -208,6 +208,7 @@ abstract class RDD[T: ClassTag](
*/
def unpersist(blocking: Boolean = true): this.type = {
logInfo("Removing RDD " + id + " from persistence list")
sc.dagScheduler.renewDepMap(id)
sc.unpersistRDD(id, blocking)
storageLevel = StorageLevel.NONE
this
@@ -307,6 +308,31 @@
ancestors.filterNot(_ == this).toSeq
}

/**
 * Return the ids of the nearest cached ancestors of this RDD that are reachable
 * through narrow dependencies.
 */
private[spark] def getNarrowCachedAncestors: Set[Int] = {
val cachedAncestors = new mutable.HashSet[Int]
val ancestors = new mutable.HashSet[RDD[_]]
def visit(rdd: RDD[_]): Unit = {
val narrowDependencies = rdd.dependencies.filter(_.isInstanceOf[NarrowDependency[_]])
val narrowParents = narrowDependencies.map(_.rdd)
val narrowParentsNotVisited = narrowParents.filterNot(ancestors.contains)
narrowParentsNotVisited.foreach { parent =>
ancestors.add(parent)
if (parent.getStorageLevel != StorageLevel.NONE) {
cachedAncestors.add(parent.id)
} else {
visit(parent)
}
}
}

visit(this)

cachedAncestors.filterNot(_ == this.id).toSet
}

/**
* Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing.
*/
@@ -328,6 +354,39 @@
// This method is called on executors, so we need to call SparkEnv.get instead of sc.env.
SparkEnv.get.blockManager.getOrElseUpdate(blockId, storageLevel, elementClassTag, () => {
readCachedBlock = false
val key = blockId
logInfo(s"Partition $key not found, computing it")

val blockManager = SparkEnv.get.blockManager

if (!blockManager.blockExInfo.containsKey(key)) {
blockManager.blockExInfo.put(key, new BlockExInfo(key))
}

blockManager.stageExInfos.get(blockManager.currentStage) match {
  case Some(curStageExInfo) =>
    var allParentsExist = true
    for (par <- curStageExInfo.depMap(id)) {
      val parBlockId = RDDBlockId(par, partition.index)
      if (blockManager.blockExInfo.containsKey(parBlockId) &&
        blockManager.blockExInfo.get(parBlockId).isExist == 1) {
        // this parent block is already materialized, nothing to watch
      } else {
        // parent not materialized yet: add this block to the parent's watching set
        allParentsExist = false
        if (!blockManager.blockExInfo.containsKey(parBlockId)) {
          blockManager.blockExInfo.put(parBlockId, new BlockExInfo(parBlockId))
        }
        blockManager.blockExInfo.get(parBlockId).sonSet += key
      }
    }
    if (allParentsExist) {
      // all parent blocks exist, so record when computation of this block starts
      logTrace("All parents exist, storing compute start time of " + key)
      blockManager.blockExInfo.get(key).creatStartTime = System.currentTimeMillis()
    }
  case None =>
    logError("No StageExInfo registered for stage " + blockManager.currentStage)
}
computeOrReadCheckpoint(partition, context)
}) match {
case Left(blockResult) =>
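The bookkeeping added above records, for every parent block that is not yet materialized, that the current block is waiting on it (via sonSet), and stamps a computation start time only once all parents exist. A toy, self-contained sketch of the same watch-set idea; every name here is hypothetical and independent of the real BlockManager internals:

    import scala.collection.mutable

    // Stand-in for BlockExInfo: whether a block exists and which children wait on it.
    case class WatchInfo(var exists: Boolean = false,
                         sonSet: mutable.Set[String] = mutable.Set.empty[String])

    val infos = mutable.HashMap[String, WatchInfo]()
    def info(id: String): WatchInfo = infos.getOrElseUpdate(id, WatchInfo())

    // Block "rdd_7_0" depends on parent blocks "rdd_2_0" (present) and "rdd_4_0" (missing).
    info("rdd_2_0").exists = true
    var allParentsExist = true
    for (parent <- Seq("rdd_2_0", "rdd_4_0")) {
      val p = info(parent)
      if (!p.exists) {
        allParentsExist = false
        p.sonSet += "rdd_7_0" // watch: wake this child once the parent is materialized
      }
    }
    val startTimes = mutable.HashMap[String, Long]()
    if (allParentsExist) startTimes("rdd_7_0") = System.currentTimeMillis()

    assert(!allParentsExist && info("rdd_4_0").sonSet == Set("rdd_7_0"))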
@@ -483,8 +542,7 @@
*
* @param weights weights for splits, will be normalized if they don't sum to 1
* @param seed random seed
*
* @return split RDDs in an array
* @return split RDDs in an array
*/
def randomSplit(
weights: Array[Double],
@@ -499,7 +557,8 @@
/**
* Internal method exposed for Random Splits in DataFrames. Samples an RDD given a probability
* range.
* @param lb lower bound to use for the Bernoulli sampler
*
* @param lb lower bound to use for the Bernoulli sampler
* @param ub upper bound to use for the Bernoulli sampler
* @param seed the seed for the Random number generator
* @return A random sub-sample of the RDD without replacement.
@@ -517,8 +576,7 @@
*
* @note this method should only be used if the resulting array is expected to be small, as
* all the data is loaded into the driver's memory.
*
* @param withReplacement whether sampling is done with replacement
* @param withReplacement whether sampling is done with replacement
* @param num size of the returned sample
* @param seed seed for the random number generator
* @return sample of specified size in an array
@@ -1244,8 +1302,7 @@
*
* @note this method should only be used if the resulting array is expected to be small, as
* all the data is loaded into the driver's memory.
*
* @note due to complications in the internal implementation, this method will raise
* @note due to complications in the internal implementation, this method will raise
* an exception if called on an RDD of `Nothing` or `Null`.
*/
def take(num: Int): Array[T] = withScope {
@@ -1308,8 +1365,7 @@
*
* @note this method should only be used if the resulting array is expected to be small, as
* all the data is loaded into the driver's memory.
*
* @param num k, the number of top elements to return
* @param num k, the number of top elements to return
* @param ord the implicit ordering for T
* @return an array of top elements
*/
@@ -1331,8 +1387,7 @@
*
* @note this method should only be used if the resulting array is expected to be small, as
* all the data is loaded into the driver's memory.
*
* @param num k, the number of elements to return
* @param num k, the number of elements to return
* @param ord the implicit ordering for T
* @return an array of top elements
*/
@@ -1359,15 +1414,17 @@

/**
* Returns the max of this RDD as defined by the implicit Ordering[T].
* @return the maximum element of the RDD
*
* @return the maximum element of the RDD
* */
def max()(implicit ord: Ordering[T]): T = withScope {
this.reduce(ord.max)
}

/**
* Returns the min of this RDD as defined by the implicit Ordering[T].
* @return the minimum element of the RDD
*
* @return the minimum element of the RDD
* */
def min()(implicit ord: Ordering[T]): T = withScope {
this.reduce(ord.min)
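getNarrowCachedAncestors walks narrow dependencies and stops at the first persisted RDD on each path, so uncached intermediates are skipped and anything behind a cached ancestor stays hidden. A sketch of the expected behaviour on a made-up lineage; because the method is private[spark], the snippet assumes it runs from the org.apache.spark package, for example in a test suite:

    package org.apache.spark

    import org.apache.spark.storage.StorageLevel

    object NarrowCachedAncestorsSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("sketch"))
        val base = sc.parallelize(1 to 100)                        // not cached
        val cached = base.map(_ * 2).persist(StorageLevel.MEMORY_ONLY)
        val tail = cached.filter(_ % 3 == 0).map(_ + 1)

        // The traversal stops at `cached`, so `base` never shows up in the result.
        assert(tail.getNarrowCachedAncestors == Set(cached.id))
        sc.stop()
      }
    }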
44 changes: 36 additions & 8 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -155,6 +155,11 @@ class DAGScheduler(

private[scheduler] val activeJobs = new HashSet[ActiveJob]

private[scheduler] var preRDDs = new HashSet[RDD[_]]

private[scheduler] var depMap = new HashMap[Int, Set[Int]]

private[scheduler] var curRunningRddMap = new HashMap[Int, Set[Int]]
/**
* Contains the locations that each RDD's partitions are cached on. This map's keys are RDD ids
* and its values are arrays indexed by partition numbers. Each array value is the set of
@@ -554,11 +559,9 @@
* @param callSite where in the user program this job was called
* @param resultHandler callback to pass each result to
* @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name
*
* @return a JobWaiter object that can be used to block until the job finishes executing
* @return a JobWaiter object that can be used to block until the job finishes executing
* or can be used to cancel the job.
*
* @throws IllegalArgumentException when partitions ids are illegal
* @throws IllegalArgumentException when partitions ids are illegal
*/
def submitJob[T, U](
rdd: RDD[T],
@@ -601,8 +604,7 @@
* @param callSite where in the user program this job was called
* @param resultHandler callback to pass each result to
* @param properties scheduler properties to attach to this job, e.g. fair scheduler pool name
*
* @throws Exception when the job fails
* @throws Exception when the job fails
*/
def runJob[T, U](
rdd: RDD[T],
@@ -928,6 +930,16 @@
logDebug("missing: " + missing)
if (missing.isEmpty) {
logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")

val curRDDs = stage.rdd.getNarrowAncestors ++ Seq(stage.rdd)
val newRDDs = curRDDs.filter(!preRDDs.contains(_))
val newCachedRDDs = newRDDs.filter(_.getStorageLevel != StorageLevel.NONE)
curRunningRddMap.clear()
newCachedRDDs.foreach { cachedRdd =>
depMap.put(cachedRdd.id, cachedRdd.getNarrowCachedAncestors)
curRunningRddMap.put(cachedRdd.id, cachedRdd.getNarrowCachedAncestors)
}
preRDDs = preRDDs ++ curRDDs
submitMissingTasks(stage, jobId.get)
} else {
for (parent <- missing) {
@@ -941,6 +953,22 @@
}
}

/** Update depMap when an RDD is unpersisted. */
def renewDepMap(id: Int): Unit = {
  if (depMap.contains(id)) {
    logTrace("Removing RDD " + id + " from depMap")
    val ancestorsOfRemoved = depMap(id)
    // Splice the unpersisted RDD out of every dependency set and substitute its own
    // cached ancestors. Iterate over a snapshot since the map is mutated in the loop.
    depMap.toSeq.foreach { case (rddId, deps) =>
      if (deps.contains(id)) {
        depMap.put(rddId, (deps - id) ++ ancestorsOfRemoved)
      }
    }
    depMap.remove(id)
    logTrace("After removing RDD " + id + " the depMap is " + depMap)
  }
}

/** Called when stage's parents are available and we can now do its task. */
private def submitMissingTasks(stage: Stage, jobId: Int) {
logDebug("submitMissingTasks(" + stage + ")")
@@ -1036,7 +1064,7 @@
val locs = taskIdToLocations(id)
val part = stage.rdd.partitions(id)
new ShuffleMapTask(stage.id, stage.latestInfo.attemptId,
taskBinary, part, locs, stage.internalAccumulators)
taskBinary, part, locs, stage.internalAccumulators, depMap, curRunningRddMap)
}

case stage: ResultStage =>
@@ -1046,7 +1074,7 @@
val part = stage.rdd.partitions(p)
val locs = taskIdToLocations(id)
new ResultTask(stage.id, stage.latestInfo.attemptId,
taskBinary, part, locs, id, stage.internalAccumulators)
taskBinary, part, locs, id, stage.internalAccumulators, depMap, curRunningRddMap)
}
}
} catch {
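renewDepMap splices an unpersisted RDD out of every recorded dependency set and substitutes that RDD's own cached ancestors, so downstream entries keep pointing at data that is still cached. A standalone sketch of the same transformation on a toy map, illustrative only and not scheduler code:

    import scala.collection.mutable

    // cached RDD id -> ids of the cached narrow ancestors it depends on
    val depMap = mutable.HashMap[Int, Set[Int]](3 -> Set(1), 5 -> Set(3), 7 -> Set(3, 6))

    def renew(id: Int): Unit = depMap.get(id).foreach { ancestors =>
      depMap.toSeq.foreach { case (rddId, deps) =>
        if (deps.contains(id)) depMap.put(rddId, (deps - id) ++ ancestors)
      }
      depMap.remove(id)
    }

    renew(3) // RDD 3 gets unpersisted
    assert(depMap == mutable.HashMap(5 -> Set(1), 7 -> Set(1, 6)))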
10 changes: 7 additions & 3 deletions core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -20,6 +20,8 @@ package org.apache.spark.scheduler
import java.io._
import java.nio.ByteBuffer

import scala.collection.mutable.HashMap

import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
@@ -49,9 +51,11 @@
partition: Partition,
locs: Seq[TaskLocation],
val outputId: Int,
_initialAccums: Seq[Accumulator[_]] = InternalAccumulator.createAll())
extends Task[U](stageId, stageAttemptId, partition.index, _initialAccums)
with Serializable {
_initialAccums: Seq[Accumulator[_]] = InternalAccumulator.createAll(),
depMap: HashMap[Int, Set[Int]] = null,
curRunningRddMap: HashMap[Int, Set[Int]] = null)
extends Task[U](stageId, stageAttemptId, partition.index, _initialAccums, depMap,
curRunningRddMap) with Serializable {

@transient private[this] val preferredLocs: Seq[TaskLocation] = {
if (locs == null) Nil else locs.toSet.toSeq
core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler

import java.nio.ByteBuffer

import scala.collection.mutable.HashMap
import scala.language.existentials

import org.apache.spark._
@@ -49,13 +50,16 @@
taskBinary: Broadcast[Array[Byte]],
partition: Partition,
@transient private var locs: Seq[TaskLocation],
_initialAccums: Seq[Accumulator[_]])
extends Task[MapStatus](stageId, stageAttemptId, partition.index, _initialAccums)
_initialAccums: Seq[Accumulator[_]],
depMap: HashMap[Int, Set[Int]],
curRunningRddMap: HashMap[Int, Set[Int]])
extends Task[MapStatus](stageId, stageAttemptId, partition.index, _initialAccums, depMap,
curRunningRddMap)
with Logging {

/** A constructor used only in test suites. This does not require passing in an RDD. */
def this(partitionId: Int) {
this(0, 0, null, new Partition { override def index: Int = 0 }, null, null)
this(0, 0, null, new Partition { override def index: Int = 0 }, null, null, null, null)
}

@transient private val preferredLocs: Seq[TaskLocation] = {
5 changes: 4 additions & 1 deletion core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -51,7 +51,10 @@ private[spark] abstract class Task[T](
val stageId: Int,
val stageAttemptId: Int,
val partitionId: Int,
val initialAccumulators: Seq[Accumulator[_]]) extends Serializable {
val initialAccumulators: Seq[Accumulator[_]],
var depMap: HashMap[Int, Set[Int]] = new HashMap[Int, Set[Int]],
var curRunningRddMap: HashMap[Int, Set[Int]] =
new HashMap[Int, Set[Int]]) extends Serializable {

/**
* Called by [[org.apache.spark.executor.Executor]] to run this task.