Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Fixed bug and added unit tests
  • Loading branch information
tdas committed Nov 12, 2015
commit 92694c74abaa23b93a88c29388a6dbbb034b98ef
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,51 @@ import org.apache.spark._
* Record storing the keyed-state [[TrackStateRDD]]. Each record contains a [[StateMap]] and a
* sequence of records returned by the tracking function of `trackStateByKey`.
*/
private[streaming] case class TrackStateRDDRecord[K, S, T](
var stateMap: StateMap[K, S], var emittedRecords: Seq[T])
private[streaming] case class TrackStateRDDRecord[K, S, E](
var stateMap: StateMap[K, S], var emittedRecords: Seq[E])

object TrackStateRDDRecord {
def updateRecordWithData[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
prevRecord: Option[TrackStateRDDRecord[K, S, E]],
dataIterator: Iterator[(K, V)],
updateFunction: (Time, K, Option[V], State[S]) => Option[E],
batchTime: Time,
timeoutThresholdTime: Option[Long],
removeTimedoutData: Boolean
): TrackStateRDDRecord[K, S, E] = {
// Create a new state map by cloning the previous one (if it exists) or by creating an empty one
val newStateMap = prevRecord.map { _.stateMap.copy() }. getOrElse { new EmptyStateMap[K, S]() }

val emittedRecords = new ArrayBuffer[E]
val wrappedState = new StateImpl[S]()

// Call the tracking function on each record in the data iterator, and accordingly
// update the states touched, and collect the data returned by the tracking function
dataIterator.foreach { case (key, value) =>
wrappedState.wrap(newStateMap.get(key))
val emittedRecord = updateFunction(batchTime, key, Some(value), wrappedState)
if (wrappedState.isRemoved) {
newStateMap.remove(key)
} else if (wrappedState.isUpdated || timeoutThresholdTime.isDefined) {
newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
}
emittedRecords ++= emittedRecord
}

// Get the timed out state records, call the tracking function on each and collect the
// data returned
if (removeTimedoutData && timeoutThresholdTime.isDefined) {
newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
wrappedState.wrapTiminoutState(state)
val emittedRecord = updateFunction(batchTime, key, None, wrappedState)
emittedRecords ++= emittedRecord
newStateMap.remove(key)
}
}

TrackStateRDDRecord(newStateMap, emittedRecords)
}
}

/**
* Partition of the [[TrackStateRDD]], which depends on corresponding partitions of prev state
Expand Down Expand Up @@ -71,15 +114,16 @@ private[streaming] class TrackStateRDDPartition(
* @param trackingFunction The function that will be used to update state and return new data
* @param batchTime The time of the batch to which this RDD belongs to. Use to update
*/
private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag](
private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, T]],
private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, E: ClassTag](
private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, E]],
private var partitionedDataRDD: RDD[(K, V)],
trackingFunction: (Time, K, Option[V], State[S]) => Option[T],
batchTime: Time, timeoutThresholdTime: Option[Long]
) extends RDD[TrackStateRDDRecord[K, S, T]](
trackingFunction: (Time, K, Option[V], State[S]) => Option[E],
batchTime: Time,
timeoutThresholdTime: Option[Long]
) extends RDD[TrackStateRDDRecord[K, S, E]](
partitionedDataRDD.sparkContext,
List(
new OneToOneDependency[TrackStateRDDRecord[K, S, T]](prevStateRDD),
new OneToOneDependency[TrackStateRDDRecord[K, S, E]](prevStateRDD),
new OneToOneDependency(partitionedDataRDD))
) {

Expand All @@ -96,50 +140,24 @@ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T:
}

