[SPARK-24763][SS] Remove redundant key data from value in streaming aggregation #21733
StatefulOperatorsHelper.scala (new file):

@@ -0,0 +1,136 @@
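This file introduces a small abstraction over how streaming aggregation state is laid out in the state store. V1 preserves the current format, where the stored value is the entire aggregation row (so key columns are written twice: once as the key and once inside the value). V2 stores only the non-key columns as the value and reconstructs the full row on read by joining the key back in.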
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.streaming

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner}
import org.apache.spark.sql.execution.streaming.state.{StateStore, UnsafeRowPair}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType

object StatefulOperatorsHelper {

  sealed trait StreamingAggregationStateManager extends Serializable {
    /** Extract the columns of `keyExpressions` from an input row. */
    def extractKey(row: InternalRow): UnsafeRow

    /** The schema of the value rows actually persisted in the state store. */
    def getValueExpressions: Seq[Attribute]

    /** Rebuild the full aggregation row from a stored (key, value) pair. */
    def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow

    /** Look up a key and return the full aggregation row, or null if absent. */
    def get(store: StateStore, key: UnsafeRow): UnsafeRow

    /** Persist one aggregation row into the store. */
    def put(store: StateStore, row: UnsafeRow): Unit
  }

  object StreamingAggregationStateManager extends Logging {
    def newImpl(
        keyExpressions: Seq[Attribute],
        childOutput: Seq[Attribute],
        conf: SQLConf): StreamingAggregationStateManager = {
      if (conf.advancedRemoveRedundantInStatefulAggregation) {
        log.info("Advanced option removeRedundantInStatefulAggregation activated!")
        new StreamingAggregationStateManagerImplV2(keyExpressions, childOutput)
      } else {
        new StreamingAggregationStateManagerImplV1(keyExpressions, childOutput)
      }
    }
  }

  abstract class StreamingAggregationStateManagerBaseImpl(
      protected val keyExpressions: Seq[Attribute],
      protected val childOutput: Seq[Attribute]) extends StreamingAggregationStateManager {

    @transient protected lazy val keyProjector =
      GenerateUnsafeProjection.generate(keyExpressions, childOutput)

    def extractKey(row: InternalRow): UnsafeRow = keyProjector(row)
  }

  /** V1 keeps the legacy state format: the stored value is the entire row. */
  class StreamingAggregationStateManagerImplV1(
      keyExpressions: Seq[Attribute],
      childOutput: Seq[Attribute])
    extends StreamingAggregationStateManagerBaseImpl(keyExpressions, childOutput) {

    override def getValueExpressions: Seq[Attribute] = {
      childOutput
    }

    override def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = {
      rowPair.value
    }

    override def get(store: StateStore, key: UnsafeRow): UnsafeRow = {
      store.get(key)
    }

    override def put(store: StateStore, row: UnsafeRow): Unit = {
      store.put(extractKey(row), row)
    }
  }

  /**
   * V2 stores only the non-key columns as the value and rejoins the key on
   * read, so key data is no longer duplicated in the state store.
   */
  class StreamingAggregationStateManagerImplV2(
      keyExpressions: Seq[Attribute],
      childOutput: Seq[Attribute])
    extends StreamingAggregationStateManagerBaseImpl(keyExpressions, childOutput) {

    private val valueExpressions: Seq[Attribute] = childOutput.diff(keyExpressions)
    private val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions

    // If the key columns are not a prefix of child output, (key ++ value) comes
    // back in a different column order and must be projected back to child output.
    private val needToProjectToRestoreValue: Boolean = keyValueJoinedExpressions != childOutput

    @transient private lazy val valueProjector =
      GenerateUnsafeProjection.generate(valueExpressions, childOutput)

    @transient private lazy val joiner =
      GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions),
        StructType.fromAttributes(valueExpressions))
    @transient private lazy val restoreValueProjector = GenerateUnsafeProjection.generate(
      keyValueJoinedExpressions, childOutput)

    override def getValueExpressions: Seq[Attribute] = {
      valueExpressions
    }

    override def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = {
      val joinedRow = joiner.join(rowPair.key, rowPair.value)
      if (needToProjectToRestoreValue) {
        restoreValueProjector(joinedRow)
      } else {
        joinedRow
      }
    }

    override def get(store: StateStore, key: UnsafeRow): UnsafeRow = {
      val savedState = store.get(key)
      if (savedState == null) {
        return savedState
      }

      val joinedRow = joiner.join(key, savedState)
      if (needToProjectToRestoreValue) {
        restoreValueProjector(joinedRow)
      } else {
        joinedRow
      }
    }

    override def put(store: StateStore, row: UnsafeRow): Unit = {
      val key = keyProjector(row)
      val value = valueProjector(row)
      store.put(key, value)
    }
  }
}
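To make the V2 layout concrete, here is a minimal, self-contained Scala sketch of the idea using plain collections; the names (StateLayoutSketch, KeyedStore, keyIndices) are illustrative only and are not Spark API. The store keeps just the non-key columns as the value, and the full row is rebuilt by joining the key back in on read:

object StateLayoutSketch {
  type Row = Vector[Any]

  // Models the V2 layout: the map value holds only non-key columns.
  final case class KeyedStore(keyIndices: Seq[Int]) {
    private val state = scala.collection.mutable.Map.empty[Row, Row]

    private def keyOf(row: Row): Row =
      keyIndices.map(i => row(i)).toVector
    private def nonKeyOf(row: Row): Row =
      row.indices.filterNot(i => keyIndices.contains(i)).map(i => row(i)).toVector

    // put: persist only the non-key columns (V1 would store the whole row).
    def put(row: Row): Unit = state(keyOf(row)) = nonKeyOf(row)

    // get: rebuild the original row as key ++ value. This assumes the key
    // columns come first in the row; otherwise an extra reordering projection
    // is needed, which is what needToProjectToRestoreValue guards in the PR.
    def get(key: Row): Option[Row] = state.get(key).map(value => key ++ value)
  }

  def main(args: Array[String]): Unit = {
    val store = KeyedStore(keyIndices = Seq(0)) // column 0 is the grouping key
    store.put(Vector("user-1", 42L, 3.14))      // full aggregation row
    println(store.get(Vector("user-1")))        // Some(Vector(user-1, 42, 3.14))
    // State now holds only (42L, 3.14); "user-1" is no longer stored twice.
  }
}

Under V1 semantics, put would store the whole row and get would return it unchanged, at the cost of duplicating the key columns in every stored value.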
statefulOperators.scala:

Both stateful aggregation operators now delegate their state reads and writes to the new state manager; the inline removeRedundant logic introduced by the earlier commits of this PR is removed.

@@ -20,18 +20,17 @@ package org.apache.spark.sql.execution.streaming
 import java.util.UUID
 import java.util.concurrent.TimeUnit._

 import scala.collection.JavaConverters._

 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner, Predicate}
+import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate}
 import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
 import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
 import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.execution.streaming.StatefulOperatorsHelper.StreamingAggregationStateManager
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress}
 import org.apache.spark.sql.types._
@@ -204,35 +203,18 @@ case class StateStoreRestoreExec(
     child: SparkPlan)
   extends UnaryExecNode with StateStoreReader {

-  val removeRedundant: Boolean = sqlContext.conf.advancedRemoveRedundantInStatefulAggregation
-  if (removeRedundant) {
-    log.info("Advanced option removeRedundantInStatefulAggregation activated!")
-  }
-
-  val valueExpressions: Seq[Attribute] = if (removeRedundant) {
-    child.output.diff(keyExpressions)
-  } else {
-    child.output
-  }
-  val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions
-  val needToProjectToRestoreValue: Boolean = keyValueJoinedExpressions != child.output
-
   override protected def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
+    val stateManager = StreamingAggregationStateManager.newImpl(keyExpressions, child.output,
+      sqlContext.conf)

     child.execute().mapPartitionsWithStateStore(
       getStateInfo,
       keyExpressions.toStructType,
-      valueExpressions.toStructType,
+      stateManager.getValueExpressions.toStructType,
       indexOrdinal = None,
       sqlContext.sessionState,
       Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
-      val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
-      val joiner = GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions),
-        StructType.fromAttributes(valueExpressions))
-      val restoreValueProject = GenerateUnsafeProjection.generate(
-        keyValueJoinedExpressions, child.output)
-
       val hasInput = iter.hasNext
       if (!hasInput && keyExpressions.isEmpty) {
         // If our `keyExpressions` are empty, we're getting a global aggregation. In that case
@@ -243,23 +225,8 @@ case class StateStoreRestoreExec(
         store.iterator().map(_.value)
       } else {
         iter.flatMap { row =>
-          val key = getKey(row)
-          val savedState = store.get(key)
-          val restoredRow = if (removeRedundant) {
-            if (savedState == null) {
-              savedState
-            } else {
-              val joinedRow = joiner.join(key, savedState)
-              if (needToProjectToRestoreValue) {
-                restoreValueProject(joinedRow)
-              } else {
-                joinedRow
-              }
-            }
-          } else {
-            savedState
-          }
-
+          val key = stateManager.extractKey(row)
+          val restoredRow = stateManager.get(store, key)
           numOutputRows += 1
           Option(restoredRow).toSeq :+ row
         }
@@ -291,38 +258,21 @@ case class StateStoreSaveExec(
     child: SparkPlan)
   extends UnaryExecNode with StateStoreWriter with WatermarkSupport {

-  val removeRedundant: Boolean = sqlContext.conf.advancedRemoveRedundantInStatefulAggregation
-  if (removeRedundant) {
-    log.info("Advanced option removeRedundantInStatefulAggregation activated!")
-  }
-
-  val valueExpressions: Seq[Attribute] = if (removeRedundant) {
-    child.output.diff(keyExpressions)
-  } else {
-    child.output
-  }
-  val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions
-  val needToProjectToRestoreValue: Boolean = keyValueJoinedExpressions != child.output
-
   override protected def doExecute(): RDD[InternalRow] = {
     metrics // force lazy init at driver
     assert(outputMode.nonEmpty,
       "Incorrect planning in IncrementalExecution, outputMode has not been set")

+    val stateManager = StreamingAggregationStateManager.newImpl(keyExpressions, child.output,
+      sqlContext.conf)
+
     child.execute().mapPartitionsWithStateStore(
       getStateInfo,
       keyExpressions.toStructType,
-      valueExpressions.toStructType,
+      stateManager.getValueExpressions.toStructType,
       indexOrdinal = None,
       sqlContext.sessionState,
       Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
-      val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
-      val getValue = GenerateUnsafeProjection.generate(valueExpressions, child.output)
-      val joiner = GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions),
-        StructType.fromAttributes(valueExpressions))
-      val restoreValueProject = GenerateUnsafeProjection.generate(
-        keyValueJoinedExpressions, child.output)
-
       val numOutputRows = longMetric("numOutputRows")
       val numUpdatedStateRows = longMetric("numUpdatedStateRows")
       val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
@@ -335,13 +285,7 @@ case class StateStoreSaveExec(
           allUpdatesTimeMs += timeTakenMs {
             while (iter.hasNext) {
               val row = iter.next().asInstanceOf[UnsafeRow]
-              val key = getKey(row)
-              val value = if (removeRedundant) {
-                getValue(row)
-              } else {
-                row
-              }
-              store.put(key, value)
+              stateManager.put(store, row)
               numUpdatedStateRows += 1
             }
           }
@@ -352,18 +296,7 @@ case class StateStoreSaveExec(
           setStoreMetrics(store)
           store.iterator().map { rowPair =>
             numOutputRows += 1
-
-            if (removeRedundant) {
-              val joinedRow = joiner.join(rowPair.key, rowPair.value)
-              if (needToProjectToRestoreValue) {
-                restoreValueProject(joinedRow)
-              } else {
-                joinedRow
-              }
-            } else {
-              rowPair.value
-            }
-
+            stateManager.restoreOriginRow(rowPair)
           }

         // Update and output only rows being evicted from the StateStore
@@ -373,13 +306,7 @@ case class StateStoreSaveExec(
           val filteredIter = iter.filter(row => !watermarkPredicateForData.get.eval(row))
           while (filteredIter.hasNext) {
             val row = filteredIter.next().asInstanceOf[UnsafeRow]
-            val key = getKey(row)
-            val value = if (removeRedundant) {
-              getValue(row)
-            } else {
-              row
-            }
-            store.put(key, value)
+            stateManager.put(store, row)
             numUpdatedStateRows += 1
           }
         }
@@ -394,17 +321,7 @@ case class StateStoreSaveExec(
             val rowPair = rangeIter.next()
             if (watermarkPredicateForKeys.get.eval(rowPair.key)) {
               store.remove(rowPair.key)
-
-              if (removeRedundant) {
-                val joinedRow = joiner.join(rowPair.key, rowPair.value)
-                removedValueRow = if (needToProjectToRestoreValue) {
-                  restoreValueProject(joinedRow)
-                } else {
-                  joinedRow
-                }
-              } else {
-                removedValueRow = rowPair.value
-              }
+              removedValueRow = stateManager.restoreOriginRow(rowPair)
             }
           }
           if (removedValueRow == null) {
@@ -436,13 +353,7 @@ case class StateStoreSaveExec(
           override protected def getNext(): InternalRow = {
             if (baseIterator.hasNext) {
               val row = baseIterator.next().asInstanceOf[UnsafeRow]
-              val key = getKey(row)
-              val value = if (removeRedundant) {
-                getValue(row)
-              } else {
-                row
-              }
-              store.put(key, value)
+              stateManager.put(store, row)
               numOutputRows += 1
               numUpdatedStateRows += 1
               row
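One subtlety that survives the refactor is needToProjectToRestoreValue: joining key ++ value yields all of the original columns, but only in the original order when the key columns are a prefix of child.output. The following standalone Scala sketch shows the reordering case; the column names and values are made up for illustration and are not from the PR:

object RestoreProjectionSketch {
  def main(args: Array[String]): Unit = {
    val childOutput = Vector("cnt", "word")     // key column is NOT first
    val keyCols     = Vector("word")
    val valueCols   = childOutput.diff(keyCols) // Vector("cnt")
    val joined      = keyCols ++ valueCols      // Vector("word", "cnt")

    // joined != childOutput, so restoring a row needs a reordering projection.
    println(joined != childOutput) // true, i.e. needToProjectToRestoreValue

    // The "projection": pick each child.output column out of the joined row.
    val reorder   = childOutput.map(c => joined.indexOf(c)) // Vector(1, 0)
    val joinedRow = Vector[Any]("apple", 3L)                // (word, cnt)
    val restored  = reorder.map(i => joinedRow(i))          // back to (cnt, word)
    println(restored) // Vector(3, apple)
  }
}

In the PR, this reordering is performed by the generated restoreValueProjector; the sketch only shows why that projection is conditional.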
Reviewer:

I am not sure why it is inside this generically named object StatefulOperatorsHelper. Rather, make it a top-level trait StreamingAggregationStateManager in the execution.streaming.state package (similar to FlatMapGroupsWithStateExecHelper). If you are modeling this against my state format PR for mapGroupsWithState, the only reason I put it in the StateManager class inside the object FlatMapGroupsWithStateExecHelper was to avoid names like FlatMapGroupsWithStateExec_StateManager. I don't think that concern applies if you use the name StreamingAggregationStateManager.

Author:

Yeah, right. I found your PR useful for getting an idea of how to model the classes, since it deals with a similar requirement, but it didn't indicate the reason for placing the manager inside the helper object. I'll move them to the state package.
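For reference, the shape the reviewer is suggesting would look roughly like the following. This is a sketch of the proposal, not code from this PR, and the exact file layout is left to the follow-up commit; StateStore and UnsafeRowPair need no import because they live in the same package:

package org.apache.spark.sql.execution.streaming.state

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}

// Top-level trait in the state package, mirroring FlatMapGroupsWithStateExecHelper's
// StateManager but without the enclosing helper object; the V1/V2 implementations
// would move here alongside it.
sealed trait StreamingAggregationStateManager extends Serializable {
  def extractKey(row: InternalRow): UnsafeRow
  def getValueExpressions: Seq[Attribute]
  def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow
  def get(store: StateStore, key: UnsafeRow): UnsafeRow
  def put(store: StateStore, row: UnsafeRow): Unit
}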