-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-24763][SS] Remove redundant key data from value in streaming aggregation #21733
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
4252f41
941b88d
abec57f
977428c
63dfb5d
e844636
26701a3
60c231e
b4a3807
e0ee04a
8629f59
65801a6
19888ab
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
* TODO list * replace all the usages for direct call of store.xxx whenever state manager is available * add iterator / remove in StreamingAggregationStateManager to remove restoreOriginRow * add docs
- Loading branch information
There are no files selected for viewing
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -20,8 +20,12 @@ package org.apache.spark.sql.execution.streaming | |
| import scala.reflect.ClassTag | ||
|
|
||
| import org.apache.spark.TaskContext | ||
| import org.apache.spark.internal.Logging | ||
| import org.apache.spark.rdd.RDD | ||
| import org.apache.spark.sql.SQLContext | ||
| import org.apache.spark.sql.catalyst.InternalRow | ||
| import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} | ||
| import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner} | ||
| import org.apache.spark.sql.internal.SessionState | ||
| import org.apache.spark.sql.types.StructType | ||
|
|
||
|
|
@@ -81,4 +85,110 @@ package object state { | |
| storeCoordinator) | ||
| } | ||
| } | ||
|
|
||
/**
 * Abstracts the state row format used by streaming aggregation away from the
 * physical operators, so multiple state format versions can coexist.
 */
sealed trait StreamingAggregationStateManager extends Serializable {
  /** Extracts the grouping key row from an input row. */
  def getKey(row: InternalRow): UnsafeRow
  /** Returns the schema of the value rows persisted in the state store. */
  def getStateValueSchema: StructType
  /** Reconstructs the original input row from a stored (key, value) pair. */
  def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow
  /** Reads the full (reconstructed) row for the given key from the store. */
  def get(store: StateStore, key: UnsafeRow): UnsafeRow
  /** Persists the given input row into the store, keyed by its extracted key. */
  def put(store: StateStore, row: UnsafeRow): Unit
}
|
|
||
/**
 * Factory for [[StreamingAggregationStateManager]] instances, one implementation
 * per supported state format version.
 */
object StreamingAggregationStateManager extends Logging {
  /** State format versions this factory knows how to create managers for. */
  val supportedVersions = Seq(1, 2)
  /** Format version used by checkpoints written before versioning existed. */
  val legacyVersion = 1

  /**
   * Creates the state manager matching `stateFormatVersion`.
   *
   * @throws IllegalArgumentException if the version is not in [[supportedVersions]]
   */
  def createStateManager(
      keyExpressions: Seq[Attribute],
      inputRowAttributes: Seq[Attribute],
      stateFormatVersion: Int): StreamingAggregationStateManager = {
    stateFormatVersion match {
      case 1 => new StreamingAggregationStateManagerImplV1(keyExpressions, inputRowAttributes)
      case 2 => new StreamingAggregationStateManagerImplV2(keyExpressions, inputRowAttributes)
      case _ =>
        // Name the valid versions so a misconfiguration is actionable.
        throw new IllegalArgumentException(
          s"Version $stateFormatVersion is invalid: supported versions are " +
            supportedVersions.mkString(", "))
    }
  }
}
|
|
||
/**
 * Shared base implementation: provides key extraction via a generated unsafe
 * projection from the input row attributes to the grouping key expressions.
 */
abstract class StreamingAggregationStateManagerBaseImpl(
    protected val keyExpressions: Seq[Attribute],
    protected val inputRowAttributes: Seq[Attribute]) extends StreamingAggregationStateManager {

  // NOTE(review): @transient lazy presumably so the generated (non-serializable)
  // projection is rebuilt after task deserialization rather than shipped — confirm.
  @transient protected lazy val keyProjector =
    GenerateUnsafeProjection.generate(keyExpressions, inputRowAttributes)

  // Projects the input row down to just the grouping key columns.
  def getKey(row: InternalRow): UnsafeRow = keyProjector(row)
}
|
|
||
/**
 * State format version 1: stores the entire input row as the state value, so
 * the key columns are kept redundantly in both key and value.
 */
class StreamingAggregationStateManagerImplV1(
    keyExpressions: Seq[Attribute],
    inputRowAttributes: Seq[Attribute])
  extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) {

  // The stored value is the full input row, so its schema is the input schema.
  override def getStateValueSchema: StructType = inputRowAttributes.toStructType

  // The stored value already is the original row; nothing to reconstruct.
  override def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = rowPair.value

  // Direct pass-through to the underlying store.
  override def get(store: StateStore, key: UnsafeRow): UnsafeRow = store.get(key)

  // Persist the whole input row, keyed by its projected grouping key.
  override def put(store: StateStore, row: UnsafeRow): Unit = store.put(getKey(row), row)
}
|
|
||
/**
 * State format version 2: stores only the non-key columns as the state value,
 * removing the redundant key data kept by version 1. Rows read back from the
 * store are reconstructed by joining the key with the stored value.
 */
