@@ -33,17 +33,149 @@ import org.mockito.Mockito.mock
 import org.scalatest.concurrent.Eventually._
 import org.scalatest.time.SpanSugar._
 
-import org.apache.spark.TestUtils
+import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite, TestUtils}
 import org.apache.spark.streaming.dstream.{DStream, FileInputDStream}
 import org.apache.spark.streaming.scheduler._
 import org.apache.spark.util.{MutableURLClassLoader, Clock, ManualClock, Utils}
 
+/**
+ * A trait that can be mixed in to get methods for testing DStream operations under
+ * DStream checkpointing. Note that implementations of this trait have to implement
+ * the `setupCheckpointOperation` method.
+ */
+trait DStreamCheckpointTester { self: SparkFunSuite =>
+
+  /**
+   * Tests a streaming operation under checkpointing, by restarting the operation
+   * from the checkpoint file and verifying whether the final output is correct.
+   * The output is assumed to have come from a reliable queue which can replay
+   * data as required.
+   *
+   * NOTE: This takes into consideration that the last batch processed before
+   * master failure will be re-processed after restart/recovery.
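+   *
+   * For example, a per-batch word count could be exercised roughly as follows
+   * (an illustrative sketch, not a test taken from this suite; note that
+   * `numBatchesBeforeRestart` must be smaller than the number of expected outputs):
+   * {{{
+   *   testCheckpointedOperation(
+   *     input = Seq(Seq("a"), Seq("a", "b"), Seq("b"), Seq("b")),
+   *     operation = (s: DStream[String]) => s.map(x => (x, 1)).reduceByKey(_ + _),
+   *     expectedOutput = Seq(
+   *       Seq(("a", 1)), Seq(("a", 1), ("b", 1)), Seq(("b", 1)), Seq(("b", 1))),
+   *     numBatchesBeforeRestart = 2)
+   * }}}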
+   */
+  protected def testCheckpointedOperation[U: ClassTag, V: ClassTag](
+      input: Seq[Seq[U]],
+      operation: DStream[U] => DStream[V],
+      expectedOutput: Seq[Seq[V]],
+      numBatchesBeforeRestart: Int,
+      batchDuration: Duration = Milliseconds(500),
+      stopSparkContextAfterTest: Boolean = true
+    ) {
+    require(numBatchesBeforeRestart < expectedOutput.size,
+      "Number of batches before context restart must be less than the number of expected " +
+      "outputs (i.e. the total number of batches to run)")
+    require(StreamingContext.getActive().isEmpty,
+      "Cannot run test with an already active streaming context")
+
+    // Current code assumes that the number of batches to be run = number of inputs
+    val totalNumBatches = input.size
+    val batchDurationMillis = batchDuration.milliseconds
+
+    // Set up the stream computation
+    val checkpointDir = Utils.createTempDir(this.getClass.getSimpleName()).toString
+    logDebug(s"Using checkpoint directory $checkpointDir")
+    val ssc = createContextForCheckpointOperation(batchDuration)
+    require(ssc.conf.get("spark.streaming.clock") === classOf[ManualClock].getName,
+      "Cannot run test without manual clock in the conf")
+
+    val inputStream = new TestInputStream(ssc, input, numPartitions = 2)
+    val operatedStream = operation(inputStream)
+    operatedStream.print()
+    val outputStream = new TestOutputStreamWithPartitions(operatedStream,
+      new ArrayBuffer[Seq[Seq[V]]] with SynchronizedBuffer[Seq[Seq[V]]])
+    outputStream.register()
+    ssc.checkpoint(checkpointDir)
+
+    // Run the computation for the initial number of batches, writing checkpoint
+    // files along the way, and then stop
+    val beforeRestartOutput = generateOutput[V](ssc,
+      Time(batchDurationMillis * numBatchesBeforeRestart), checkpointDir, stopSparkContextAfterTest)
+    assertOutput(beforeRestartOutput, expectedOutput, beforeRestart = true)
+    // Restart and complete the computation from the checkpoint file
+    logInfo(
+      "\n-------------------------------------------\n" +
+      "        Restarting stream computation          " +
+      "\n-------------------------------------------\n"
+    )
+
+    val restartedSsc = new StreamingContext(checkpointDir)
+    val afterRestartOutput = generateOutput[V](restartedSsc,
+      Time(batchDurationMillis * totalNumBatches), checkpointDir, stopSparkContextAfterTest)
+    assertOutput(afterRestartOutput, expectedOutput, beforeRestart = false)
+  }
+
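+  /**
+   * Creates a local StreamingContext that uses a ManualClock, so that batch execution
+   * can be driven deterministically by the test rather than by wall-clock time. An
+   * existing SparkContext is reused if one is already running.
+   */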
+  protected def createContextForCheckpointOperation(batchDuration: Duration): StreamingContext = {
+    val conf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName)
+    conf.set("spark.streaming.clock", classOf[ManualClock].getName())
+    new StreamingContext(SparkContext.getOrCreate(conf), batchDuration)
+  }
+
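+  /**
+   * Starts the given StreamingContext, advances its ManualClock to the target batch
+   * time, waits for all batches up to that time to complete and for the corresponding
+   * checkpoint files to be written, and returns the collected output. The context is
+   * always stopped (in a finally block) before returning.
+   */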
+  private def generateOutput[V: ClassTag](
+      ssc: StreamingContext,
+      targetBatchTime: Time,
+      checkpointDir: String,
+      stopSparkContext: Boolean
+    ): Seq[Seq[V]] = {
+    try {
+      val batchDuration = ssc.graph.batchDuration
+      val batchCounter = new BatchCounter(ssc)
+      ssc.start()
+      val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+      val currentTime = clock.getTimeMillis()
+
+      logInfo("Manual clock before advancing = " + clock.getTimeMillis())
+      clock.setTime(targetBatchTime.milliseconds)
+      logInfo("Manual clock after advancing = " + clock.getTimeMillis())
+
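+      // Find the registered TestOutputStreamWithPartitions from which the
+      // collected output will be read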
+      val outputStream = ssc.graph.getOutputStreams().filter { dstream =>
+        dstream.isInstanceOf[TestOutputStreamWithPartitions[V]]
+      }.head.asInstanceOf[TestOutputStreamWithPartitions[V]]
+
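+      // Wait until all batches up to the target batch time have been completed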
+      eventually(timeout(10 seconds)) {
+        ssc.awaitTerminationOrTimeout(10)
+        assert(batchCounter.getLastCompletedBatchTime === targetBatchTime)
+      }
+
+      eventually(timeout(10 seconds)) {
+        val checkpointFilesOfLatestTime = Checkpoint.getCheckpointFiles(checkpointDir).filter {
+          _.toString.contains(clock.getTimeMillis.toString)
+        }
+        // Checkpoint files are written twice for every batch interval, so assert
+        // that both files for the latest batch time have been written
+        assert(checkpointFilesOfLatestTime.size === 2)
+      }
+      outputStream.output.map(_.flatten)
+
+    } finally {
+      ssc.stop(stopSparkContext = stopSparkContext)
+    }
+  }
+
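+  /**
+   * Compares the generated output with the corresponding slice of the expected output,
+   * batch by batch as sets, since the ordering of elements within a batch is not
+   * guaranteed.
+   */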
+  private def assertOutput[V: ClassTag](
+      output: Seq[Seq[V]],
+      expectedOutput: Seq[Seq[V]],
+      beforeRestart: Boolean): Unit = {
+    val expectedPartialOutput = if (beforeRestart) {
+      expectedOutput.take(output.size)
+    } else {
+      expectedOutput.takeRight(output.size)
+    }
+    val setComparison = output.zip(expectedPartialOutput).forall {
+      case (o, e) => o.toSet === e.toSet
+    }
+    assert(setComparison, s"set comparison failed\n" +
+      s"Expected output items:\n${expectedPartialOutput.mkString("\n")}\n" +
+      s"Generated output items: ${output.mkString("\n")}"
+    )
+  }
+}
+
 /**
  * This test suite tests the checkpointing functionality of DStreams -
  * the checkpointing of a DStream's RDDs as well as the checkpointing of
  * the whole DStream graph.
  */
-class CheckpointSuite extends TestSuiteBase {
+class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester {
 
   var ssc: StreamingContext = null
 
@@ -56,7 +188,7 @@ class CheckpointSuite extends TestSuiteBase {
 
   override def afterFunction() {
     super.afterFunction()
-    if (ssc != null) ssc.stop()
+    if (ssc != null) { ssc.stop() }
     Utils.deleteRecursively(new File(checkpointDir))
   }
 
@@ -251,7 +383,9 @@ class CheckpointSuite extends TestSuiteBase {
         Seq(("", 2)),
         Seq(),
         Seq(("a", 2), ("b", 1)),
-        Seq(("", 2)), Seq() ),
+        Seq(("", 2)),
+        Seq()
+      ),
       3
     )
   }
@@ -634,53 +768,6 @@ class CheckpointSuite extends TestSuiteBase {
     checkpointWriter.stop()
   }
 
-  /**
-   * Tests a streaming operation under checkpointing, by restarting the operation
-   * from checkpoint file and verifying whether the final output is correct.
-   * The output is assumed to have come from a reliable queue which an replay
-   * data as required.
-   *
-   * NOTE: This takes into consideration that the last batch processed before
-   * master failure will be re-processed after restart/recovery.
-   */
-  def testCheckpointedOperation[U: ClassTag, V: ClassTag](
-      input: Seq[Seq[U]],
-      operation: DStream[U] => DStream[V],
-      expectedOutput: Seq[Seq[V]],
-      initialNumBatches: Int
-    ) {
-
-    // Current code assumes that:
-    // number of inputs = number of outputs = number of batches to be run
-    val totalNumBatches = input.size
-    val nextNumBatches = totalNumBatches - initialNumBatches
-    val initialNumExpectedOutputs = initialNumBatches
-    val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + 1
-    // because the last batch will be processed again
-
-    // Do the computation for initial number of batches, create checkpoint file and quit
-    ssc = setupStreams[U, V](input, operation)
-    ssc.start()
-    val output = advanceTimeWithRealDelay[V](ssc, initialNumBatches)
-    ssc.stop()
-    verifyOutput[V](output, expectedOutput.take(initialNumBatches), true)
-    Thread.sleep(1000)
-
-    // Restart and complete the computation from checkpoint file
-    logInfo(
-      "\n-------------------------------------------\n" +
-      "        Restarting stream computation          " +
-      "\n-------------------------------------------\n"
-    )
-    ssc = new StreamingContext(checkpointDir)
-    ssc.start()
-    val outputNew = advanceTimeWithRealDelay[V](ssc, nextNumBatches)
-    // the first element will be re-processed data of the last batch before restart
-    verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true)
-    ssc.stop()
-    ssc = null
-  }
-
   /**
    * Advances the manual clock on the streaming scheduler by the given number of batches.
    * It also waits for the expected amount of time for each batch.