Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Refine code change: introduce trait and classes to group duplicate me…
…thods
  • Loading branch information
HeartSaVioR committed Jul 20, 2018
commit 977428cb35a6fc0a9fa7a0ca1a51e39a94447a01
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner}
import org.apache.spark.sql.execution.streaming.state.{StateStore, UnsafeRowPair}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType

object StatefulOperatorsHelper {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure why it is inside this generically named object StatefulOperatorsHelper. Rather make it a top-level trait StreamingAggregationStateManager in the execution.streaming.state package (similar to FlatMapGroupsWithStateExecHelper).

If you are modeling this against my state format PR for mapGroupsWithState, the only reason I put it in the StateManager class inside object FlatMapGroupsWithStateExecHelper was to avoid names like FlatMapGroupsWithStateExec_StateManager. I dont think that concern applies if you use the name StreamingAggregationStateManager.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah right. I found your PR useful to get an idea of how to model the classes because it was dealing with similar requirement, but didn't indicate the reason why you place it into StatefulOperatorsHelper. I'll move them to the state package.

sealed trait StreamingAggregationStateManager extends Serializable {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add docs

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will fix.

def extractKey(row: InternalRow): UnsafeRow
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the row here? add docs.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe rename this to getKey to be consistent with other methods.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renaming sounds better. Will rename, and will also add docs.

def getValueExpressions: Seq[Attribute]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does getValueExpressions mean? its not obvious from the name.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is to define the schema of value from / to state. For V1 it would be same to input schema and for V2 it would be input schema - key schema. Would getStateValueExpressions be OK for us?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be going to be getStateValueSchema btw, once we change return type.

def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dont think this method is needed if you rather add methods getIterator and remove to this interface. The only reason restoreOriginRow this being used is because the operator is directly accessing the store (through store.remove() and store.iterator()) and then trying to fix the row, instead of the delegating those operations to the StateManager. In fact, if there exists a StateManager to manage all the state in the store, then ALL operations to add/remove state should go through the manager and store should not be accessed directly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In fact, if there exists a StateManager to manage all the state in the store, then ALL operations to add/remove state should go through the manager and store should not be accessed directly.

Totally agreed that it should be better design of StateManager. I don't remember I tried to do before, so let me try applying your suggestion and see there's anything blocks.

def get(store: StateStore, key: UnsafeRow): UnsafeRow
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what are you getting? what are you putting? More docs please :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might think naively about this: I thought its interface is similar to StateStore so wondered we need to add docs, but I think I was wrong. Will add docs. Thanks for the insightful input!

def put(store: StateStore, row: UnsafeRow): Unit
}

object StreamingAggregationStateManager extends Logging {
def newImpl(
keyExpressions: Seq[Attribute],
childOutput: Seq[Attribute],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe rename childOutput to inputRowAttributes to make the name more meaningful in the context of the StateManager interface (which does not have any concept of child).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds much better and you're right about concept of child. Will rename.

conf: SQLConf): StreamingAggregationStateManager = {

if (conf.advancedRemoveRedundantInStatefulAggregation) {
log.info("Advanced option removeRedundantInStatefulAggregation activated!")
new StreamingAggregationStateManagerImplV2(keyExpressions, childOutput)
} else {
new StreamingAggregationStateManagerImplV1(keyExpressions, childOutput)
}
}
}

abstract class StreamingAggregationStateManagerBaseImpl(
protected val keyExpressions: Seq[Attribute],
protected val childOutput: Seq[Attribute]) extends StreamingAggregationStateManager {

@transient protected lazy val keyProjector =
GenerateUnsafeProjection.generate(keyExpressions, childOutput)

def extractKey(row: InternalRow): UnsafeRow = keyProjector(row)
}

class StreamingAggregationStateManagerImplV1(
keyExpressions: Seq[Attribute],
childOutput: Seq[Attribute])
extends StreamingAggregationStateManagerBaseImpl(keyExpressions, childOutput) {

override def getValueExpressions: Seq[Attribute] = {
childOutput
}

override def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = {
rowPair.value
}

override def get(store: StateStore, key: UnsafeRow): UnsafeRow = {
store.get(key)
}

override def put(store: StateStore, row: UnsafeRow): Unit = {
store.put(extractKey(row), row)
}
}

class StreamingAggregationStateManagerImplV2(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add docs on the state formats. How does each format organize the data in the row?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great point. I might be in a rush to show its shape. Will add doc for state formats in both V1 and V2.

keyExpressions: Seq[Attribute],
childOutput: Seq[Attribute])
extends StreamingAggregationStateManagerBaseImpl(keyExpressions, childOutput) {

private val valueExpressions: Seq[Attribute] = childOutput.diff(keyExpressions)
private val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions
private val needToProjectToRestoreValue: Boolean = keyValueJoinedExpressions != childOutput

@transient private lazy val valueProjector =
GenerateUnsafeProjection.generate(valueExpressions, childOutput)

@transient private lazy val joiner =
GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions),
StructType.fromAttributes(valueExpressions))
@transient private lazy val restoreValueProjector = GenerateUnsafeProjection.generate(
keyValueJoinedExpressions, childOutput)