override def compute(
partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, T]] = {
partition: Partition, context: TaskContext): Iterator[TrackStateRDDRecord[K, S, E]] = {

val stateRDDPartition = partition.asInstanceOf[TrackStateRDDPartition]
val prevStateRDDIterator = prevStateRDD.iterator(
stateRDDPartition.previousSessionRDDPartition, context)
val dataIterator = partitionedDataRDD.iterator(
stateRDDPartition.partitionedDataRDDPartition, context)

// Create a new state map by cloning the previous one (if it exists) or by creating an empty one
val newStateMap = if (prevStateRDDIterator.hasNext) {
prevStateRDDIterator.next().stateMap.copy()
} else {
new EmptyStateMap[K, S]()
}

val emittedRecords = new ArrayBuffer[T]
val wrappedState = new StateImpl[S]()

// Call the tracking function on each record in the data RDD partition, and accordingly
// update the states touched, and the data returned by the tracking function.
dataIterator.foreach { case (key, value) =>
wrappedState.wrap(newStateMap.get(key))
val emittedRecord = trackingFunction(batchTime, key, Some(value), wrappedState)
if (wrappedState.isRemoved) {
newStateMap.remove(key)
} else if (wrappedState.isUpdated) {
newStateMap.put(key, wrappedState.get(), batchTime.milliseconds)
}
emittedRecords ++= emittedRecord
}

// If the RDD is expected to be doing a full scan of all the data in the StateMap,
// then use this opportunity to filter out those keys that have timed out.
// For each of them call the tracking function.
if (doFullScan && timeoutThresholdTime.isDefined) {
newStateMap.getByTime(timeoutThresholdTime.get).foreach { case (key, state, _) =>
wrappedState.wrapTiminoutState(state)
val emittedRecord = trackingFunction(batchTime, key, None, wrappedState)
emittedRecords ++= emittedRecord
newStateMap.remove(key)
}
}

Iterator(TrackStateRDDRecord(newStateMap, emittedRecords))
val prevRecord = if (prevStateRDDIterator.hasNext) Some(prevStateRDDIterator.next()) else None
val newRecord = TrackStateRDDRecord.updateRecordWithData(
prevRecord,
dataIterator,
trackingFunction,
batchTime,
timeoutThresholdTime,
removeTimedoutData = doFullScan // remove timedout data only when full scan is enabled
)
Iterator(newRecord)
}

override protected def getPartitions: Array[Partition] = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import scala.reflect.ClassTag
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.util.OpenHashMapBasedStateMap
import org.apache.spark.streaming.{Time, State}
import org.apache.spark.{HashPartitioner, SparkConf, SparkContext, SparkFunSuite}

Expand All @@ -46,6 +47,131 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
assert(rdd.partitioner === Some(partitioner))
}

test("updating state and generating emitted data in TrackStateRecord") {

val initialTime = 1000L
val updatedTime = 2000L
val thresholdTime = 1500L
@volatile var functionCalled = false

/**
* Assert that applying given data on a prior record generates correct updated record, with
* correct state map and emitted data
*/
def assertRecordUpdate(
initStates: Iterable[Int],
data: Iterable[String],
expectedStates: Iterable[(Int, Long)],
timeoutThreshold: Option[Long] = None,
removeTimedoutData: Boolean = false,
expectedOutput: Iterable[Int] = None,
expectedTimingOutStates: Iterable[Int] = None,
expectedRemovedStates: Iterable[Int] = None
): Unit = {
val initialStateMap = new OpenHashMapBasedStateMap[String, Int]()
initStates.foreach { s => initialStateMap.put("key", s, initialTime) }
functionCalled = false
val record = TrackStateRDDRecord[String, Int, Int](initialStateMap, Seq.empty)
val dataIterator = data.map { v => ("key", v) }.iterator
val removedStates = new ArrayBuffer[Int]
val timingOutStates = new ArrayBuffer[Int]
/**
* Tracking function that updates/removes state based on instructions in the data, and
* return state (when instructed or when state is timing out).
*/
def testFunc(t: Time, key: String, data: Option[String], state: State[Int]): Option[Int] = {
functionCalled = true

assert(t.milliseconds === updatedTime, "tracking func called with wrong time")

data match {
case Some("noop") =>
None
case Some("get-state") =>
Some(state.getOption().getOrElse(-1))
case Some("update-state") =>
if (state.exists) state.update(state.get + 1) else state.update(0)
None
case Some("remove-state") =>
removedStates += state.get()
state.remove()
None
case None =>
assert(state.isTimingOut() === true, "State is not timing out when data = None")
timingOutStates += state.get()
None
case _ =>
fail("Unexpected test data")
}
}

val updatedRecord = TrackStateRDDRecord.updateRecordWithData[String, String, Int, Int](
Some(record), dataIterator, testFunc,
Time(updatedTime), timeoutThreshold, removeTimedoutData)

val updatedStateData = updatedRecord.stateMap.getAll().map { x => (x._2, x._3) }
assert(updatedStateData.toSet === expectedStates.toSet,
"states do not match after updating the TrackStateRecord")

assert(updatedRecord.emittedRecords.toSet === expectedOutput.toSet,
"emitted data do not match after updating the TrackStateRecord")

assert(timingOutStates.toSet === expectedTimingOutStates.toSet, "timing out states do not " +
"match those that were expected to do so while updating the TrackStateRecord")

assert(removedStates.toSet === expectedRemovedStates.toSet, "removed states do not " +
"match those that were expected to do so while updating the TrackStateRecord")

}

// No data, no state should be changed, function should not be called,
assertRecordUpdate(initStates = Nil, data = None, expectedStates = Nil)
assert(functionCalled === false)
assertRecordUpdate(initStates = Seq(0), data = None, expectedStates = Seq((0, initialTime)))
assert(functionCalled === false)

// Data present, function should be called irrespective of whether state exists
assertRecordUpdate(initStates = Seq(0), data = Seq("noop"),
expectedStates = Seq((0, initialTime)))
assert(functionCalled === true)
assertRecordUpdate(initStates = None, data = Some("noop"), expectedStates = None)
assert(functionCalled === true)

// Function called with right state data
assertRecordUpdate(initStates = None, data = Seq("get-state"),
expectedStates = None, expectedOutput = Seq(-1))
assertRecordUpdate(initStates = Seq(123), data = Seq("get-state"),
expectedStates = Seq((123, initialTime)), expectedOutput = Seq(123))

// Update state and timestamp, when timeout not present
assertRecordUpdate(initStates = Nil, data = Seq("update-state"),
expectedStates = Seq((0, updatedTime)))
assertRecordUpdate(initStates = Seq(0), data = Seq("update-state"),
expectedStates = Seq((1, updatedTime)))

// Remove state
assertRecordUpdate(initStates = Seq(345), data = Seq("remove-state"),
expectedStates = Nil, expectedRemovedStates = Seq(345))

// State strictly older than timeout threshold should be timed out
assertRecordUpdate(initStates = Seq(123), data = Nil,
timeoutThreshold = Some(initialTime), removeTimedoutData = true,
expectedStates = Seq((123, initialTime)), expectedTimingOutStates = Nil)

assertRecordUpdate(initStates = Seq(123), data = Nil,
timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
expectedStates = Nil, expectedTimingOutStates = Seq(123))

// State should not be timed out after it has received data
assertRecordUpdate(initStates = Seq(123), data = Seq("noop"),
timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
expectedStates = Seq((123, updatedTime)), expectedTimingOutStates = Nil)
assertRecordUpdate(initStates = Seq(123), data = Seq("remove-state"),
timeoutThreshold = Some(initialTime + 1), removeTimedoutData = true,
expectedStates = Nil, expectedTimingOutStates = Nil, expectedRemovedStates = Seq(123))

}

