core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala
@@ -17,7 +17,7 @@

package org.apache.spark.rdd

import java.io.{FileNotFoundException, IOException}
import java.io.{FileNotFoundException, InputStream, IOException, OutputStream}

import scala.reflect.ClassTag
import scala.util.control.NonFatal
@@ -27,8 +27,11 @@ import org.apache.hadoop.fs.Path
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.io.CompressionCodec
import org.apache.spark.util.{SerializableConfiguration, Utils}


Review comment (Member): nit: please remove unnecessary space changes

/**
* An RDD that reads from checkpoint files previously written to reliable storage.
*/
@@ -133,9 +136,14 @@ private[spark] object ReliableCheckpointRDD extends Logging {
val broadcastedConf = sc.broadcast(
new SerializableConfiguration(sc.hadoopConfiguration))
// TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
val startTime = System.currentTimeMillis()
sc.runJob(originalRDD,
writePartitionToCheckpointFile[T](checkpointDirPath.toString, broadcastedConf) _)

logInfo(s"Checkpointing took ${System.currentTimeMillis() - startTime} ms.")
sc.conf.getOption("spark.checkpoint.compress.codec").foreach(codec => {
Review comment (Member): For consistency, I suggest we just add a new config spark.checkpoint.compress which means whether to enable checkpoint compression. See, for example:
    compressionCodec = if (conf.getBoolean("spark.broadcast.compress", true)) {
logInfo(s"The checkpoint compression codec is $codec.")
})
if (originalRDD.partitioner.nonEmpty) {
writePartitionerToCheckpointDir(sc, originalRDD.partitioner.get, checkpointDirPath)
}
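
For reference, a minimal sketch of the boolean-flag approach suggested in the review comment above, assuming a spark.checkpoint.compress config that defaults to false; the helper name and its placement are illustrative, not code from this PR:

    package org.apache.spark.rdd

    import java.io.OutputStream

    import org.apache.spark.SparkConf
    import org.apache.spark.io.CompressionCodec

    object CheckpointCompressionSketch {
      // Hypothetical helper: wrap the checkpoint output stream only when the
      // boolean flag is enabled, mirroring how spark.broadcast.compress is handled.
      // Placed in org.apache.spark.rdd so the private[spark] CompressionCodec
      // companion object is visible.
      def wrapForCheckpoint(conf: SparkConf, fileStream: OutputStream): OutputStream = {
        if (conf.getBoolean("spark.checkpoint.compress", false)) {
          // Falls back to the codec configured via spark.io.compression.codec.
          CompressionCodec.createCodec(conf).compressedOutputStream(fileStream)
        } else {
          fileStream
        }
      }
    }
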
@@ -156,7 +164,7 @@ private[spark] object ReliableCheckpointRDD extends Logging {
def writePartitionToCheckpointFile[T: ClassTag](
path: String,
broadcastedConf: Broadcast[SerializableConfiguration],
blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]): Unit = {
val env = SparkEnv.get
val outputDir = new Path(path)
val fs = outputDir.getFileSystem(broadcastedConf.value.value)
@@ -169,14 +177,23 @@ private[spark] object ReliableCheckpointRDD extends Logging {
val bufferSize = env.conf.getInt("spark.buffer.size", 65536)

val fileOutputStream = if (blockSize < 0) {
fs.create(tempOutputPath, false, bufferSize)
lazy val fileStream: OutputStream = fs.create(tempOutputPath, false, bufferSize)
env.conf.getOption("spark.checkpoint.compress.codec").fold(fileStream) {
codec => {
logDebug(s"Compressing using $codec.")
CompressionCodec.createCodec(env.conf, codec)
.compressedOutputStream(fileStream)
}
}
} else {
// This is mainly for testing purpose
fs.create(tempOutputPath, false, bufferSize,
fs.getDefaultReplication(fs.getWorkingDirectory), blockSize)
}
val serializer = env.serializer.newInstance()
val serializeStream = serializer.serializeStream(fileOutputStream)
logTrace(s"Starting to write to checkpoint file $tempOutputPath.")
val startTimeMs = System.currentTimeMillis()
Review comment (Member): same as above

Utils.tryWithSafeFinally {
serializeStream.writeAll(iterator)
} {
@@ -197,6 +214,7 @@ private[spark] object ReliableCheckpointRDD extends Logging {
}
}
}
logInfo(s"Checkpointing took ${System.currentTimeMillis() - startTimeMs} ms.")
Review comment (Contributor): Add codec (if used) here.

}

/**
@@ -273,9 +291,17 @@ private[spark] object ReliableCheckpointRDD extends Logging {
val env = SparkEnv.get
val fs = path.getFileSystem(broadcastedConf.value.value)
val bufferSize = env.conf.getInt("spark.buffer.size", 65536)
val fileInputStream = fs.open(path, bufferSize)
lazy val fileStream: InputStream = fs.open(path, bufferSize)
val inputStream: InputStream =
env.conf.getOption("spark.checkpoint.compress.codec").fold(fileStream) {
codec => {
logDebug(s"Decompressing using $codec.")
CompressionCodec.createCodec(env.conf, codec)
.compressedInputStream(fileStream)
}
}
val serializer = env.serializer.newInstance()
val deserializeStream = serializer.deserializeStream(fileInputStream)
val deserializeStream = serializer.deserializeStream(inputStream)

// Register an on-task-completion callback to close the input stream.
context.addTaskCompletionListener(context => deserializeStream.close())
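
For context, the read path mirrors the write path: whatever codec wrapped the output stream must also wrap the input stream. Below is a minimal, self-contained sketch of that round trip over a local file, assuming an explicit lz4 codec and the JavaSerializer; it is illustrative only, not code from this PR:

    package org.apache.spark.rdd

    import java.io.{File, FileInputStream, FileOutputStream}

    import org.apache.spark.SparkConf
    import org.apache.spark.io.CompressionCodec
    import org.apache.spark.serializer.JavaSerializer

    object CheckpointStreamRoundTrip {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
        val codec = CompressionCodec.createCodec(conf, "lz4")
        val serializer = new JavaSerializer(conf).newInstance()
        val file = File.createTempFile("checkpoint-roundtrip", ".bin")
        file.deleteOnExit()

        // Write: file stream -> compressed stream -> serialization stream.
        val out = serializer.serializeStream(
          codec.compressedOutputStream(new FileOutputStream(file)))
        out.writeAll(Iterator(1, 2, 3, 4, 5))
        out.close()

        // Read: file stream -> decompressed stream -> deserialization stream.
        val in = serializer.deserializeStream(
          codec.compressedInputStream(new FileInputStream(file)))
        val values = in.asIterator.toList
        in.close()
        println(values) // List(1, 2, 3, 4, 5)
      }
    }
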
core/src/test/scala/org/apache/spark/CheckpointSuite.scala (101 changes: 94 additions & 7 deletions)
@@ -21,12 +21,15 @@ import java.io.File

import scala.reflect.ClassTag

import com.google.common.io.ByteStreams
import org.apache.hadoop.fs.Path

import org.apache.spark.io.CompressionCodec
import org.apache.spark.rdd._
import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId}
import org.apache.spark.util.Utils


Review comment (Member): nit: please remove unnecessary changes.

trait RDDCheckpointTester { self: SparkFunSuite =>

protected val partitioner = new HashPartitioner(2)
@@ -238,6 +241,42 @@ trait RDDCheckpointTester { self: SparkFunSuite =>
protected def generateFatPairRDD(): RDD[(Int, Int)] = {
new FatPairRDD(sparkContext.makeRDD(1 to 100, 4), partitioner).mapValues(x => x)
}

protected def testBasicCheckpoint(sc: SparkContext, reliableCheckpoint: Boolean): Unit = {
Review comment (Member): nit: does this one test any special logic? If it's covered by other tests, there is no need to add it and increase the test time.

val parCollection = sc.makeRDD(1 to 4)
val flatMappedRDD = parCollection.flatMap(x => 1 to x)
checkpoint(flatMappedRDD, reliableCheckpoint)
assert(flatMappedRDD.dependencies.head.rdd === parCollection)
val result = flatMappedRDD.collect()
assert(flatMappedRDD.dependencies.head.rdd != parCollection)
assert(flatMappedRDD.collect() === result)
}

protected def testCompression(checkpointDir: File, compressionCodec: String): Unit = {
val sparkConf = new SparkConf()
sparkConf.set("spark.checkpoint.compress.codec", compressionCodec)
val sc = new SparkContext("local", "test", sparkConf)
sc.setCheckpointDir(checkpointDir.toString)
val initialSize = 20
// Use just one partition for now since compression works best on large data sets.
val collection = sc.makeRDD(1 to initialSize, numSlices = 1)
val flatMappedRDD = collection.flatMap(x => 1 to x)
checkpoint(flatMappedRDD, reliableCheckpoint = true)
assert(flatMappedRDD.collect().length == initialSize * (initialSize + 1)/2,
"The checkpoint was lossy!")
sc.stop()
val checkpointPath = new Path(flatMappedRDD.getCheckpointFile.get)
val fs = checkpointPath.getFileSystem(sc.hadoopConfiguration)
val fileStatus = fs.listStatus(checkpointPath).find(_.getPath.getName.startsWith("part-")).get
val compressedSize = fileStatus.getLen
assert(compressedSize > 0, "The checkpoint file was not written!")
val compressedInputStream = CompressionCodec.createCodec(sparkConf, compressionCodec)
.compressedInputStream(fs.open(fileStatus.getPath))
val uncompressedSize = ByteStreams.toByteArray(compressedInputStream).length
compressedInputStream.close()
assert(compressedSize < uncompressedSize, "The compression was not successful!")
}

}

/**
@@ -251,10 +290,14 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
super.beforeEach()
checkpointDir = File.createTempFile("temp", "", Utils.createTempDir())
checkpointDir.delete()
}

private def startSparkContext(): Unit = {
sc = new SparkContext("local", "test")
sc.setCheckpointDir(checkpointDir.toString)
}


Review comment (Member): nit: please remove unnecessary changes.

override def afterEach(): Unit = {
try {
Utils.deleteRecursively(checkpointDir)
@@ -266,13 +309,44 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
override def sparkContext: SparkContext = sc

runTest("basic checkpointing") { reliableCheckpoint: Boolean =>
val parCollection = sc.makeRDD(1 to 4)
val flatMappedRDD = parCollection.flatMap(x => 1 to x)
checkpoint(flatMappedRDD, reliableCheckpoint)
assert(flatMappedRDD.dependencies.head.rdd === parCollection)
val result = flatMappedRDD.collect()
assert(flatMappedRDD.dependencies.head.rdd != parCollection)
assert(flatMappedRDD.collect() === result)
startSparkContext()
testBasicCheckpoint(sc, reliableCheckpoint)
}

runTest("compression with snappy", skipLocalCheckpoint = true) { _: Boolean =>
Review comment (Member): After you change the config to spark.checkpoint.compress, you don't need to test all compression codecs. Just write one test for the default codec. Others should be covered in CompressionCodecSuite.

Review comment (Member): For the new test, I think we just need one simple test. And if we put it into a new suite (e.g., the example below), then we don't need to touch the existing code:

class CheckpointCompressionSuite extends SparkFunSuite with LocalSparkContext {

  test("checkpoint compression") {
    val checkpointDir = File.createTempFile("temp", "", Utils.createTempDir())
    try {
      val conf = new SparkConf().set("spark.checkpoint.compress", "true")
      sc = new SparkContext("local", "test", conf)
      sc.setCheckpointDir(checkpointDir.toString)
      val rdd = sc.makeRDD(1 to 20, numSlices = 1)
      rdd.checkpoint()
      assert(rdd.collect().toSeq === (1 to 20))
      val checkpointPath = new Path(rdd.getCheckpointFile.get)
      val fs = checkpointPath.getFileSystem(sc.hadoopConfiguration)
      val checkpointFile =
        fs.listStatus(checkpointPath).map(_.getPath).find(_.getName.startsWith("part-")).get

      // Verify the checkpoint file can be decompressed
      val compressedInputStream = CompressionCodec.createCodec(conf)
        .compressedInputStream(fs.open(checkpointFile))
      ByteStreams.toByteArray(compressedInputStream)

      // Verify that the compressed content can be read back
      assert(rdd.collect().toSeq === (1 to 20))
    } finally {
      Utils.deleteRecursively(checkpointDir)
    }
  }
}

val sparkConf = new SparkConf()
sparkConf.set("spark.checkpoint.compress.codec", "snappy")
sc = new SparkContext("local", "test", sparkConf)
sc.setCheckpointDir(checkpointDir.toString)
testBasicCheckpoint(sc, reliableCheckpoint = true)
}

runTest("compression with lz4", skipLocalCheckpoint = true) { _: Boolean =>
val sparkConf = new SparkConf()
sparkConf.set("spark.checkpoint.compress.codec", "lz4")
sc = new SparkContext("local", "test", sparkConf)
sc.setCheckpointDir(checkpointDir.toString)
testBasicCheckpoint(sc, reliableCheckpoint = true)
}

runTest("compression with lzf", skipLocalCheckpoint = true) { _: Boolean =>
val sparkConf = new SparkConf()
sparkConf.set("spark.checkpoint.compress.codec", "lzf")
sc = new SparkContext("local", "test", sparkConf)
sc.setCheckpointDir(checkpointDir.toString)
testBasicCheckpoint(sc, reliableCheckpoint = true)
}

runTest("compression size snappy", skipLocalCheckpoint = true) { _: Boolean =>
testCompression(checkpointDir, "snappy")
}

runTest("compression size lzf", skipLocalCheckpoint = true) { _: Boolean =>
testCompression(checkpointDir, "lzf")
}

runTest("compression size lz4", skipLocalCheckpoint = true) { _: Boolean =>
testCompression(checkpointDir, "lz4")
}
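
As a usage note, an application would opt in to compressed checkpoints along these lines with the spark.checkpoint.compress.codec config introduced in this PR; the app name, master, and checkpoint directory below are placeholders:

    import org.apache.spark.{SparkConf, SparkContext}

    object CompressedCheckpointExample {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          .setAppName("compressed-checkpoint-example")     // placeholder app name
          .setMaster("local[2]")                           // placeholder master
          .set("spark.checkpoint.compress.codec", "lz4")   // config added by this PR
        val sc = new SparkContext(conf)
        sc.setCheckpointDir("/tmp/spark-checkpoints")      // placeholder directory
        val rdd = sc.parallelize(1 to 1000, 4).map(_ * 2)
        rdd.checkpoint()   // mark the RDD for reliable checkpointing
        rdd.count()        // an action materializes the (compressed) checkpoint files
        sc.stop()
      }
    }
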

runTest("checkpointing partitioners", skipLocalCheckpoint = true) { _: Boolean =>
@@ -312,13 +386,15 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
}
}

startSparkContext()
testPartitionerCheckpointing(partitioner)

// Test that corrupted partitioner file does not prevent recovery of RDD
testPartitionerCheckpointing(partitioner, corruptPartitionerFile = true)
}

runTest("RDDs with one-to-one dependencies") { reliableCheckpoint: Boolean =>
startSparkContext()
testRDD(_.map(x => x.toString), reliableCheckpoint)
testRDD(_.flatMap(x => 1 to x), reliableCheckpoint)
testRDD(_.filter(_ % 2 == 0), reliableCheckpoint)
@@ -332,6 +408,7 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
}

runTest("ParallelCollectionRDD") { reliableCheckpoint: Boolean =>
startSparkContext()
val parCollection = sc.makeRDD(1 to 4, 2)
val numPartitions = parCollection.partitions.size
checkpoint(parCollection, reliableCheckpoint)
@@ -348,6 +425,7 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
}

runTest("BlockRDD") { reliableCheckpoint: Boolean =>
startSparkContext()
val blockId = TestBlockId("id")
val blockManager = SparkEnv.get.blockManager
blockManager.putSingle(blockId, "test", StorageLevel.MEMORY_ONLY)
@@ -365,19 +443,22 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
}

runTest("ShuffleRDD") { reliableCheckpoint: Boolean =>
startSparkContext()
testRDD(rdd => {
// Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD
new ShuffledRDD[Int, Int, Int](rdd.map(x => (x % 2, 1)), partitioner)
}, reliableCheckpoint)
}

runTest("UnionRDD") { reliableCheckpoint: Boolean =>
startSparkContext()
def otherRDD: RDD[Int] = sc.makeRDD(1 to 10, 1)
testRDD(_.union(otherRDD), reliableCheckpoint)
testRDDPartitions(_.union(otherRDD), reliableCheckpoint)
}

runTest("CartesianRDD") { reliableCheckpoint: Boolean =>
startSparkContext()
def otherRDD: RDD[Int] = sc.makeRDD(1 to 10, 1)
testRDD(new CartesianRDD(sc, _, otherRDD), reliableCheckpoint)
testRDDPartitions(new CartesianRDD(sc, _, otherRDD), reliableCheckpoint)
@@ -401,6 +482,7 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
}

runTest("CoalescedRDD") { reliableCheckpoint: Boolean =>
startSparkContext()
testRDD(_.coalesce(2), reliableCheckpoint)
testRDDPartitions(_.coalesce(2), reliableCheckpoint)

@@ -423,6 +505,7 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
}

runTest("CoGroupedRDD") { reliableCheckpoint: Boolean =>
startSparkContext()
val longLineageRDD1 = generateFatPairRDD()

// Collect the RDD as sequences instead of arrays to enable equality tests in testRDD
@@ -441,6 +524,7 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
}

runTest("ZippedPartitionsRDD") { reliableCheckpoint: Boolean =>
startSparkContext()
testRDD(rdd => rdd.zip(rdd.map(x => x)), reliableCheckpoint)
testRDDPartitions(rdd => rdd.zip(rdd.map(x => x)), reliableCheckpoint)

@@ -466,6 +550,7 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
}

runTest("PartitionerAwareUnionRDD") { reliableCheckpoint: Boolean =>
startSparkContext()
testRDD(rdd => {
new PartitionerAwareUnionRDD[(Int, Int)](sc, Array(
generateFatPairRDD(),
@@ -500,6 +585,7 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
}

runTest("CheckpointRDD with zero partitions") { reliableCheckpoint: Boolean =>
startSparkContext()
val rdd = new BlockRDD[Int](sc, Array.empty[BlockId])
assert(rdd.partitions.size === 0)
assert(rdd.isCheckpointed === false)
@@ -514,6 +600,7 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS
}

runTest("checkpointAllMarkedAncestors") { reliableCheckpoint: Boolean =>
startSparkContext()
testCheckpointAllMarkedAncestors(reliableCheckpoint, checkpointAllMarkedAncestors = true)
testCheckpointAllMarkedAncestors(reliableCheckpoint, checkpointAllMarkedAncestors = false)
}