override def getValueExpressions: Seq[Attribute] = {
valueExpressions
}

override def restoreOriginRow(rowPair: UnsafeRowPair): UnsafeRow = {
val joinedRow = joiner.join(rowPair.key, rowPair.value)
if (needToProjectToRestoreValue) {
restoreValueProjector(joinedRow)
} else {
joinedRow
}
}

override def get(store: StateStore, key: UnsafeRow): UnsafeRow = {
val savedState = store.get(key)
if (savedState == null) {
return savedState
}

val joinedRow = joiner.join(key, savedState)
if (needToProjectToRestoreValue) {
restoreValueProjector(joinedRow)
} else {
joinedRow
}
}

override def put(store: StateStore, row: UnsafeRow): Unit = {
val key = keyProjector(row)
val value = valueProjector(row)
store.put(key, value)
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,17 @@ package org.apache.spark.sql.execution.streaming
import java.util.UUID
import java.util.concurrent.TimeUnit._

import scala.collection.JavaConverters._

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateUnsafeRowJoiner, Predicate}
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate}
import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.streaming.StatefulOperatorsHelper.StreamingAggregationStateManager
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.streaming.{OutputMode, StateOperatorProgress}
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -204,35 +203,18 @@ case class StateStoreRestoreExec(
child: SparkPlan)
extends UnaryExecNode with StateStoreReader {

val removeRedundant: Boolean = sqlContext.conf.advancedRemoveRedundantInStatefulAggregation
if (removeRedundant) {
log.info("Advanced option removeRedundantInStatefulAggregation activated!")
}

val valueExpressions: Seq[Attribute] = if (removeRedundant) {
child.output.diff(keyExpressions)
} else {
child.output
}
val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions
val needToProjectToRestoreValue: Boolean = keyValueJoinedExpressions != child.output

override protected def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
val stateManager = StreamingAggregationStateManager.newImpl(keyExpressions, child.output,
sqlContext.conf)

child.execute().mapPartitionsWithStateStore(
getStateInfo,
keyExpressions.toStructType,
valueExpressions.toStructType,
stateManager.getValueExpressions.toStructType,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like you need to only get the schema, not the actual expressions. So the StateManager can only return the schema and not the expressions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right. Sounds like StructType is preferred than Seq[Attribute] in this case. Will apply.

Maybe dumb question from newbie on Spark SQL (still trying to get familiar with) : I guess we prefer StructType in this case cause it's less restrictive and also get rid of headache of dealing with fields reference. Do I understand correctly?

indexOrdinal = None,
sqlContext.sessionState,
Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) =>
val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
val joiner = GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions),
StructType.fromAttributes(valueExpressions))
val restoreValueProject = GenerateUnsafeProjection.generate(
keyValueJoinedExpressions, child.output)

