Skip to content
Next Next commit
Partial implementation
  • Loading branch information
tdas committed Jul 3, 2018
commit ef509c8986dbcc9b37387b0bde56c3d71abb7602
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,8 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Attribut
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
import org.apache.spark.sql.types.IntegerType
import org.apache.spark.util.CompletionIterator

/**
Expand Down Expand Up @@ -60,32 +58,14 @@ case class FlatMapGroupsWithStateExec(
) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter with WatermarkSupport {

import GroupStateImpl._
import FlatMapGroupsWithStateExecHelper._

private val isTimeoutEnabled = timeoutConf != NoTimeout
private val timestampTimeoutAttribute =
AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)()
private val stateAttributes: Seq[Attribute] = {
val encSchemaAttribs = stateEncoder.schema.toAttributes
if (isTimeoutEnabled) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs
}
// Get the serializer for the state, taking into account whether we need to save timestamps
private val stateSerializer = {
val encoderSerializer = stateEncoder.namedExpressions
if (isTimeoutEnabled) {
encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP)
} else {
encoderSerializer
}
}
// Get the deserializer for the state. Note that this must be done in the driver, as
// resolving and binding of deserializer expressions to the encoded type can be safely done
// only in the driver.
private val stateDeserializer = stateEncoder.resolveAndBind().deserializer

private val watermarkPresent = child.output.exists {
case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true
case _ => false
}
private[sql] val stateManager = createStateManager(stateEncoder, isTimeoutEnabled)

/** Distribute by grouping attributes */
override def requiredChildDistribution: Seq[Distribution] =
Expand Down Expand Up @@ -125,11 +105,11 @@ case class FlatMapGroupsWithStateExec(
child.execute().mapPartitionsWithStateStore[InternalRow](
getStateInfo,
groupingAttributes.toStructType,
stateAttributes.toStructType,
stateManager.stateSchema,
indexOrdinal = None,
sqlContext.sessionState,
Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
val updater = new StateStoreUpdater(store)
val processor = new InputProcessor(store)

// If timeout is based on event time, then filter late data based on watermark
val filteredIter = watermarkPredicateForData match {
Expand All @@ -143,7 +123,7 @@ case class FlatMapGroupsWithStateExec(
// all the data has been processed. This is to ensure that the timeout information of all
// the keys with data is updated before they are processed for timeouts.
val outputIterator =
updater.updateStateForKeysWithData(filteredIter) ++ updater.updateStateForTimedOutKeys()
processor.processNewData(filteredIter) ++ processor.processTimedOutState()

// Return an iterator of all the rows generated by all the keys, such that when fully
// consumed, all the state updates will be committed by the state store
Expand All @@ -158,7 +138,7 @@ case class FlatMapGroupsWithStateExec(
}

/** Helper class to update the state store */
class StateStoreUpdater(store: StateStore) {
class InputProcessor(store: StateStore) {

// Converters for translating input keys, values, output data between rows and Java objects
private val getKeyObj =
Expand All @@ -167,14 +147,6 @@ case class FlatMapGroupsWithStateExec(
ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)

// Converters for translating state between rows and Java objects
private val getStateObjFromRow = ObjectOperator.deserializeRowToObject(
stateDeserializer, stateAttributes)
private val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer)

// Index of the additional metadata fields in the state row
private val timeoutTimestampIndex = stateAttributes.indexOf(timestampTimeoutAttribute)

// Metrics
private val numUpdatedStateRows = longMetric("numUpdatedStateRows")
private val numOutputRows = longMetric("numOutputRows")
Expand All @@ -183,20 +155,19 @@ case class FlatMapGroupsWithStateExec(
* For every group, get the key, values and corresponding state and call the function,
* and return an iterator of rows
*/
def updateStateForKeysWithData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = {
def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = {
val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output)
groupedIter.flatMap { case (keyRow, valueRowIter) =>
val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow]
callFunctionAndUpdateState(
keyUnsafeRow,
stateManager.getState(store, keyUnsafeRow),
valueRowIter,
store.get(keyUnsafeRow),
hasTimedOut = false)
}
}

/** Find the groups that have timeout set and are timing out right now, and call the function */
def updateStateForTimedOutKeys(): Iterator[InternalRow] = {
def processTimedOutState(): Iterator[InternalRow] = {
if (isTimeoutEnabled) {
val timeoutThreshold = timeoutConf match {
case ProcessingTimeTimeout => batchTimestampMs.get
Expand All @@ -205,12 +176,11 @@ case class FlatMapGroupsWithStateExec(
throw new IllegalStateException(
s"Cannot filter timed out keys for $timeoutConf")
}
val timingOutPairs = store.getRange(None, None).filter { rowPair =>
val timeoutTimestamp = getTimeoutTimestamp(rowPair.value)
timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold
val timingOutPairs = stateManager.getAllState(store).filter { state =>
state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold
}
timingOutPairs.flatMap { rowPair =>
callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true)
timingOutPairs.flatMap { stateData =>
callFunctionAndUpdateState(stateData, Iterator.empty, hasTimedOut = true)
}
} else Iterator.empty
}
Expand All @@ -220,73 +190,44 @@ case class FlatMapGroupsWithStateExec(
* iterator. Note that the store updating is lazy, that is, the store will be updated only
* after the returned iterator is fully consumed.
*
* @param keyRow Row representing the key, cannot be null
* @param stateData All the data related to the state to be updated
* @param valueRowIter Iterator of values as rows, cannot be null, but can be empty
* @param prevStateRow Row representing the previous state, can be null
* @param hasTimedOut Whether this function is being called for a key timeout
*/
private def callFunctionAndUpdateState(
keyRow: UnsafeRow,
stateData: StateData,
valueRowIter: Iterator[InternalRow],
prevStateRow: UnsafeRow,
hasTimedOut: Boolean): Iterator[InternalRow] = {

val keyObj = getKeyObj(keyRow) // convert key to objects
val keyObj = getKeyObj(stateData.keyRow) // convert key to objects
val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects
val stateObj = getStateObj(prevStateRow)
val keyedState = GroupStateImpl.createForStreaming(
Option(stateObj),
val groupState = GroupStateImpl.createForStreaming(
Option(stateData.stateObj),
batchTimestampMs.getOrElse(NO_TIMESTAMP),
eventTimeWatermark.getOrElse(NO_TIMESTAMP),
timeoutConf,
hasTimedOut,
watermarkPresent)

// Call function, get the returned objects and convert them to rows
val mappedIterator = func(keyObj, valueObjIter, keyedState).map { obj =>
val mappedIterator = func(keyObj, valueObjIter, groupState).map { obj =>
numOutputRows += 1
getOutputRow(obj)
}

// When the iterator is consumed, then write changes to state
def onIteratorCompletion: Unit = {

val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp
// If the state has not yet been set but timeout has been set, then
// we have to generate a row to save the timeout. However, attempting serialize
// null using case class encoder throws -
// java.lang.NullPointerException: Null value appeared in non-nullable field:
// If the schema is inferred from a Scala tuple / case class, or a Java bean, please
// try to use scala.Option[_] or other nullable types.
if (!keyedState.exists && currentTimeoutTimestamp != NO_TIMESTAMP) {
throw new IllegalStateException(
"Cannot set timeout when state is not defined, that is, state has not been" +
"initialized or has been removed")
}

if (keyedState.hasRemoved) {
store.remove(keyRow)
if (groupState.hasRemoved && groupState.getTimeoutTimestamp == NO_TIMESTAMP) {
stateManager.removeState(store, stateData.keyRow)
numUpdatedStateRows += 1

} else {
val previousTimeoutTimestamp = getTimeoutTimestamp(prevStateRow)
val stateRowToWrite = if (keyedState.hasUpdated) {
getStateRow(keyedState.get)
} else {
prevStateRow
}

val hasTimeoutChanged = currentTimeoutTimestamp != previousTimeoutTimestamp
val shouldWriteState = keyedState.hasUpdated || hasTimeoutChanged
val currentTimeoutTimestamp = groupState.getTimeoutTimestamp
val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp
val shouldWriteState = groupState.hasUpdated || groupState.hasRemoved || hasTimeoutChanged

if (shouldWriteState) {
if (stateRowToWrite == null) {
// This should never happen because checks in GroupStateImpl should avoid cases
// where empty state would need to be written
throw new IllegalStateException("Attempting to write empty state")
}
setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp)
store.put(keyRow, stateRowToWrite)
val updatedStateObj = if (groupState.exists) groupState.get else null
stateManager.putState(store, stateData.keyRow, updatedStateObj, currentTimeoutTimestamp)
numUpdatedStateRows += 1
}
}
Expand All @@ -295,28 +236,5 @@ case class FlatMapGroupsWithStateExec(
// Return an iterator of rows such that fully consumed, the updated state value will be saved
CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion)
}

/** Returns the state as Java object if defined */
def getStateObj(stateRow: UnsafeRow): Any = {
if (stateRow != null) getStateObjFromRow(stateRow) else null
}

/** Returns the row for an updated state */
def getStateRow(obj: Any): UnsafeRow = {
assert(obj != null)
getStateRowFromObj(obj)
}

/** Returns the timeout timestamp of a state row is set */
def getTimeoutTimestamp(stateRow: UnsafeRow): Long = {
if (isTimeoutEnabled && stateRow != null) {
stateRow.getLong(timeoutTimestampIndex)
} else NO_TIMESTAMP
}

/** Set the timestamp in a state row */
def setTimeoutTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = {
if (isTimeoutEnabled) stateRow.setLong(timeoutTimestampIndex, timeoutTimestamps)
}
}
}
Loading