class StreamingAggregationStateManagerImplV2(
    keyExpressions: Seq[Attribute],
    inputRowAttributes: Seq[Attribute])
  extends StreamingAggregationStateManagerBaseImpl(keyExpressions, inputRowAttributes) {

  // Columns persisted as the value: everything except the grouping key.
  private val valueExpressions: Seq[Attribute] = inputRowAttributes.diff(keyExpressions)
  private val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions
  // When the key columns are not a prefix of the input row, the joined
  // (key ++ value) row has a different column order than the original input
  // and must be re-projected back into the input layout.
  private val needToProjectToRestoreValue: Boolean =
    keyValueJoinedExpressions != inputRowAttributes

  @transient private lazy val valueProjector =
    GenerateUnsafeProjection.generate(valueExpressions, inputRowAttributes)

  @transient private lazy val joiner =
    GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions),
      StructType.fromAttributes(valueExpressions))
  @transient private lazy val restoreValueProjector = GenerateUnsafeProjection.generate(
    keyValueJoinedExpressions, inputRowAttributes)

  override def getStateValueSchema: StructType = valueExpressions.toStructType

  override def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow =
    restoreOriginRow(rowPair.key, rowPair.value)

  override def get(store: StateStore, key: UnsafeRow): UnsafeRow = {
    val savedState = store.get(key)
    // Missing key: propagate the store's null rather than joining with it.
    // (if/else expression instead of an early `return`, which is non-idiomatic
    // Scala and compiles to a thrown NonLocalReturnControl inside closures.)
    if (savedState == null) {
      savedState
    } else {
      restoreOriginRow(key, savedState)
    }
  }

  override def put(store: StateStore, row: UnsafeRow): Unit = {
    // Use getKey (inherited) instead of calling keyProjector directly, for
    // consistency with the V1 implementation.
    store.put(getKey(row), valueProjector(row))
  }

  // Joins key and value back into the original input-row layout, re-projecting
  // only when the key columns are not a prefix of the input schema.
  // Shared by both restoreOriginRow and get to avoid duplicated logic.
  private def restoreOriginRow(key: UnsafeRow, value: UnsafeRow): UnsafeRow = {
    val joinedRow = joiner.join(key, value)
    if (needToProjectToRestoreValue) {
      restoreValueProjector(joinedRow)
    } else {
      joinedRow
    }
  }
}
|
|
||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils | |
| import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} | ||
| import org.apache.spark.sql.execution.exchange.Exchange | ||
| import org.apache.spark.sql.execution.streaming._ | ||
| import org.apache.spark.sql.execution.streaming.state.StateStore | ||
| import org.apache.spark.sql.execution.streaming.state.{StateStore, StreamingAggregationStateManager} | ||
| import org.apache.spark.sql.expressions.scalalang.typed | ||
| import org.apache.spark.sql.functions._ | ||
| import org.apache.spark.sql.internal.SQLConf | ||
|
|
@@ -65,7 +65,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest | |
|
|
||
| def testWithAllStateVersions(name: String, confPairs: (String, String)*) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. super nit: the confPairs param is used only in one location, do you think it's worth adding it as a param? The only test that needs it can stay unchanged.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, it basically came from wondering how |
||
| (func: => Any): Unit = { | ||
| for (version <- StatefulOperatorsHelper.supportedVersions) { | ||
| for (version <- StreamingAggregationStateManager.supportedVersions) { | ||
| test(s"$name - state format version $version") { | ||
| executeFuncWithStateVersionSQLConf(version, confPairs, func) | ||
| } | ||
|
|
@@ -74,7 +74,7 @@ class StreamingAggregationSuite extends StateStoreMetricsTest | |
|
|
||
| def testQuietlyWithAllStateVersions(name: String, confPairs: (String, String)*) | ||
| (func: => Any): Unit = { | ||
| for (version <- StatefulOperatorsHelper.supportedVersions) { | ||
| for (version <- StreamingAggregationStateManager.supportedVersions) { | ||
| testQuietly(s"$name - state format version $version") { | ||
| executeFuncWithStateVersionSQLConf(version, confPairs, func) | ||
| } | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: why is the input typed InternalRow where everything else is UnsafeRow? seems inconsistent.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`getKey` was basically an UnsafeProjection in statefulOperator, so it didn't necessarily require UnsafeRow. I just followed the usage to make it less restrictive, but we know that, in reality, `row` will always be an UnsafeRow. So it's OK to fix if it provides consistency.