
Commit 5d80d8c

[SPARK-11932][STREAMING] Partition previous TrackStateRDD if partitioner not present

TrackStateRDDs generated by trackStateByKey expect the previous batch's TrackStateRDDs to have a partitioner. However, when recovering from DStream checkpoints, the RDDs recovered from RDD checkpoints do not have a partitioner attached to them, because RDD checkpoints do not preserve the partitioner (SPARK-12004). While #9983 solves SPARK-12004 by preserving the partitioner through RDD checkpoints, there is still a non-zero chance that saving and recovering the partitioner fails. To be resilient, this PR repartitions the previous state RDD whenever a partitioner is not detected.

Author: Tathagata Das <[email protected]>

Closes #9988 from tdas/SPARK-11932.

1 parent ef3f047 commit 5d80d8c
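
The core of the fix, distilled: before chaining a new TrackStateRDD off the previous batch's state RDD, check `rdd.partitioner` and rebuild the state with the expected partitioner when it is absent or different. The sketch below extracts that guard into a hypothetical helper (`ensurePartitioned` is not a real method; the logic lives inline in `InternalTrackStateDStream.compute`, shown in the diff further down, and `TrackStateRDD` is private[streaming], so this only compiles inside that package):

    package org.apache.spark.streaming

    import scala.reflect.ClassTag

    import org.apache.spark.Partitioner
    import org.apache.spark.rdd.RDD
    import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord}

    object EnsurePartitionedSketch {
      // Returns a state RDD that is guaranteed to use `partitioner`.
      def ensurePartitioned[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
          prev: RDD[TrackStateRDDRecord[K, S, E]],
          partitioner: Partitioner,
          validTime: Time): RDD[TrackStateRDDRecord[K, S, E]] = {
        if (prev.partitioner != Some(partitioner)) {
          // Recovered from RDD checkpoint files, which may not carry partitioner
          // metadata (SPARK-12004): dump the (key, state, updateTime) triples and
          // rebuild a correctly partitioned state RDD from them.
          TrackStateRDD.createFromRDD[K, V, S, E](
            prev.flatMap { _.stateMap.getAll() }, partitioner, validTime)
        } else {
          prev // already partitioned as expected; reuse without a shuffle
        }
      }
    }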

File tree: 6 files changed (+258, -84 lines)

streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala

Lines changed: 1 addition & 1 deletion

@@ -277,7 +277,7 @@ class CheckpointWriter(
       val bytes = Checkpoint.serialize(checkpoint, conf)
       executor.execute(new CheckpointWriteHandler(
         checkpoint.checkpointTime, bytes, clearCheckpointDataLater))
-      logDebug("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue")
+      logInfo("Submitted checkpoint of time " + checkpoint.checkpointTime + " writer queue")
     } catch {
       case rej: RejectedExecutionException =>
         logError("Could not submit checkpoint task to the thread pool executor", rej)

streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala

Lines changed: 27 additions & 12 deletions

@@ -132,22 +132,37 @@ class InternalTrackStateDStream[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassT
   /** Method that generates a RDD for the given time */
   override def compute(validTime: Time): Option[RDD[TrackStateRDDRecord[K, S, E]]] = {
     // Get the previous state or create a new empty state RDD
-    val prevStateRDD = getOrCompute(validTime - slideDuration).getOrElse {
-      TrackStateRDD.createFromPairRDD[K, V, S, E](
-        spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
-        partitioner, validTime
-      )
+    val prevStateRDD = getOrCompute(validTime - slideDuration) match {
+      case Some(rdd) =>
+        if (rdd.partitioner != Some(partitioner)) {
+          // If the RDD is not partitioned the right way, let us repartition it using the
+          // partition index as the key. This is to ensure that state RDD is always partitioned
+          // before creating another state RDD using it
+          TrackStateRDD.createFromRDD[K, V, S, E](
+            rdd.flatMap { _.stateMap.getAll() }, partitioner, validTime)
+        } else {
+          rdd
+        }
+      case None =>
+        TrackStateRDD.createFromPairRDD[K, V, S, E](
+          spec.getInitialStateRDD().getOrElse(new EmptyRDD[(K, S)](ssc.sparkContext)),
+          partitioner,
+          validTime
+        )
     }
 
+
     // Compute the new state RDD with previous state RDD and partitioned data RDD
-    parent.getOrCompute(validTime).map { dataRDD =>
-      val partitionedDataRDD = dataRDD.partitionBy(partitioner)
-      val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
-        (validTime - interval).milliseconds
-      }
-      new TrackStateRDD(
-        prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime)
+    // Even if there is no data RDD, use an empty one to create a new state RDD
+    val dataRDD = parent.getOrCompute(validTime).getOrElse {
+      context.sparkContext.emptyRDD[(K, V)]
+    }
+    val partitionedDataRDD = dataRDD.partitionBy(partitioner)
+    val timeoutThresholdTime = spec.getTimeoutInterval().map { interval =>
+      (validTime - interval).milliseconds
     }
+    Some(new TrackStateRDD(
+      prevStateRDD, partitionedDataRDD, trackingFunction, validTime, timeoutThresholdTime))
   }
 }
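
Note that `rdd.partitioner` is an `Option[Partitioner]`, so the single inequality above handles both a missing partitioner (`None`, the checkpoint-recovery case) and a mismatched one. A standalone toy illustration of that comparison (only spark-core needed; `HashPartitioner` stands in for whatever partitioner the spec configures):

    import org.apache.spark.{HashPartitioner, Partitioner}

    object PartitionerCheckSketch extends App {
      val expected = new HashPartitioner(2)

      // Recovery case: an RDD rebuilt from checkpoint files reports no partitioner,
      // so the guard in compute() fires and the state is rebuilt/repartitioned.
      val recovered: Option[Partitioner] = None
      assert(recovered != Some(expected))

      // Steady-state case: the previous state RDD already carries the expected
      // partitioner and is reused as-is, avoiding a shuffle.
      val steady: Option[Partitioner] = Some(expected)
      assert(steady == Some(expected))
    }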

streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala

Lines changed: 25 additions & 4 deletions

@@ -179,22 +179,43 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E:
 
 private[streaming] object TrackStateRDD {
 
-  def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
+  def createFromPairRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
       pairRDD: RDD[(K, S)],
       partitioner: Partitioner,
-      updateTime: Time): TrackStateRDD[K, V, S, T] = {
+      updateTime: Time): TrackStateRDD[K, V, S, E] = {
 
     val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions ({ iterator =>
       val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
       iterator.foreach { case (key, state) => stateMap.put(key, state, updateTime.milliseconds) }
-      Iterator(TrackStateRDDRecord(stateMap, Seq.empty[T]))
+      Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E]))
     }, preservesPartitioning = true)
 
     val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)
 
     val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None
 
-    new TrackStateRDD[K, V, S, T](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
+    new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
+  }
+
+  def createFromRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
+      rdd: RDD[(K, S, Long)],
+      partitioner: Partitioner,
+      updateTime: Time): TrackStateRDD[K, V, S, E] = {
+
+    val pairRDD = rdd.map { x => (x._1, (x._2, x._3)) }
+    val rddOfTrackStateRecords = pairRDD.partitionBy(partitioner).mapPartitions({ iterator =>
+      val stateMap = StateMap.create[K, S](SparkEnv.get.conf)
+      iterator.foreach { case (key, (state, updateTime)) =>
+        stateMap.put(key, state, updateTime)
+      }
+      Iterator(TrackStateRDDRecord(stateMap, Seq.empty[E]))
+    }, preservesPartitioning = true)
+
+    val emptyDataRDD = pairRDD.sparkContext.emptyRDD[(K, V)].partitionBy(partitioner)
+
+    val noOpFunc = (time: Time, key: K, value: Option[V], state: State[S]) => None
+
+    new TrackStateRDD[K, V, S, E](rddOfTrackStateRecords, emptyDataRDD, noOpFunc, updateTime, None)
   }
 }
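
A behavioral difference between the two factories worth noting: createFromPairRDD stamps every entry with the current batch time, while the new createFromRDD preserves each entry's own last-update time, so timeout bookkeeping survives the rebuild. A hedged usage sketch (values hypothetical; `TrackStateRDD` is private[streaming], so a call like this is only possible from within that package, e.g. in a streaming test):

    import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
    import org.apache.spark.streaming.Time
    import org.apache.spark.streaming.rdd.TrackStateRDD

    object CreateFromRDDSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setMaster("local[2]").setAppName("createFromRDD-sketch"))

        // (key, state, lastUpdateTimeMs) triples -- the shape stateMap.getAll() yields.
        val recovered = sc.parallelize(Seq(("a", 1, 1000L), ("b", 2, 2000L)))

        // Rebuild a state RDD with the expected partitioner; per-key update times
        // are carried over rather than reset to the current batch time.
        val rebuilt = TrackStateRDD.createFromRDD[String, Int, Int, Int](
          recovered, new HashPartitioner(2), Time(3000))

        println(rebuilt.partitioner) // expected: the HashPartitioner carried through
        sc.stop()
      }
    }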

streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala

Lines changed: 138 additions & 51 deletions

@@ -33,17 +33,149 @@ import org.mockito.Mockito.mock
 import org.scalatest.concurrent.Eventually._
 import org.scalatest.time.SpanSugar._
 
-import org.apache.spark.TestUtils
+import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite, TestUtils}
 import org.apache.spark.streaming.dstream.{DStream, FileInputDStream}
 import org.apache.spark.streaming.scheduler._
 import org.apache.spark.util.{MutableURLClassLoader, Clock, ManualClock, Utils}
 
+/**
+ * A trait of that can be mixed in to get methods for testing DStream operations under
+ * DStream checkpointing. Note that the implementations of this trait has to implement
+ * the `setupCheckpointOperation`
+ */
+trait DStreamCheckpointTester { self: SparkFunSuite =>
+
+  /**
+   * Tests a streaming operation under checkpointing, by restarting the operation
+   * from checkpoint file and verifying whether the final output is correct.
+   * The output is assumed to have come from a reliable queue which an replay
+   * data as required.
+   *
+   * NOTE: This takes into consideration that the last batch processed before
+   * master failure will be re-processed after restart/recovery.
+   */
+  protected def testCheckpointedOperation[U: ClassTag, V: ClassTag](
+      input: Seq[Seq[U]],
+      operation: DStream[U] => DStream[V],
+      expectedOutput: Seq[Seq[V]],
+      numBatchesBeforeRestart: Int,
+      batchDuration: Duration = Milliseconds(500),
+      stopSparkContextAfterTest: Boolean = true
+    ) {
+    require(numBatchesBeforeRestart < expectedOutput.size,
+      "Number of batches before context restart less than number of expected output " +
+        "(i.e. number of total batches to run)")
+    require(StreamingContext.getActive().isEmpty,
+      "Cannot run test with already active streaming context")
+
+    // Current code assumes that number of batches to be run = number of inputs
+    val totalNumBatches = input.size
+    val batchDurationMillis = batchDuration.milliseconds
+
+    // Setup the stream computation
+    val checkpointDir = Utils.createTempDir(this.getClass.getSimpleName()).toString
+    logDebug(s"Using checkpoint directory $checkpointDir")
+    val ssc = createContextForCheckpointOperation(batchDuration)
+    require(ssc.conf.get("spark.streaming.clock") === classOf[ManualClock].getName,
+      "Cannot run test without manual clock in the conf")
+
+    val inputStream = new TestInputStream(ssc, input, numPartitions = 2)
+    val operatedStream = operation(inputStream)
+    operatedStream.print()
+    val outputStream = new TestOutputStreamWithPartitions(operatedStream,
+      new ArrayBuffer[Seq[Seq[V]]] with SynchronizedBuffer[Seq[Seq[V]]])
+    outputStream.register()
+    ssc.checkpoint(checkpointDir)
+
+    // Do the computation for initial number of batches, create checkpoint file and quit
+    val beforeRestartOutput = generateOutput[V](ssc,
+      Time(batchDurationMillis * numBatchesBeforeRestart), checkpointDir, stopSparkContextAfterTest)
+    assertOutput(beforeRestartOutput, expectedOutput, beforeRestart = true)
+    // Restart and complete the computation from checkpoint file
+    logInfo(
+      "\n-------------------------------------------\n" +
+        "        Restarting stream computation          " +
+        "\n-------------------------------------------\n"
+    )
+
+    val restartedSsc = new StreamingContext(checkpointDir)
+    val afterRestartOutput = generateOutput[V](restartedSsc,
+      Time(batchDurationMillis * totalNumBatches), checkpointDir, stopSparkContextAfterTest)
+    assertOutput(afterRestartOutput, expectedOutput, beforeRestart = false)
+  }
+
+  protected def createContextForCheckpointOperation(batchDuration: Duration): StreamingContext = {
+    val conf = new SparkConf().setMaster("local").setAppName(this.getClass.getSimpleName)
+    conf.set("spark.streaming.clock", classOf[ManualClock].getName())
+    new StreamingContext(SparkContext.getOrCreate(conf), batchDuration)
+  }
+
+  private def generateOutput[V: ClassTag](
+      ssc: StreamingContext,
+      targetBatchTime: Time,
+      checkpointDir: String,
+      stopSparkContext: Boolean
+    ): Seq[Seq[V]] = {
+    try {
+      val batchDuration = ssc.graph.batchDuration
+      val batchCounter = new BatchCounter(ssc)
+      ssc.start()
+      val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
+      val currentTime = clock.getTimeMillis()
+
+      logInfo("Manual clock before advancing = " + clock.getTimeMillis())
+      clock.setTime(targetBatchTime.milliseconds)
+      logInfo("Manual clock after advancing = " + clock.getTimeMillis())
+
+      val outputStream = ssc.graph.getOutputStreams().filter { dstream =>
+        dstream.isInstanceOf[TestOutputStreamWithPartitions[V]]
+      }.head.asInstanceOf[TestOutputStreamWithPartitions[V]]
+
+      eventually(timeout(10 seconds)) {
+        ssc.awaitTerminationOrTimeout(10)
+        assert(batchCounter.getLastCompletedBatchTime === targetBatchTime)
+      }
+
+      eventually(timeout(10 seconds)) {
+        val checkpointFilesOfLatestTime = Checkpoint.getCheckpointFiles(checkpointDir).filter {
+          _.toString.contains(clock.getTimeMillis.toString)
+        }
+        // Checkpoint files are written twice for every batch interval. So assert that both
+        // are written to make sure that both of them have been written.
+        assert(checkpointFilesOfLatestTime.size === 2)
+      }
+      outputStream.output.map(_.flatten)
+
+    } finally {
+      ssc.stop(stopSparkContext = stopSparkContext)
+    }
+  }
+
+  private def assertOutput[V: ClassTag](
+      output: Seq[Seq[V]],
+      expectedOutput: Seq[Seq[V]],
+      beforeRestart: Boolean): Unit = {
+    val expectedPartialOutput = if (beforeRestart) {
+      expectedOutput.take(output.size)
+    } else {
+      expectedOutput.takeRight(output.size)
+    }
+    val setComparison = output.zip(expectedPartialOutput).forall {
+      case (o, e) => o.toSet === e.toSet
+    }
+    assert(setComparison, s"set comparison failed\n" +
+      s"Expected output items:\n${expectedPartialOutput.mkString("\n")}\n" +
+      s"Generated output items: ${output.mkString("\n")}"
+    )
+  }
+}
+
 /**
  * This test suites tests the checkpointing functionality of DStreams -
  * the checkpointing of a DStream's RDDs as well as the checkpointing of
  * the whole DStream graph.
  */
-class CheckpointSuite extends TestSuiteBase {
+class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester {
 
   var ssc: StreamingContext = null
 
@@ -56,7 +188,7 @@ class CheckpointSuite extends TestSuiteBase {
 
   override def afterFunction() {
     super.afterFunction()
-    if (ssc != null) ssc.stop()
+    if (ssc != null) { ssc.stop() }
     Utils.deleteRecursively(new File(checkpointDir))
   }
 
@@ -251,7 +383,9 @@ class CheckpointSuite extends TestSuiteBase {
       Seq(("", 2)),
       Seq(),
       Seq(("a", 2), ("b", 1)),
-      Seq(("", 2)), Seq() ),
+      Seq(("", 2)),
+      Seq()
+    ),
     3
   )
 }
@@ -634,53 +768,6 @@ class CheckpointSuite extends TestSuiteBase {
     checkpointWriter.stop()
   }
 
-  /**
-   * Tests a streaming operation under checkpointing, by restarting the operation
-   * from checkpoint file and verifying whether the final output is correct.
-   * The output is assumed to have come from a reliable queue which an replay
-   * data as required.
-   *
-   * NOTE: This takes into consideration that the last batch processed before
-   * master failure will be re-processed after restart/recovery.
-   */
-  def testCheckpointedOperation[U: ClassTag, V: ClassTag](
-      input: Seq[Seq[U]],
-      operation: DStream[U] => DStream[V],
-      expectedOutput: Seq[Seq[V]],
-      initialNumBatches: Int
-    ) {
-
-    // Current code assumes that:
-    // number of inputs = number of outputs = number of batches to be run
-    val totalNumBatches = input.size
-    val nextNumBatches = totalNumBatches - initialNumBatches
-    val initialNumExpectedOutputs = initialNumBatches
-    val nextNumExpectedOutputs = expectedOutput.size - initialNumExpectedOutputs + 1
-    // because the last batch will be processed again
-
-    // Do the computation for initial number of batches, create checkpoint file and quit
-    ssc = setupStreams[U, V](input, operation)
-    ssc.start()
-    val output = advanceTimeWithRealDelay[V](ssc, initialNumBatches)
-    ssc.stop()
-    verifyOutput[V](output, expectedOutput.take(initialNumBatches), true)
-    Thread.sleep(1000)
-
-    // Restart and complete the computation from checkpoint file
-    logInfo(
-      "\n-------------------------------------------\n" +
-      "        Restarting stream computation          " +
-      "\n-------------------------------------------\n"
-    )
-    ssc = new StreamingContext(checkpointDir)
-    ssc.start()
-    val outputNew = advanceTimeWithRealDelay[V](ssc, nextNumBatches)
-    // the first element will be re-processed data of the last batch before restart
-    verifyOutput[V](outputNew, expectedOutput.takeRight(nextNumExpectedOutputs), true)
-    ssc.stop()
-    ssc = null
-  }
-
   /**
    * Advances the manual clock on the streaming scheduler by given number of batches.
   * It also waits for the expected amount of time for each batch.
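
For context, a suite that mixes in the new trait can exercise checkpoint recovery in a few lines. A hedged sketch of a call site (test name, input, and operation are illustrative, mirroring the map/reduceByKey-style tests already in this suite; the defaults supply the 500 ms batch duration and context teardown):

    // Assumes: class MySuite extends TestSuiteBase with DStreamCheckpointTester
    test("reduceByKey output survives restart from checkpoint") {
      testCheckpointedOperation(
        input = Seq(Seq("a"), Seq("a", "b"), Seq("b"), Seq("a")),
        operation = (s: DStream[String]) => s.map(w => (w, 1)).reduceByKey(_ + _),
        expectedOutput = Seq(
          Seq(("a", 1)), Seq(("a", 1), ("b", 1)), Seq(("b", 1)), Seq(("a", 1))),
        numBatchesBeforeRestart = 2 // stop after two batches, then recover and finish
      )
    }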

streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala

Lines changed: 6 additions & 0 deletions

@@ -142,6 +142,7 @@ class BatchCounter(ssc: StreamingContext) {
   // All access to this state should be guarded by `BatchCounter.this.synchronized`
   private var numCompletedBatches = 0
   private var numStartedBatches = 0
+  private var lastCompletedBatchTime: Time = null
 
   private val listener = new StreamingListener {
     override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit =
@@ -152,6 +153,7 @@ class BatchCounter(ssc: StreamingContext) {
     override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit =
       BatchCounter.this.synchronized {
         numCompletedBatches += 1
+        lastCompletedBatchTime = batchCompleted.batchInfo.batchTime
         BatchCounter.this.notifyAll()
       }
   }
@@ -165,6 +167,10 @@ class BatchCounter(ssc: StreamingContext) {
     numStartedBatches
   }
 
+  def getLastCompletedBatchTime: Time = this.synchronized {
+    lastCompletedBatchTime
+  }
+
   /**
    * Wait until `expectedNumCompletedBatches` batches are completed, or timeout. Return true if
    * `expectedNumCompletedBatches` batches are completed. Otherwise, return false to indicate it's
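
The new accessor lets tests wait on the completed batch time rather than a batch count, which is what generateOutput in the CheckpointSuite diff above relies on after jumping the manual clock. A hedged fragment of that pattern (assumes a started StreamingContext `ssc` configured with a ManualClock and a 500 ms batch duration, inside the streaming test package as in the suites above):

    val batchCounter = new BatchCounter(ssc)
    val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]

    // Jump the clock over several batch intervals at once...
    val targetBatchTime = Time(500 * 4)
    clock.setTime(targetBatchTime.milliseconds)

    // ...then wait until the batch *for that time* has completed, which stays
    // correct no matter how many intermediate batches run or in what order.
    eventually(timeout(10 seconds)) {
      assert(batchCounter.getLastCompletedBatchTime === targetBatchTime)
    }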
