diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
index 7ba1182f0ed27..45feac983407c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -59,7 +59,7 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String)
     Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i))
   }

-  checkpointData = Some(new RDDCheckpointData[T](this))
+  checkpointData = Some(new RDDCheckpointData[T](this, None))
   checkpointData.get.cpFile = Some(checkpointPath)

   override def getPreferredLocations(split: Partition): Seq[String] = {
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index e4025bcf48db6..98c07fbea398d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1193,22 +1193,36 @@ abstract class RDD[T: ClassTag](
     sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
   }

-  /**
-   * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint
-   * directory set with SparkContext.setCheckpointDir() and all references to its parent
-   * RDDs will be removed. This function must be called before any job has been
-   * executed on this RDD. It is strongly recommended that this RDD is persisted in
-   * memory, otherwise saving it on a file will require recomputation.
-   */
-  def checkpoint() {
-    if (context.checkpointDir.isEmpty) {
+  /** A private method to execute checkpointing with the provided function f. */
+  private[spark] def checkpoint(f: Option[RDD[T] => RDD[T]]) {
+    if (f.isEmpty && context.checkpointDir.isEmpty) {
       throw new SparkException("Checkpoint directory has not been set in the SparkContext")
     } else if (checkpointData.isEmpty) {
-      checkpointData = Some(new RDDCheckpointData(this))
+      checkpointData = Some(new RDDCheckpointData(this, f))
       checkpointData.get.markForCheckpoint()
     }
   }

+  /**
+   * Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint
+   * directory set with SparkContext.setCheckpointDir() and all references to its parent
+   * RDDs will be removed. This function must be called before any job has been executed
+   * on this RDD. It is strongly recommended that this RDD is persisted in memory,
+   * otherwise saving it on a file will require recomputation.
+   * This no-argument overload invokes the default implementation.
+   */
+  def checkpoint(): Unit = checkpoint(None)
+
+  /**
+   * Mark this RDD for checkpointing. Its saving and reloading logic will be defined by the
+   * function f (e.g. saving to a custom file format) and all references to its parent
+   * RDDs will be removed. This function must be called before any job has been executed
+   * on this RDD. It is strongly recommended that this RDD is persisted in memory,
+   * otherwise saving it on a file will require recomputation.
+   * The function f must not break the deterministic behavior of RDDs.
+   */
+  def checkpoint(f: RDD[T] => RDD[T]): Unit = checkpoint(Some(f))
+
   /**
    * Return whether this RDD has been checkpointed or not
    */
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
index f67e5f1857979..2fcc31ddffc15 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -35,12 +35,16 @@ private[spark] object CheckpointState extends Enumeration {

 /**
  * This class contains all the information related to RDD checkpointing. Each instance of this
- * class is associated with a RDD. It manages process of checkpointing of the associated RDD,
- * as well as, manages the post-checkpoint state by providing the updated partitions,
- * iterator and preferred locations of the checkpointed RDD.
+ * class is associated with an RDD. It manages the process of checkpointing the associated RDD,
+ * as well as the post-checkpoint state, by providing the updated partitions, iterator
+ * and preferred locations of the checkpointed RDD. The default save and reload implementation
+ * can be overridden by providing a custom saveAndReloadRDD function, which may return any kind
+ * of RDD, not only a CheckpointRDD.
  */
-private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
-  extends Logging with Serializable {
+private[spark] class RDDCheckpointData[T: ClassTag](
+    @transient rdd: RDD[T],
+    @transient saveAndReloadRDD: Option[RDD[T] => RDD[T]]
+  ) extends Logging with Serializable {

   import CheckpointState._

@@ -82,32 +86,46 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
       }
     }

-    // Create the output path for the checkpoint
-    val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id)
-    val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
-    if (!fs.mkdirs(path)) {
-      throw new SparkException("Failed to create checkpoint path " + path)
-    }
+    if (saveAndReloadRDD.isEmpty) {
+      // Default implementation
+      // Create the output path for the checkpoint
+      val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id)
+      val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
+      if (!fs.mkdirs(path)) {
+        throw new SparkException("Failed to create checkpoint path " + path)
+      }

-    // Save to file, and reload it as an RDD
-    val broadcastedConf = rdd.context.broadcast(
-      new SerializableWritable(rdd.context.hadoopConfiguration))
-    rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _)
-    val newRDD = new CheckpointRDD[T](rdd.context, path.toString)
-    if (newRDD.partitions.size != rdd.partitions.size) {
-      throw new SparkException(
-        "Checkpoint RDD " + newRDD + "(" + newRDD.partitions.size + ") has different " +
-          "number of partitions than original RDD " + rdd + "(" + rdd.partitions.size + ")")
-    }
+      // Save to file, and reload it as an RDD
+      val broadcastedConf = rdd.context.broadcast(
+        new SerializableWritable(rdd.context.hadoopConfiguration))
+      rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _)
+      val newRDD = new CheckpointRDD[T](rdd.context, path.toString)
+      if (newRDD.partitions.size != rdd.partitions.size) {
+        throw new SparkException(
+          "Checkpoint RDD " + newRDD + "(" + newRDD.partitions.size + ") has different " +
+            "number of partitions than original RDD " + rdd + "(" + rdd.partitions.size + ")")
+      }

-    // Change the dependencies and partitions of the RDD
-    RDDCheckpointData.synchronized {
-      cpFile = Some(path.toString)
-      cpRDD = Some(newRDD)
-      rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
-      cpState = Checkpointed
+      // Change the dependencies and partitions of the RDD
+      RDDCheckpointData.synchronized {
+        cpFile = Some(path.toString)
+        cpRDD = Some(newRDD)
+        rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
+        cpState = Checkpointed
+      }
+
+      logInfo(
+        "Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id)
+    } else {
+      val newRDD = saveAndReloadRDD.get(rdd)
+      RDDCheckpointData.synchronized {
+        cpRDD = Some(newRDD)
+        rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
+        cpState = Checkpointed
+      }
+
+      logInfo("Done custom checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id)
     }
-    logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id)
   }

   // Get preferred location of a split after checkpointing
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index a41914a1a9d0c..e0d1ccea8ecdc 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -56,6 +56,16 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
     assert(flatMappedRDD.collect() === result)
   }

+  test("checkpointing with external function") {
+    val parCollection = sc.makeRDD(1 to 4)
+    val flatMappedRDD = parCollection.flatMap(x => 1 to x)
+    flatMappedRDD.checkpoint { rdd => rdd.sparkContext.makeRDD(rdd.collect, rdd.partitions.size) }
+    assert(flatMappedRDD.dependencies.head.rdd == parCollection)
+    val result = flatMappedRDD.collect()
+    assert(flatMappedRDD.dependencies.head.rdd != parCollection)
+    assert(flatMappedRDD.collect() === result)
+  }
+
   test("RDDs with one-to-one dependencies") {
     testRDD(_.map(x => x.toString))
     testRDD(_.flatMap(x => 1 to x))
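
Usage sketch (not part of the patch): a minimal example of the new checkpoint(f) overload, assuming a local SparkContext. The output path and the choice of saveAsObjectFile/objectFile below are illustrative assumptions only; any save-and-reload logic that preserves the RDD's contents and its deterministic behavior would work.

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.rdd.RDD

    object CustomCheckpointExample {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf().setAppName("custom-checkpoint").setMaster("local[2]")
        val sc = new SparkContext(conf)
        val numbers: RDD[Int] = sc.parallelize(1 to 1000, 4)

        // Must be called before any job has been executed on this RDD.
        numbers.checkpoint { rdd =>
          val path = "/tmp/custom-checkpoint-" + rdd.id  // hypothetical location
          rdd.saveAsObjectFile(path)                     // save in a user-chosen format
          // The reloaded RDD is returned and becomes the new parent of `numbers`.
          rdd.sparkContext.objectFile[Int](path, rdd.partitions.size)
        }

        numbers.count()  // the first job triggers the checkpoint and truncates the lineage
        sc.stop()
      }
    }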