test("states generated by TrackStateRDD") {
val initStates = Seq(("k1", 0), ("k2", 0))
val initTime = 123
Expand Down Expand Up @@ -100,16 +226,18 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
// Assert that the function was called only for the keys present in the data
assert(TrackStateRDDSuite.touchedStateKeys.size === testData.size,
"More number of keys are being touched than that is expected")

assert(TrackStateRDDSuite.touchedStateKeys.toSet === testData.toMap.keys,
"Keys not in the data are being touched unexpectedly")

// Assert that the test RDD's data has not changed
assertRDD(initStateRDD, initStateWthTime, Set.empty)
//assertRDD(initStateRDD, initStateWthTime, Set.empty)
newStateRDD
}

// Test no-op, no state should change
testStateUpdates(initStateRDD, Seq(), initStateWthTime) // should not scan any state

testStateUpdates(
initStateRDD, Seq(("k1", 0)), initStateWthTime) // should not update existing state
testStateUpdates(
Expand All @@ -123,6 +251,7 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
Set(("k1", 0, initTime), ("k2", 0, initTime), ("k3", 0, updateTime), ("k4", 0, updateTime)))

// Test updating of state
println("---------------")
val rdd3 = testStateUpdates(
initStateRDD, Seq(("k1", 1)), // should increment k1's state 0 -> 1
Set(("k1", 1, updateTime), ("k2", 0, initTime)))
Expand All @@ -142,9 +271,8 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
val rdd7 = testStateUpdates( // should remove k2's state
rdd6, Seq(("k2", 2), ("k0", 2), ("k3", 1)), Set(("k3", 0, updateTime)))

val rdd8 = testStateUpdates(
rdd7, Seq(("k3", 2)), Set() //
)
val rdd8 = testStateUpdates( // should remove k3's state
rdd7, Seq(("k3", 2)), Set())
}

/** Assert whether the `trackStateByKey` operation generates expected results */
Expand All @@ -170,7 +298,7 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {

// Persist to make sure that it gets computed only once and we can track precisely how many
// state keys the computing touched
newStateRDD.persist()
newStateRDD.persist().count()
assertRDD(newStateRDD, expectedStates, expectedEmittedRecords)
newStateRDD
}
Expand All @@ -182,7 +310,8 @@ class TrackStateRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
expectedEmittedRecords: Set[T]): Unit = {
val states = trackStateRDD.flatMap { _.stateMap.getAll() }.collect().toSet
val emittedRecords = trackStateRDD.flatMap { _.emittedRecords }.collect().toSet
assert(states === expectedStates, "states after track state operation were not as expected")
assert(states === expectedStates,
"states after track state operation were not as expected")
assert(emittedRecords === expectedEmittedRecords,
"emitted records after track state operation were not as expected")
}
Expand Down