Change the strategy: "add new option" -> "apply by default, but keep backward compatible"
HeartSaVioR committed Jul 20, 2018
commit 63dfb5d2c82dfdf0a9e681fd5608f72a11dc04ed
@@ -871,15 +871,15 @@ object SQLConf {
.intConf
.createWithDefault(2)

val ADVANCED_REMOVE_REDUNDANT_IN_STATEFUL_AGGREGATION =
buildConf("spark.sql.streaming.advanced.removeRedundantInStatefulAggregation")
val STREAMING_AGGREGATION_STATE_FORMAT_VERSION =
buildConf("spark.sql.streaming.streamingAggregation.stateFormatVersion")
Contributor

no need to say "streaming" in "streamingAggregation" since it's already qualified by "spark.sql.streaming."

Contributor Author

Ah OK. Sounds better. Will fix.

.internal()
.doc("ADVANCED: When true, stateful aggregation tries to remove redundant data " +
"between key and value in state. Enabling this option helps minimizing state size, " +
"but no longer be compatible with state with disabling this option." +
"You can't change this option after starting the query.")
.booleanConf
.createWithDefault(false)
.doc("State format version used by streaming aggregation operations triggered " +
"explicitly or implicitly via agg() in a streaming query. State between versions are " +
Contributor

what do you mean by "implicitly"? Which operations are implicit?

Contributor Author
@HeartSaVioR Aug 1, 2018

I was trying to explain that the option only applies to the operators which go through StateStoreRestoreExec / StateStoreSaveExec (so max("field1") as well as agg("field1" -> "max")), but now I feel it just causes confusion, and I don't think end users need to understand the details behind the config. Will remove the part "triggered explicitly or implicitly via agg()".
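
For illustration, both spellings mentioned above compile down to the same streaming aggregation plan. A minimal sketch (the rate source and column names are stand-ins, not from this PR; assumes an active SparkSession `spark`):

```scala
val input = spark.readStream.format("rate").load()
  .selectExpr("CAST(value % 10 AS STRING) AS key", "value AS field1")

// Aggregation spelled explicitly via agg():
val viaAgg = input.groupBy("key").agg("field1" -> "max")

// The shorthand max() plans through the same
// StateStoreRestoreExec / StateStoreSaveExec operators:
val viaMax = input.groupBy("key").max("field1")
```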

"tend to be incompatible, so state format version shouldn't be modified after running.")
.intConf
.checkValue(v => Set(1, 2).contains(v), "Valid versions are 1 and 2")
.createWithDefault(2)
Contributor

If you intend to change the default to the new version, then you HAVE TO add a test that ensures that existing streaming aggregation checkpoints (generated in Spark 2.3.1 for example) will not fail to recover.

Similar to this test - https://github.com/apache/spark/blob/master/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala#L883

Contributor Author

Nice suggestion. Will add the test.
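
A hedged sketch of what such a test could look like, modeled on the linked FlatMapGroupsWithStateSuite test. The checkpoint resource path, test name, and expected counts are hypothetical, and it assumes the suite's usual imports plus commons-io's FileUtils, java.io.File, and Spark's Utils:

```scala
test("recover from a streaming aggregation checkpoint generated by Spark 2.3.1") {
  // Hypothetical pre-generated checkpoint directory shipped as a test resource.
  val resourceUri = this.getClass.getResource(
    "/structured-streaming/checkpoint-version-2.3.1-streaming-aggregation/").toURI
  val checkpointDir = Utils.createTempDir().getCanonicalFile
  // Copy the checkpoint so the test never mutates the shared resource.
  FileUtils.copyDirectory(new File(resourceUri), checkpointDir)

  val inputData = MemoryStream[Int]
  val aggregated = inputData.toDF().groupBy($"value").agg(count("*"))
  inputData.addData(3) // data already reflected in the old checkpoint's state

  testStream(aggregated, Update)(
    StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
    // State written with format v1 must stay readable under the new default.
    AddData(inputData, 3),
    CheckLastBatch((3, 2))
  )
}
```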


val UNSUPPORTED_OPERATION_CHECK_ENABLED =
buildConf("spark.sql.streaming.unsupportedOperationCheck")
@@ -1628,9 +1628,6 @@ class SQLConf extends Serializable with Logging {
def advancedPartitionPredicatePushdownEnabled: Boolean =
getConf(ADVANCED_PARTITION_PREDICATE_PUSHDOWN)

def advancedRemoveRedundantInStatefulAggregation: Boolean =
getConf(ADVANCED_REMOVE_REDUNDANT_IN_STATEFUL_AGGREGATION)

def fallBackToHdfsForStatsEnabled: Boolean = getConf(ENABLE_FALL_BACK_TO_HDFS_FOR_STATS)

def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN)
Expand Down
@@ -328,10 +328,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
"Streaming aggregation doesn't support group aggregate pandas UDF")
}

val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION)

aggregate.AggUtils.planStreamingAggregation(
namedGroupingExpressions,
aggregateExpressions.map(expr => expr.asInstanceOf[AggregateExpression]),
rewrittenResultExpressions,
stateVersion,
planLater(child))

case _ => Nil
@@ -256,6 +256,7 @@ object AggUtils {
groupingExpressions: Seq[NamedExpression],
functionsWithoutDistinct: Seq[AggregateExpression],
resultExpressions: Seq[NamedExpression],
stateFormatVersion: Int,
child: SparkPlan): Seq[SparkPlan] = {

val groupingAttributes = groupingExpressions.map(_.toAttribute)
@@ -287,7 +288,8 @@
child = partialAggregate)
}

val restored = StateStoreRestoreExec(groupingAttributes, None, partialMerged1)
val restored = StateStoreRestoreExec(groupingAttributes, None, stateFormatVersion,
partialMerged1)

val partialMerged2: SparkPlan = {
val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge))
@@ -311,6 +313,7 @@
stateInfo = None,
outputMode = None,
eventTimeWatermark = None,
stateFormatVersion = stateFormatVersion,
partialMerged2)

val finalAndCompleteAggregate: SparkPlan = {
@@ -100,19 +100,21 @@ class IncrementalExecution(
val state = new Rule[SparkPlan] {

override def apply(plan: SparkPlan): SparkPlan = plan transform {
case StateStoreSaveExec(keys, None, None, None,
case StateStoreSaveExec(keys, None, None, None, stateFormatVersion,
UnaryExecNode(agg,
StateStoreRestoreExec(_, None, child))) =>
StateStoreRestoreExec(_, None, _, child))) =>
val aggStateInfo = nextStatefulOperationStateInfo
StateStoreSaveExec(
keys,
Some(aggStateInfo),
Some(outputMode),
Some(offsetSeqMetadata.batchWatermarkMs),
stateFormatVersion,
agg.withNewChildren(
StateStoreRestoreExec(
keys,
Some(aggStateInfo),
stateFormatVersion,
child) :: Nil))

case StreamingDeduplicateExec(keys, child, None, None) =>
@@ -89,7 +89,7 @@ object OffsetSeqMetadata extends Logging {
private implicit val format = Serialization.formats(NoTypeHints)
private val relevantSQLConfs = Seq(
SHUFFLE_PARTITIONS, STATE_STORE_PROVIDER_CLASS, STREAMING_MULTIPLE_WATERMARK_POLICY,
FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, ADVANCED_REMOVE_REDUNDANT_IN_STATEFUL_AGGREGATION)
FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, STREAMING_AGGREGATION_STATE_FORMAT_VERSION)

/**
* Default values of relevant configurations that are used for backward compatibility.
@@ -104,7 +104,9 @@ object OffsetSeqMetadata extends Logging {
private val relevantSQLConfDefaultValues = Map[String, String](
STREAMING_MULTIPLE_WATERMARK_POLICY.key -> MultipleWatermarkPolicy.DEFAULT_POLICY_NAME,
FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION.key ->
FlatMapGroupsWithStateExecHelper.legacyVersion.toString
FlatMapGroupsWithStateExecHelper.legacyVersion.toString,
STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key ->
StatefulOperatorsHelper.legacyVersion.toString
)

def apply(json: String): OffsetSeqMetadata = Serialization.read[OffsetSeqMetadata](json)
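
The defaults map above only matters at recovery time. A minimal sketch of the fallback logic, assuming the metadata lookup works as the doc comment describes (this is not the PR's exact code):

```scala
// If a conf key is present in the checkpoint's metadata, use that value;
// if it is absent, the checkpoint predates the conf, so apply the legacy
// default rather than the current session default.
def resolveConfForRecovery(
    confKey: String,
    metadataConf: Map[String, String],
    legacyDefaults: Map[String, String]): Option[String] =
  metadataConf.get(confKey).orElse(legacyDefaults.get(confKey))

// For STREAMING_AGGREGATION_STATE_FORMAT_VERSION this yields "1"
// (StatefulOperatorsHelper.legacyVersion) on old checkpoints, keeping
// previously written state readable.
```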
@@ -22,10 +22,13 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner}
import org.apache.spark.sql.execution.streaming.state.{StateStore, UnsafeRowPair}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType

object StatefulOperatorsHelper {
Contributor

I am not sure why it is inside this generically named object StatefulOperatorsHelper. Rather, make it a top-level trait StreamingAggregationStateManager in the execution.streaming.state package (similar to FlatMapGroupsWithStateExecHelper).

If you are modeling this against my state format PR for mapGroupsWithState, the only reason I put it in the StateManager class inside object FlatMapGroupsWithStateExecHelper was to avoid names like FlatMapGroupsWithStateExec_StateManager. I don't think that concern applies if you use the name StreamingAggregationStateManager.

Contributor Author

Yeah, right. I found your PR useful to get an idea of how to model the classes because it was dealing with a similar requirement, but it didn't indicate the reason why you placed it into StatefulOperatorsHelper. I'll move them to the state package.


val supportedVersions = Seq(1, 2)
val legacyVersion = 1

sealed trait StreamingAggregationStateManager extends Serializable {
Contributor

Add docs

Contributor Author

Will fix.

def extractKey(row: InternalRow): UnsafeRow
Contributor

what is the row here? add docs.

Contributor

Maybe rename this to getKey to be consistent with other methods.

Contributor Author

Renaming sounds better. Will rename, and will also add docs.

def getValueExpressions: Seq[Attribute]
Contributor

what does getValueExpressions mean? It's not obvious from the name.

Contributor Author

It defines the schema of the value read from / written to state. For V1 it would be the same as the input schema, and for V2 it would be the input schema minus the key schema. Would getStateValueExpressions be OK for us?

Contributor Author

It's going to be getStateValueSchema btw, once we change the return type.
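
To make the V1/V2 distinction concrete, a hedged sketch (the names mirror the manager's fields, but this is not the PR's code):

```scala
import org.apache.spark.sql.catalyst.expressions.Attribute

// V1 stores the entire input row as the state value:
def stateValueSchemaV1(inputRowAttributes: Seq[Attribute]): Seq[Attribute] =
  inputRowAttributes

// V2 drops the key columns from the value, since the state key already
// carries them; this is what shrinks the state size:
def stateValueSchemaV2(
    keyExpressions: Seq[Attribute],
    inputRowAttributes: Seq[Attribute]): Seq[Attribute] =
  inputRowAttributes.diff(keyExpressions)
```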

@@ -35,16 +38,14 @@ object StatefulOperatorsHelper {
}

object StreamingAggregationStateManager extends Logging {
def newImpl(
def createStateManager(
keyExpressions: Seq[Attribute],
childOutput: Seq[Attribute],
Contributor

maybe rename childOutput to inputRowAttributes to make the name more meaningful in the context of the StateManager interface (which does not have any concept of child).

Contributor Author

Sounds much better, and you're right about the concept of child. Will rename.

conf: SQLConf): StreamingAggregationStateManager = {

if (conf.advancedRemoveRedundantInStatefulAggregation) {
log.info("Advanced option removeRedundantInStatefulAggregation activated!")
new StreamingAggregationStateManagerImplV2(keyExpressions, childOutput)
} else {
new StreamingAggregationStateManagerImplV1(keyExpressions, childOutput)
stateFormatVersion: Int): StreamingAggregationStateManager = {
stateFormatVersion match {
case 1 => new StreamingAggregationStateManagerImplV1(keyExpressions, childOutput)
case 2 => new StreamingAggregationStateManagerImplV2(keyExpressions, childOutput)
case _ => throw new IllegalArgumentException(s"Version $stateFormatVersion is invalid")
}
}
}
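
A hypothetical call site, mirroring how the stateful operators below construct the manager from the checkpoint-pinned conf (the surrounding names `groupingAttributes`, `child`, and `conf` are assumptions):

```scala
val stateManager = StreamingAggregationStateManager.createStateManager(
  keyExpressions = groupingAttributes,   // grouping keys of the aggregation
  childOutput = child.output,            // schema of the input rows
  stateFormatVersion =
    conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION))
```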
@@ -200,13 +200,15 @@ object WatermarkSupport {
case class StateStoreRestoreExec(
keyExpressions: Seq[Attribute],
stateInfo: Option[StatefulOperatorStateInfo],
stateFormatVersion: Int,
child: SparkPlan)
extends UnaryExecNode with StateStoreReader {

private[sql] val stateManager = StreamingAggregationStateManager.createStateManager(
keyExpressions, child.output, stateFormatVersion)

override protected def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
val stateManager = StreamingAggregationStateManager.newImpl(keyExpressions, child.output,
sqlContext.conf)

child.execute().mapPartitionsWithStateStore(
getStateInfo,
@@ -255,17 +257,18 @@ case class StateStoreSaveExec(
stateInfo: Option[StatefulOperatorStateInfo] = None,
outputMode: Option[OutputMode] = None,
eventTimeWatermark: Option[Long] = None,
stateFormatVersion: Int,
child: SparkPlan)
extends UnaryExecNode with StateStoreWriter with WatermarkSupport {

private[sql] val stateManager = StreamingAggregationStateManager.createStateManager(
keyExpressions, child.output, stateFormatVersion)

override protected def doExecute(): RDD[InternalRow] = {
metrics // force lazy init at driver
assert(outputMode.nonEmpty,
"Incorrect planning in IncrementalExecution, outputMode has not been set")

val stateManager = StreamingAggregationStateManager.newImpl(keyExpressions, child.output,
sqlContext.conf)

child.execute().mapPartitionsWithStateStore(
getStateInfo,
keyExpressions.toStructType,
@@ -19,8 +19,7 @@ package org.apache.spark.sql.streaming

import java.util.{Locale, TimeZone}

import org.scalatest.Assertions
import org.scalatest.BeforeAndAfterAll
import org.scalatest.{Assertions, BeforeAndAfterAll}

import org.apache.spark.{SparkEnv, SparkException}
import org.apache.spark.rdd.BlockRDD
@@ -54,30 +53,35 @@ class StreamingAggregationSuite extends StateStoreMetricsTest

import testImplicits._

val confAndTestNamePostfixMatrix = List(
(Seq(SQLConf.ADVANCED_REMOVE_REDUNDANT_IN_STATEFUL_AGGREGATION.key -> "false"), ""),
(Seq(SQLConf.ADVANCED_REMOVE_REDUNDANT_IN_STATEFUL_AGGREGATION.key -> "true"),
" : enable remove redundant in stateful aggregation")
)
def executeFuncWithStateVersionSQLConf(
stateVersion: Int,
confPairs: Seq[(String, String)],
func: => Any): Unit = {
withSQLConf(confPairs ++
Seq(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> stateVersion.toString): _*) {
func
}
}

def testWithAggrOptions(testName: String, pairs: (String, String)*)(testFun: => Any): Unit = {
confAndTestNamePostfixMatrix.foreach {
case (conf, testNamePostfix) => withSQLConf(pairs ++ conf: _*) {
test(testName + testNamePostfix)(testFun)
def testWithAllStateVersions(name: String, confPairs: (String, String)*)
Contributor

super nit: the confPairs param is used in only one location; do you think it's worth adding it as a param? The only test that needs it can stay unchanged.

Contributor Author

Actually this basically comes from wondering how withSQLConf works. Does withSQLConf handle nested withSQLConf properly? If so, we don't need to add the confPairs param at all; if not, I guess we might still want to keep it.
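
For what it's worth, a quick illustration of the nesting question, assuming the standard SQLTestUtils implementation (which snapshots the current values and restores them in a finally block, so nested calls compose):

```scala
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
  withSQLConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION.key -> "1") {
    // Both confs are in effect here; the inner setting wins while active.
  }
  // Only SHUFFLE_PARTITIONS = 1 remains in effect here.
}
```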

(func: => Any): Unit = {
for (version <- StatefulOperatorsHelper.supportedVersions) {
test(s"$name - state format version $version") {
executeFuncWithStateVersionSQLConf(version, confPairs, func)
}
}
}

def testQuietlyWithAggrOptions(testName: String, pairs: (String, String)*)
(testFun: => Any): Unit = {
confAndTestNamePostfixMatrix.foreach {
case (conf, testNamePostfix) => withSQLConf(pairs ++ conf: _*) {
testQuietly(testName + testNamePostfix)(testFun)
def testQuietlyWithAllStateVersions(name: String, confPairs: (String, String)*)
(func: => Any): Unit = {
for (version <- StatefulOperatorsHelper.supportedVersions) {
testQuietly(s"$name - state format version $version") {
executeFuncWithStateVersionSQLConf(version, confPairs, func)
}
}
}

testWithAggrOptions("simple count, update mode") {
testWithAllStateVersions("simple count, update mode") {
val inputData = MemoryStream[Int]

val aggregated =
Expand All @@ -101,7 +105,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
)
}

testWithAggrOptions("count distinct") {
testWithAllStateVersions("count distinct") {
val inputData = MemoryStream[(Int, Seq[Int])]

val aggregated =
Expand All @@ -117,7 +121,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
)
}

testWithAggrOptions("simple count, complete mode") {
testWithAllStateVersions("simple count, complete mode") {
val inputData = MemoryStream[Int]

val aggregated =
Expand All @@ -140,7 +144,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
)
}

testWithAggrOptions("simple count, append mode") {
testWithAllStateVersions("simple count, append mode") {
val inputData = MemoryStream[Int]

val aggregated =
Expand All @@ -157,7 +161,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
}
}

testWithAggrOptions("sort after aggregate in complete mode") {
testWithAllStateVersions("sort after aggregate in complete mode") {
val inputData = MemoryStream[Int]

val aggregated =
Expand All @@ -182,7 +186,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
)
}

testWithAggrOptions("state metrics") {
testWithAllStateVersions("state metrics") {
val inputData = MemoryStream[Int]

val aggregated =
Expand Down Expand Up @@ -235,7 +239,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
)
}

testWithAggrOptions("multiple keys") {
testWithAllStateVersions("multiple keys") {
val inputData = MemoryStream[Int]

val aggregated =
Expand All @@ -252,7 +256,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
)
}

testQuietlyWithAggrOptions("midbatch failure") {
testQuietlyWithAllStateVersions("midbatch failure") {
val inputData = MemoryStream[Int]
FailureSingleton.firstTime = true
val aggregated =
Expand All @@ -278,7 +282,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
)
}

testWithAggrOptions("typed aggregators") {
testWithAllStateVersions("typed aggregators") {
val inputData = MemoryStream[(String, Int)]
val aggregated = inputData.toDS().groupByKey(_._1).agg(typed.sumLong(_._2))

Expand All @@ -288,7 +292,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
)
}

testWithAggrOptions("prune results by current_time, complete mode") {
testWithAllStateVersions("prune results by current_time, complete mode") {
import testImplicits._
val clock = new StreamManualClock
val inputData = MemoryStream[Long]
Expand Down Expand Up @@ -340,7 +344,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
)
}

testWithAggrOptions("prune results by current_date, complete mode") {
testWithAllStateVersions("prune results by current_date, complete mode") {
import testImplicits._
val clock = new StreamManualClock
val tz = TimeZone.getDefault.getID
Expand Down Expand Up @@ -389,7 +393,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
)
}

testWithAggrOptions("SPARK-19690: do not convert batch aggregation in streaming query " +
testWithAllStateVersions("SPARK-19690: do not convert batch aggregation in streaming query " +
"to streaming") {
val streamInput = MemoryStream[Int]
val batchDF = Seq(1, 2, 3, 4, 5)
Expand Down Expand Up @@ -454,7 +458,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
true
}

testWithAggrOptions("SPARK-21977: coalesce(1) with 0 partition RDD should be " +
testWithAllStateVersions("SPARK-21977: coalesce(1) with 0 partition RDD should be " +
"repartitioned to 1") {
val inputSource = new BlockRDDBackedSource(spark)
MockSourceProvider.withMockSources(inputSource) {
Expand Down Expand Up @@ -493,8 +497,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
}
}

testWithAggrOptions("SPARK-21977: coalesce(1) with aggregation should still be repartitioned " +
"when it has non-empty grouping keys") {
testWithAllStateVersions("SPARK-21977: coalesce(1) with aggregation should still be " +
"repartitioned when it has non-empty grouping keys") {
val inputSource = new BlockRDDBackedSource(spark)
MockSourceProvider.withMockSources(inputSource) {
withTempDir { tempDir =>
Expand Down Expand Up @@ -546,7 +550,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
}
}

testWithAggrOptions("SPARK-22230: last should change with new batches") {
testWithAllStateVersions("SPARK-22230: last should change with new batches") {
val input = MemoryStream[Int]

val aggregated = input.toDF().agg(last('value))
Expand All @@ -562,7 +566,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest
)
}

testWithAggrOptions("SPARK-23004: Ensure that TypedImperativeAggregate functions " +
testWithAllStateVersions("SPARK-23004: Ensure that TypedImperativeAggregate functions " +
"do not throw errors", SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
// See the JIRA SPARK-23004 for more details. In short, this test reproduces the error
// by ensuring the following.
Expand Down