core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -59,7 +59,7 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String)
Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i))
}

checkpointData = Some(new RDDCheckpointData[T](this))
checkpointData = Some(new RDDCheckpointData[T](this, None))
checkpointData.get.cpFile = Some(checkpointPath)

override def getPreferredLocations(split: Partition): Seq[String] = {
core/src/main/scala/org/apache/spark/rdd/RDD.scala (34 changes: 24 additions & 10 deletions)
@@ -1193,22 +1193,36 @@ abstract class RDD[T: ClassTag](
sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
}

/**
* Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint
* directory set with SparkContext.setCheckpointDir() and all references to its parent
* RDDs will be removed. This function must be called before any job has been
* executed on this RDD. It is strongly recommended that this RDD is persisted in
* memory, otherwise saving it on a file will require recomputation.
*/
def checkpoint() {
if (context.checkpointDir.isEmpty) {
/** A private method that performs checkpointing, using the provided function f when it is defined */
private[spark] def checkpoint(f: Option[RDD[T] => RDD[T]]) {
if (f.isEmpty && context.checkpointDir.isEmpty) {
throw new SparkException("Checkpoint directory has not been set in the SparkContext")
} else if (checkpointData.isEmpty) {
checkpointData = Some(new RDDCheckpointData(this))
checkpointData = Some(new RDDCheckpointData(this, f))
checkpointData.get.markForCheckpoint()
}
}

/**
* Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint
* directory set with SparkContext.setCheckpointDir() and all references to its parent
* RDDs will be removed. This function must be called before any job has been executed
* on this RDD. It is strongly recommended that this RDD is persisted in memory,
* otherwise saving it to a file will require recomputation.
* Calling this method without parameters invokes the default implementation.
*/
def checkpoint(): Unit = checkpoint(None)

/**
* Mark this RDD for checkpointing. Its save and reload logic is defined by the function f
* (e.g. saving to a custom file format), and all references to its parent RDDs will be
* removed. This function must be called before any job has been executed on this RDD.
* It is strongly recommended that this RDD is persisted in memory, otherwise saving it
* to a file will require recomputation.
* The function f must not break the deterministic behavior of RDDs.
*/
def checkpoint(f: RDD[T] => RDD[T]): Unit = checkpoint(Some(f))

/**
* Return whether this RDD has been checkpointed or not
*/
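The new public overload above takes a user-supplied save-and-reload function. Below is a minimal sketch of how it might be invoked from user code, assuming a local-mode SparkContext (the application name and master are placeholders); the collect-based function mirrors the one used in the test added by this change:

```scala
import org.apache.spark.{SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setAppName("custom-checkpoint").setMaster("local[2]"))
val rdd = sc.parallelize(1 to 100, 4)

// The function receives the RDD being checkpointed and returns the RDD that will
// replace it (and truncate its lineage) once checkpointing actually runs.
rdd.checkpoint { r =>
  r.sparkContext.makeRDD(r.collect(), r.partitions.size)
}

rdd.count()  // the first job after checkpoint() triggers doCheckpoint() and swaps the parent
```

Because the supplied function is invoked on the driver inside doCheckpoint(), a collect-based function like this is only practical for small RDDs; a file-based sketch appears after the test suite below.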
core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala (74 changes: 46 additions & 28 deletions)
@@ -35,12 +35,16 @@ private[spark] object CheckpointState extends Enumeration {

/**
* This class contains all the information related to RDD checkpointing. Each instance of this
* class is associated with a RDD. It manages process of checkpointing of the associated RDD,
* as well as, manages the post-checkpoint state by providing the updated partitions,
* iterator and preferred locations of the checkpointed RDD.
class is associated with an RDD. It manages the process of checkpointing the associated
RDD, as well as the post-checkpoint state, by providing the updated partitions, iterator
and preferred locations of the checkpointed RDD. The default save and reload implementation
can be overridden by providing a custom saveAndReloadRDD function that can return any kind
of RDD, not only a CheckpointRDD.
*/
private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
extends Logging with Serializable {
private[spark] class RDDCheckpointData[T: ClassTag](
@transient rdd: RDD[T],
@transient saveAndReloadRDD: Option[RDD[T] => RDD[T]]
) extends Logging with Serializable {

import CheckpointState._

@@ -82,32 +86,46 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
}
}

// Create the output path for the checkpoint
val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id)
val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
if (!fs.mkdirs(path)) {
throw new SparkException("Failed to create checkpoint path " + path)
}
if (saveAndReloadRDD.isEmpty) {
// Default implementation
// Create the output path for the checkpoint
val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id)
val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
if (!fs.mkdirs(path)) {
throw new SparkException("Failed to create checkpoint path " + path)
}

// Save to file, and reload it as an RDD
val broadcastedConf = rdd.context.broadcast(
new SerializableWritable(rdd.context.hadoopConfiguration))
rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _)
val newRDD = new CheckpointRDD[T](rdd.context, path.toString)
if (newRDD.partitions.size != rdd.partitions.size) {
throw new SparkException(
"Checkpoint RDD " + newRDD + "(" + newRDD.partitions.size + ") has different " +
"number of partitions than original RDD " + rdd + "(" + rdd.partitions.size + ")")
}
// Save to file, and reload it as an RDD
val broadcastedConf = rdd.context.broadcast(
new SerializableWritable(rdd.context.hadoopConfiguration))
rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _)
val newRDD = new CheckpointRDD[T](rdd.context, path.toString)
if (newRDD.partitions.size != rdd.partitions.size) {
throw new SparkException(
"Checkpoint RDD " + newRDD + "(" + newRDD.partitions.size + ") has different " +
"number of partitions than original RDD " + rdd + "(" + rdd.partitions.size + ")")
}

// Change the dependencies and partitions of the RDD
RDDCheckpointData.synchronized {
cpFile = Some(path.toString)
cpRDD = Some(newRDD)
rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
cpState = Checkpointed
// Change the dependencies and partitions of the RDD
RDDCheckpointData.synchronized {
cpFile = Some(path.toString)
cpRDD = Some(newRDD)
rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
cpState = Checkpointed
}

logInfo(
"Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id)
} else {
val newRDD = saveAndReloadRDD.get(rdd)
RDDCheckpointData.synchronized {
cpRDD = Some(newRDD)
rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
cpState = Checkpointed
}

logInfo("Done custom checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id)
}
logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id)
}

// Get preferred location of a split after checkpointing
core/src/test/scala/org/apache/spark/CheckpointSuite.scala (10 changes: 10 additions & 0 deletions)
@@ -56,6 +56,16 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
assert(flatMappedRDD.collect() === result)
}

test("checkpointing with external function") {
val parCollection = sc.makeRDD(1 to 4)
val flatMappedRDD = parCollection.flatMap(x => 1 to x)
flatMappedRDD.checkpoint { rdd => rdd.sparkContext.makeRDD(rdd.collect(), rdd.partitions.size) }
assert(flatMappedRDD.dependencies.head.rdd == parCollection)
val result = flatMappedRDD.collect()
assert(flatMappedRDD.dependencies.head.rdd != parCollection)
assert(flatMappedRDD.collect() === result)
}

test("RDDs with one-to-one dependencies") {
testRDD(_.map(x => x.toString))
testRDD(_.flatMap(x => 1 to x))
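The test above re-materializes the data through the driver. For the "custom file format" case mentioned in the checkpoint(f) scaladoc, here is a hedged sketch of a reusable helper built only on the existing saveAsObjectFile/objectFile APIs; the helper name and the output path are illustrative and not part of this change:

```scala
import scala.reflect.ClassTag

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.rdd.RDD

// Hypothetical helper (not part of this PR): checkpoint by saving the RDD as object
// files and reloading it, so the data never has to be collected to the driver.
def objectFileCheckpoint[T: ClassTag](path: String): RDD[T] => RDD[T] = { rdd =>
  rdd.saveAsObjectFile(path)                                  // save in a custom format
  rdd.sparkContext.objectFile[T](path, rdd.partitions.size)   // reloaded RDD becomes the new parent
}

val sc = new SparkContext(new SparkConf().setAppName("object-file-checkpoint").setMaster("local[2]"))
val data = sc.parallelize(1 to 100, 4)
data.checkpoint(objectFileCheckpoint[Int]("/tmp/checkpoints/rdd-custom"))
data.count()  // lineage is truncated once this job completes and doCheckpoint() runs
```

Note that the custom branch does not require SparkContext.setCheckpointDir to have been called, since the user-supplied function is responsible for where and how the data is stored.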