WIP Address a part of review comments from @tdas
* TODO list
  * replace all usages of direct calls to store.xxx wherever a state manager is available
  * add iterator / remove to StreamingAggregationStateManager so that restoreOriginRow can be removed
  * add docs
HeartSaVioR committed Aug 1, 2018
commit 60c231e98a550b0e439827caff75a29c23423a9c
@@ -22,7 +22,7 @@ import org.json4s.jackson.Serialization

import org.apache.spark.internal.Logging
import org.apache.spark.sql.RuntimeConfig
import org.apache.spark.sql.execution.streaming.state.FlatMapGroupsWithStateExecHelper
import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StreamingAggregationStateManager}
import org.apache.spark.sql.internal.SQLConf.{FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, _}

/**
@@ -106,7 +106,7 @@ object OffsetSeqMetadata extends Logging {
FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key ->
FlatMapGroupsWithStateExecHelper.legacyVersion.toString,
STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key ->
StatefulOperatorsHelper.legacyVersion.toString
StreamingAggregationStateManager.legacyVersion.toString
)

def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json)

This file was deleted.

@@ -20,8 +20,12 @@ package org.apache.spark.sql.execution.streaming
import scala.reflect.ClassTag

import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner}
import org.apache.spark.sql.internal.SessionState
import org.apache.spark.sql.types.StructType

@@ -81,4 +85,110 @@ package object state {
storeCoordinator)
}
}

sealed trait StreamingAggregationStateManager extends Serializable {
def getKey(row: InternalRow): UnsafeRow
Contributor:
nit: why is the input typed InternalRow where everything else is UnsafeRow? seems inconsistent.

Contributor Author:
getKey was basically an UnsafeProjection in the stateful operator, so it didn't necessarily require UnsafeRow. I just followed the existing usage to make it less restrictive, but in practice the row will always be an UnsafeRow, so it's OK to fix if that provides consistency.
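A minimal sketch of the tightened signature being discussed; this is an assumption about a possible follow-up, not part of this commit:

  // Accepting UnsafeRow keeps the trait consistent with the rest of the API; the key
  // projector still works, since an UnsafeRow is an InternalRow.
  def getKey(row: UnsafeRow): UnsafeRow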

def getStateValueSchema: StructType
def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow
def get(store: StateStore, key: UnsafeRow): UnsafeRow
def put(store: StateStore, row: UnsafeRow): Unit
}

object StreamingAggregationStateManager extends Logging {
val supportedVersions = Seq(1, 2)
val legacyVersion = 1

def createStateManager(
keyExpressions: Seq[Attribute],
inputRowAttributes: Seq[Attribute],
stateFormatVersion: Int): StreamingAggregationStateManager = {
stateFormatVersion match {
case 1 => new StreamingAggregationStateManagerImplV1(keyExpressions, inputRowAttributes)
case 2 => new StreamingAggregationStateManagerImplV2(keyExpressions, inputRowAttributes)
case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid")
}
}
}
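For context, a minimal usage sketch of the factory and trait above; the names store, keyExpressions, inputRowAttributes and aggregatedRow are hypothetical and assumed to be in scope:

  // Pick the state manager for the configured state format version
  // (1 = legacy layout, 2 = key/value split layout).
  val stateManager = StreamingAggregationStateManager.createStateManager(
    keyExpressions, inputRowAttributes, stateFormatVersion = 2)

  // Persist an aggregated row; the manager extracts the key and stores the value it needs.
  stateManager.put(store, aggregatedRow)

  // Read it back by key; for version 2 the manager rebuilds the full row from key + value.
  val restored: UnsafeRow = stateManager.get(store, stateManager.getKey(aggregatedRow))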

abstract class StreamingAggregationStateManagerBaseImpl(
protected val keyExpressions: Seq[Attribute],
protected val inputRowAttributes: Seq[Attribute]) extends StreamingAggregationStateManager {

@transient protected lazy val keyProjector =
GenerateUnsafeProjection.generate(keyExpressions, inputRowAttributes)

def getKey(row: InternalRow): UnsafeRow = keyProjector(row)
}

class StreamingAggregationStateManagerImplV1(
keyExpressions: Seq[Attribute],
inputRowAttributes: Seq[Attribute])
extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) {

override def getStateValueSchema: StructType = inputRowAttributes.toStructType

override def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = {
rowPair.value
}

override def get(store: StateStore, key: UnsafeRow): UnsafeRow = {
store.get(key)
}

override def put(store: StateStore, row: UnsafeRow): Unit = {
store.put(getKey(row), row)
}
}

class StreamingAggregationStateManagerImplV2(
keyExpressions: Seq[Attribute],
inputRowAttributes: Seq[Attribute])
extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) {

private val valueExpressions: Seq[Attribute] = inputRowAttributes.diff(keyExpressions)
private val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions
private val needToProjectToRestoreValue: Boolean =
Contributor:
add docs on what this means (that, if the fields in the joined row are not in the expected order, an additional projection is used)

Contributor Author:
Will add.

keyValueJoinedExpressions != inputRowAttributes
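A sketch of the documentation requested above; the wording is assumed, not taken from this commit:

  // If the concatenation of key and value attributes does not match the original input row
  // attributes (i.e. key and value columns are interleaved in the input row), an additional
  // projection is needed when restoring the original row; otherwise the joined row can be
  // returned as-is.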

@transient private lazy val valueProjector =
GenerateUnsafeProjection.generate(valueExpressions, inputRowAttributes)

@transient private lazy val joiner =
GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions),
StructType.fromAttributes(valueExpressions))
@transient private lazy val restoreValueProjector = GenerateUnsafeProjection.generate(
keyValueJoinedExpressions, inputRowAttributes)
Contributor:
Are you sure this is right??

def generate(expressions: InType, inputSchema: Seq[Attribute])

So the 2nd param is the schema of the input rows of the projection. This projection is applied to the joined rows, which have the schema keyValueJoinedExpressions. So I think these two should be flipped.

Contributor:
I am wondering why this does not fail any test. Is it because needToProjectToRestoreValue is always false?

Contributor Author:
My bad. You're right. Will fix. Btw, needToProjectToRestoreValue is always false unless the sequence of columns for key and value gets mixed up.
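A sketch of the fix agreed on above, following the reviewer's reading of GenerateUnsafeProjection.generate; this is an assumption about the follow-up change, not code from this commit:

  // The projection runs over the joined (key ++ value) rows, so the joined attributes are the
  // input schema and the original input attributes are the expressions to produce.
  @transient private lazy val restoreValueProjector = GenerateUnsafeProjection.generate(
    inputRowAttributes, keyValueJoinedExpressions)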


override def getStateValueSchema: StructType = valueExpressions.toStructType

override def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = {
val joinedRow = joiner.join(rowPair.key, rowPair.value)
if (needToProjectToRestoreValue) {
restoreValueProjector(joinedRow)
} else {
joinedRow
}
}

override def get(store: StateStore, key: UnsafeRow): UnsafeRow = {
val savedState = store.get(key)
if (savedState == null) {
return savedState
}

val joinedRow = joiner.join(key, savedState)
Contributor:
Can't you dedup this code with the restoreOriginRow method?

Contributor Author:
Missed spot. Will leverage restoreOriginRow.

if (needToProjectToRestoreValue) {
restoreValueProjector(joinedRow)
} else {
joinedRow
}
}
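A sketch of the deduplication suggested above, delegating get to restoreOriginRow; this is an assumption about the follow-up change, not code from this commit:

  override def get(store: StateStore, key: UnsafeRow): UnsafeRow = {
    val savedState = store.get(key)
    if (savedState == null) {
      null
    } else {
      // Rebuild the original row from the stored key and value parts.
      restoreOriginRow(new UnsafeRowPair(key, savedState))
    }
  }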

override def put(store: StateStore, row: UnsafeRow): Unit = {
val key = keyProjector(row)
val value = valueProjector(row)
store.put(key, value)
}
}

}
@@ -30,7 +30,6 @@ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistrib
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.streaming.StatefulOperatorsHelper.StreamingAggregationStateManager
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress}
import org.apache.spark.sql.types._
@@ -19,11 +19,10 @@ package org.apache.spark.sql.execution.streaming.state

import org.apache.spark.sql.catalyst.expressions.{Attribute, SpecificInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.execution.streaming.StatefulOperatorsHelper.StreamingAggregationStateManager
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class StatefulOperatorsHelperSuite extends StreamTest {
class StreamingAggregationStateManagerSuite extends StreamTest {
// ============================ fields and method for test data ============================

val testKeys: Seq[String] = Seq("key1", "key2")
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode}
import org.apache.spark.sql.execution.exchange.Exchange
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.state.StateStore
import org.apache.spark.sql.execution.streaming.state.{StateStore, StreamingAggregationStateManager}
import org.apache.spark.sql.expressions.scalalang.typed
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -65,7 +65,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest

def testWithAllStateVersions(name: String, confPairs: (String, String)*)
Contributor:
super nit: the confPairs param is used in only one location; do you think it's worth adding it as a param? The only test that needs it can stay unchanged.

Contributor Author:
Actually this basically came from wondering how withSQLConf works. Does withSQLConf handle nested withSQLConf calls properly? If so, we don't need to add the confPairs param at all; if not, I guess we might still want to add it.

(func: => Any): Unit = {
for (version <- StatefulOperatorsHelper.supportedVersions) {
for (version <- StreamingAggregationStateManager.supportedVersions) {
test(s"$name - state format version $version") {
executeFuncWithStateVersionSQLConf(version, confPairs, func)
}
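For context, a minimal sketch of how the version conf and the extra confPairs could be combined in a single withSQLConf call; executeFuncWithStateVersionSQLConf is not shown in this diff, so its body here is an assumption:

  private def executeFuncWithStateVersionSQLConf(
      stateVersion: Int,
      confPairs: Seq[(String, String)],
      func: => Any): Unit = {
    // withSQLConf restores the previous values afterwards, so nesting it inside other
    // withSQLConf blocks in a test is expected to behave correctly.
    withSQLConf(confPairs ++
      Seq(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> stateVersion.toString): _*) {
      func
    }
  }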
@@ -74,7 +74,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest

def testQuietlyWithAllStateVersions(name: String, confPairs: (String, String)*)
(func: => Any): Unit = {
for (version <- StatefulOperatorsHelper.supportedVersions) {
for (version <- StreamingAggregationStateManager.supportedVersions) {
testQuietly(s"$name - state format version $version") {
executeFuncWithStateVersionSQLConf(version, confPairs, func)
}