@@ -19,6 +19,8 @@ package org.apache.spark.util

import java.io.{ObjectInputStream, ObjectOutputStream}

import scala.util.control.NonFatal

import org.apache.hadoop.mapred.JobConf

private[spark]
@@ -730,7 +730,8 @@ class PairDStreamFunctions[K, V](self: DStream[(K, V)])
val serializableConf = new SerializableJobConf(conf)
val saveFunc = (rdd: RDD[(K, V)], time: Time) => {
val file = rddToFileName(prefix, suffix, time)
rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass, serializableConf.value)
rdd.saveAsHadoopFile(file, keyClass, valueClass, outputFormatClass,
Contributor Author:

This is the same change done in #10088

new JobConf(serializableConf.value))
}
self.foreachRDD(saveFunc)
}
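
For context, a minimal sketch of the pattern this hunk adopts, with a hypothetical saveBatch helper and SequenceFile output that are not part of the patch: saveAsHadoopFile writes job settings into the JobConf it is given, so each batch should receive its own copy of the shared configuration rather than the single serializableConf.value instance.

import org.apache.hadoop.io.{IntWritable, Text}
import org.apache.hadoop.mapred.{JobConf, SequenceFileOutputFormat}
import org.apache.spark.rdd.RDD

// Hypothetical helper (not in the patch): save one batch using a private copy of a shared JobConf.
def saveBatch(rdd: RDD[(Text, IntWritable)], path: String, sharedConf: JobConf): Unit = {
  // saveAsHadoopFile sets the output format, key/value classes and output path on the conf
  // it receives, so hand it a fresh copy instead of the shared instance.
  val privateConf = new JobConf(sharedConf)
  rdd.saveAsHadoopFile(
    path,
    classOf[Text],
    classOf[IntWritable],
    classOf[SequenceFileOutputFormat[Text, IntWritable]],
    privateConf)
}

Copying a JobConf is cheap compared to a Hadoop write, so the per-batch allocation is a small price for keeping concurrently generated jobs from stepping on each other's settings.
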
@@ -132,22 +132,37 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT
/** Method that generates a RDD for the given time */
override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = {
// Get the previous state or create a new empty state RDD
val prevStateRDD = getOrCompute(validTime - slideDuration).getOrElse {
TrackStateRDD.createFromPairRDD[K, V, S, E](
spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
partitioner, validTime
)
val prevStateRDD = getOrCompute(validTime - slideDuration) match {
case Some(rdd) =>
if (rdd.partitioner != Some(partitioner)) {
// If the RDD is not partitioned the right way, let us repartition it using the
// partition index as the key. This is to ensure that state RDD is always partitioned
// before creating another state RDD using it
TrackStateRDD.createFromRDD[K, V, S, E](
rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)
} else {
rdd
}
case None =>
TrackStateRDD.createFromPairRDD[K, V, S, E](
spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
partitioner,
validTime
)
}


// Compute the new state RDD with previous state RDD and partitioned data RDD
parent.getOrCompute(validTime).map { dataRDD =>
val partitionedDataRDD = dataRDD.partitionBy(partitioner)
val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
(validTime - interval).milliseconds
}
new TrackStateRDD(
prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime)
// Even if there is no data RDD, use an empty one to create a new state RDD
val dataRDD = parent.getOrCompute(validTime).getOrElse {
context.sparkContext.emptyRDD[(K, V)]
}
val partitionedDataRDD = dataRDD.partitionBy(partitioner)
val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
(validTime - interval).milliseconds
}
Some(new TrackStateRDD(
prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime))
}
}

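A minimal sketch of the guard introduced above, under illustrative names that are not in the patch: compare the RDD's advertised partitioner with the one the DStream expects, and shuffle only on a mismatch. The actual change goes one step further and rebuilds the state via TrackStateRDD.createFromRDD, because the record RDD holds per-partition state maps rather than plain key-value pairs.

import scala.reflect.ClassTag

import org.apache.spark.Partitioner
import org.apache.spark.rdd.RDD

// Illustrative only: reuse an RDD's existing layout when it already matches the desired
// partitioner, and pay for a shuffle only when it does not.
def ensurePartitioned[K: ClassTag, V: ClassTag](
    rdd: RDD[(K, V)],
    partitioner: Partitioner): RDD[(K, V)] = {
  if (rdd.partitioner == Some(partitioner)) {
    rdd // already laid out as required, so no shuffle is needed
  } else {
    // e.g. an RDD restored from checkpoint files, which may not retain its partitioner
    rdd.partitionBy(partitioner)
  }
}
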
@@ -179,22 +179,43 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E:

private[streaming] object TrackStateRDD {

def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
pairRDD: RDD[(K, S)],
partitioner: Partitioner,
updateTime: Time): TrackStateRDD[K, V, S, T] = {
updateTime: Time): TrackStateRDD[K, V, S, E] = {

val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator =>
val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) }
Iterator(TrackStateRDDRecord(stateMap, Seq.empty[T]))
Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E]))
}, preservesPartitioning = true)

val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)

val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None

new TrackStateRDD[K, V, S, T](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
}

def createFromRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
rdd: RDD[(K, S, Long)],
partitioner: Partitioner,
updateTime: Time): TrackStateRDD[K, V, S, E] = {

val pairRDD = rdd.map { x => (x._1, (x._2, x._3)) }
val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions({ iterator =>
val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
iterator.foreach { case (key, (state, updateTime)) =>
stateMap.put(key, state, updateTime)
}
Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E]))
}, preservesPartitioning = true)

val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)

val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None

new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
}
}

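A minimal sketch of the mapPartitions pattern that createFromRDD relies on, with a plain mutable.Map standing in for the internal StateMap and an invented helper name: the (key, state, updateTime) triples are keyed, partitioned, and folded into exactly one record per partition, and preservesPartitioning = true keeps the partitioner so the resulting state RDD can later be combined with partitioned data without another shuffle.

import scala.collection.mutable
import scala.reflect.ClassTag

import org.apache.spark.Partitioner
import org.apache.spark.rdd.RDD

// Illustrative only: collapse (key, state, updateTime) triples into one map per partition.
def buildPartitionedState[K: ClassTag, S: ClassTag](
    triples: RDD[(K, S, Long)],
    partitioner: Partitioner): RDD[mutable.Map[K, (S, Long)]] = {
  triples
    .map { case (key, state, time) => (key, (state, time)) }
    .partitionBy(partitioner)
    .mapPartitions({ iter =>
      val stateMap = mutable.Map.empty[K, (S, Long)]
      iter.foreach { case (key, stateAndTime) => stateMap(key) = stateAndTime }
      Iterator(stateMap) // exactly one record per partition, mirroring TrackStateRDDRecord
    }, preservesPartitioning = true)
}

Emitting exactly one record per partition is what lets TrackStateRDD keep a single StateMap per partition.
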
@@ -33,17 +33,135 @@ import org.mockito.Mockito.mock
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._

import org.apache.spark.TestUtils
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite, TestUtils}
import org.apache.spark.streaming.dstream.{DStream, FileInputDStream}
import org.apache.spark.streaming.scheduler._
import org.apache.spark.util.{MutableURLClassLoader, Clock, ManualClock, Utils}

/**
* A trait that can be mixed in to get methods for testing DStream operations under
* DStream checkpointing. Note that the implementations of this trait have to implement
* the `setupCheckpointOperation`
*/
trait DStreamCheckpointTester { self: SparkFunSuite =>
Contributor Author:

This is a refactoring where I extract out testCheckpointedOperation so that it can be used in other unit tests.


/**
* Tests a streaming operation under checkpointing, by restarting the operation
* from the checkpoint file and verifying whether the final output is correct.
* The output is assumed to have come from a reliable queue which can replay
* data as required.
*
* NOTE: This takes into consideration that the last batch processed before
* master failure will be re-processed after restart/recovery.
*/
protected def testCheckpointedOperation[U: ClassTag, V: ClassTag](
input: Seq[Seq[U]],
operation: DStream[U] => DStream[V],
expectedOutput: Seq[Seq[V]],
numBatchesBeforeRestart: Int,
batchDuration: Duration = Milliseconds(500),
stopSparkContextAfterTest: Boolean = true
) {
require(numBatchesBeforeRestart < expectedOutput.size,
"Number of batches before context restart must be less than the number of expected outputs " +
"(i.e. the total number of batches to run)")
require(StreamingContext.getActive().isEmpty,
"Cannot run test with already active streaming context")

// Current code assumes that:
// number of inputs = number of outputs = number of batches to be run
val totalNumBatches = input.size
val nextNumBatches = totalNumBatches - numBatchesBeforeRestart
val initialNumExpectedOutputs = numBatchesBeforeRestart
val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + 1
// because the last batch will be processed again

// Setup the stream computation
val checkpointDir = Utils.createTempDir(this.getClass.getSimpleName()).toString
logDebug(s"Using checkpoint directory $checkpointDir")
val ssc = createContextForCheckpointOperation(batchDuration)
require(ssc.conf.get("spark.streaming.clock") === classOf[ManualClock].getName,
"Cannot run test without manual clock in the conf")

val inputStream = new TestInputStream(ssc, input, numPartitions = 2)
val operatedStream = operation(inputStream)
val outputStream = new TestOutputStreamWithPartitions(operatedStream,
new ArrayBuffer[Seq[Seq[V]]] with SynchronizedBuffer[Seq[Seq[V]]])
outputStream.register()
ssc.checkpoint(checkpointDir)

// Do the computation for initial number of batches, create checkpoint file and quit
generateAndAssertOutput[V](ssc, batchDuration, checkpointDir, numBatchesBeforeRestart,
expectedOutput.take(numBatchesBeforeRestart), stopSparkContextAfterTest)

// Restart and complete the computation from checkpoint file
logInfo(
"\n-------------------------------------------\n" +
" Restarting stream computation " +
"\n-------------------------------------------\n"
)
val restartedSsc = new StreamingContext(checkpointDir)
generateAndAssertOutput[V](restartedSsc, batchDuration, checkpointDir, nextNumBatches,
expectedOutput.takeRight(nextNumExpectedOutputs), stopSparkContextAfterTest)
}

protected def createContextForCheckpointOperation(batchDuration: Duration): StreamingContext = {
val conf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName)
conf.set("spark.streaming.clock", classOf[ManualClock].getName())
new StreamingContext(SparkContext.getOrCreate(conf), batchDuration)
}

private def generateAndAssertOutput[V: ClassTag](
ssc: StreamingContext,
batchDuration: Duration,
checkpointDir: String,
numBatchesToRun: Int,
expectedOutput: Seq[Seq[V]],
stopSparkContext: Boolean
) {
try {
val batchCounter = new BatchCounter(ssc)
ssc.start()
val numBatches = expectedOutput.size
val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
logDebug("Manual clock before advancing = " + clock.getTimeMillis())
clock.advance((batchDuration * numBatches).milliseconds)
logDebug("Manual clock after advancing = " + clock.getTimeMillis())

val outputStream = ssc.graph.getOutputStreams().filter { dstream =>
dstream.isInstanceOf[TestOutputStreamWithPartitions[V]]
}.head.asInstanceOf[TestOutputStreamWithPartitions[V]]

eventually(timeout(10 seconds)) {
ssc.awaitTerminationOrTimeout(10)
assert(batchCounter.getNumCompletedBatches === numBatchesToRun)
}

eventually(timeout(10 seconds)) {
// Wait until a checkpoint file for the latest batch time has been written
assert(Checkpoint.getCheckpointFiles(checkpointDir).exists {
_.toString.contains(clock.getTimeMillis.toString)
})
}

val output = outputStream.output.map(_.flatten)
assert(
output.zip(expectedOutput).forall { case (o, e) => o.toSet === e.toSet },
s"Set comparison failed\n" +
s"Expected output (${expectedOutput.size} items):\n${expectedOutput.mkString("\n")}\n" +
s"Generated output (${output.size} items): ${output.mkString("\n")}"
)
} finally {
ssc.stop(stopSparkContext = stopSparkContext)
}
}
}
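
To illustrate how the extracted trait is meant to be reused (the suite name and word-count operation below are invented, not part of the patch): a suite that mixes in DStreamCheckpointTester can exercise any DStream operation across a simulated driver restart. Because the last pre-restart batch is replayed, the helper verifies expectedOutput.size - numBatchesBeforeRestart + 1 outputs after recovery.

import org.apache.spark.SparkFunSuite
import org.apache.spark.streaming.dstream.DStream

// Hypothetical suite, shown only as a sketch of how the extracted helper can be reused.
class ReduceByKeyCheckpointSuite extends SparkFunSuite with DStreamCheckpointTester {
  test("reduceByKey recovers from checkpoint") {
    val input = Seq(Seq("a"), Seq("a", "b"), Seq("a", "b"), Seq("a"))
    val operation = (words: DStream[String]) => words.map(w => (w, 1)).reduceByKey(_ + _)
    val expectedOutput = Seq(
      Seq(("a", 1)),
      Seq(("a", 1), ("b", 1)),
      Seq(("a", 1), ("b", 1)),
      Seq(("a", 1)))
    // Run two batches, stop, restart from the checkpoint directory, and finish the rest.
    testCheckpointedOperation(input, operation, expectedOutput, numBatchesBeforeRestart = 2)
  }
}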

/**
* This test suite tests the checkpointing functionality of DStreams -
* the checkpointing of a DStream's RDDs as well as the checkpointing of
* the whole DStream graph.
*/
class CheckpointSuite extends TestSuiteBase {
class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester {

var ssc: StreamingContext = null

@@ -56,7 +174,7 @@ class CheckpointSuite extends TestSuiteBase {

override def afterFunction() {
super.afterFunction()
if (ssc != null) ssc.stop()
if (ssc != null) { ssc.stop() }
Utils.deleteRecursively(new File(checkpointDir))
}

@@ -634,53 +752,6 @@ class CheckpointSuite extends TestSuiteBase {
checkpointWriter.stop()
}

/**
* Tests a streaming operation under checkpointing, by restarting the operation
* from checkpoint file and verifying whether the final output is correct.
* The output is assumed to have come from a reliable queue which an replay
* data as required.
*
* NOTE: This takes into consideration that the last batch processed before
* master failure will be re-processed after restart/recovery.
*/
def testCheckpointedOperation[U: ClassTag, V: ClassTag](
input: Seq[Seq[U]],
operation: DStream[U] => DStream[V],
expectedOutput: Seq[Seq[V]],
initialNumBatches: Int
) {

// Current code assumes that:
// number of inputs = number of outputs = number of batches to be run
val totalNumBatches = input.size
val nextNumBatches = totalNumBatches - initialNumBatches
val initialNumExpectedOutputs = initialNumBatches
val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + 1
// because the last batch will be processed again

// Do the computation for initial number of batches, create checkpoint file and quit
ssc = setupStreams[U, V](input, operation)
ssc.start()
val output = advanceTimeWithRealDelay[V](ssc, initialNumBatches)
ssc.stop()
verifyOutput[V](output, expectedOutput.take(initialNumBatches), true)
Thread.sleep(1000)

// Restart and complete the computation from checkpoint file
logInfo(
"\n-------------------------------------------\n" +
" Restarting stream computation " +
"\n-------------------------------------------\n"
)
ssc = new StreamingContext(checkpointDir)
ssc.start()
val outputNew = advanceTimeWithRealDelay[V](ssc, nextNumBatches)
// the first element will be re-processed data of the last batch before restart
verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true)
ssc.stop()
ssc = null
}

/**
* Advances the manual clock on the streaming scheduler by the given number of batches.
* It also waits for the expected amount of time for each batch.