val hasInput = iter.hasNext
if (!hasInput && keyExpressions.isEmpty) {
// If our `keyExpressions` are empty, we're getting a global aggregation. In that case
Expand All @@ -243,23 +225,8 @@ case class StateStoreRestoreExec(
store.iterator().map(_.value)
} else {
iter.flatMap { row =>
val key = getKey(row)
val savedState = store.get(key)
val restoredRow = if (removeRedundant) {
if (savedState == null) {
savedState
} else {
val joinedRow = joiner.join(key, savedState)
if (needToProjectToRestoreValue) {
restoreValueProject(joinedRow)
} else {
joinedRow
}
}
} else {
savedState
}

val key = stateManager.extractKey(row)
val restoredRow = stateManager.get(store, key)
numOutputRows += 1
Option(restoredRow).toSeq :+ row
}
Expand Down Expand Up @@ -291,38 +258,21 @@ case class StateStoreSaveExec(
child: SparkPlan)
extends UnaryExecNode with StateStoreWriter with WatermarkSupport {

val removeRedundant: Boolean = sqlContext.conf.advancedRemoveRedundantInStatefulAggregation
if (removeRedundant) {
log.info("Advanced option removeRedundantInStatefulAggregation activated!")
}

val valueExpressions: Seq[Attribute] = if (removeRedundant) {
child.output.diff(keyExpressions)
} else {
child.output
}
val keyValueJoinedExpressions: Seq[Attribute] = keyExpressions ++ valueExpressions
val needToProjectToRestoreValue: Boolean = keyValueJoinedExpressions != child.output

override protected def doExecute(): RDD[InternalRow] = {
metrics // force lazy init at driver
assert(outputMode.nonEmpty,
"Incorrect planning in IncrementalExecution, outputMode has not been set")

val stateManager = StreamingAggregationStateManager.newImpl(keyExpressions, child.output,
sqlContext.conf)

child.execute().mapPartitionsWithStateStore(
getStateInfo,
keyExpressions.toStructType,
valueExpressions.toStructType,
stateManager.getValueExpressions.toStructType,
indexOrdinal = None,
sqlContext.sessionState,
Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) =>
val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output)
val getValue = GenerateUnsafeProjection.generate(valueExpressions, child.output)
val joiner = GenerateUnsafeRowJoiner.create(StructType.fromAttributes(keyExpressions),
StructType.fromAttributes(valueExpressions))
val restoreValueProject = GenerateUnsafeProjection.generate(
keyValueJoinedExpressions, child.output)

val numOutputRows = longMetric("numOutputRows")
val numUpdatedStateRows = longMetric("numUpdatedStateRows")
val allUpdatesTimeMs = longMetric("allUpdatesTimeMs")
Expand All @@ -335,13 +285,7 @@ case class StateStoreSaveExec(
allUpdatesTimeMs += timeTakenMs {
while (iter.hasNext) {
val row = iter.next().asInstanceOf[UnsafeRow]
val key = getKey(row)
val value = if (removeRedundant) {
getValue(row)
} else {
row
}
store.put(key, value)
stateManager.put(store, row)
numUpdatedStateRows += 1
}
}
Expand All @@ -352,18 +296,7 @@ case class StateStoreSaveExec(
setStoreMetrics(store)
store.iterator().map { rowPair =>
numOutputRows += 1

if (removeRedundant) {
val joinedRow = joiner.join(rowPair.key, rowPair.value)
if (needToProjectToRestoreValue) {
restoreValueProject(joinedRow)
} else {
joinedRow
}
} else {
rowPair.value
}

stateManager.restoreOriginRow(rowPair)
}

// Update and output only rows being evicted from the StateStore
Expand All @@ -373,13 +306,7 @@ case class StateStoreSaveExec(
val filteredIter = iter.filter(row => !watermarkPredicateForData.get.eval(row))
while (filteredIter.hasNext) {
val row = filteredIter.next().asInstanceOf[UnsafeRow]
val key = getKey(row)
val value = if (removeRedundant) {
getValue(row)
} else {
row
}
store.put(key, value)
stateManager.put(store, row)
numUpdatedStateRows += 1
}
}
Expand All @@ -394,17 +321,7 @@ case class StateStoreSaveExec(
val rowPair = rangeIter.next()
if (watermarkPredicateForKeys.get.eval(rowPair.key)) {
store.remove(rowPair.key)

if (removeRedundant) {
val joinedRow = joiner.join(rowPair.key, rowPair.value)
removedValueRow = if (needToProjectToRestoreValue) {
restoreValueProject(joinedRow)
} else {
joinedRow
}
} else {
removedValueRow = rowPair.value
}
removedValueRow = stateManager.restoreOriginRow(rowPair)
}
}
if (removedValueRow == null) {
Expand Down Expand Up @@ -436,13 +353,7 @@ case class StateStoreSaveExec(
override protected def getNext(): InternalRow = {
if (baseIterator.hasNext) {
val row = baseIterator.next().asInstanceOf[UnsafeRow]
val key = getKey(row)
val value = if (removeRedundant) {
getValue(row)
} else {
row
}
store.put(key, value)
stateManager.put(store, row)
numOutputRows += 1
numUpdatedStateRows += 1
row
Expand Down