From d10da2700a970c82537b678db34b2d80cebebcc8 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 17 Jan 2017 16:45:54 -0800 Subject: [PATCH 01/21] Prototype - almost working --- .../sql/catalyst/analysis/CheckAnalysis.scala | 2 + .../sql/catalyst/plans/logical/object.scala | 34 +++++++ .../catalyst/streaming/InternalState.scala | 20 ++++ .../spark/sql/KeyValueGroupedDataset.scala | 14 +++ .../spark/sql/execution/SparkStrategies.scala | 18 +++- .../apache/spark/sql/execution/objects.scala | 6 ++ .../streaming/IncrementalExecution.scala | 17 +++- .../streaming/StatefulAggregate.scala | 69 +++++++++++-- .../state/HDFSBackedStateStoreProvider.scala | 17 ++++ .../execution/streaming/state/StateImpl.scala | 97 +++++++++++++++++++ .../streaming/state/StateStore.scala | 2 + .../apache/spark/sql/streaming/State.scala | 44 +++++++++ .../streaming/MapGroupsWithStateSuite.scala | 55 +++++++++++ 13 files changed, 385 insertions(+), 10 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalState.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateImpl.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/streaming/State.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index aa77a6efef347..bf2dbebf7fcce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -405,6 +405,8 @@ trait CheckAnalysis extends PredicateHelper { } case o if !o.resolved => + println(o) + println(o.expressions.filterNot(_.resolved).mkString("\n")) failAnalysis( s"unresolved operator ${operator.simpleString}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 0ab4c9016623e..6168df8292ff9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke +import org.apache.spark.sql.catalyst.streaming.InternalState import org.apache.spark.sql.types._ object CatalystSerde { @@ -313,6 +314,39 @@ case class MapGroups( outputObjAttr: Attribute, child: LogicalPlan) extends UnaryNode with ObjectProducer +/** Factory for constructing new `MapGroups` nodes. 
*/ +object MapGroupsWithState { + def apply[K: Encoder, T: Encoder, S: Encoder, U: Encoder]( + func: (T, InternalState[S]) => U, + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + child: LogicalPlan): LogicalPlan = { + val mapped = new MapGroupsWithState( + func.asInstanceOf[(Any, InternalState[Any]) => Any], + UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes), + UnresolvedDeserializer(encoderFor[T].deserializer, dataAttributes), + groupingAttributes, + dataAttributes, + CatalystSerde.generateObjAttr[U], + encoderFor[S].deserializer, + encoderFor[S].namedExpressions, + child) + CatalystSerde.serialize[U](mapped) + } +} + +case class MapGroupsWithState( + func: (Any, InternalState[Any]) => Any, + keyDeserializer: Expression, + valueDeserializer: Expression, + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + outputObjAttr: Attribute, + stateDeserializer: Expression, + stateSerializer: Seq[NamedExpression], + child: LogicalPlan) extends UnaryNode with ObjectProducer + + /** Factory for constructing new `FlatMapGroupsInR` nodes. */ object FlatMapGroupsInR { def apply( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalState.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalState.scala new file mode 100644 index 0000000000000..0d9391b535f99 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalState.scala @@ -0,0 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.streaming + +trait InternalState[S] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 395d709f26591..796f3f41c54ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -24,8 +24,10 @@ import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.streaming.InternalState import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator +import org.apache.spark.sql.streaming.State /** * :: Experimental :: @@ -108,6 +110,18 @@ class KeyValueGroupedDataset[K, V] private[sql]( mapValues { (v: V) => func.call(v) } } + def mapValuesWithState[STATE: Encoder, OUT: Encoder]( + func: (V, State[STATE]) => OUT): Dataset[OUT] = { + Dataset[OUT]( + sparkSession, + MapGroupsWithState[K, V, STATE, OUT]( + func.asInstanceOf[(V, InternalState[STATE]) => OUT], + groupingAttributes, + dataAttributes, + logicalPlan)) + } + + /** * Returns a [[Dataset]] that contains each unique key. This is equivalent to doing mapping * over the Dataset to extract the keys and then running a distinct operation on those. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index fafb91967086f..f00d01f3ec4da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, EventTimeWatermark, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, EventTimeWatermark, LogicalPlan, MapGroupsWithState} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} @@ -244,6 +244,22 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + object MapGroupsWithStateStrategy extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case MapGroupsWithState( + func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, + stateDeser, stateSer, child) => + val execPlan = MapGroupsWithStateExec( + func, keyDeser, valueDeser, + groupAttr, dataAttr, outputAttr, None, stateDeser, stateSer, + planLater(child)) + execPlan :: Nil + case _ => + println("here") + Nil + } + } + /** * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index fde3b2a528994..313452099d2eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -144,6 +144,12 @@ object ObjectOperator { (i: InternalRow) => proj(i).get(0, deserializer.dataType) } + def deserializeRowToObject( + deserializer: Expression): InternalRow => Any = { + val proj = GenerateSafeProjection.generate(deserializer :: Nil) + (i: InternalRow) => proj(i).get(0, deserializer.dataType) + } + def serializeObjectToRow(serializer: Seq[Expression]): Any => UnsafeRow = { val proj = GenerateUnsafeProjection.generate(serializer) val objType = serializer.head.collect { case b: BoundReference => b.dataType }.head diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 6ab6fa61dc200..147401b370793 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming +import java.util.concurrent.atomic.AtomicInteger + import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions.{CurrentBatchTimestamp, Literal} import org.apache.spark.sql.SparkSession @@ -41,6 +43,7 @@ class IncrementalExecution( // TODO: make this always part of planning. val stateStrategy = sparkSession.sessionState.planner.StatefulAggregationStrategy +: + sparkSession.sessionState.planner.MapGroupsWithStateStrategy +: sparkSession.sessionState.planner.StreamingRelationStrategy +: sparkSession.sessionState.experimentalMethods.extraStrategies @@ -68,7 +71,7 @@ class IncrementalExecution( * Records the current id for a given stateful operator in the query plan as the `state` * preparation walks the query plan. */ - private var operatorId = 0 + private val operatorId = new AtomicInteger(0) /** Locates save/restore pairs surrounding aggregation. 
*/ val state = new Rule[SparkPlan] { @@ -77,8 +80,8 @@ class IncrementalExecution( case StateStoreSaveExec(keys, None, None, None, UnaryExecNode(agg, StateStoreRestoreExec(keys2, None, child))) => - val stateId = OperatorStateId(checkpointLocation, operatorId, currentBatchId) - operatorId += 1 + val stateId = + OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) StateStoreSaveExec( keys, @@ -90,6 +93,14 @@ class IncrementalExecution( keys, Some(stateId), child) :: Nil)) + case MapGroupsWithStateExec( + func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, + None, stateDeser, stateSer, child) => + val stateId = + OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) + MapGroupsWithStateExec( + func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, + Some(stateId), stateDeser, stateSer, child) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala index 0551e4b4a2ef5..6cabd6f217087 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala @@ -23,14 +23,14 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark -import org.apache.spark.sql.catalyst.plans.physical.Partitioning +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution +import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.streaming.state._ -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.streaming.{OutputMode, State} +import org.apache.spark.sql.types.{DataType, StructType} /** Used to identify the state store for a given operator. 
*/ @@ -176,7 +176,7 @@ case class StateStoreSaveExec( } // Assumption: Append mode can be done only when watermark has been specified - store.remove(watermarkPredicate.get.eval) + store.remove(watermarkPredicate.get.eval _) store.commit() numTotalStateRows += store.numKeys() @@ -199,7 +199,7 @@ case class StateStoreSaveExec( override def hasNext: Boolean = { if (!baseIterator.hasNext) { // Remove old aggregates if watermark specified - if (watermarkPredicate.nonEmpty) store.remove(watermarkPredicate.get.eval) + if (watermarkPredicate.nonEmpty) store.remove(watermarkPredicate.get.eval _) store.commit() numTotalStateRows += store.numKeys() false @@ -227,3 +227,60 @@ case class StateStoreSaveExec( override def outputPartitioning: Partitioning = child.outputPartitioning } + +case class MapGroupsWithStateExec( + func: (Any, State[Any]) => Any, + keyDeserializer: Expression, + valueDeserializer: Expression, + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + outputObjAttr: Attribute, + stateId: Option[OperatorStateId], + stateDeserializer: Expression, + stateSerializer: Seq[NamedExpression], + child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with StatefulOperator { + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = + ClusteredDistribution(groupingAttributes) :: Nil + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingAttributes.map(SortOrder(_, Ascending))) + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsWithStateStore[InternalRow]( + getStateId.checkpointLocation, + operatorId = getStateId.operatorId, + storeVersion = getStateId.batchId, + groupingAttributes.toStructType, + child.output.toStructType, + sqlContext.sessionState, + Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => + val getKeyObj = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) + val getKey = GenerateUnsafeProjection.generate(groupingAttributes, child.output) + val getValueObj = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) + val outputMappedObj = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) + + val getStateObj = + ObjectOperator.deserializeRowToObject(stateDeserializer) + val outputStateObj = ObjectOperator.serializeObjectToRow(stateSerializer) + + iter.map { row => + val key = getKey(row) + val keyObj = getKeyObj(row) + val valueObj = getValueObj(row) + val stateObjOption = store.get(key).map(getStateObj) + val wrappedState = new StateImpl[Any]() + wrappedState.wrap(stateObjOption) + val mapped = func(key, wrappedState) + if (wrappedState.isRemoved) { + store.remove(key) + } else if (wrappedState.isUpdated) { + store.put(key, outputStateObj(wrappedState.get())) + } + outputMappedObj(mapped) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 4f3f8181d1f4e..3f4fd8f4ab288 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -147,6 +147,23 @@ private[state] class HDFSBackedStateStoreProvider( } } + override def remove(key: UnsafeRow): Unit = { + verify(state == UPDATING, "Cannot remove after 
already committed or aborted") + if (mapToUpdate.containsKey(key)) { + val value = mapToUpdate.remove(key) + Option(allUpdates.get(key)) match { + case Some(ValueUpdated(_, _)) | None => + // Value existed in previous version and maybe was updated, mark removed + allUpdates.put(key, ValueRemoved(key, value)) + case Some(ValueAdded(_, _)) => + // Value did not exist in previous version and was added, should not appear in updates + allUpdates.remove(key) + case Some(ValueRemoved(_, _)) => + // Remove already in update map, no need to change + } + } + } + /** Commit all the updates that have been made to the store, and return the new version. */ override def commit(): Long = { verify(state == UPDATING, "Cannot commit after already committed or aborted") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateImpl.scala new file mode 100644 index 0000000000000..a911d142f8859 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateImpl.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.sql.streaming.State + +/** Internal implementation of the [[State]] interface */ +class StateImpl[S] extends State[S] { + private var state: S = null.asInstanceOf[S] + private var defined: Boolean = false + private var timingOut: Boolean = false + private var updated: Boolean = false + private var removed: Boolean = false + + // ========= Public API ========= + override def exists(): Boolean = { + defined + } + + override def get(): S = { + if (defined) { + state + } else { + throw new NoSuchElementException("State is not set") + } + } + + override def update(newState: S): Unit = { + require(!removed, "Cannot update the state after it has been removed") + require(!timingOut, "Cannot update the state that is timing out") + state = newState + defined = true + updated = true + } + + override def isTimingOut(): Boolean = { + timingOut + } + + override def remove(): Unit = { + require(!timingOut, "Cannot remove the state that is timing out") + require(!removed, "Cannot remove the state that has already been removed") + defined = false + updated = false + removed = true + } + + // ========= Internal API ========= + + /** Whether the state has been marked for removing */ + def isRemoved(): Boolean = { + removed + } + + /** Whether the state has been been updated */ + def isUpdated(): Boolean = { + updated + } + + def wrap(optionalState: Option[S]): Unit = { + optionalState match { + case Some(newState) => + this.state = newState + defined = true + + case None => + this.state = null.asInstanceOf[S] + defined = false + } + timingOut = false + removed = false + updated = false + } + + def wrapTimingOutState(newState: S): Unit = { + this.state = newState + defined = true + timingOut = true + removed = false + updated = false + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 9bc6c0e2b9334..e0293dacad355 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -57,6 +57,8 @@ trait StateStore { */ def remove(condition: UnsafeRow => Boolean): Unit + def remove(key: UnsafeRow): Unit + /** * Commit all the updates that have been made to the store, and return the new version. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/State.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/State.scala new file mode 100644 index 0000000000000..4c44fa194c14b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/State.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.catalyst.streaming.InternalState + +trait State[S] extends InternalState[S] { + + def exists(): Boolean + + def get(): S + + def update(newState: S): Unit + + def remove(): Unit + + def isTimingOut(): Boolean + + @inline final def getOption(): Option[S] = if (exists) Some(get()) else None + + @inline final override def toString(): String = { + getOption.map { + _.toString + }.getOrElse("") + } +} + + + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala new file mode 100644 index 0000000000000..b9642e502cc3f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.state.StateStore + +class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { + + import testImplicits._ + + override def afterAll(): Unit = { + super.afterAll() + StateStore.stop() + } + + test("basics") { + val inputData = MemoryStream[String] + + val stateFunc = (data: String, state: State[Int]) => { + val count = state.getOption().getOrElse(0) + 1 + state.update(count) + (data, count.toString) + } + val result = + inputData.toDS() + .groupByKey(x => x) + .mapValuesWithState[Int, (String, String)](stateFunc) + + testStream(result, Append)( + AddData(inputData, "a"), + CheckLastBatch(("a", "1")), + AddData(inputData, "a"), + CheckLastBatch(("a", "2")) + ) + } +} From 78cd1853033a091b62cb350879bb6c3a0b6c8641 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 17 Jan 2017 17:05:51 -0800 Subject: [PATCH 02/21] Renamed to mapGroupsWithState --- .../scala/org/apache/spark/sql/KeyValueGroupedDataset.scala | 2 +- .../apache/spark/sql/streaming/MapGroupsWithStateSuite.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 796f3f41c54ff..4d367b39d0ede 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -110,7 +110,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( mapValues { (v: V) => 
func.call(v) } } - def mapValuesWithState[STATE: Encoder, OUT: Encoder]( + def mapGroupsWithState[STATE: Encoder, OUT: Encoder]( func: (V, State[STATE]) => OUT): Dataset[OUT] = { Dataset[OUT]( sparkSession, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala index b9642e502cc3f..cd3d60d6238f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala @@ -43,7 +43,7 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { val result = inputData.toDS() .groupByKey(x => x) - .mapValuesWithState[Int, (String, String)](stateFunc) + .mapGroupsWithState[Int, (String, String)](stateFunc) testStream(result, Append)( AddData(inputData, "a"), From 0c22e08a8f9ad66a49bc939652fee14577f9bd4b Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 17 Jan 2017 17:56:07 -0800 Subject: [PATCH 03/21] Fixed bugs --- .../sql/catalyst/plans/logical/object.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 1 - .../streaming/StatefulAggregate.scala | 55 +++++++++++-------- .../streaming/MapGroupsWithStateSuite.scala | 7 ++- 4 files changed, 36 insertions(+), 29 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 6168df8292ff9..2012c80bf8649 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -328,7 +328,7 @@ object MapGroupsWithState { groupingAttributes, dataAttributes, CatalystSerde.generateObjAttr[U], - encoderFor[S].deserializer, + encoderFor[S].resolveAndBind().deserializer, encoderFor[S].namedExpressions, child) CatalystSerde.serialize[U](mapped) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index f00d01f3ec4da..294dcefb652b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -255,7 +255,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { planLater(child)) execPlan :: Nil case _ => - println("here") Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala index 6cabd6f217087..49de6e6fa310d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{OutputMode, State} import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.util.{CompletionIterator, NextIterator} /** Used to identify the state store for a given operator. 
*/ @@ -230,7 +231,7 @@ case class StateStoreSaveExec( case class MapGroupsWithStateExec( func: (Any, State[Any]) => Any, - keyDeserializer: Expression, + keyDeserializer: Expression, // probably not needed valueDeserializer: Expression, groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], @@ -257,30 +258,36 @@ case class MapGroupsWithStateExec( child.output.toStructType, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => - val getKeyObj = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) - val getKey = GenerateUnsafeProjection.generate(groupingAttributes, child.output) - val getValueObj = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) - val outputMappedObj = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) - - val getStateObj = - ObjectOperator.deserializeRowToObject(stateDeserializer) - val outputStateObj = ObjectOperator.serializeObjectToRow(stateSerializer) - - iter.map { row => - val key = getKey(row) - val keyObj = getKeyObj(row) - val valueObj = getValueObj(row) - val stateObjOption = store.get(key).map(getStateObj) - val wrappedState = new StateImpl[Any]() - wrappedState.wrap(stateObjOption) - val mapped = func(key, wrappedState) - if (wrappedState.isRemoved) { - store.remove(key) - } else if (wrappedState.isUpdated) { - store.put(key, outputStateObj(wrappedState.get())) + try { + val getKey = GenerateUnsafeProjection.generate(groupingAttributes, child.output) + val getValueObj = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) + val outputMappedObj = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) + + val getStateObj = + ObjectOperator.deserializeRowToObject(stateDeserializer) + val outputStateObj = ObjectOperator.serializeObjectToRow(stateSerializer) + + val mappedIter = iter.map { row => + val key = getKey(row) + val valueObj = getValueObj(row) + val stateObjOption = store.get(key).map(getStateObj) + val wrappedState = new StateImpl[Any]() + wrappedState.wrap(stateObjOption) + + val mapped = func(valueObj, wrappedState) + if (wrappedState.isRemoved) { + store.remove(key) + } else if (wrappedState.isUpdated) { + store.put(key, outputStateObj(wrappedState.get())) + } + outputMappedObj(mapped) } - outputMappedObj(mapped) + CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIter, { store.commit() }) + } catch { + case e: Throwable => + store.abort() + throw e } - } + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala index cd3d60d6238f9..7359948e438d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala @@ -35,6 +35,7 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { test("basics") { val inputData = MemoryStream[String] + // Function to maintain a running count val stateFunc = (data: String, state: State[Int]) => { val count = state.getOption().getOrElse(0) + 1 state.update(count) @@ -43,13 +44,13 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { val result = inputData.toDS() .groupByKey(x => x) - .mapGroupsWithState[Int, (String, String)](stateFunc) + .mapGroupsWithState[Int, (String, String)](stateFunc) // Int => State, (Str, Str) => Out testStream(result, Append)( AddData(inputData, "a"), 
CheckLastBatch(("a", "1")), - AddData(inputData, "a"), - CheckLastBatch(("a", "2")) + AddData(inputData, "a", "b"), + CheckLastBatch(("a", "2"), ("b", "1")) ) } } From 52e14e479ffa1e38c9efc5b063a95831caab6997 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 18 Jan 2017 02:52:20 -0800 Subject: [PATCH 04/21] Removed prints --- .../org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 2 -- .../apache/spark/sql/streaming/MapGroupsWithStateSuite.scala | 1 + 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index bf2dbebf7fcce..aa77a6efef347 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -405,8 +405,6 @@ trait CheckAnalysis extends PredicateHelper { } case o if !o.resolved => - println(o) - println(o.expressions.filterNot(_.resolved).mkString("\n")) failAnalysis( s"unresolved operator ${operator.simpleString}") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala index 7359948e438d1..eada2a3aeed3c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala @@ -41,6 +41,7 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { state.update(count) (data, count.toString) } + val result = inputData.toDS() .groupByKey(x => x) From 529aefe6d7cb9cd54e20c0cdaa11cec90a4f16be Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 18 Jan 2017 02:58:27 -0800 Subject: [PATCH 05/21] Test state remove --- .../streaming/MapGroupsWithStateSuite.scala | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala index eada2a3aeed3c..7a432ac90d2ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala @@ -35,11 +35,18 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { test("basics") { val inputData = MemoryStream[String] - // Function to maintain a running count + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) val stateFunc = (data: String, state: State[Int]) => { - val count = state.getOption().getOrElse(0) + 1 - state.update(count) - (data, count.toString) + val oldCount = state.getOption().getOrElse(0) + if (oldCount == 2) { + state.remove() + (data, "-1") + } else { + val newCount = oldCount + 1 + state.update(newCount) + (data, newCount.toString) + } } val result = @@ -51,7 +58,11 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { AddData(inputData, "a"), CheckLastBatch(("a", "1")), AddData(inputData, "a", "b"), - CheckLastBatch(("a", "2"), ("b", "1")) + CheckLastBatch(("a", "2"), ("b", "1")), + AddData(inputData, "a", "b"), + CheckLastBatch(("a", "-1"), ("b", "2")), // state for a remove + 
AddData(inputData, "a", "b"), + CheckLastBatch(("a", "1"), ("b", "-1")) // state for a recreated ) } } From 57f5e8d2e8a74a8269667a9d4d89971eb9107c07 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 18 Jan 2017 12:26:53 -0800 Subject: [PATCH 06/21] Test restart, and test with metrics --- .../streaming/ProgressReporter.scala | 2 +- .../streaming/StatefulAggregate.scala | 21 ++++++++++++++++-- .../streaming/MapGroupsWithStateSuite.scala | 22 +++++++++++++++---- 3 files changed, 38 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index c5e9eae607b39..738b6cb60c6d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -239,7 +239,7 @@ trait ProgressReporter extends Logging { // Extract statistics about stateful operators in the query plan. val stateNodes = lastExecution.executedPlan.collect { - case p if p.isInstanceOf[StateStoreSaveExec] => p + case p if (p.isInstanceOf[StateStoreSaveExec] || p.isInstanceOf[MapGroupsWithStateExec]) => p } val stateOperators = stateNodes.map { node => new StateOperatorProgress( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala index 49de6e6fa310d..d4efcb6deae93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala @@ -247,9 +247,20 @@ case class MapGroupsWithStateExec( ClusteredDistribution(groupingAttributes) :: Nil override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(groupingAttributes.map(SortOrder(_, Ascending))) + Seq(groupingAttributes.map(SortOrder(_, Ascending))) // is this ordering needed? 
+ + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"), + "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows"), + "numRemovedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of removed state rows") + ) override protected def doExecute(): RDD[InternalRow] = { + val numTotalStateRows = longMetric("numTotalStateRows") + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + val numRemovedStateRows = longMetric("numRemovedStateRows") + child.execute().mapPartitionsWithStateStore[InternalRow]( getStateId.checkpointLocation, operatorId = getStateId.operatorId, @@ -259,6 +270,7 @@ case class MapGroupsWithStateExec( sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => try { + val getKey = GenerateUnsafeProjection.generate(groupingAttributes, child.output) val getValueObj = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) val outputMappedObj = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) @@ -277,12 +289,17 @@ case class MapGroupsWithStateExec( val mapped = func(valueObj, wrappedState) if (wrappedState.isRemoved) { store.remove(key) + numRemovedStateRows += 1 } else if (wrappedState.isUpdated) { store.put(key, outputStateObj(wrappedState.get())) + numUpdatedStateRows += 1 } outputMappedObj(mapped) } - CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIter, { store.commit() }) + CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIter, { + store.commit() + numTotalStateRows += store.numKeys() + }) } catch { case e: Throwable => store.abort() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala index 7a432ac90d2ea..1cb64ac1c3c25 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala @@ -57,12 +57,26 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { testStream(result, Append)( AddData(inputData, "a"), CheckLastBatch(("a", "1")), + assertNumStateRows(1), AddData(inputData, "a", "b"), CheckLastBatch(("a", "2"), ("b", "1")), - AddData(inputData, "a", "b"), - CheckLastBatch(("a", "-1"), ("b", "2")), // state for a remove - AddData(inputData, "a", "b"), - CheckLastBatch(("a", "1"), ("b", "-1")) // state for a recreated + assertNumStateRows(2), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1 + CheckLastBatch(("a", "-1"), ("b", "2")), + assertNumStateRows(1), + StopStream, + StartStream(), + AddData(inputData, "a", "b", "c"), // should recreate state for "a" and return count as 1 + CheckLastBatch(("a", "1"), ("b", "-1"), ("c", "1")), + assertNumStateRows(2) ) } + + private def assertNumStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q => + val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get + assert(progressWithData.stateOperators(0).numRowsTotal === numTotalRows) + true + } } From 3e0d8dcfa81d58ae6ca6754cc54c19383179802a Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Fri, 20 Jan 2017 18:25:29 -0800 Subject: [PATCH 07/21] Fixed everything --- .../UnsupportedOperationChecker.scala | 2 +- 
.../sql/catalyst/plans/logical/object.scala | 10 ++-- .../spark/sql/KeyValueGroupedDataset.scala | 30 ++++++---- .../streaming/ProgressReporter.scala | 3 +- .../streaming/StatefulAggregate.scala | 21 ++++--- .../streaming/MapGroupsWithStateSuite.scala | 60 ++++++++++++++++--- 6 files changed, 89 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index c2666b2ab9129..f4d016cb96711 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -87,7 +87,7 @@ object UnsupportedOperationChecker { * data. */ def containsCompleteData(subplan: LogicalPlan): Boolean = { - val aggs = plan.collect { case a@Aggregate(_, _, _) if a.isStreaming => a } + val aggs = subplan.collect { case a@Aggregate(_, _, _) if a.isStreaming => a } // Either the subplan has no streaming source, or it has aggregation with Complete mode !subplan.isStreaming || (aggs.nonEmpty && outputMode == InternalOutputModes.Complete) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 2012c80bf8649..382244a4ea272 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -316,15 +316,15 @@ case class MapGroups( /** Factory for constructing new `MapGroups` nodes. */ object MapGroupsWithState { - def apply[K: Encoder, T: Encoder, S: Encoder, U: Encoder]( - func: (T, InternalState[S]) => U, + def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder]( + func: (Any, Iterator[Any], InternalState[Any]) => Iterator[Any], groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], child: LogicalPlan): LogicalPlan = { val mapped = new MapGroupsWithState( - func.asInstanceOf[(Any, InternalState[Any]) => Any], + func, UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes), - UnresolvedDeserializer(encoderFor[T].deserializer, dataAttributes), + UnresolvedDeserializer(encoderFor[V].deserializer, dataAttributes), groupingAttributes, dataAttributes, CatalystSerde.generateObjAttr[U], @@ -336,7 +336,7 @@ object MapGroupsWithState { } case class MapGroupsWithState( - func: (Any, InternalState[Any]) => Any, + func: (Any, Iterator[Any], InternalState[Any]) => Iterator[Any], keyDeserializer: Expression, valueDeserializer: Expression, groupingAttributes: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 4d367b39d0ede..de98048167699 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -110,18 +110,6 @@ class KeyValueGroupedDataset[K, V] private[sql]( mapValues { (v: V) => func.call(v) } } - def mapGroupsWithState[STATE: Encoder, OUT: Encoder]( - func: (V, State[STATE]) => OUT): Dataset[OUT] = { - Dataset[OUT]( - sparkSession, - MapGroupsWithState[K, V, STATE, OUT]( - func.asInstanceOf[(V, InternalState[STATE]) => OUT], - groupingAttributes, - dataAttributes, - 
logicalPlan)) - } - - /** * Returns a [[Dataset]] that contains each unique key. This is equivalent to doing mapping * over the Dataset to extract the keys and then running a distinct operation on those. @@ -232,6 +220,24 @@ class KeyValueGroupedDataset[K, V] private[sql]( mapGroups((key, data) => f.call(key, data.asJava))(encoder) } + def mapGroupsWithState[STATE: Encoder, OUT: Encoder]( + f: (K, Iterator[V], State[STATE]) => OUT): Dataset[OUT] = { + val func = (key: K, it: Iterator[V], s: State[STATE]) => Iterator(f(key, it, s)) + flatMapGroupsWithState[STATE, OUT](func) + } + + + def flatMapGroupsWithState[STATE: Encoder, OUT: Encoder]( + func: (K, Iterator[V], State[STATE]) => Iterator[OUT]): Dataset[OUT] = { + Dataset[OUT]( + sparkSession, + MapGroupsWithState[K, V, STATE, OUT]( + func.asInstanceOf[(Any, Iterator[Any], InternalState[Any]) => Iterator[Any]], + groupingAttributes, + dataAttributes, + logicalPlan)) + } + /** * (Scala-specific) * Reduces the elements of each group of data using the specified binary function. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 738b6cb60c6d9..f6a20d3840b0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -239,7 +239,8 @@ trait ProgressReporter extends Logging { // Extract statistics about stateful operators in the query plan. val stateNodes = lastExecution.executedPlan.collect { - case p if (p.isInstanceOf[StateStoreSaveExec] || p.isInstanceOf[MapGroupsWithStateExec]) => p + case p if + (p.isInstanceOf[StateStoreSaveExec] || p.isInstanceOf[MapGroupsWithStateExec]) => p } val stateOperators = stateNodes.map { node => new StateOperatorProgress( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala index d4efcb6deae93..27ab09db40f18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjecti import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.catalyst.streaming.InternalState import org.apache.spark.sql.execution import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics @@ -230,7 +231,7 @@ case class StateStoreSaveExec( } case class MapGroupsWithStateExec( - func: (Any, State[Any]) => Any, + func: (Any, Iterator[Any], InternalState[Any]) => Iterator[Any], keyDeserializer: Expression, // probably not needed valueDeserializer: Expression, groupingAttributes: Seq[Attribute], @@ -270,23 +271,24 @@ case class MapGroupsWithStateExec( sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => try { + val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) + val getKeyObj = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) val getKey = 
GenerateUnsafeProjection.generate(groupingAttributes, child.output) val getValueObj = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) val outputMappedObj = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) - val getStateObj = ObjectOperator.deserializeRowToObject(stateDeserializer) val outputStateObj = ObjectOperator.serializeObjectToRow(stateSerializer) - val mappedIter = iter.map { row => - val key = getKey(row) - val valueObj = getValueObj(row) + val finalIterator = groupedIter.flatMap { case (keyRow, valueRowIter) => + val key = keyRow.asInstanceOf[UnsafeRow] + val keyObj = getKeyObj(keyRow) + val valueObjIter = valueRowIter.map(getValueObj.apply) val stateObjOption = store.get(key).map(getStateObj) val wrappedState = new StateImpl[Any]() wrappedState.wrap(stateObjOption) - - val mapped = func(valueObj, wrappedState) + val mappedIterator = func(keyObj, valueObjIter, wrappedState) if (wrappedState.isRemoved) { store.remove(key) numRemovedStateRows += 1 @@ -294,9 +296,10 @@ case class MapGroupsWithStateExec( store.put(key, outputStateObj(wrappedState.get())) numUpdatedStateRows += 1 } - outputMappedObj(mapped) + + mappedIterator.map(outputMappedObj.apply) } - CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIter, { + CompletionIterator[InternalRow, Iterator[InternalRow]](finalIterator, { store.commit() numTotalStateRows += store.numKeys() }) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala index 1cb64ac1c3c25..46282f3434e9d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala @@ -32,27 +32,69 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { StateStore.stop() } - test("basics") { + test("mapGroupWithState") { val inputData = MemoryStream[String] // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) - val stateFunc = (data: String, state: State[Int]) => { - val oldCount = state.getOption().getOrElse(0) - if (oldCount == 2) { + val stateFunc = (key: String, values: Iterator[String], state: State[Int]) => { + + var count = state.getOption().getOrElse(0) + values.size + if (count == 3) { + state.remove() + (key, "-1") + } else { + state.update(count) + (key, count.toString) + } + } + + val result = + inputData.toDS() + .groupByKey(x => x) + .mapGroupsWithState[Int, (String, String)](stateFunc) // Int => State, (Str, Str) => Out + + testStream(result, Append)( + AddData(inputData, "a"), + CheckLastBatch(("a", "1")), + assertNumStateRows(1), + AddData(inputData, "a", "b"), + CheckLastBatch(("a", "2"), ("b", "1")), + assertNumStateRows(2), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1 + CheckLastBatch(("a", "-1"), ("b", "2")), + assertNumStateRows(1), + StopStream, + StartStream(), + AddData(inputData, "a", "b", "c"), // should recreate state for "a" and return count as 1 + CheckLastBatch(("a", "1"), ("b", "-1"), ("c", "1")), + assertNumStateRows(2) + ) + } + + test("flatMapGroupWithState") { + val inputData = MemoryStream[String] + + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count (-1 if count reached beyond 2 and state 
was just removed) + val stateFunc = (key: String, values: Iterator[String], state: State[Int]) => { + + var count = state.getOption().getOrElse(0) + values.size + if (count == 3) { state.remove() - (data, "-1") + Iterator((key, "-1")) } else { - val newCount = oldCount + 1 - state.update(newCount) - (data, newCount.toString) + state.update(count) + Iterator((key, count.toString)) } } val result = inputData.toDS() .groupByKey(x => x) - .mapGroupsWithState[Int, (String, String)](stateFunc) // Int => State, (Str, Str) => Out + .flatMapGroupsWithState[Int, (String, String)](stateFunc) // Int => State, (Str, Str) => Out testStream(result, Append)( AddData(inputData, "a"), From b54fa230eda141316713e3b1d1c56d8a28fd3a6c Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 29 Jan 2017 18:06:00 -0800 Subject: [PATCH 08/21] Refactored, added java APIs and tests --- .../sql/catalyst/plans/logical/object.scala | 4 +- .../catalyst/streaming/InternalState.scala | 20 --- .../spark/sql/KeyValueGroupedDataset.scala | 48 ++++++- .../scala/org/apache/spark/sql/State.scala | 101 +++++++++++++++ .../spark/sql/execution/SparkStrategies.scala | 32 ++--- .../spark/sql/execution/StateImpl.scala | 70 +++++++++++ .../apache/spark/sql/execution/objects.scala | 16 +++ .../streaming/IncrementalExecution.scala | 4 +- .../streaming/StatefulAggregate.scala | 12 +- .../execution/streaming/state/StateImpl.scala | 97 -------------- .../apache/spark/sql/streaming/State.scala | 44 ------- .../apache/spark/sql/JavaDatasetSuite.java | 32 +++++ .../streaming/MapGroupsWithStateSuite.scala | 118 +++++++++++++++--- 13 files changed, 393 insertions(+), 205 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalState.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/State.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/StateImpl.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateImpl.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/streaming/State.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 382244a4ea272..530515f7cdff2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke -import org.apache.spark.sql.catalyst.streaming.InternalState import org.apache.spark.sql.types._ object CatalystSerde { @@ -314,6 +313,9 @@ case class MapGroups( outputObjAttr: Attribute, child: LogicalPlan) extends UnaryNode with ObjectProducer +/** Internal class representing State */ +trait InternalState[S] + /** Factory for constructing new `MapGroups` nodes. 
*/ object MapGroupsWithState { def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder]( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalState.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalState.scala deleted file mode 100644 index 0d9391b535f99..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalState.scala +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.streaming - -trait InternalState[S] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index de98048167699..83f08b78ae875 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -24,10 +24,8 @@ import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.streaming.InternalState import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator -import org.apache.spark.sql.streaming.State /** * :: Experimental :: @@ -220,13 +218,42 @@ class KeyValueGroupedDataset[K, V] private[sql]( mapGroups((key, data) => f.call(key, data.asJava))(encoder) } + /** + * ::Experimental:: + * (Scala-specific) + * @since 2.1.1 + */ + @Experimental + @InterfaceStability.Evolving def mapGroupsWithState[STATE: Encoder, OUT: Encoder]( - f: (K, Iterator[V], State[STATE]) => OUT): Dataset[OUT] = { + f: (K, Iterator[V], State[STATE]) => OUT): Dataset[OUT] = { val func = (key: K, it: Iterator[V], s: State[STATE]) => Iterator(f(key, it, s)) flatMapGroupsWithState[STATE, OUT](func) } + /** + * ::Experimental:: + * (Java-specific) + * @since 2.1.1 + */ + @Experimental + @InterfaceStability.Evolving + def mapGroupsWithState[STATE, OUT]( + f: MapGroupsWithStateFunction[K, V, STATE, OUT], + stateEncoder: Encoder[STATE], + outputEncoder: Encoder[OUT]): Dataset[OUT] = { + val func = (key: K, it: Iterator[V], s: State[STATE]) => Iterator(f.call(key, it.asJava, s)) + flatMapGroupsWithState[STATE, OUT](func)(stateEncoder, outputEncoder) + } + + /** + * ::Experimental:: + * (Scala-specific) + * @since 2.1.1 + */ + @Experimental + @InterfaceStability.Evolving def flatMapGroupsWithState[STATE: Encoder, OUT: Encoder]( func: (K, Iterator[V], State[STATE]) => Iterator[OUT]): Dataset[OUT] = { Dataset[OUT]( @@ -238,6 +265,21 @@ class 
KeyValueGroupedDataset[K, V] private[sql]( logicalPlan)) } + /** + * ::Experimental:: + * (Java-specific) + * @since 2.1.1 + */ + @Experimental + @InterfaceStability.Evolving + def flatMapGroupsWithState[STATE, OUT]( + f: FlatMapGroupsWithStateFunction[K, V, STATE, OUT], + stateEncoder: Encoder[STATE], + outputEncoder: Encoder[OUT]): Dataset[OUT] = { + val func = (key: K, it: Iterator[V], s: State[STATE]) => f.call(key, it.asJava, s).asScala + flatMapGroupsWithState[STATE, OUT](func)(stateEncoder, outputEncoder) + } + /** * (Scala-specific) * Reduces the elements of each group of data using the specified binary function. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/State.scala b/sql/core/src/main/scala/org/apache/spark/sql/State.scala new file mode 100644 index 0000000000000..4dd690cca2dc0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/State.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.sql.catalyst.plans.logical.InternalState + +/** + * :: Experimental :: + * + * Wrapper class for interacting with state data in `mapGroupsWithState` and + * `flatMapGroupsWithState` operations on + * [[org.apache.spark.sql.KeyValueGroupedDataset KeyValueGroupedDataset]]. + * + * @note Operations on state are not threadsafe. + * + * Scala example of using `State`: + * {{{ + * // A mapping function that maintains an integer state for string keys and returns a string. + * def mappingFunction(key: String, value: Iterable[Int], state: State[Int]): Option[String] = { + * // Check if state exists + * if (state.exists) { + * val existingState = state.get // Get the existing state + * val shouldRemove = ... // Decide whether to remove the state + * if (shouldRemove) { + * state.remove() // Remove the state + * } else { + * val newState = ... + * state.update(newState) // Set the new state + * } + * } else { + * val initialState = ... + * state.update(initialState) // Set the initial state + * } + * ... // return something + * } + * + * }}} + * + * Java example of using `State`: + * {{{ + * // A mapping function that maintains an integer state for string keys and returns a string. 
+ * Function3, State, String> mappingFunction = + * new Function3, State, String>() { + * + * @Override + * public String call(String key, Optional value, State state) { + * if (state.exists()) { + * int existingState = state.get(); // Get the existing state + * boolean shouldRemove = ...; // Decide whether to remove the state + * if (shouldRemove) { + * state.remove(); // Remove the state + * } else { + * int newState = ...; + * state.update(newState); // Set the new state + * } + * } else { + * int initialState = ...; // Set the initial state + * state.update(initialState); + * } + * ... // return something + * } + * }; + * }}} + * + * @tparam S Type of the state + * @since 2.1.1 + */ +@Experimental +@InterfaceStability.Evolving +trait State[S] extends InternalState[S] { + + def exists: Boolean + + def get(): S + + def update(newState: S): Unit + + def remove(): Unit + + @inline final def getOption(): Option[S] = if (exists) Some(get()) else None + + @inline final override def toString(): String = { + getOption.map { _.toString }.getOrElse("") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 294dcefb652b1..b0bb1f6c00305 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -244,21 +244,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - object MapGroupsWithStateStrategy extends Strategy { - override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case MapGroupsWithState( - func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, - stateDeser, stateSer, child) => - val execPlan = MapGroupsWithStateExec( - func, keyDeser, valueDeser, - groupAttr, dataAttr, outputAttr, None, stateDeser, stateSer, - planLater(child)) - execPlan :: Nil - case _ => - Nil - } - } - /** * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface. */ @@ -328,6 +313,21 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + object MapGroupsWithStateStrategy extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case MapGroupsWithState( + func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, + stateDeser, stateSer, child) => + val execPlan = MapGroupsWithStateExec( + func, keyDeser, valueDeser, + groupAttr, dataAttr, outputAttr, None, stateDeser, stateSer, + planLater(child)) + execPlan :: Nil + case _ => + Nil + } + } + // Can we automate these 'pass through' operations? 
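
A minimal end-to-end usage sketch of the Scala API introduced above (not part of the patch itself): it assumes a Spark build with this change applied and the `mapGroupsWithState` signature that takes a `(K, Iterator[V], State[S]) => U` function. The session settings and input data are placeholders.

{{{
import org.apache.spark.sql.SparkSession

object MapGroupsWithStateSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    import spark.implicits._

    // Maintain a per-key running count. In a streaming query the updated state
    // carries over to the next trigger; in a batch query (as here) the state
    // always starts empty.
    val counts = Seq("apple", "apple", "banana").toDS()
      .groupByKey(x => x)
      .mapGroupsWithState[Long, (String, Long)] { (key, values, state) =>
        val newCount = state.getOption.getOrElse(0L) + values.size
        state.update(newCount)
        (key, newCount)
      }

    counts.show() // expected: (apple, 2), (banana, 1)
    spark.stop()
  }
}
}}}
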
object BasicOperators extends Strategy { def numPartitions: Int = self.numPartitions @@ -369,6 +369,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.AppendColumnsWithObjectExec(f, childSer, newSer, planLater(child)) :: Nil case logical.MapGroups(f, key, value, grouping, data, objAttr, child) => execution.MapGroupsExec(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil + case logical.MapGroupsWithState(f, key, value, grouping, data, output, _, _, child) => + execution.MapGroupsExec(f, key, value, grouping, data, output, planLater(child)) :: Nil case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) => execution.CoGroupExec( f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/StateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/StateImpl.scala new file mode 100644 index 0000000000000..c6c8d12c49083 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/StateImpl.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import org.apache.spark.sql.State + +/** Internal implementation of the [[State]] interface */ +private[sql] class StateImpl[S](optionalValue: Option[S]) extends State[S] { + private var value: S = optionalValue.getOrElse(null.asInstanceOf[S]) + private var defined: Boolean = optionalValue.isDefined + private var updated: Boolean = false // whether value has been updated (but not removed) + private var removed: Boolean = false // whether value has eben removed + + // ========= Public API ========= + override def exists: Boolean = { + defined + } + + override def get(): S = { + if (defined) { + value + } else { + throw new NoSuchElementException("State is either not defined or has already been removed") + } + } + + override def update(newValue: S): Unit = { + value = newValue + defined = true + updated = true + removed = false + } + + override def remove(): Unit = { + defined = false + updated = false + removed = true + } + + // ========= Internal API ========= + + /** Whether the state has been marked for removing */ + def isRemoved: Boolean = { + removed + } + + /** Whether the state has been been updated */ + def isUpdated: Boolean = { + updated + } +} + +object StateImpl { + def apply[S](optionalValue: Option[S]): StateImpl[S] = new StateImpl[S](optionalValue) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 313452099d2eb..c6a5616902d75 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.plans.logical.InternalState import org.apache.spark.sql.types.{DataType, ObjectType, StructType} @@ -350,6 +351,21 @@ case class MapGroupsExec( } } +object MapGroupsExec { + def apply( + func: (Any, Iterator[Any], InternalState[Any]) => TraversableOnce[Any], + keyDeserializer: Expression, + valueDeserializer: Expression, + groupingAttributes: Seq[Attribute], + dataAttributes: Seq[Attribute], + outputObjAttr: Attribute, + child: SparkPlan): MapGroupsExec = { + val f = (key: Any, values: Iterator[Any]) => func(key, values, StateImpl[Any](None)) + new MapGroupsExec(f, keyDeserializer, valueDeserializer, + groupingAttributes, dataAttributes, outputObjAttr, child) + } +} + /** * Groups the input rows together and calls the R function with each group and an iterator * containing all elements in the group. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 147401b370793..dc840c862438f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -41,7 +41,7 @@ class IncrementalExecution( extends QueryExecution(sparkSession, logicalPlan) with Logging { // TODO: make this always part of planning. 
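
A simplified, self-contained sketch (not part of the patch) of the contract the `isUpdated`/`isRemoved` flags above are meant to support. Since `StateImpl` is `private[sql]`, code like this would have to live inside Spark's own `org.apache.spark.sql.execution` package; the mutable `Map` is only a stand-in for the per-partition `StateStore`.

{{{
package org.apache.spark.sql.execution

import scala.collection.mutable

object StateImplContractSketch {
  def main(args: Array[String]): Unit = {
    val store = mutable.Map[String, Long]() // stand-in for the StateStore

    // Run a user function for one key against its current state, then apply the
    // outcome back to the store, mirroring what the stateful operator does.
    def withState(key: String)(userFunc: StateImpl[Long] => Unit): Unit = {
      val state = StateImpl[Long](store.get(key))
      userFunc(state)
      if (state.isRemoved) store.remove(key)                 // state.remove() was called
      else if (state.isUpdated) store.put(key, state.get())  // state.update(...) was called
      // otherwise the previously stored value is left untouched
    }

    withState("a")(_.update(1L))                // store == Map("a" -> 1)
    withState("a")(s => s.update(s.get() + 1))  // store == Map("a" -> 2)
    withState("a")(_.remove())                  // store == Map()
  }
}
}}}
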
- val stateStrategy = + val streamingExtraStrategies = sparkSession.sessionState.planner.StatefulAggregationStrategy +: sparkSession.sessionState.planner.MapGroupsWithStateStrategy +: sparkSession.sessionState.planner.StreamingRelationStrategy +: @@ -52,7 +52,7 @@ class IncrementalExecution( new SparkPlanner( sparkSession.sparkContext, sparkSession.sessionState.conf, - stateStrategy) + streamingExtraStrategies) /** * See [SPARK-18339] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala index 27ab09db40f18..e87a2e0a6387c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala @@ -22,17 +22,16 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} -import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, InternalState} import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.catalyst.streaming.InternalState import org.apache.spark.sql.execution import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.streaming.state._ -import org.apache.spark.sql.streaming.{OutputMode, State} -import org.apache.spark.sql.types.{DataType, StructType} -import org.apache.spark.util.{CompletionIterator, NextIterator} +import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.CompletionIterator /** Used to identify the state store for a given operator. */ @@ -286,8 +285,7 @@ case class MapGroupsWithStateExec( val keyObj = getKeyObj(keyRow) val valueObjIter = valueRowIter.map(getValueObj.apply) val stateObjOption = store.get(key).map(getStateObj) - val wrappedState = new StateImpl[Any]() - wrappedState.wrap(stateObjOption) + val wrappedState = StateImpl[Any](stateObjOption) val mappedIterator = func(keyObj, valueObjIter, wrappedState) if (wrappedState.isRemoved) { store.remove(key) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateImpl.scala deleted file mode 100644 index a911d142f8859..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateImpl.scala +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming.state - -import org.apache.spark.sql.streaming.State - -/** Internal implementation of the [[State]] interface */ -class StateImpl[S] extends State[S] { - private var state: S = null.asInstanceOf[S] - private var defined: Boolean = false - private var timingOut: Boolean = false - private var updated: Boolean = false - private var removed: Boolean = false - - // ========= Public API ========= - override def exists(): Boolean = { - defined - } - - override def get(): S = { - if (defined) { - state - } else { - throw new NoSuchElementException("State is not set") - } - } - - override def update(newState: S): Unit = { - require(!removed, "Cannot update the state after it has been removed") - require(!timingOut, "Cannot update the state that is timing out") - state = newState - defined = true - updated = true - } - - override def isTimingOut(): Boolean = { - timingOut - } - - override def remove(): Unit = { - require(!timingOut, "Cannot remove the state that is timing out") - require(!removed, "Cannot remove the state that has already been removed") - defined = false - updated = false - removed = true - } - - // ========= Internal API ========= - - /** Whether the state has been marked for removing */ - def isRemoved(): Boolean = { - removed - } - - /** Whether the state has been been updated */ - def isUpdated(): Boolean = { - updated - } - - def wrap(optionalState: Option[S]): Unit = { - optionalState match { - case Some(newState) => - this.state = newState - defined = true - - case None => - this.state = null.asInstanceOf[S] - defined = false - } - timingOut = false - removed = false - updated = false - } - - def wrapTimingOutState(newState: S): Unit = { - this.state = newState - defined = true - timingOut = true - removed = false - updated = false - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/State.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/State.scala deleted file mode 100644 index 4c44fa194c14b..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/State.scala +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.streaming - -import org.apache.spark.sql.catalyst.streaming.InternalState - -trait State[S] extends InternalState[S] { - - def exists(): Boolean - - def get(): S - - def update(newState: S): Unit - - def remove(): Unit - - def isTimingOut(): Boolean - - @inline final def getOption(): Option[S] = if (exists) Some(get()) else None - - @inline final override def toString(): String = { - getOption.map { - _.toString - }.getOrElse("") - } -} - - - diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 8304b728aa238..6260b6127d52e 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -225,6 +225,38 @@ public Iterator call(Integer key, Iterator values) { Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped.collectAsList())); + Dataset mapped2 = grouped.mapGroupsWithState( + new MapGroupsWithStateFunction() { + @Override + public String call(Integer key, Iterator values, State s) throws Exception { + StringBuilder sb = new StringBuilder(key.toString()); + while (values.hasNext()) { + sb.append(values.next()); + } + return sb.toString(); + } + }, + Encoders.LONG(), + Encoders.STRING()); + + Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped2.collectAsList())); + + Dataset flatMapped2 = grouped.flatMapGroupsWithState( + new FlatMapGroupsWithStateFunction() { + @Override + public Iterator call(Integer key, Iterator values, State s) { + StringBuilder sb = new StringBuilder(key.toString()); + while (values.hasNext()) { + sb.append(values.next()); + } + return Collections.singletonList(sb.toString()).iterator(); + } + }, + Encoders.LONG(), + Encoders.STRING()); + + Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped2.collectAsList())); + Dataset> reduced = grouped.reduceGroups(new ReduceFunction() { @Override public String call(String v1, String v2) throws Exception { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala index 46282f3434e9d..dfe6bf31b0d60 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala @@ -19,10 +19,15 @@ package org.apache.spark.sql.streaming import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.State import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.execution.StateImpl import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.execution.streaming.state.StateStore +/** Class to check custom state types */ +case class RunningCount(count: Long) + class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { import testImplicits._ @@ -32,19 +37,100 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { StateStore.stop() } - test("mapGroupWithState") { + test("state - get, exists, update, remove, ") { + var state: StateImpl[String] = null + + def testState( + expectedData: Option[String], + shouldBeUpdated: Boolean = false, + shouldBeRemoved: Boolean = false + ): Unit = { + if (expectedData.isDefined) { + assert(state.exists) + assert(state.get() === expectedData.get) + assert(state.getOption() === expectedData) + } else { + 
assert(!state.exists) + intercept[NoSuchElementException] { + state.get() + } + assert(state.getOption() === None) + } + + assert(state.isUpdated === shouldBeUpdated) + assert(state.isRemoved === shouldBeRemoved) + } + + // Updating empty state + state = StateImpl[String](None) + testState(None) + state.update("") + testState(Some(""), shouldBeUpdated = true) + + // Updating exiting state, even if with null + state = StateImpl[String](Some("2")) + testState(Some("2")) + state.update("3") + testState(Some("3"), shouldBeUpdated = true) + state.update(null) + testState(Some(null), shouldBeUpdated = true) + + // Removing state + state.remove() + testState(None, shouldBeRemoved = true, shouldBeUpdated = false) + state.remove() // should be still callable + state.update("4") + testState(Some("4"), shouldBeRemoved = false, shouldBeUpdated = true) + } + + + + // ************* Batch query tests for [flat]mapGroupsWithState ************* + + test("batch - mapGroupsWithState") { + val stateFunc = (key: String, values: Iterator[String], state: State[RunningCount]) => { + assert(!state.exists) + assert(state.getOption.isEmpty) + (key, values.size) + } + + checkAnswer( + spark.createDataset(Seq("a", "a", "b")) + .groupByKey(x => x) + .mapGroupsWithState(stateFunc) + .toDF, + spark.createDataset(Seq(("a", 2), ("b", 1))).toDF) + } + + test("batch - flatMapGroupsWithState") { + // Function that returns running count only if its even, otherwise does not return + val stateFunc = (key: String, values: Iterator[String], state: State[RunningCount]) => { + assert(!state.exists) + assert(state.getOption.isEmpty) + if (values.size == 2) { + Iterator((key, values.size)) + } else Iterator.empty + } + checkAnswer( + Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc).toDF, + Seq(("a", 2), ("b", 1)).toDF) + } + + // ************* Streaming query tests for [flat]mapGroupsWithState ************* + + test("streaming - mapGroupsWithState") { val inputData = MemoryStream[String] // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) - val stateFunc = (key: String, values: Iterator[String], state: State[Int]) => { + val stateFunc = (key: String, values: Iterator[String], state: State[RunningCount]) => { - var count = state.getOption().getOrElse(0) + values.size + var count = state.getOption.map(_.count).getOrElse(0L) + values.size if (count == 3) { state.remove() (key, "-1") } else { - state.update(count) + state.update(RunningCount(count)) (key, count.toString) } } @@ -52,7 +138,7 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { val result = inputData.toDS() .groupByKey(x => x) - .mapGroupsWithState[Int, (String, String)](stateFunc) // Int => State, (Str, Str) => Out + .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) testStream(result, Append)( AddData(inputData, "a"), @@ -74,19 +160,19 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { ) } - test("flatMapGroupWithState") { + test("streaming - flatMapGroupsWithState") { val inputData = MemoryStream[String] // Function to maintain running count up to 2, and then remove the count - // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) - val stateFunc = (key: String, values: Iterator[String], state: State[Int]) => { + // Returns the data and the count if state is defined, otherwise does not return anything + val stateFunc = 
(key: String, values: Iterator[String], state: State[RunningCount]) => { - var count = state.getOption().getOrElse(0) + values.size + var count = state.getOption.map(_.count).getOrElse(0L) + values.size if (count == 3) { state.remove() - Iterator((key, "-1")) + Iterator.empty } else { - state.update(count) + state.update(RunningCount(count)) Iterator((key, count.toString)) } } @@ -94,7 +180,7 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { val result = inputData.toDS() .groupByKey(x => x) - .flatMapGroupsWithState[Int, (String, String)](stateFunc) // Int => State, (Str, Str) => Out + .flatMapGroupsWithState(stateFunc) // State: Int, Out: (Str, Str) testStream(result, Append)( AddData(inputData, "a"), @@ -105,13 +191,13 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { assertNumStateRows(2), StopStream, StartStream(), - AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1 - CheckLastBatch(("a", "-1"), ("b", "2")), + AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a + CheckLastBatch( ("b", "2")), assertNumStateRows(1), StopStream, StartStream(), - AddData(inputData, "a", "b", "c"), // should recreate state for "a" and return count as 1 - CheckLastBatch(("a", "1"), ("b", "-1"), ("c", "1")), + AddData(inputData, "a", "b", "c"), // should recreate state for "a" and return count as 1 and + CheckLastBatch(("a", "1"), ("c", "1")), // ... not return anything for b assertNumStateRows(2) ) } From ab3cb6c961f0d861a24c2146dcb9dc0380c8adc9 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Sun, 29 Jan 2017 21:07:19 -0800 Subject: [PATCH 09/21] Refactored --- .../sql/catalyst/plans/logical/object.scala | 6 +- .../spark/sql/KeyValueGroupedDataset.scala | 20 +-- .../scala/org/apache/spark/sql/State.scala | 4 +- .../apache/spark/sql/execution/objects.scala | 4 +- .../streaming/ProgressReporter.scala | 3 +- .../streaming/StatefulAggregate.scala | 56 ++++---- .../streaming/MapGroupsWithStateSuite.scala | 125 +++++++++--------- 7 files changed, 110 insertions(+), 108 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 530515f7cdff2..7ae50a0007fd1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -314,12 +314,12 @@ case class MapGroups( child: LogicalPlan) extends UnaryNode with ObjectProducer /** Internal class representing State */ -trait InternalState[S] +trait LogicalState[S] /** Factory for constructing new `MapGroups` nodes. 
*/ object MapGroupsWithState { def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder]( - func: (Any, Iterator[Any], InternalState[Any]) => Iterator[Any], + func: (Any, Iterator[Any], LogicalState[Any]) => Iterator[Any], groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], child: LogicalPlan): LogicalPlan = { @@ -338,7 +338,7 @@ object MapGroupsWithState { } case class MapGroupsWithState( - func: (Any, Iterator[Any], InternalState[Any]) => Iterator[Any], + func: (Any, Iterator[Any], LogicalState[Any]) => Iterator[Any], keyDeserializer: Expression, valueDeserializer: Expression, groupingAttributes: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 83f08b78ae875..6a5f9c5b86923 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -226,9 +226,9 @@ class KeyValueGroupedDataset[K, V] private[sql]( @Experimental @InterfaceStability.Evolving def mapGroupsWithState[STATE: Encoder, OUT: Encoder]( - f: (K, Iterator[V], State[STATE]) => OUT): Dataset[OUT] = { - val func = (key: K, it: Iterator[V], s: State[STATE]) => Iterator(f(key, it, s)) - flatMapGroupsWithState[STATE, OUT](func) + func: (K, Iterator[V], State[STATE]) => OUT): Dataset[OUT] = { + val f = (key: K, it: Iterator[V], s: State[STATE]) => Iterator(func(key, it, s)) + flatMapGroupsWithState[STATE, OUT](f) } /** @@ -239,11 +239,11 @@ class KeyValueGroupedDataset[K, V] private[sql]( @Experimental @InterfaceStability.Evolving def mapGroupsWithState[STATE, OUT]( - f: MapGroupsWithStateFunction[K, V, STATE, OUT], + func: MapGroupsWithStateFunction[K, V, STATE, OUT], stateEncoder: Encoder[STATE], outputEncoder: Encoder[OUT]): Dataset[OUT] = { - val func = (key: K, it: Iterator[V], s: State[STATE]) => Iterator(f.call(key, it.asJava, s)) - flatMapGroupsWithState[STATE, OUT](func)(stateEncoder, outputEncoder) + val f = (key: K, it: Iterator[V], s: State[STATE]) => Iterator(func.call(key, it.asJava, s)) + flatMapGroupsWithState[STATE, OUT](f)(stateEncoder, outputEncoder) } @@ -259,7 +259,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( Dataset[OUT]( sparkSession, MapGroupsWithState[K, V, STATE, OUT]( - func.asInstanceOf[(Any, Iterator[Any], InternalState[Any]) => Iterator[Any]], + func.asInstanceOf[(Any, Iterator[Any], LogicalState[Any]) => Iterator[Any]], groupingAttributes, dataAttributes, logicalPlan)) @@ -273,11 +273,11 @@ class KeyValueGroupedDataset[K, V] private[sql]( @Experimental @InterfaceStability.Evolving def flatMapGroupsWithState[STATE, OUT]( - f: FlatMapGroupsWithStateFunction[K, V, STATE, OUT], + func: FlatMapGroupsWithStateFunction[K, V, STATE, OUT], stateEncoder: Encoder[STATE], outputEncoder: Encoder[OUT]): Dataset[OUT] = { - val func = (key: K, it: Iterator[V], s: State[STATE]) => f.call(key, it.asJava, s).asScala - flatMapGroupsWithState[STATE, OUT](func)(stateEncoder, outputEncoder) + val f = (key: K, it: Iterator[V], s: State[STATE]) => func.call(key, it.asJava, s).asScala + flatMapGroupsWithState[STATE, OUT](f)(stateEncoder, outputEncoder) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/State.scala b/sql/core/src/main/scala/org/apache/spark/sql/State.scala index 4dd690cca2dc0..65f621fa05e57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/State.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/State.scala @@ -18,7 +18,7 @@ 
package org.apache.spark.sql import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.catalyst.plans.logical.InternalState +import org.apache.spark.sql.catalyst.plans.logical.LogicalState /** * :: Experimental :: @@ -83,7 +83,7 @@ import org.apache.spark.sql.catalyst.plans.logical.InternalState */ @Experimental @InterfaceStability.Evolving -trait State[S] extends InternalState[S] { +trait State[S] extends LogicalState[S] { def exists: Boolean diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index c6a5616902d75..2615eaa5ffe3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.plans.logical.InternalState +import org.apache.spark.sql.catalyst.plans.logical.LogicalState import org.apache.spark.sql.types.{DataType, ObjectType, StructType} @@ -353,7 +353,7 @@ case class MapGroupsExec( object MapGroupsExec { def apply( - func: (Any, Iterator[Any], InternalState[Any]) => TraversableOnce[Any], + func: (Any, Iterator[Any], LogicalState[Any]) => TraversableOnce[Any], keyDeserializer: Expression, valueDeserializer: Expression, groupingAttributes: Seq[Attribute], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index f6a20d3840b0e..5eefca7aec955 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -239,8 +239,7 @@ trait ProgressReporter extends Logging { // Extract statistics about stateful operators in the query plan. 
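
A small user-facing sketch (not part of the patch) of how the state-operator statistics collected here surface to applications: once a streaming query that uses `mapGroupsWithState` is running, the totals appear in each `StreamingQueryProgress`. The query handle below is assumed to come from `writeStream...start()`.

{{{
import org.apache.spark.sql.streaming.StreamingQuery

object StateMetricsSketch {
  // Print the per-operator state row counts from the most recent progress, if any.
  def printStateMetrics(query: StreamingQuery): Unit = {
    Option(query.lastProgress).foreach { progress =>
      progress.stateOperators.foreach { op =>
        println(s"state rows: total=${op.numRowsTotal}, updated=${op.numRowsUpdated}")
      }
    }
  }
}
}}}
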
val stateNodes = lastExecution.executedPlan.collect { - case p if - (p.isInstanceOf[StateStoreSaveExec] || p.isInstanceOf[MapGroupsWithStateExec]) => p + case p if p.isInstanceOf[StateStoreWriter] => p } val stateOperators = stateNodes.map { node => new StateOperatorProgress( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala index e87a2e0a6387c..b0578961bce3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} -import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, InternalState} +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalState} import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution @@ -54,6 +54,18 @@ trait StatefulOperator extends SparkPlan { } } +trait StateStoreReader extends StatefulOperator { + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) +} + +trait StateStoreWriter extends StatefulOperator { + override lazy val metrics = Map( + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"), + "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows")) +} + /** * For each input tuple, the key is calculated and the value from the [[StateStore]] is added * to the stream (in addition to the input tuple) if present. @@ -64,9 +76,6 @@ case class StateStoreRestoreExec( child: SparkPlan) extends execution.UnaryExecNode with StatefulOperator { - override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) - override protected def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") @@ -102,12 +111,7 @@ case class StateStoreSaveExec( outputMode: Option[OutputMode] = None, eventTimeWatermark: Option[Long] = None, child: SparkPlan) - extends execution.UnaryExecNode with StatefulOperator { - - override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"), - "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows")) + extends execution.UnaryExecNode with StateStoreWriter { /** Generate a predicate that matches data older than the watermark */ private lazy val watermarkPredicate: Option[Predicate] = { @@ -229,8 +233,12 @@ case class StateStoreSaveExec( override def outputPartitioning: Partitioning = child.outputPartitioning } + +/** + * Physical operator for executing streaming mapGroupsWithState. 
+ */ case class MapGroupsWithStateExec( - func: (Any, Iterator[Any], InternalState[Any]) => Iterator[Any], + func: (Any, Iterator[Any], LogicalState[Any]) => Iterator[Any], keyDeserializer: Expression, // probably not needed valueDeserializer: Expression, groupingAttributes: Seq[Attribute], @@ -239,7 +247,7 @@ case class MapGroupsWithStateExec( stateId: Option[OperatorStateId], stateDeserializer: Expression, stateSerializer: Seq[NamedExpression], - child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with StatefulOperator { + child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with StateStoreWriter { override def outputPartitioning: Partitioning = child.outputPartitioning @@ -249,17 +257,7 @@ case class MapGroupsWithStateExec( override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(groupingAttributes.map(SortOrder(_, Ascending))) // is this ordering needed? - override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), - "numTotalStateRows" -> SQLMetrics.createMetric(sparkContext, "number of total state rows"), - "numUpdatedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of updated state rows"), - "numRemovedStateRows" -> SQLMetrics.createMetric(sparkContext, "number of removed state rows") - ) - override protected def doExecute(): RDD[InternalRow] = { - val numTotalStateRows = longMetric("numTotalStateRows") - val numUpdatedStateRows = longMetric("numUpdatedStateRows") - val numRemovedStateRows = longMetric("numRemovedStateRows") child.execute().mapPartitionsWithStateStore[InternalRow]( getStateId.checkpointLocation, @@ -270,12 +268,16 @@ case class MapGroupsWithStateExec( sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => try { + val numTotalStateRows = longMetric("numTotalStateRows") + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + val numOutputRows = longMetric("numOutputRows") + val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) val getKeyObj = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) val getKey = GenerateUnsafeProjection.generate(groupingAttributes, child.output) val getValueObj = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) - val outputMappedObj = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) + val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) val getStateObj = ObjectOperator.deserializeRowToObject(stateDeserializer) val outputStateObj = ObjectOperator.serializeObjectToRow(stateSerializer) @@ -287,15 +289,19 @@ case class MapGroupsWithStateExec( val stateObjOption = store.get(key).map(getStateObj) val wrappedState = StateImpl[Any](stateObjOption) val mappedIterator = func(keyObj, valueObjIter, wrappedState) + if (wrappedState.isRemoved) { store.remove(key) - numRemovedStateRows += 1 + numUpdatedStateRows += 1 } else if (wrappedState.isUpdated) { store.put(key, outputStateObj(wrappedState.get())) numUpdatedStateRows += 1 } - mappedIterator.map(outputMappedObj.apply) + mappedIterator.map { obj => + numOutputRows += 1 + getOutputRow(obj) + } } CompletionIterator[InternalRow, Iterator[InternalRow]](finalIterator, { store.commit() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala index dfe6bf31b0d60..f6fab7c905713 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala @@ -37,7 +37,7 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { StateStore.stop() } - test("state - get, exists, update, remove, ") { + test("state - get, exists, update, remove") { var state: StateImpl[String] = null def testState( @@ -83,128 +83,125 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { testState(Some("4"), shouldBeRemoved = false, shouldBeUpdated = true) } - - - // ************* Batch query tests for [flat]mapGroupsWithState ************* - - test("batch - mapGroupsWithState") { - val stateFunc = (key: String, values: Iterator[String], state: State[RunningCount]) => { - assert(!state.exists) - assert(state.getOption.isEmpty) - (key, values.size) - } - - checkAnswer( - spark.createDataset(Seq("a", "a", "b")) - .groupByKey(x => x) - .mapGroupsWithState(stateFunc) - .toDF, - spark.createDataset(Seq(("a", 2), ("b", 1))).toDF) - } - - test("batch - flatMapGroupsWithState") { - // Function that returns running count only if its even, otherwise does not return - val stateFunc = (key: String, values: Iterator[String], state: State[RunningCount]) => { - assert(!state.exists) - assert(state.getOption.isEmpty) - if (values.size == 2) { - Iterator((key, values.size)) - } else Iterator.empty - } - checkAnswer( - Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc).toDF, - Seq(("a", 2), ("b", 1)).toDF) - } - - // ************* Streaming query tests for [flat]mapGroupsWithState ************* - - test("streaming - mapGroupsWithState") { + test("flatMapGroupsWithState - streaming") { val inputData = MemoryStream[String] // Function to maintain running count up to 2, and then remove the count - // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) + // Returns the data and the count if state is defined, otherwise does not return anything val stateFunc = (key: String, values: Iterator[String], state: State[RunningCount]) => { var count = state.getOption.map(_.count).getOrElse(0L) + values.size if (count == 3) { state.remove() - (key, "-1") + Iterator.empty } else { state.update(RunningCount(count)) - (key, count.toString) + Iterator((key, count.toString)) } } val result = inputData.toDS() .groupByKey(x => x) - .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) + .flatMapGroupsWithState(stateFunc) // State: Int, Out: (Str, Str) testStream(result, Append)( AddData(inputData, "a"), CheckLastBatch(("a", "1")), - assertNumStateRows(1), + assertNumStateRows(total = 1, updated = 1), AddData(inputData, "a", "b"), CheckLastBatch(("a", "2"), ("b", "1")), - assertNumStateRows(2), + assertNumStateRows(total = 2, updated = 2), StopStream, StartStream(), - AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1 - CheckLastBatch(("a", "-1"), ("b", "2")), - assertNumStateRows(1), + AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a + CheckLastBatch(("b", "2")), + assertNumStateRows(total = 1, updated = 2), StopStream, StartStream(), - AddData(inputData, "a", "b", "c"), // should recreate state for "a" and return count as 1 - CheckLastBatch(("a", "1"), ("b", "-1"), ("c", "1")), - assertNumStateRows(2) + AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and + CheckLastBatch(("a", "1"), ("c", "1")) 
+ // assertNumStateRows(total = 3, updated = 2) ) } - test("streaming - flatMapGroupsWithState") { + test("flatMapGroupsWithState - batch") { + // Function that returns running count only if its even, otherwise does not return + val stateFunc = (key: String, values: Iterator[String], state: State[RunningCount]) => { + if (state.exists) throw new IllegalArgumentException("state.exists should be false") + if (state.getOption.nonEmpty) { + throw new IllegalArgumentException("state.getOption should be empty") + } + Iterator((key, values.size)) + } + checkAnswer( + Seq("a", "a", "b").toDS.groupByKey(x => x).flatMapGroupsWithState(stateFunc).toDF, + Seq(("a", 2), ("b", 1)).toDF) + } + + test("mapGroupsWithState - streaming") { val inputData = MemoryStream[String] // Function to maintain running count up to 2, and then remove the count - // Returns the data and the count if state is defined, otherwise does not return anything + // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) val stateFunc = (key: String, values: Iterator[String], state: State[RunningCount]) => { var count = state.getOption.map(_.count).getOrElse(0L) + values.size if (count == 3) { state.remove() - Iterator.empty + (key, "-1") } else { state.update(RunningCount(count)) - Iterator((key, count.toString)) + (key, count.toString) } } val result = inputData.toDS() .groupByKey(x => x) - .flatMapGroupsWithState(stateFunc) // State: Int, Out: (Str, Str) + .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) testStream(result, Append)( AddData(inputData, "a"), CheckLastBatch(("a", "1")), - assertNumStateRows(1), + assertNumStateRows(total = 1, updated = 1), AddData(inputData, "a", "b"), CheckLastBatch(("a", "2"), ("b", "1")), - assertNumStateRows(2), + assertNumStateRows(total = 2, updated = 2), StopStream, StartStream(), - AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for a - CheckLastBatch( ("b", "2")), - assertNumStateRows(1), + AddData(inputData, "a", "b"), // should remove state for "a" and return count as -1 + CheckLastBatch(("a", "-1"), ("b", "2")), + assertNumStateRows(total = 1, updated = 2), StopStream, StartStream(), - AddData(inputData, "a", "b", "c"), // should recreate state for "a" and return count as 1 and - CheckLastBatch(("a", "1"), ("c", "1")), // ... 
not return anything for b - assertNumStateRows(2) + AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 + CheckLastBatch(("a", "1"), ("c", "1")), + assertNumStateRows(total = 3, updated = 2) ) } - private def assertNumStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q => + test("mapGroupsWithState - batch") { + val stateFunc = (key: String, values: Iterator[String], state: State[RunningCount]) => { + if (state.exists) throw new IllegalArgumentException("state.exists should be false") + if (state.getOption.nonEmpty) { + throw new IllegalArgumentException("state.getOption should be empty") + } + (key, values.size) + } + + checkAnswer( + spark.createDataset(Seq("a", "a", "b")) + .groupByKey(x => x) + .mapGroupsWithState(stateFunc) + .toDF, + spark.createDataset(Seq(("a", 2), ("b", 1))).toDF) + } + + private def assertNumStateRows(total: Long, updated: Long): AssertOnQuery = AssertOnQuery { q => val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get - assert(progressWithData.stateOperators(0).numRowsTotal === numTotalRows) + assert(progressWithData.stateOperators(0).numRowsTotal === total) + assert(progressWithData.stateOperators(0).numRowsUpdated === updated) true } } From ddf4550b765af89a4ed7d80edabfe3370cbd1e23 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Mon, 30 Jan 2017 13:29:44 -0800 Subject: [PATCH 10/21] Added more test --- .../streaming/MapGroupsWithStateSuite.scala | 51 ++++++++++++++++--- 1 file changed, 44 insertions(+), 7 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala index f6fab7c905713..e712b41904ad8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.streaming import org.scalatest.BeforeAndAfterAll +import org.apache.spark.SparkException import org.apache.spark.sql.State import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.StateImpl @@ -84,8 +85,6 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { } test("flatMapGroupsWithState - streaming") { - val inputData = MemoryStream[String] - // Function to maintain running count up to 2, and then remove the count // Returns the data and the count if state is defined, otherwise does not return anything val stateFunc = (key: String, values: Iterator[String], state: State[RunningCount]) => { @@ -100,6 +99,7 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { } } + val inputData = MemoryStream[String] val result = inputData.toDS() .groupByKey(x => x) @@ -120,8 +120,8 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { StopStream, StartStream(), AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1 and - CheckLastBatch(("a", "1"), ("c", "1")) - // assertNumStateRows(total = 3, updated = 2) + CheckLastBatch(("a", "1"), ("c", "1")), + assertNumStateRows(total = 3, updated = 2) ) } @@ -140,13 +140,11 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { } test("mapGroupsWithState - streaming") { - val inputData = MemoryStream[String] - // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count 
reached beyond 2 and state was just removed) val stateFunc = (key: String, values: Iterator[String], state: State[RunningCount]) => { - var count = state.getOption.map(_.count).getOrElse(0L) + values.size + val count = state.getOption.map(_.count).getOrElse(0L) + values.size if (count == 3) { state.remove() (key, "-1") @@ -156,6 +154,7 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { } } + val inputData = MemoryStream[String] val result = inputData.toDS() .groupByKey(x => x) @@ -198,6 +197,40 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { spark.createDataset(Seq(("a", 2), ("b", 1))).toDF) } + testQuietly("StateStore.abort on task failure handling") { + val stateFunc = (key: String, values: Iterator[String], state: State[RunningCount]) => { + if (MapGroupsWithStateSuite.failInTask) throw new Exception("expected failure") + val count = state.getOption.map(_.count).getOrElse(0L) + values.size + state.update(RunningCount(count)) + (key, count) + } + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) + + def setFailInTask(value: Boolean): AssertOnQuery = AssertOnQuery { q => + MapGroupsWithStateSuite.failInTask = value + true + } + + testStream(result, Append)( + setFailInTask(false), + AddData(inputData, "a"), + CheckLastBatch(("a", 1L)), + AddData(inputData, "a"), + CheckLastBatch(("a", 2L)), + setFailInTask(true), + AddData(inputData, "a"), + ExpectFailure[SparkException](), // task should fail but should not increment count + setFailInTask(false), + StartStream(), + CheckLastBatch(("a", 3L)) // task should not fail, and should show correct count + ) + } + private def assertNumStateRows(total: Long, updated: Long): AssertOnQuery = AssertOnQuery { q => val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get assert(progressWithData.stateOperators(0).numRowsTotal === total) @@ -205,3 +238,7 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { true } } + +object MapGroupsWithStateSuite { + var failInTask = true +} From 6fab7a5fde75309198d1e73e66948d82d0f590e6 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 31 Jan 2017 11:16:16 -0800 Subject: [PATCH 11/21] Added docs --- .../spark/sql/KeyValueGroupedDataset.scala | 92 +++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 6a5f9c5b86923..0b05f412a6d94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -221,6 +221,29 @@ class KeyValueGroupedDataset[K, V] private[sql]( /** * ::Experimental:: * (Scala-specific) + * Applies the given function to each group of data, while using an additional keyed state. + * For each unique group, the function will be passed the group key and an iterator that contains + * all of the elements in the group. The function can return an object of arbitrary type, and + * optionally update or remove the corresponding state. The returned object will form a new + * [[Dataset]]. + * + * This function can be applied on both batch and streaming Datasets. With a streaming dataset, + * this function will be once for each in every trigger. 
For each key, the updated state from the + * function call in a trigger will be the state available in the function call in the next + * trigger. However, for batch, `mapGroupsWithState` behaves exactly as `mapGroups` and the + * function is called only once per key without any prior state. + * + * There is no guaranteed ordering of values in the iterator in the function. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + * + * @see [[State]] for more details of how to update/remove state in the function. * @since 2.1.1 */ @Experimental @@ -234,6 +257,29 @@ class KeyValueGroupedDataset[K, V] private[sql]( /** * ::Experimental:: * (Java-specific) + * Applies the given function to each group of data, while using an additional keyed state. + * For each unique group, the function will be passed the group key and an iterator that contains + * all of the elements in the group. The function can return an object of arbitrary type, and + * optionally update or remove the corresponding state. The returned object will form a new + * [[Dataset]]. + * + * This function can be applied on both batch and streaming Datasets. With a streaming dataset, + * this function will be once for each in every trigger. For each key, the updated state from the + * function call in a trigger will be the state available in the function call in the next + * trigger. However, for batch, `mapGroupsWithState` behaves exactly as `mapGroups` and the + * function is called only once per key without any prior state. + * + * There is no guaranteed ordering of values in the iterator in the function. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + * + * @see [[State]] for more details of how to update/remove state in the function. * @since 2.1.1 */ @Experimental @@ -250,6 +296,29 @@ class KeyValueGroupedDataset[K, V] private[sql]( /** * ::Experimental:: * (Scala-specific) + * Applies the given function to each group of data, while using an additional keyed state. + * For each unique group, the function will be passed the group key and an iterator that contains + * all of the elements in the group. The function can return an iteratior of object of arbitrary + * type, and optionally update or remove the corresponding state. The returned object will form a + * new [[Dataset]]. + * + * This function can be applied on both batch and streaming Datasets. With a streaming dataset, + * this function will be once for each in every trigger. For each key, the updated state from the + * function call in a trigger will be the state available in the function call in the next + * trigger. However, for batch, `mapGroupsWithState` behaves exactly as `mapGroups` and the + * function is called only once per key without any prior state. 
+ * + * There is no guaranteed ordering of values in the iterator in the function. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + * + * @see [[State]] for more details of how to update/remove state in the function. * @since 2.1.1 */ @Experimental @@ -268,6 +337,29 @@ class KeyValueGroupedDataset[K, V] private[sql]( /** * ::Experimental:: * (Java-specific) + * Applies the given function to each group of data, while using an additional keyed state. + * For each unique group, the function will be passed the group key and an iterator that contains + * all of the elements in the group. The function can return an iteratior of object of arbitrary + * type, and optionally update or remove the corresponding state. The returned object will form a + * new [[Dataset]]. + * + * This function can be applied on both batch and streaming Datasets. With a streaming dataset, + * this function will be once for each in every trigger. For each key, the updated state from the + * function call in a trigger will be the state available in the function call in the next + * trigger. However, for batch, `mapGroupsWithState` behaves exactly as `flatMapGroups` and the + * function is called only once per key without any prior state. + * + * There is no guaranteed ordering of values in the iterator in the function. + * + * This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. + * + * Internally, the implementation will spill to disk if any given group is too large to fit into + * memory. However, users must take care to avoid materializing the whole iterator for a group + * (for example, by calling `toList`) unless they are sure that this is possible given the memory + * constraints of their cluster. + * + * @see [[State]] for more details of how to update/remove state in the function. * @since 2.1.1 */ @Experimental From 8be63de507f9e4fc258aca1d467509e06b674565 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 31 Jan 2017 11:51:25 -0800 Subject: [PATCH 12/21] Added docs --- .../sql/catalyst/plans/logical/object.scala | 16 +++++++++++++++- .../spark/sql/execution/SparkStrategies.scala | 4 ++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 7ae50a0007fd1..136d52173ab13 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -316,7 +316,7 @@ case class MapGroups( /** Internal class representing State */ trait LogicalState[S] -/** Factory for constructing new `MapGroups` nodes. */ +/** Factory for constructing new `MapGroupsWithState` nodes. 
*/ object MapGroupsWithState { def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder]( func: (Any, Iterator[Any], LogicalState[Any]) => Iterator[Any], @@ -337,6 +337,20 @@ object MapGroupsWithState { } } +/** + * Applies func to each unique group in `child`, based on the evaluation of `groupingAttributes`, + * while using state data. + * Func is invoked with an object representation of the grouping key an iterator containing the + * object representation of all the rows with that key. + * + * @param keyDeserializer used to extract the key object for each group. + * @param valueDeserializer used to extract the items in the iterator from an input row. + * @param groupingAttributes used to group the data + * @param dataAttributes used to read the data + * @param outputObjAttr used to define the output object + * @param stateDeserializer used to deserialize state before calling `func` + * @param stateSerializer used to serialize updated state after calling `func` + */ case class MapGroupsWithState( func: (Any, Iterator[Any], LogicalState[Any]) => Iterator[Any], keyDeserializer: Expression, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index b0bb1f6c00305..7a78ed9663911 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -313,6 +313,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * Strategy to convert MapGroupsWithState logical operator to physical operator + * in streaming plans. Conversion for batch plans is handled by [[BasicOperators]]. + */ object MapGroupsWithStateStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case MapGroupsWithState( From 34449e4c9de2ed8bd3b0d81ec7bc31580b066420 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 31 Jan 2017 15:33:49 -0800 Subject: [PATCH 13/21] Addressed many comments --- .../UnsupportedOperationChecker.scala | 11 +- .../sql/catalyst/plans/logical/object.scala | 6 +- .../analysis/UnsupportedOperationsSuite.scala | 25 +++- .../spark/sql/KeyValueGroupedDataset.scala | 28 +++-- .../org/apache/spark/sql/KeyedState.scala | 116 ++++++++++++++++++ .../scala/org/apache/spark/sql/State.scala | 101 --------------- .../{StateImpl.scala => KeyedStateImpl.scala} | 14 +-- .../apache/spark/sql/execution/objects.scala | 9 +- .../streaming/IncrementalExecution.scala | 6 +- .../streaming/StatefulAggregate.scala | 103 +++++++--------- .../execution/streaming/state/package.scala | 11 +- .../apache/spark/sql/JavaDatasetSuite.java | 4 +- .../streaming/MapGroupsWithStateSuite.scala | 28 ++--- 13 files changed, 249 insertions(+), 213 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/State.scala rename sql/core/src/main/scala/org/apache/spark/sql/execution/{StateImpl.scala => KeyedStateImpl.scala} (81%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index f4d016cb96711..a202064828db2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -46,8 +46,13 @@ object UnsupportedOperationChecker { "Queries without streaming sources cannot be executed with writeStream.start()")(plan) } + /** Collect all the streaming aggregates in a sub plan */ + def collectStreamingAggregates(subplan: LogicalPlan): Seq[Aggregate] = { + subplan.collect { case a@Aggregate(_, _, _) if a.isStreaming => a } + } + // Disallow multiple streaming aggregations - val aggregates = plan.collect { case a@Aggregate(_, _, _) if a.isStreaming => a } + val aggregates = collectStreamingAggregates(plan) if (aggregates.size > 1) { throwError( @@ -114,6 +119,10 @@ object UnsupportedOperationChecker { case _: InsertIntoTable => throwError("InsertIntoTable is not supported with streaming DataFrames/Datasets") + case m: MapGroupsWithState if collectStreamingAggregates(m).nonEmpty => + throwError("(map/flatMap)GroupsWithState is not supported after aggregation on a " + + "streaming DataFrame/Dataset") + case Join(left, right, joinType, _) => joinType match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 136d52173ab13..c632b2d80330e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -314,12 +314,12 @@ case class MapGroups( child: LogicalPlan) extends UnaryNode with ObjectProducer /** Internal class representing State */ -trait LogicalState[S] +trait LogicalKeyedState[S] /** Factory for constructing new `MapGroupsWithState` nodes. */ object MapGroupsWithState { def apply[K: Encoder, V: Encoder, S: Encoder, U: Encoder]( - func: (Any, Iterator[Any], LogicalState[Any]) => Iterator[Any], + func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], child: LogicalPlan): LogicalPlan = { @@ -352,7 +352,7 @@ object MapGroupsWithState { * @param stateSerializer used to serialize updated state after calling `func` */ case class MapGroupsWithState( - func: (Any, Iterator[Any], LogicalState[Any]) => Iterator[Any], + func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], keyDeserializer: Expression, valueDeserializer: Expression, groupingAttributes: Seq[Attribute], diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index dcdb1ae089328..c17e00b497599 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -22,13 +22,13 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal, NamedExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans._ -import 
org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.{MapGroupsWithState, _} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.streaming.OutputMode -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{IntegerType, LongType} /** A dummy command for testing unsupported operations. */ case class DummyCommand() extends Command @@ -111,6 +111,25 @@ class UnsupportedOperationsSuite extends SparkFunSuite { outputMode = Complete, expectedMsgs = Seq("distinct aggregation")) + // MapGroupsWithState: Not supported after a streaming aggregation + val att = new AttributeReference(name = "a", dataType = LongType)() + assertSupportedInStreamingPlan( + "mapGroupsWithState - mapGroupsWithState on batch relation", + MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation), + outputMode = Append) + + assertSupportedInStreamingPlan( + "mapGroupsWithState - mapGroupsWithState on streaming relation before aggregation", + MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), streamRelation), + outputMode = Append) + + assertNotSupportedInStreamingPlan( + "mapGroupsWithState - mapGroupsWithState on streaming relation after aggregation", + MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), + Aggregate(Nil, aggExprs("c"), streamRelation)), + outputMode = Complete, + expectedMsgs = Seq("(map/flatMap)GroupsWithState")) + // Inner joins: Stream-stream not supported testBinaryOperationInStreamingPlan( "inner join", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 0b05f412a6d94..75773467528a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -243,15 +243,15 @@ class KeyValueGroupedDataset[K, V] private[sql]( * (for example, by calling `toList`) unless they are sure that this is possible given the memory * constraints of their cluster. * - * @see [[State]] for more details of how to update/remove state in the function. + * @see [[KeyedState]] for more details of how to update/remove state in the function. * @since 2.1.1 */ @Experimental @InterfaceStability.Evolving def mapGroupsWithState[STATE: Encoder, OUT: Encoder]( - func: (K, Iterator[V], State[STATE]) => OUT): Dataset[OUT] = { - val f = (key: K, it: Iterator[V], s: State[STATE]) => Iterator(func(key, it, s)) - flatMapGroupsWithState[STATE, OUT](f) + func: (K, Iterator[V], KeyedState[STATE]) => OUT): Dataset[OUT] = { + flatMapGroupsWithState[STATE, OUT]( + (key: K, it: Iterator[V], s: KeyedState[STATE]) => Iterator(func(key, it, s))) } /** @@ -279,7 +279,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * (for example, by calling `toList`) unless they are sure that this is possible given the memory * constraints of their cluster. * - * @see [[State]] for more details of how to update/remove state in the function. + * @see [[KeyedState]] for more details of how to update/remove state in the function. 
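
As a concrete illustration of the API documented above, the following is a minimal sketch (not part of the patch) of a running-count query written against the `KeyedState`-based signatures as they stand after this commit. The socket source, host/port, and schema handling are assumptions made for the example; only `groupByKey`, `mapGroupsWithState`, and `KeyedState` come from the patch itself.

    import org.apache.spark.sql.{Dataset, KeyedState, SparkSession}

    object RunningCountSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().appName("RunningCountSketch").getOrCreate()
        import spark.implicits._

        // Assumed streaming source of words; any Dataset[String] would work the same way.
        val words: Dataset[String] = spark.readStream
          .format("socket").option("host", "localhost").option("port", 9999)
          .load().as[String]

        // Keep a running count per word. On a streaming Dataset the state survives across
        // triggers; on a batch Dataset the state starts empty and the function runs once per
        // key, so this degenerates to mapGroups. Per the new rule in
        // UnsupportedOperationChecker, it may not be applied after a streaming aggregation.
        val counts: Dataset[(String, Long)] =
          words.groupByKey(identity).mapGroupsWithState[Long, (String, Long)] {
            (word: String, values: Iterator[String], state: KeyedState[Long]) =>
              val newCount = state.getOption.getOrElse(0L) + values.size
              state.update(newCount)
              (word, newCount)
          }

        counts.printSchema()  // starting the query with writeStream is omitted here
      }
    }
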
* @since 2.1.1 */ @Experimental @@ -288,8 +288,9 @@ class KeyValueGroupedDataset[K, V] private[sql]( func: MapGroupsWithStateFunction[K, V, STATE, OUT], stateEncoder: Encoder[STATE], outputEncoder: Encoder[OUT]): Dataset[OUT] = { - val f = (key: K, it: Iterator[V], s: State[STATE]) => Iterator(func.call(key, it.asJava, s)) - flatMapGroupsWithState[STATE, OUT](f)(stateEncoder, outputEncoder) + flatMapGroupsWithState[STATE, OUT]( + (key: K, it: Iterator[V], s: KeyedState[STATE]) => Iterator(func.call(key, it.asJava, s)) + )(stateEncoder, outputEncoder) } @@ -318,17 +319,17 @@ class KeyValueGroupedDataset[K, V] private[sql]( * (for example, by calling `toList`) unless they are sure that this is possible given the memory * constraints of their cluster. * - * @see [[State]] for more details of how to update/remove state in the function. + * @see [[KeyedState]] for more details of how to update/remove state in the function. * @since 2.1.1 */ @Experimental @InterfaceStability.Evolving def flatMapGroupsWithState[STATE: Encoder, OUT: Encoder]( - func: (K, Iterator[V], State[STATE]) => Iterator[OUT]): Dataset[OUT] = { + func: (K, Iterator[V], KeyedState[STATE]) => Iterator[OUT]): Dataset[OUT] = { Dataset[OUT]( sparkSession, MapGroupsWithState[K, V, STATE, OUT]( - func.asInstanceOf[(Any, Iterator[Any], LogicalState[Any]) => Iterator[Any]], + func.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]], groupingAttributes, dataAttributes, logicalPlan)) @@ -359,7 +360,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( * (for example, by calling `toList`) unless they are sure that this is possible given the memory * constraints of their cluster. * - * @see [[State]] for more details of how to update/remove state in the function. + * @see [[KeyedState]] for more details of how to update/remove state in the function. * @since 2.1.1 */ @Experimental @@ -368,8 +369,9 @@ class KeyValueGroupedDataset[K, V] private[sql]( func: FlatMapGroupsWithStateFunction[K, V, STATE, OUT], stateEncoder: Encoder[STATE], outputEncoder: Encoder[OUT]): Dataset[OUT] = { - val f = (key: K, it: Iterator[V], s: State[STATE]) => func.call(key, it.asJava, s).asScala - flatMapGroupsWithState[STATE, OUT](f)(stateEncoder, outputEncoder) + flatMapGroupsWithState[STATE, OUT]( + (key: K, it: Iterator[V], s: KeyedState[STATE]) => func.call(key, it.asJava, s).asScala + )(stateEncoder, outputEncoder) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala new file mode 100644 index 0000000000000..7bc0af996725d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.annotation.{Experimental, InterfaceStability} +import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState + +/** + * :: Experimental :: + * + * Wrapper class for interacting with keyed state data in `mapGroupsWithState` and + * `flatMapGroupsWithState` operations on + * [[org.apache.spark.sql.KeyValueGroupedDataset KeyValueGroupedDataset]]. + * + * Important points to note. + * - State can be `null`. So updating the state to null is not same as removing the state. + * - Operations on state are not threadsafe. This is to avoid memory barriers. + * - If the `remove()` is called, then `exists()` will return `false`, and + * `getOption()` will return `None`. + * - After that `update(newState)` is called, then `exists()` will return `true`, + * and `getOption()` will return `Some(...)`. + * + * Scala example of using `KeyedState`: + * {{{ + * // A mapping function that maintains an integer state for string keys and returns a string. + * def mappingFunction(key: String, value: Iterable[Int], state: KeyedState[Int]): Option[String]= { + * // Check if state exists + * if (state.exists) { + * val existingState = state.get // Get the existing state + * val shouldRemove = ... // Decide whether to remove the state + * if (shouldRemove) { + * state.remove() // Remove the state + * } else { + * val newState = ... + * state.update(newState) // Set the new state + * } + * } else { + * val initialState = ... + * state.update(initialState) // Set the initial state + * } + * ... // return something + * } + * + * }}} + * + * Java example of using `KeyedState`: + * {{{ + * // A mapping function that maintains an integer state for string keys and returns a string. + * MapGroupsWithStateFunction mappingFunction = + * new MapGroupsWithStateFunction() { + * + * @Override + * public String call(String key, Optional value, KeyedState state) { + * if (state.exists()) { + * int existingState = state.get(); // Get the existing state + * boolean shouldRemove = ...; // Decide whether to remove the state + * if (shouldRemove) { + * state.remove(); // Remove the state + * } else { + * int newState = ...; + * state.update(newState); // Set the new state + * } + * } else { + * int initialState = ...; // Set the initial state + * state.update(initialState); + * } + * ... // return something + * } + * }; + * }}} + * + * @tparam S User-defined type of the state to be stored for each key. Must be encodable into + * Spark SQL types (see [[Encoder]] for more details). + * @since 2.1.1 + */ +@Experimental +@InterfaceStability.Evolving +trait KeyedState[S] extends LogicalKeyedState[S] { + + /** Whether state exists or not. */ + def exists: Boolean + + /** Get the state object if it is defined, otherwise throws NoSuchElementException. */ + def get: S + + /** + * Update the value of the state. Note that null is a valid value, and does not signify removing + * of the state. + */ + def update(newState: S): Unit + + /** Remove this keyed state. */ + def remove(): Unit + + /** (scala friendly) Get the state object as an [[Option]]. 
*/ + @inline final def getOption: Option[S] = if (exists) Some(get) else None + + @inline final override def toString: String = { + getOption.map { _.toString }.getOrElse("") + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/State.scala b/sql/core/src/main/scala/org/apache/spark/sql/State.scala deleted file mode 100644 index 65f621fa05e57..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/State.scala +++ /dev/null @@ -1,101 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql - -import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.catalyst.plans.logical.LogicalState - -/** - * :: Experimental :: - * - * Wrapper class for interacting with state data in `mapGroupsWithState` and - * `flatMapGroupsWithState` operations on - * [[org.apache.spark.sql.KeyValueGroupedDataset KeyValueGroupedDataset]]. - * - * @note Operations on state are not threadsafe. - * - * Scala example of using `State`: - * {{{ - * // A mapping function that maintains an integer state for string keys and returns a string. - * def mappingFunction(key: String, value: Iterable[Int], state: State[Int]): Option[String] = { - * // Check if state exists - * if (state.exists) { - * val existingState = state.get // Get the existing state - * val shouldRemove = ... // Decide whether to remove the state - * if (shouldRemove) { - * state.remove() // Remove the state - * } else { - * val newState = ... - * state.update(newState) // Set the new state - * } - * } else { - * val initialState = ... - * state.update(initialState) // Set the initial state - * } - * ... // return something - * } - * - * }}} - * - * Java example of using `State`: - * {{{ - * // A mapping function that maintains an integer state for string keys and returns a string. - * Function3, State, String> mappingFunction = - * new Function3, State, String>() { - * - * @Override - * public String call(String key, Optional value, State state) { - * if (state.exists()) { - * int existingState = state.get(); // Get the existing state - * boolean shouldRemove = ...; // Decide whether to remove the state - * if (shouldRemove) { - * state.remove(); // Remove the state - * } else { - * int newState = ...; - * state.update(newState); // Set the new state - * } - * } else { - * int initialState = ...; // Set the initial state - * state.update(initialState); - * } - * ... 
// return something - * } - * }; - * }}} - * - * @tparam S Type of the state - * @since 2.1.1 - */ -@Experimental -@InterfaceStability.Evolving -trait State[S] extends LogicalState[S] { - - def exists: Boolean - - def get(): S - - def update(newState: S): Unit - - def remove(): Unit - - @inline final def getOption(): Option[S] = if (exists) Some(get()) else None - - @inline final override def toString(): String = { - getOption.map { _.toString }.getOrElse("") - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/StateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyedStateImpl.scala similarity index 81% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/StateImpl.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/KeyedStateImpl.scala index c6c8d12c49083..59107e69250e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/StateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyedStateImpl.scala @@ -17,21 +17,21 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.State +import org.apache.spark.sql.KeyedState -/** Internal implementation of the [[State]] interface */ -private[sql] class StateImpl[S](optionalValue: Option[S]) extends State[S] { +/** Internal implementation of the [[KeyedState]] interface */ +private[sql] class KeyedStateImpl[S](optionalValue: Option[S]) extends KeyedState[S] { private var value: S = optionalValue.getOrElse(null.asInstanceOf[S]) private var defined: Boolean = optionalValue.isDefined private var updated: Boolean = false // whether value has been updated (but not removed) - private var removed: Boolean = false // whether value has eben removed + private var removed: Boolean = false // whether value has been removed // ========= Public API ========= override def exists: Boolean = { defined } - override def get(): S = { + override def get: S = { if (defined) { value } else { @@ -65,6 +65,6 @@ private[sql] class StateImpl[S](optionalValue: Option[S]) extends State[S] { } } -object StateImpl { - def apply[S](optionalValue: Option[S]): StateImpl[S] = new StateImpl[S](optionalValue) +object KeyedStateImpl { + def apply[S](optionalValue: Option[S]): KeyedStateImpl[S] = new KeyedStateImpl[S](optionalValue) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 2615eaa5ffe3b..124843c16f7ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.plans.logical.LogicalState +import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState import org.apache.spark.sql.types.{DataType, ObjectType, StructType} @@ -145,8 +145,7 @@ object ObjectOperator { (i: InternalRow) => proj(i).get(0, deserializer.dataType) } - def deserializeRowToObject( - deserializer: Expression): InternalRow => Any = { + def deserializeRowToObject(deserializer: Expression): InternalRow => Any = { val proj = GenerateSafeProjection.generate(deserializer :: Nil) (i: InternalRow) => proj(i).get(0, deserializer.dataType) } @@ -353,14 +352,14 @@ case class MapGroupsExec( object MapGroupsExec { def apply( - 
func: (Any, Iterator[Any], LogicalState[Any]) => TraversableOnce[Any], + func: (Any, Iterator[Any], LogicalKeyedState[Any]) => TraversableOnce[Any], keyDeserializer: Expression, valueDeserializer: Expression, groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], outputObjAttr: Attribute, child: SparkPlan): MapGroupsExec = { - val f = (key: Any, values: Iterator[Any]) => func(key, values, StateImpl[Any](None)) + val f = (key: Any, values: Iterator[Any]) => func(key, values, KeyedStateImpl[Any](None)) new MapGroupsExec(f, keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, outputObjAttr, child) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index ef0cc74f3ffe3..a3e108b29eda6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -94,13 +94,11 @@ class IncrementalExecution( Some(stateId), child) :: Nil)) case MapGroupsWithStateExec( - func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, - None, stateDeser, stateSer, child) => + f, kDeser, vDeser, group, data, output, None, stateDeser, stateSer, child) => val stateId = OperatorStateId(checkpointLocation, operatorId.getAndIncrement(), currentBatchId) MapGroupsWithStateExec( - func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, - Some(stateId), stateDeser, stateSer, child) + f, kDeser, vDeser, group, data, output, Some(stateId), stateDeser, stateSer, child) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala index e6b4d1fa4d075..88e4d112d871a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, Predicate} -import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalState} +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalKeyedState} import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution @@ -32,7 +32,6 @@ import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.util.CompletionIterator -import org.apache.spark.TaskContext /** Used to identify the state store for a given operator. */ @@ -156,13 +155,6 @@ case class StateStoreSaveExec( val numTotalStateRows = longMetric("numTotalStateRows") val numUpdatedStateRows = longMetric("numUpdatedStateRows") - // Abort the state store in case of error - TaskContext.get().addTaskCompletionListener(_ => { - if (!store.hasCommitted) { - store.abort() - } - }) - outputMode match { // Update and output all rows in the StateStore. 
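
One detail worth calling out from the `MapGroupsExec` change earlier in this patch: the batch code path simply reuses `MapGroupsExec` and hands the user function a `KeyedStateImpl` built from `None`, which is why batch `mapGroupsWithState` behaves like `mapGroups`. A rough sketch of that wiring (illustrative only; `KeyedStateImpl` is `private[sql]`, so this would not compile outside Spark's own packages):

    import org.apache.spark.sql.KeyedState
    import org.apache.spark.sql.execution.KeyedStateImpl

    // Every group in a batch query starts with empty state; whatever the user function
    // writes into it is dropped after the single invocation, matching mapGroups semantics.
    def asStatelessGroupFunc[K, V, S, U](
        func: (K, Iterator[V], KeyedState[S]) => TraversableOnce[U]): (K, Iterator[V]) => TraversableOnce[U] =
      (key: K, values: Iterator[V]) => func(key, values, new KeyedStateImpl[S](None))
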
case Some(Complete) => @@ -242,12 +234,10 @@ case class StateStoreSaveExec( } -/** - * Physical operator for executing streaming mapGroupsWithState. - */ +/** Physical operator for executing streaming mapGroupsWithState. */ case class MapGroupsWithStateExec( - func: (Any, Iterator[Any], LogicalState[Any]) => Iterator[Any], - keyDeserializer: Expression, // probably not needed + func: (Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any], + keyDeserializer: Expression, valueDeserializer: Expression, groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], @@ -259,14 +249,15 @@ case class MapGroupsWithStateExec( override def outputPartitioning: Partitioning = child.outputPartitioning + /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(groupingAttributes) :: Nil + /** Ordering needed for using GroupingIterator */ override def requiredChildOrdering: Seq[Seq[SortOrder]] = - Seq(groupingAttributes.map(SortOrder(_, Ascending))) // is this ordering needed? + Seq(groupingAttributes.map(SortOrder(_, Ascending))) override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitionsWithStateStore[InternalRow]( getStateId.checkpointLocation, operatorId = getStateId.operatorId, @@ -275,51 +266,45 @@ case class MapGroupsWithStateExec( child.output.toStructType, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { (store, iter) => - try { - val numTotalStateRows = longMetric("numTotalStateRows") - val numUpdatedStateRows = longMetric("numUpdatedStateRows") - val numOutputRows = longMetric("numOutputRows") - - val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) - - val getKeyObj = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) - val getKey = GenerateUnsafeProjection.generate(groupingAttributes, child.output) - val getValueObj = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) - val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) - val getStateObj = - ObjectOperator.deserializeRowToObject(stateDeserializer) - val outputStateObj = ObjectOperator.serializeObjectToRow(stateSerializer) - - val finalIterator = groupedIter.flatMap { case (keyRow, valueRowIter) => - val key = keyRow.asInstanceOf[UnsafeRow] - val keyObj = getKeyObj(keyRow) - val valueObjIter = valueRowIter.map(getValueObj.apply) - val stateObjOption = store.get(key).map(getStateObj) - val wrappedState = StateImpl[Any](stateObjOption) - val mappedIterator = func(keyObj, valueObjIter, wrappedState) - - if (wrappedState.isRemoved) { - store.remove(key) - numUpdatedStateRows += 1 - } else if (wrappedState.isUpdated) { - store.put(key, outputStateObj(wrappedState.get())) - numUpdatedStateRows += 1 - } + val numTotalStateRows = longMetric("numTotalStateRows") + val numUpdatedStateRows = longMetric("numUpdatedStateRows") + val numOutputRows = longMetric("numOutputRows") - mappedIterator.map { obj => - numOutputRows += 1 - getOutputRow(obj) - } + val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) + + val getKeyObj = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) + val getKey = GenerateUnsafeProjection.generate(groupingAttributes, child.output) + val getValueObj = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) + val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) + val getStateObj = + 
ObjectOperator.deserializeRowToObject(stateDeserializer) + val outputStateObj = ObjectOperator.serializeObjectToRow(stateSerializer) + + val finalIterator = groupedIter.flatMap { case (keyRow, valueRowIter) => + val key = keyRow.asInstanceOf[UnsafeRow] + val keyObj = getKeyObj(keyRow) + val valueObjIter = valueRowIter.map(getValueObj.apply) + val stateObjOption = store.get(key).map(getStateObj) + val wrappedState = KeyedStateImpl[Any](stateObjOption) + val mappedIterator = func(keyObj, valueObjIter, wrappedState) + + if (wrappedState.isRemoved) { + store.remove(key) + numUpdatedStateRows += 1 + } else if (wrappedState.isUpdated) { + store.put(key, outputStateObj(wrappedState.get)) + numUpdatedStateRows += 1 + } + + mappedIterator.map { obj => + numOutputRows += 1 + getOutputRow(obj) } - CompletionIterator[InternalRow, Iterator[InternalRow]](finalIterator, { - store.commit() - numTotalStateRows += store.numKeys() - }) - } catch { - case e: Throwable => - store.abort() - throw e } - } + CompletionIterator[InternalRow, Iterator[InternalRow]](finalIterator, { + store.commit() + numTotalStateRows += store.numKeys() + }) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index 1b56c08f729c6..589042afb1e52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming import scala.reflect.ClassTag +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.internal.SessionState @@ -59,10 +60,18 @@ package object state { sessionState: SessionState, storeCoordinator: Option[StateStoreCoordinatorRef])( storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = { + val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) + val wrappedF = (store: StateStore, iter: Iterator[T]) => { + // Abort the state store in case of error + TaskContext.get().addTaskCompletionListener(_ => { + if (!store.hasCommitted) store.abort() + }) + cleanedF(store, iter) + } new StateStoreRDD( dataRDD, - cleanedF, + wrappedF, checkpointLocation, operatorId, storeVersion, diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 6260b6127d52e..5ef4e887ded09 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -228,7 +228,7 @@ public Iterator call(Integer key, Iterator values) { Dataset mapped2 = grouped.mapGroupsWithState( new MapGroupsWithStateFunction() { @Override - public String call(Integer key, Iterator values, State s) throws Exception { + public String call(Integer key, Iterator values, KeyedState s) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); while (values.hasNext()) { sb.append(values.next()); @@ -244,7 +244,7 @@ public String call(Integer key, Iterator values, State s) throws E Dataset flatMapped2 = grouped.flatMapGroupsWithState( new FlatMapGroupsWithStateFunction() { @Override - public Iterator call(Integer key, Iterator values, State s) { + public Iterator call(Integer key, Iterator values, KeyedState s) { StringBuilder sb = new 
StringBuilder(key.toString()); while (values.hasNext()) { sb.append(values.next()); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala index e712b41904ad8..57e0676a1cfb2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.streaming import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException -import org.apache.spark.sql.State +import org.apache.spark.sql.KeyedState import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution.StateImpl +import org.apache.spark.sql.execution.KeyedStateImpl import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.execution.streaming.state.StateStore @@ -39,7 +39,7 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { } test("state - get, exists, update, remove") { - var state: StateImpl[String] = null + var state: KeyedStateImpl[String] = null def testState( expectedData: Option[String], @@ -48,14 +48,14 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { ): Unit = { if (expectedData.isDefined) { assert(state.exists) - assert(state.get() === expectedData.get) - assert(state.getOption() === expectedData) + assert(state.get === expectedData.get) + assert(state.getOption === expectedData) } else { assert(!state.exists) intercept[NoSuchElementException] { - state.get() + state.get } - assert(state.getOption() === None) + assert(state.getOption === None) } assert(state.isUpdated === shouldBeUpdated) @@ -63,13 +63,13 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { } // Updating empty state - state = StateImpl[String](None) + state = KeyedStateImpl[String](None) testState(None) state.update("") testState(Some(""), shouldBeUpdated = true) // Updating exiting state, even if with null - state = StateImpl[String](Some("2")) + state = KeyedStateImpl[String](Some("2")) testState(Some("2")) state.update("3") testState(Some("3"), shouldBeUpdated = true) @@ -87,7 +87,7 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { test("flatMapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count if state is defined, otherwise does not return anything - val stateFunc = (key: String, values: Iterator[String], state: State[RunningCount]) => { + val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { var count = state.getOption.map(_.count).getOrElse(0L) + values.size if (count == 3) { @@ -127,7 +127,7 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { test("flatMapGroupsWithState - batch") { // Function that returns running count only if its even, otherwise does not return - val stateFunc = (key: String, values: Iterator[String], state: State[RunningCount]) => { + val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { if (state.exists) throw new IllegalArgumentException("state.exists should be false") if (state.getOption.nonEmpty) { throw new IllegalArgumentException("state.getOption should be empty") @@ -142,7 +142,7 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { 
test("mapGroupsWithState - streaming") { // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) - val stateFunc = (key: String, values: Iterator[String], state: State[RunningCount]) => { + val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { val count = state.getOption.map(_.count).getOrElse(0L) + values.size if (count == 3) { @@ -181,7 +181,7 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { } test("mapGroupsWithState - batch") { - val stateFunc = (key: String, values: Iterator[String], state: State[RunningCount]) => { + val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { if (state.exists) throw new IllegalArgumentException("state.exists should be false") if (state.getOption.nonEmpty) { throw new IllegalArgumentException("state.getOption should be empty") @@ -198,7 +198,7 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { } testQuietly("StateStore.abort on task failure handling") { - val stateFunc = (key: String, values: Iterator[String], state: State[RunningCount]) => { + val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { if (MapGroupsWithStateSuite.failInTask) throw new Exception("expected failure") val count = state.getOption.map(_.count).getOrElse(0L) + values.size state.update(RunningCount(count)) From 3628af82fb85e2a6dc145af5b60dfb8d23f71819 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 31 Jan 2017 15:39:42 -0800 Subject: [PATCH 14/21] Addressed more comments --- .../main/scala/org/apache/spark/sql/execution/objects.scala | 1 + .../spark/sql/execution/{ => streaming}/KeyedStateImpl.scala | 2 +- .../{StatefulAggregate.scala => statefulOperators.scala} | 0 .../apache/spark/sql/streaming/MapGroupsWithStateSuite.scala | 3 +-- 4 files changed, 3 insertions(+), 3 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/{ => streaming}/KeyedStateImpl.scala (97%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/{StatefulAggregate.scala => statefulOperators.scala} (100%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 124843c16f7ab..de7e9ad00ec02 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState +import org.apache.spark.sql.execution.streaming.KeyedStateImpl import org.apache.spark.sql.types.{DataType, ObjectType, StructType} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyedStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala similarity index 97% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/KeyedStateImpl.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala index 59107e69250e0..0e94220f8990e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/KeyedStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala 
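
For orientation while these files are being moved: the core of the `MapGroupsWithStateExec.doExecute` body shown above can be summarized by the sketch below. This is a simplification, not the operator's actual code; the (de)serialization helpers are passed in explicitly here, whereas the real operator generates them from the key/value/state expressions, and `KeyedStateImpl`/`StateStore` are internal `private[sql]` APIs.

    import org.apache.spark.sql.KeyedState
    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.sql.execution.streaming.KeyedStateImpl
    import org.apache.spark.sql.execution.streaming.state.StateStore

    // Per-partition flow: look up prior state for each key, run the user function,
    // then write updates/removals back to the StateStore before it is committed.
    def processPartition[K, V, S, U](
        groups: Iterator[(UnsafeRow, K, Iterator[V])],  // (key row, key object, values) per group
        store: StateStore,
        func: (K, Iterator[V], KeyedState[S]) => Iterator[U],
        deserializeState: UnsafeRow => S,
        serializeState: S => UnsafeRow): Iterator[U] = {
      val output = groups.flatMap { case (keyRow, keyObj, values) =>
        val state = new KeyedStateImpl[S](store.get(keyRow).map(deserializeState))
        val results = func(keyObj, values, state)
        if (state.isRemoved) {
          store.remove(keyRow)                           // user called state.remove()
        } else if (state.isUpdated) {
          store.put(keyRow, serializeState(state.get))   // user called state.update(...)
        }
        results
      }
      // In the operator, store.commit() runs via a CompletionIterator once `output` has been
      // fully consumed, and the task-completion listener aborts any uncommitted store.
      output
    }
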
@@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution +package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.KeyedState diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala similarity index 100% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala index 57e0676a1cfb2..7fc1e6eed5eb5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala @@ -22,8 +22,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException import org.apache.spark.sql.KeyedState import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution.KeyedStateImpl -import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.{KeyedStateImpl, MemoryStream} import org.apache.spark.sql.execution.streaming.state.StateStore /** Class to check custom state types */ From af4f1f2cf184f155f80a55196a544f12c9ef6102 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 31 Jan 2017 16:02:53 -0800 Subject: [PATCH 15/21] Addresed more comments --- .../spark/sql/KeyValueGroupedDataset.scala | 131 +++++++++++------- .../spark/sql/execution/SparkStrategies.scala | 6 +- .../streaming/statefulOperators.scala | 4 +- 3 files changed, 86 insertions(+), 55 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 75773467528a7..8500bd1cad994 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -221,22 +221,30 @@ class KeyValueGroupedDataset[K, V] private[sql]( /** * ::Experimental:: * (Scala-specific) - * Applies the given function to each group of data, while using an additional keyed state. - * For each unique group, the function will be passed the group key and an iterator that contains - * all of the elements in the group. The function can return an object of arbitrary type, and + * Applies the given function to each group of data, while maintaining some user-defined per-group + * state. + * + * For each unique group, the given function will be invoked once for each group + * with the following arguments: + * - The key of the group. + * - An iterator containing all the values for this key. + * - A user-defined state object set by previous invocations of the given function. + * Note that, for batch queries, there is only ever one invocation and thus the state object + * will always be empty. And the function can return an object of arbitrary type, and * optionally update or remove the corresponding state. The returned object will form a new * [[Dataset]]. * - * This function can be applied on both batch and streaming Datasets. With a streaming dataset, - * this function will be once for each in every trigger. 
For each key, the updated state from the - * function call in a trigger will be the state available in the function call in the next - * trigger. However, for batch, `mapGroupsWithState` behaves exactly as `mapGroups` and the - * function is called only once per key without any prior state. + * This operation can be applied on both batch and streaming Datasets. With a streaming dataset, + * the given function will be invoked once for each group in every trigger/batch that has + * data in the group. The updates to the state will be stored and passed to the function in the + * next invocation. However, for batch, `mapGroupsWithState` behaves exactly as `mapGroups` and + * the function is called only once per key without any prior state. * - * There is no guaranteed ordering of values in the iterator in the function. - * - * This function does not support partial aggregation, and as a result requires shuffling all + * Other points to note + * - There is no guaranteed ordering of values in the iterator in the function. + * - This function does not support partial aggregation, and as a result requires shuffling all * the data in the [[Dataset]]. + * - Operations on [[KeyedState]] are not threadsafe. See corresponding docs for more details. * * Internally, the implementation will spill to disk if any given group is too large to fit into * memory. However, users must take care to avoid materializing the whole iterator for a group @@ -257,22 +265,30 @@ class KeyValueGroupedDataset[K, V] private[sql]( /** * ::Experimental:: * (Java-specific) - * Applies the given function to each group of data, while using an additional keyed state. - * For each unique group, the function will be passed the group key and an iterator that contains - * all of the elements in the group. The function can return an object of arbitrary type, and + * Applies the given function to each group of data, while maintaining some user-defined per-group + * state. + * + * For each unique group, the given function will be invoked once for each group + * with the following arguments: + * - The key of the group. + * - An iterator containing all the values for this key. + * - A user-defined state object set by previous invocations of the given function. + * Note that, for batch queries, there is only ever one invocation and thus the state object + * will always be empty. And the function can return an object of arbitrary type, and * optionally update or remove the corresponding state. The returned object will form a new * [[Dataset]]. * - * This function can be applied on both batch and streaming Datasets. With a streaming dataset, - * this function will be once for each in every trigger. For each key, the updated state from the - * function call in a trigger will be the state available in the function call in the next - * trigger. However, for batch, `mapGroupsWithState` behaves exactly as `mapGroups` and the - * function is called only once per key without any prior state. + * This operation can be applied on both batch and streaming Datasets. With a streaming dataset, + * the given function will be invoked once for each group in every trigger/batch that has + * data in the group. The updates to the state will be stored and passed to the function in the + * next invocation. However, for batch, `mapGroupsWithState` behaves exactly as `mapGroups` and + * the function is called only once per key without any prior state. * - * There is no guaranteed ordering of values in the iterator in the function. 
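
The carry-over behaviour described above can be made concrete with a small, hand-driven sketch (illustrative only, and ignoring the `private[sql]` access on `KeyedStateImpl`): the engine stores whatever the function left in the state object and hands it back, wrapped in a fresh `KeyedStateImpl`, at the next trigger.

    import org.apache.spark.sql.KeyedState
    import org.apache.spark.sql.execution.streaming.KeyedStateImpl

    // A user function that keeps a running count per key.
    def countFunc(key: String, values: Iterator[String], state: KeyedState[Long]): Long = {
      val count = state.getOption.getOrElse(0L) + values.size
      state.update(count)
      count
    }

    // Trigger 1: no prior state exists for key "a".
    val state1 = new KeyedStateImpl[Long](None)
    assert(countFunc("a", Iterator("a", "a"), state1) == 2)

    // Trigger 2: the state written in trigger 1 is handed back to the function.
    val state2 = new KeyedStateImpl[Long](Some(state1.get))
    assert(countFunc("a", Iterator("a"), state2) == 3)

    // In a batch query there is only ever one invocation, so the state always starts empty.
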
- * - * This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. + * Other points to note + * - There is no guaranteed ordering of values in the iterator in the function. + * - This function does not support partial aggregation, and as a result requires shuffling all + * the data in the [[Dataset]]. + * - Operations on [[KeyedState]] are not threadsafe. See corresponding docs for more details. * * Internally, the implementation will spill to disk if any given group is too large to fit into * memory. However, users must take care to avoid materializing the whole iterator for a group @@ -293,26 +309,33 @@ class KeyValueGroupedDataset[K, V] private[sql]( )(stateEncoder, outputEncoder) } - /** * ::Experimental:: * (Scala-specific) - * Applies the given function to each group of data, while using an additional keyed state. - * For each unique group, the function will be passed the group key and an iterator that contains - * all of the elements in the group. The function can return an iteratior of object of arbitrary - * type, and optionally update or remove the corresponding state. The returned object will form a - * new [[Dataset]]. - * - * This function can be applied on both batch and streaming Datasets. With a streaming dataset, - * this function will be once for each in every trigger. For each key, the updated state from the - * function call in a trigger will be the state available in the function call in the next - * trigger. However, for batch, `mapGroupsWithState` behaves exactly as `mapGroups` and the - * function is called only once per key without any prior state. + * Applies the given function to each group of data, while maintaining some user-defined per-group + * state. + * + * For each unique group, the given function will be invoked once for each group + * with the following arguments: + * - The key of the group. + * - An iterator containing all the values for this key. + * - A user-defined state object set by previous invocations of the given function. + * Note that, for batch queries, there is only ever one invocation and thus the state object + * will always be empty. And the function can return an iterator of objects of arbitrary type, and + * optionally update or remove the corresponding state. The returned object will form a new + * [[Dataset]]. * - * There is no guaranteed ordering of values in the iterator in the function. + * This operation can be applied on both batch and streaming Datasets. With a streaming dataset, + * the given function will be invoked once for each group in every trigger/batch that has + * data in the group. The updates to the state will be stored and passed to the function in the + * next invocation. However, for batch, `mapGroupsWithState` behaves exactly as `mapGroups` and + * the function is called only once per key without any prior state. * - * This function does not support partial aggregation, and as a result requires shuffling all + * Other points to note + * - There is no guaranteed ordering of values in the iterator in the function. + * - This function does not support partial aggregation, and as a result requires shuffling all * the data in the [[Dataset]]. + * - Operations on [[KeyedState]] are not threadsafe. See corresponding docs for more details. * * Internally, the implementation will spill to disk if any given group is too large to fit into * memory. 
However, users must take care to avoid materializing the whole iterator for a group @@ -338,22 +361,30 @@ class KeyValueGroupedDataset[K, V] private[sql]( /** * ::Experimental:: * (Java-specific) - * Applies the given function to each group of data, while using an additional keyed state. - * For each unique group, the function will be passed the group key and an iterator that contains - * all of the elements in the group. The function can return an iteratior of object of arbitrary - * type, and optionally update or remove the corresponding state. The returned object will form a - * new [[Dataset]]. - * - * This function can be applied on both batch and streaming Datasets. With a streaming dataset, - * this function will be once for each in every trigger. For each key, the updated state from the - * function call in a trigger will be the state available in the function call in the next - * trigger. However, for batch, `mapGroupsWithState` behaves exactly as `flatMapGroups` and the - * function is called only once per key without any prior state. + * Applies the given function to each group of data, while maintaining some user-defined per-group + * state. + * + * For each unique group, the given function will be invoked once for each group + * with the following arguments: + * - The key of the group. + * - An iterator containing all the values for this key. + * - A user-defined state object set by previous invocations of the given function. + * Note that, for batch queries, there is only ever one invocation and thus the state object + * will always be empty. And the function can return an iterator of objects of arbitrary type, and + * optionally update or remove the corresponding state. The returned object will form a new + * [[Dataset]]. * - * There is no guaranteed ordering of values in the iterator in the function. + * This operation can be applied on both batch and streaming Datasets. With a streaming dataset, + * the given function will be invoked once for each group in every trigger/batch that has + * data in the group. The updates to the state will be stored and passed to the function in the + * next invocation. However, for batch, `mapGroupsWithState` behaves exactly as `mapGroups` and + * the function is called only once per key without any prior state. * - * This function does not support partial aggregation, and as a result requires shuffling all + * Other points to note + * - There is no guaranteed ordering of values in the iterator in the function. + * - This function does not support partial aggregation, and as a result requires shuffling all * the data in the [[Dataset]]. + * - Operations on [[KeyedState]] are not threadsafe. See corresponding docs for more details. * * Internally, the implementation will spill to disk if any given group is too large to fit into * memory. 
However, users must take care to avoid materializing the whole iterator for a group diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 7a78ed9663911..d9f4334918e69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -320,11 +320,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object MapGroupsWithStateStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case MapGroupsWithState( - func, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, - stateDeser, stateSer, child) => + f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, stateDeser, stateSer, child) => val execPlan = MapGroupsWithStateExec( - func, keyDeser, valueDeser, - groupAttr, dataAttr, outputAttr, None, stateDeser, stateSer, + f, keyDeser, valueDeser, groupAttr, dataAttr, outputAttr, None, stateDeser, stateSer, planLater(child)) execPlan :: Nil case _ => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 88e4d112d871a..9b42623593ce7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -41,7 +41,7 @@ case class OperatorStateId( batchId: Long) /** - * An operator that saves or restores state from the [[StateStore]]. The [[OperatorStateId]] should + * An operator that reads or writes state from the [[StateStore]]. The [[OperatorStateId]] should * be filled in by `prepareForExecution` in [[IncrementalExecution]]. */ trait StatefulOperator extends SparkPlan { @@ -54,11 +54,13 @@ trait StatefulOperator extends SparkPlan { } } +/** An operator that reads from a StateStore. */ trait StateStoreReader extends StatefulOperator { override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) } +/** An operator that writes to a StateStore. */ trait StateStoreWriter extends StatefulOperator { override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), From 59c229b0be934b643950e40ac75051f22b756c93 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 31 Jan 2017 16:04:29 -0800 Subject: [PATCH 16/21] Added missing classes --- .../FlatMapGroupsWithStateFunction.java | 38 +++++++++++++++++++ .../function/MapGroupsWithStateFunction.java | 38 +++++++++++++++++++ 2 files changed, 76 insertions(+) create mode 100644 sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java create mode 100644 sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java new file mode 100644 index 0000000000000..2570c8d02ab7c --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupsWithStateFunction.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.KeyedState; + +/** + * ::Experimental:: + * Base interface for a map function used in + * {@link org.apache.spark.sql.KeyValueGroupedDataset#flatMapGroupsWithState(FlatMapGroupsWithStateFunction, Encoder, Encoder)}. + * @since 2.1.1 + */ +@Experimental +@InterfaceStability.Evolving +public interface FlatMapGroupsWithStateFunction extends Serializable { + Iterator call(K key, Iterator values, KeyedState state) throws Exception; +} diff --git a/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java new file mode 100644 index 0000000000000..614d3925e0510 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/api/java/function/MapGroupsWithStateFunction.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.Encoder; +import org.apache.spark.sql.KeyedState; + +/** + * ::Experimental:: + * Base interface for a map function used in + * {@link org.apache.spark.sql.KeyValueGroupedDataset#mapGroupsWithState(MapGroupsWithStateFunction, Encoder, Encoder)} + * @since 2.1.1 + */ +@Experimental +@InterfaceStability.Evolving +public interface MapGroupsWithStateFunction extends Serializable { + R call(K key, Iterator values, KeyedState state) throws Exception; +} From db2dbb230210131ceb8d3f7ab0ebb6f94e44233b Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 31 Jan 2017 17:14:25 -0800 Subject: [PATCH 17/21] Addressed 1 comment --- .../spark/sql/KeyValueGroupedDataset.scala | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 8500bd1cad994..028bf090ca3d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -256,10 +256,10 @@ class KeyValueGroupedDataset[K, V] private[sql]( */ @Experimental @InterfaceStability.Evolving - def mapGroupsWithState[STATE: Encoder, OUT: Encoder]( - func: (K, Iterator[V], KeyedState[STATE]) => OUT): Dataset[OUT] = { - flatMapGroupsWithState[STATE, OUT]( - (key: K, it: Iterator[V], s: KeyedState[STATE]) => Iterator(func(key, it, s))) + def mapGroupsWithState[S: Encoder, U: Encoder]( + func: (K, Iterator[V], KeyedState[S]) => U): Dataset[U] = { + flatMapGroupsWithState[S, U]( + (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func(key, it, s))) } /** @@ -300,12 +300,12 @@ class KeyValueGroupedDataset[K, V] private[sql]( */ @Experimental @InterfaceStability.Evolving - def mapGroupsWithState[STATE, OUT]( - func: MapGroupsWithStateFunction[K, V, STATE, OUT], - stateEncoder: Encoder[STATE], - outputEncoder: Encoder[OUT]): Dataset[OUT] = { - flatMapGroupsWithState[STATE, OUT]( - (key: K, it: Iterator[V], s: KeyedState[STATE]) => Iterator(func.call(key, it.asJava, s)) + def mapGroupsWithState[S, U]( + func: MapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U]): Dataset[U] = { + flatMapGroupsWithState[S, U]( + (key: K, it: Iterator[V], s: KeyedState[S]) => Iterator(func.call(key, it.asJava, s)) )(stateEncoder, outputEncoder) } @@ -347,11 +347,11 @@ class KeyValueGroupedDataset[K, V] private[sql]( */ @Experimental @InterfaceStability.Evolving - def flatMapGroupsWithState[STATE: Encoder, OUT: Encoder]( - func: (K, Iterator[V], KeyedState[STATE]) => Iterator[OUT]): Dataset[OUT] = { - Dataset[OUT]( + def flatMapGroupsWithState[S: Encoder, U: Encoder]( + func: (K, Iterator[V], KeyedState[S]) => Iterator[U]): Dataset[U] = { + Dataset[U]( sparkSession, - MapGroupsWithState[K, V, STATE, OUT]( + MapGroupsWithState[K, V, S, U]( func.asInstanceOf[(Any, Iterator[Any], LogicalKeyedState[Any]) => Iterator[Any]], groupingAttributes, dataAttributes, @@ -396,12 +396,12 @@ class KeyValueGroupedDataset[K, V] private[sql]( */ @Experimental @InterfaceStability.Evolving - def flatMapGroupsWithState[STATE, OUT]( - func: FlatMapGroupsWithStateFunction[K, V, STATE, OUT], - stateEncoder: Encoder[STATE], - 
outputEncoder: Encoder[OUT]): Dataset[OUT] = { - flatMapGroupsWithState[STATE, OUT]( - (key: K, it: Iterator[V], s: KeyedState[STATE]) => func.call(key, it.asJava, s).asScala + def flatMapGroupsWithState[S, U]( + func: FlatMapGroupsWithStateFunction[K, V, S, U], + stateEncoder: Encoder[S], + outputEncoder: Encoder[U]): Dataset[U] = { + flatMapGroupsWithState[S, U]( + (key: K, it: Iterator[V], s: KeyedState[S]) => func.call(key, it.asJava, s).asScala )(stateEncoder, outputEncoder) } From 8b18fa1c5a457198e3b99d41aaf770c8cb11106d Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 31 Jan 2017 21:07:59 -0800 Subject: [PATCH 18/21] Addressed comments --- .../spark/sql/KeyValueGroupedDataset.scala | 166 +++++------------- .../org/apache/spark/sql/KeyedState.scala | 56 ++++-- .../apache/spark/sql/execution/objects.scala | 2 +- .../execution/streaming/KeyedStateImpl.scala | 43 ++--- .../streaming/statefulOperators.scala | 2 +- .../streaming/MapGroupsWithStateSuite.scala | 29 ++- 6 files changed, 113 insertions(+), 185 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index 028bf090ca3d5..94e689a4d5b97 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -221,37 +221,17 @@ class KeyValueGroupedDataset[K, V] private[sql]( /** * ::Experimental:: * (Scala-specific) - * Applies the given function to each group of data, while maintaining some user-defined per-group - * state. - * - * For each unique group, the given function will be invoked once for each group - * with the following arguments: - * - The key of the group. - * - An iterator containing all the values for this key. - * - A user-defined state object set by previous invocations of the given function. - * Note that, for batch queries, there is only ever one invocation and thus the state object - * will always be empty. And the function can return an object of arbitrary type, and - * optionally update or remove the corresponding state. The returned object will form a new - * [[Dataset]]. - * - * This operation can be applied on both batch and streaming Datasets. With a streaming dataset, - * the given function will be invoked once for each group in every trigger/batch that has - * data in the group. The updates to the state will be stored and passed to the function in the - * next invocation. However, for batch, `mapGroupsWithState` behaves exactly as `mapGroups` and - * the function is called only once per key without any prior state. + * Applies the given function to each group of data, while maintaining a user-defined per-group + * state. The result Dataset will represent the objects returned by the function. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger, and + * updates to each group's state will be saved across invocations. + * See [[KeyedState]] for more details. * - * Other points to note - * - There is no guaranteed ordering of values in the iterator in the function. - * - This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. - * - Operations on [[KeyedState]] are not threadsafe. See corresponding docs for more details. 
- * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the memory - * constraints of their cluster. + * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U The type of the output objects. Must be encodable to Spark SQL types. * - * @see [[KeyedState]] for more details of how to update/remove state in the function. + * See [[Encoder]] for more details on what types are encodable to Spark SQL. * @since 2.1.1 */ @Experimental @@ -265,37 +245,20 @@ class KeyValueGroupedDataset[K, V] private[sql]( /** * ::Experimental:: * (Java-specific) - * Applies the given function to each group of data, while maintaining some user-defined per-group - * state. - * - * For each unique group, the given function will be invoked once for each group - * with the following arguments: - * - The key of the group. - * - An iterator containing all the values for this key. - * - A user-defined state object set by previous invocations of the given function. - * Note that, for batch queries, there is only ever one invocation and thus the state object - * will always be empty. And the function can return an object of arbitrary type, and - * optionally update or remove the corresponding state. The returned object will form a new - * [[Dataset]]. - * - * This operation can be applied on both batch and streaming Datasets. With a streaming dataset, - * the given function will be invoked once for each group in every trigger/batch that has - * data in the group. The updates to the state will be stored and passed to the function in the - * next invocation. However, for batch, `mapGroupsWithState` behaves exactly as `mapGroups` and - * the function is called only once per key without any prior state. - * - * Other points to note - * - There is no guaranteed ordering of values in the iterator in the function. - * - This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. - * - Operations on [[KeyedState]] are not threadsafe. See corresponding docs for more details. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the memory - * constraints of their cluster. - * - * @see [[KeyedState]] for more details of how to update/remove state in the function. + * Applies the given function to each group of data, while maintaining a user-defined per-group + * state. The result Dataset will represent the objects returned by the function. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger, and + * updates to each group's state will be saved across invocations. + * See [[KeyedState]] for more details. + * + * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U The type of the output objects. Must be encodable to Spark SQL types. + * @param func Function to be called on every group. + * @param stateEncoder Encoder for the state type. + * @param outputEncoder Encoder for the output type. 
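The Java-friendly overload documented here takes the function and both encoders explicitly. Below is a hedged sketch of such a call, written in Scala for consistency with the other examples; the grouped Dataset `grouped`, the boxed `java.lang.Long` state type, and the variable names are assumptions for illustration, while `MapGroupsWithStateFunction` is the interface added earlier in this series and `Encoders.LONG`/`Encoders.STRING` are the standard encoder factories.

{{{
  import scala.collection.JavaConverters._
  import org.apache.spark.api.java.function.MapGroupsWithStateFunction
  import org.apache.spark.sql.{Encoders, KeyedState}

  val countFunc = new MapGroupsWithStateFunction[String, String, java.lang.Long, String] {
    override def call(
        key: String,
        values: java.util.Iterator[String],
        state: KeyedState[java.lang.Long]): String = {
      val previous = if (state.exists) state.get.longValue else 0L
      val updated = previous + values.asScala.size
      state.update(java.lang.Long.valueOf(updated))  // kept for the next trigger
      s"$key -> $updated"
    }
  }

  // grouped is assumed to be a KeyValueGroupedDataset[String, String], e.g. from ds.groupByKey(...)
  val counted = grouped.mapGroupsWithState(countFunc, Encoders.LONG, Encoders.STRING)
}}}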
+ * + * See [[Encoder]] for more details on what types are encodable to Spark SQL. * @since 2.1.1 */ @Experimental @@ -312,37 +275,17 @@ class KeyValueGroupedDataset[K, V] private[sql]( /** * ::Experimental:: * (Scala-specific) - * Applies the given function to each group of data, while maintaining some user-defined per-group - * state. - * - * For each unique group, the given function will be invoked once for each group - * with the following arguments: - * - The key of the group. - * - An iterator containing all the values for this key. - * - A user-defined state object set by previous invocations of the given function. - * Note that, for batch queries, there is only ever one invocation and thus the state object - * will always be empty. And the function can return an iterator of objects of arbitrary type, and - * optionally update or remove the corresponding state. The returned object will form a new - * [[Dataset]]. - * - * This operation can be applied on both batch and streaming Datasets. With a streaming dataset, - * the given function will be invoked once for each group in every trigger/batch that has - * data in the group. The updates to the state will be stored and passed to the function in the - * next invocation. However, for batch, `mapGroupsWithState` behaves exactly as `mapGroups` and - * the function is called only once per key without any prior state. + * Applies the given function to each group of data, while maintaining a user-defined per-group + * state. The result Dataset will represent the objects returned by the function. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger, and + * updates to each group's state will be saved across invocations. + * See [[KeyedState]] for more details. * - * Other points to note - * - There is no guaranteed ordering of values in the iterator in the function. - * - This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. - * - Operations on [[KeyedState]] are not threadsafe. See corresponding docs for more details. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the memory - * constraints of their cluster. + * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U The type of the output objects. Must be encodable to Spark SQL types. * - * @see [[KeyedState]] for more details of how to update/remove state in the function. + * See [[Encoder]] for more details on what types are encodable to Spark SQL. * @since 2.1.1 */ @Experimental @@ -361,37 +304,20 @@ class KeyValueGroupedDataset[K, V] private[sql]( /** * ::Experimental:: * (Java-specific) - * Applies the given function to each group of data, while maintaining some user-defined per-group - * state. - * - * For each unique group, the given function will be invoked once for each group - * with the following arguments: - * - The key of the group. - * - An iterator containing all the values for this key. - * - A user-defined state object set by previous invocations of the given function. - * Note that, for batch queries, there is only ever one invocation and thus the state object - * will always be empty. 
And the function can return an iterator of objects of arbitrary type, and - * optionally update or remove the corresponding state. The returned object will form a new - * [[Dataset]]. - * - * This operation can be applied on both batch and streaming Datasets. With a streaming dataset, - * the given function will be invoked once for each group in every trigger/batch that has - * data in the group. The updates to the state will be stored and passed to the function in the - * next invocation. However, for batch, `mapGroupsWithState` behaves exactly as `mapGroups` and - * the function is called only once per key without any prior state. - * - * Other points to note - * - There is no guaranteed ordering of values in the iterator in the function. - * - This function does not support partial aggregation, and as a result requires shuffling all - * the data in the [[Dataset]]. - * - Operations on [[KeyedState]] are not threadsafe. See corresponding docs for more details. - * - * Internally, the implementation will spill to disk if any given group is too large to fit into - * memory. However, users must take care to avoid materializing the whole iterator for a group - * (for example, by calling `toList`) unless they are sure that this is possible given the memory - * constraints of their cluster. - * - * @see [[KeyedState]] for more details of how to update/remove state in the function. + * Applies the given function to each group of data, while maintaining a user-defined per-group + * state. The result Dataset will represent the objects returned by the function. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger, and + * updates to each group's state will be saved across invocations. + * See [[KeyedState]] for more details. + * + * @tparam S The type of the user-defined state. Must be encodable to Spark SQL types. + * @tparam U The type of the output objects. Must be encodable to Spark SQL types. + * @param func Function to be called on every group. + * @param stateEncoder Encoder for the state type. + * @param outputEncoder Encoder for the output type. + * + * See [[Encoder]] for more details on what types are encodable to Spark SQL. * @since 2.1.1 */ @Experimental diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala index 7bc0af996725d..35155674fcfee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala @@ -25,17 +25,42 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState * * Wrapper class for interacting with keyed state data in `mapGroupsWithState` and * `flatMapGroupsWithState` operations on - * [[org.apache.spark.sql.KeyValueGroupedDataset KeyValueGroupedDataset]]. + * [[KeyValueGroupedDataset]]. * - * Important points to note. - * - State can be `null`. So updating the state to null is not same as removing the state. - * - Operations on state are not threadsafe. This is to avoid memory barriers. - * - If the `remove()` is called, then `exists()` will return `false`, and - * `getOption()` will return `None`. - * - After that `update(newState)` is called, then `exists()` will return `true`, - * and `getOption()` will return `Some(...)`. 
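The flatMap variant described in the hunks above differs from `mapGroupsWithState` only in that the function returns an iterator, so a group may emit zero or more output rows per trigger and can drop its state when it is done. A minimal sketch under the same assumptions as the earlier example (the hypothetical `RunningCount` class, a grouped Dataset `grouped` of type `KeyValueGroupedDataset[String, String]`, and `spark.implicits._` in scope):

{{{
  // Emit the running count until a key has been seen three times, then drop its state.
  val results = grouped.flatMapGroupsWithState[RunningCount, (String, Long)] {
    (key, values, state) =>
      val count = (if (state.exists) state.get.count else 0L) + values.size
      if (count >= 3) {
        state.remove()                       // exists will be false on the next trigger
        Iterator.empty
      } else {
        state.update(RunningCount(count))
        Iterator((key, count))
      }
  }
}}}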
+ * Detail description on `[map/flatMap]GroupsWithState` operation + * ------------------------------------------------------------ + * Both, `mapGroupsWithState` and `flatMapGroupsWithState` in [[KeyValueGroupedDataset]] + * will invoke the user-given function on each group (defined by the grouping function in + * `Dataset.groupByKey()`) while maintaining user-defined per-group state between invocations. + * For a static batch Dataset, the function will be invoked once per group. For a streaming + * Dataset, the function will be invoked for each group repeatedly in every trigger. + * That is, in every batch of the [[streaming.StreamingQuery StreamingQuery]], + * the function will be invoked once for each group that has data in the batch. * - * Scala example of using `KeyedState`: + * The function is invoked with following parameters. + * - The key of the group. + * - An iterator containing all the values for this key. + * - A user-defined state object set by previous invocations of the given function. + * In case of a batch Dataset, there is only invocation and state object will be empty as + * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState` + * is equivalent to `[map/flatMap]Groups`. + * + * Important points to note about the function. + * - In a trigger, the function will be called only the groups present in the batch. So do not + * assume that the function will be called in every trigger for every group that has state. + * - There is no guaranteed ordering of values in the iterator in the function, neither with + * batch, nor with streaming Datasets. + * - All the data will be shuffled before applying the function. + * + * Important points to note about using KeyedState. + * - The value of the state cannot be null. So updating state with null is same as removing it. + * - Operations on `KeyedState` are not thread-safe. This is to avoid memory barriers. + * - If the `remove()` is called, then `exists()` will return `false`, and + * `getOption()` will return `None`. + * - After that `update(newState)` is called, then `exists()` will return `true`, + * and `getOption()` will return `Some(...)`. + * + * Scala example of using `KeyedState` in `mapGroupsWithState`: * {{{ * // A mapping function that maintains an integer state for string keys and returns a string. * def mappingFunction(key: String, value: Iterable[Int], state: KeyedState[Int]): Option[String]= { @@ -95,22 +120,15 @@ trait KeyedState[S] extends LogicalKeyedState[S] { /** Whether state exists or not. */ def exists: Boolean - /** Get the state object if it is defined, otherwise throws NoSuchElementException. */ + /** Get the state object if it exists, or null. */ def get: S /** - * Update the value of the state. Note that null is a valid value, and does not signify removing - * of the state. + * Update the value of the state. Note that null is not a valid value, and `update(null)` is + * same as `remove()` */ def update(newState: S): Unit /** Remove this keyed state. */ def remove(): Unit - - /** (scala friendly) Get the state object as an [[Option]]. 
*/ - @inline final def getOption: Option[S] = if (exists) Some(get) else None - - @inline final override def toString: String = { - getOption.map { _.toString }.getOrElse("") - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index de7e9ad00ec02..6b4e1eb7b0256 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -360,7 +360,7 @@ object MapGroupsExec { dataAttributes: Seq[Attribute], outputObjAttr: Attribute, child: SparkPlan): MapGroupsExec = { - val f = (key: Any, values: Iterator[Any]) => func(key, values, KeyedStateImpl[Any](None)) + val f = (key: Any, values: Iterator[Any]) => func(key, values, KeyedStateImpl[Any](null)) new MapGroupsExec(f, keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, outputObjAttr, child) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala index 0e94220f8990e..8809e05d8616e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala @@ -20,51 +20,38 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.KeyedState /** Internal implementation of the [[KeyedState]] interface */ -private[sql] class KeyedStateImpl[S](optionalValue: Option[S]) extends KeyedState[S] { - private var value: S = optionalValue.getOrElse(null.asInstanceOf[S]) - private var defined: Boolean = optionalValue.isDefined +private[sql] case class KeyedStateImpl[S](private var value: S) extends KeyedState[S] { private var updated: Boolean = false // whether value has been updated (but not removed) private var removed: Boolean = false // whether value has been removed // ========= Public API ========= - override def exists: Boolean = { - defined - } + override def exists: Boolean = { value != null } - override def get: S = { - if (defined) { - value - } else { - throw new NoSuchElementException("State is either not defined or has already been removed") - } - } + override def get: S = value override def update(newValue: S): Unit = { - value = newValue - defined = true - updated = true - removed = false + if (newValue == null) { + remove() + } else { + value = newValue + updated = true + removed = false + } } override def remove(): Unit = { - defined = false + value = null.asInstanceOf[S] updated = false removed = true } + override def toString: String = "KeyedState($value)" + // ========= Internal API ========= /** Whether the state has been marked for removing */ - def isRemoved: Boolean = { - removed - } + def isRemoved: Boolean = removed /** Whether the state has been been updated */ - def isUpdated: Boolean = { - updated - } -} - -object KeyedStateImpl { - def apply[S](optionalValue: Option[S]): KeyedStateImpl[S] = new KeyedStateImpl[S](optionalValue) + def isUpdated: Boolean = updated } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 9b42623593ce7..d61c34a1f9657 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -287,7 +287,7 @@ case class MapGroupsWithStateExec( val keyObj = getKeyObj(keyRow) val valueObjIter = valueRowIter.map(getValueObj.apply) val stateObjOption = store.get(key).map(getStateObj) - val wrappedState = KeyedStateImpl[Any](stateObjOption) + val wrappedState = KeyedStateImpl[Any](stateObjOption.orNull) val mappedIterator = func(keyObj, valueObjIter, wrappedState) if (wrappedState.isRemoved) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala index 7fc1e6eed5eb5..44dc309593dcd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala @@ -48,32 +48,25 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { if (expectedData.isDefined) { assert(state.exists) assert(state.get === expectedData.get) - assert(state.getOption === expectedData) } else { assert(!state.exists) - intercept[NoSuchElementException] { - state.get - } - assert(state.getOption === None) + assert(state.get === null) } - assert(state.isUpdated === shouldBeUpdated) assert(state.isRemoved === shouldBeRemoved) } // Updating empty state - state = KeyedStateImpl[String](None) + state = KeyedStateImpl[String](null) testState(None) state.update("") testState(Some(""), shouldBeUpdated = true) - // Updating exiting state, even if with null - state = KeyedStateImpl[String](Some("2")) + // Updating exiting state + state = KeyedStateImpl[String]("2") testState(Some("2")) state.update("3") testState(Some("3"), shouldBeUpdated = true) - state.update(null) - testState(Some(null), shouldBeUpdated = true) // Removing state state.remove() @@ -81,6 +74,10 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { state.remove() // should be still callable state.update("4") testState(Some("4"), shouldBeRemoved = false, shouldBeUpdated = true) + + // Updating by null is same as remove + state.update(null) + testState(None, shouldBeRemoved = true, shouldBeUpdated = false) } test("flatMapGroupsWithState - streaming") { @@ -88,7 +85,7 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { // Returns the data and the count if state is defined, otherwise does not return anything val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { - var count = state.getOption.map(_.count).getOrElse(0L) + values.size + var count = Option(state.get).map(_.count).getOrElse(0L) + values.size if (count == 3) { state.remove() Iterator.empty @@ -128,7 +125,7 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { // Function that returns running count only if its even, otherwise does not return val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { if (state.exists) throw new IllegalArgumentException("state.exists should be false") - if (state.getOption.nonEmpty) { + if (state.exists) { throw new IllegalArgumentException("state.getOption should be empty") } Iterator((key, values.size)) @@ -143,7 +140,7 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) val stateFunc = (key: String, values: Iterator[String], state: 
KeyedState[RunningCount]) => { - val count = state.getOption.map(_.count).getOrElse(0L) + values.size + val count = Option(state.get).map(_.count).getOrElse(0L) + values.size if (count == 3) { state.remove() (key, "-1") @@ -182,7 +179,7 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { test("mapGroupsWithState - batch") { val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { if (state.exists) throw new IllegalArgumentException("state.exists should be false") - if (state.getOption.nonEmpty) { + if (state.exists) { throw new IllegalArgumentException("state.getOption should be empty") } (key, values.size) @@ -199,7 +196,7 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { testQuietly("StateStore.abort on task failure handling") { val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { if (MapGroupsWithStateSuite.failInTask) throw new Exception("expected failure") - val count = state.getOption.map(_.count).getOrElse(0L) + values.size + val count = Option(state.get).map(_.count).getOrElse(0L) + values.size state.update(RunningCount(count)) (key, count) } From 8b3150a7bd74d70211ee0c8bcc993191b8cd25a3 Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Wed, 1 Feb 2017 11:31:27 -0800 Subject: [PATCH 19/21] Fixed bug --- .../spark/sql/execution/streaming/statefulOperators.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index d61c34a1f9657..82594ef0ed563 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -76,7 +76,7 @@ case class StateStoreRestoreExec( keyExpressions: Seq[Attribute], stateId: Option[OperatorStateId], child: SparkPlan) - extends execution.UnaryExecNode with StatefulOperator { + extends execution.UnaryExecNode with StateStoreReader { override protected def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") From 7a39eafe4b1dbe839409bab7424943cd204c065b Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 7 Feb 2017 01:17:14 -0500 Subject: [PATCH 20/21] Addressed concerns --- .../UnsupportedOperationChecker.scala | 2 +- .../sql/catalyst/plans/logical/object.scala | 1 - .../analysis/UnsupportedOperationsSuite.scala | 5 +- .../org/apache/spark/sql/KeyedState.scala | 20 ++--- .../execution/streaming/KeyedStateImpl.scala | 6 +- .../state/HDFSBackedStateStoreProvider.scala | 2 + .../streaming/state/StateStore.scala | 3 + .../streaming/statefulOperators.scala | 1 - .../streaming/MapGroupsWithStateSuite.scala | 84 ++++++++++++++++--- 9 files changed, 93 insertions(+), 31 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index a202064828db2..d8aad42edcf5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -48,7 +48,7 @@ object UnsupportedOperationChecker { /** Collect all the streaming aggregates in a sub plan */ def 
collectStreamingAggregates(subplan: LogicalPlan): Seq[Aggregate] = { - subplan.collect { case a@Aggregate(_, _, _) if a.isStreaming => a } + subplan.collect { case a: Aggregate if a.isStreaming => a } } // Disallow multiple streaming aggregations diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index c632b2d80330e..0be4823bbc895 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -362,7 +362,6 @@ case class MapGroupsWithState( stateSerializer: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode with ObjectProducer - /** Factory for constructing new `FlatMapGroupsInR` nodes. */ object FlatMapGroupsInR { def apply( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index c17e00b497599..3b756e89d9036 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -113,10 +113,9 @@ class UnsupportedOperationsSuite extends SparkFunSuite { // MapGroupsWithState: Not supported after a streaming aggregation val att = new AttributeReference(name = "a", dataType = LongType)() - assertSupportedInStreamingPlan( + assertSupportedInBatchPlan( "mapGroupsWithState - mapGroupsWithState on batch relation", - MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation), - outputMode = Append) + MapGroupsWithState(null, att, att, Seq(att), Seq(att), att, att, Seq(att), batchRelation)) assertSupportedInStreamingPlan( "mapGroupsWithState - mapGroupsWithState on streaming relation before aggregation", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala index 35155674fcfee..81d4245e43e9b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState * - The key of the group. * - An iterator containing all the values for this key. * - A user-defined state object set by previous invocations of the given function. - * In case of a batch Dataset, there is only invocation and state object will be empty as + * In case of a batch Dataset, there is only one invocation and state object will be empty as * there is no prior state. Essentially, for batch Datasets, `[map/flatMap]GroupsWithState` * is equivalent to `[map/flatMap]Groups`. * @@ -53,17 +53,17 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState * - All the data will be shuffled before applying the function. * * Important points to note about using KeyedState. - * - The value of the state cannot be null. So updating state with null is same as removing it. + * - The value of the state cannot be null. So you cannot update state with null. * - Operations on `KeyedState` are not thread-safe. This is to avoid memory barriers. - * - If the `remove()` is called, then `exists()` will return `false`, and - * `getOption()` will return `None`. 
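Because per-group state may be absent, and from this patch onward cannot be set to null, the tests in this series read it defensively before adding to it. A small sketch of that pattern, reusing the hypothetical `RunningCount` state; the helper name is invented for illustration:

{{{
  def addToCount(state: KeyedState[RunningCount], newValues: Int): Long = {
    // Read the previous count only if state exists; treat missing state as zero.
    val previous = if (state.exists) state.get.count else 0L
    val updated = previous + newValues
    state.update(RunningCount(updated))      // update(null) would throw IllegalArgumentException
    updated
  }
}}}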
+ * - If `remove()` is called, then `exists()` will return `false`, and + * `get()` will return `null`. * - After that `update(newState)` is called, then `exists()` will return `true`, - * and `getOption()` will return `Some(...)`. + * and `get()` will return the non-null value. * - * Scala example of using `KeyedState` in `mapGroupsWithState`: + * Scala example of using KeyedState` in `mapGroupsWithState`: * {{{ * // A mapping function that maintains an integer state for string keys and returns a string. - * def mappingFunction(key: String, value: Iterable[Int], state: KeyedState[Int]): Option[String]= { + * def mappingFunction(key: String, value: Iterator[Int], state: KeyedState[Int]): String = { * // Check if state exists * if (state.exists) { * val existingState = state.get // Get the existing state @@ -90,7 +90,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState * new MapGroupsWithStateFunction() { * * @Override - * public String call(String key, Optional value, KeyedState state) { + * public String call(String key, Iterator value, KeyedState state) { * if (state.exists()) { * int existingState = state.get(); // Get the existing state * boolean shouldRemove = ...; // Decide whether to remove the state @@ -124,8 +124,8 @@ trait KeyedState[S] extends LogicalKeyedState[S] { def get: S /** - * Update the value of the state. Note that null is not a valid value, and `update(null)` is - * same as `remove()` + * Update the value of the state. Note that `null` is not a valid value, and it throws + * IllegalArgumentException. */ def update(newState: S): Unit diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala index 8809e05d8616e..d57f2aa461b57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.KeyedState -/** Internal implementation of the [[KeyedState]] interface */ +/** Internal implementation of the [[KeyedState]] interface. Methods are not thread-safe. 
*/ private[sql] case class KeyedStateImpl[S](private var value: S) extends KeyedState[S] { private var updated: Boolean = false // whether value has been updated (but not removed) private var removed: Boolean = false // whether value has been removed @@ -31,7 +31,7 @@ private[sql] case class KeyedStateImpl[S](private var value: S) extends KeyedSta override def update(newValue: S): Unit = { if (newValue == null) { - remove() + throw new IllegalArgumentException("'null' is not a valid state value") } else { value = newValue updated = true @@ -45,7 +45,7 @@ private[sql] case class KeyedStateImpl[S](private var value: S) extends KeyedSta removed = true } - override def toString: String = "KeyedState($value)" + override def toString: String = s"KeyedState($value)" // ========= Internal API ========= diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 26bde375f9739..61eb601a18c32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -147,6 +147,7 @@ private[state] class HDFSBackedStateStoreProvider( } } + /** Remove a single key. */ override def remove(key: UnsafeRow): Unit = { verify(state == UPDATING, "Cannot remove after already committed or aborted") if (mapToUpdate.containsKey(key)) { @@ -161,6 +162,7 @@ private[state] class HDFSBackedStateStoreProvider( case Some(ValueRemoved(_, _)) => // Remove already in update map, no need to change } + writeToDeltaFile(tempDeltaFileStream, ValueRemoved(key, value)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index dc2bcee95ca65..dcb24b26f78f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -58,6 +58,9 @@ trait StateStore { */ def remove(condition: UnsafeRow => Boolean): Unit + /** + * Remove a single key. 
+ */ def remove(key: UnsafeRow): Unit /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index 82594ef0ed563..fc926c787ed65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -275,7 +275,6 @@ case class MapGroupsWithStateExec( val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) val getKeyObj = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) - val getKey = GenerateUnsafeProjection.generate(groupingAttributes, child.output) val getValueObj = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) val getStateObj = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala index 44dc309593dcd..665beaefe13ac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala @@ -43,8 +43,7 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { def testState( expectedData: Option[String], shouldBeUpdated: Boolean = false, - shouldBeRemoved: Boolean = false - ): Unit = { + shouldBeRemoved: Boolean = false): Unit = { if (expectedData.isDefined) { assert(state.exists) assert(state.get === expectedData.get) @@ -75,9 +74,26 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { state.update("4") testState(Some("4"), shouldBeRemoved = false, shouldBeUpdated = true) - // Updating by null is same as remove - state.update(null) - testState(None, shouldBeRemoved = true, shouldBeUpdated = false) + // Updating by null throw exception + intercept[IllegalArgumentException] { + state.update(null) + } + } + + test("state - primitive types") { + val intState = new KeyedStateImpl[Int](10) + assert(intState.get == 10) + intState.update(0) + assert(intState.get == 0) + intState.remove() + assert(intState.get == null) + + val longState = new KeyedStateImpl[Long](10) + assert(longState.get == 10) + longState.update(0) + assert(longState.get == 0) + longState.remove() + assert(longState.get == null) } test("flatMapGroupsWithState - streaming") { @@ -85,7 +101,7 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { // Returns the data and the count if state is defined, otherwise does not return anything val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { - var count = Option(state.get).map(_.count).getOrElse(0L) + values.size + val count = Option(state.get).map(_.count).getOrElse(0L) + values.size if (count == 3) { state.remove() Iterator.empty @@ -125,8 +141,8 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { // Function that returns running count only if its even, otherwise does not return val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { if (state.exists) throw new IllegalArgumentException("state.exists should be false") - if (state.exists) { - throw new IllegalArgumentException("state.getOption should be empty") + if (state.get != null) { + throw new 
IllegalArgumentException("state.get should be empty") } Iterator((key, values.size)) } @@ -176,11 +192,55 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { ) } + test("mapGroupsWithState - streaming with aggregation later") { + // Function to maintain running count up to 2, and then remove the count + // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) + val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { + + val count = Option(state.get).map(_.count).getOrElse(0L) + values.size + if (count == 3) { + state.remove() + (key, "-1") + } else { + state.update(RunningCount(count)) + (key, count.toString) + } + } + + val inputData = MemoryStream[String] + val result = + inputData.toDS() + .groupByKey(x => x) + .mapGroupsWithState(stateFunc) // Types = State: MyState, Out: (Str, Str) + .groupByKey(_._1) + .count() + + testStream(result, Complete)( + AddData(inputData, "a"), + CheckLastBatch(("a", 1)), + AddData(inputData, "a", "b"), + // mapGroups generates ("a", "2"), ("b", "1"); so increases counts of a and b by 1 + CheckLastBatch(("a", 2), ("b", 1)), + StopStream, + StartStream(), + AddData(inputData, "a", "b"), + // mapGroups should remove state for "a" and generate ("a", "-1"), ("b", "2") ; + // so increment a and b by 1 + CheckLastBatch(("a", 3), ("b", 2)), + StopStream, + StartStream(), + AddData(inputData, "a", "c"), + // mapGroups should recreate state for "a" and generate ("a", "1"), ("c", "1") ; + // so increment a and c by 1 + CheckLastBatch(("a", 4), ("b", 2), ("c", 1)) + ) + } + test("mapGroupsWithState - batch") { val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => { if (state.exists) throw new IllegalArgumentException("state.exists should be false") - if (state.exists) { - throw new IllegalArgumentException("state.getOption should be empty") + if (state.get != null) { + throw new IllegalArgumentException("state.get should be empty") } (key, values.size) } @@ -229,8 +289,8 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll { private def assertNumStateRows(total: Long, updated: Long): AssertOnQuery = AssertOnQuery { q => val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get - assert(progressWithData.stateOperators(0).numRowsTotal === total) - assert(progressWithData.stateOperators(0).numRowsUpdated === updated) + assert(progressWithData.stateOperators(0).numRowsTotal === total, "incorrect total rows") + assert(progressWithData.stateOperators(0).numRowsUpdated === updated, "incorrect updates rows") true } } From f3d12311229a958fee5fb3cfcf4f6b33035fc87f Mon Sep 17 00:00:00 2001 From: Tathagata Das Date: Tue, 7 Feb 2017 17:16:18 -0500 Subject: [PATCH 21/21] Addressed comments --- .../org/apache/spark/sql/KeyedState.scala | 26 ++++-- .../apache/spark/sql/execution/objects.scala | 2 +- .../execution/streaming/KeyedStateImpl.scala | 49 ++++++++--- .../streaming/statefulOperators.scala | 48 +++++++---- .../streaming/MapGroupsWithStateSuite.scala | 85 +++++++++++++------ 5 files changed, 144 insertions(+), 66 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala index 81d4245e43e9b..6864b6f6b4fd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyedState.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import 
java.lang.IllegalArgumentException + import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState @@ -53,16 +55,17 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState * - All the data will be shuffled before applying the function. * * Important points to note about using KeyedState. - * - The value of the state cannot be null. So you cannot update state with null. + * - The value of the state cannot be null. So updating state with null will throw + * `IllegalArgumentException`. * - Operations on `KeyedState` are not thread-safe. This is to avoid memory barriers. - * - If `remove()` is called, then `exists()` will return `false`, and - * `get()` will return `null`. - * - After that `update(newState)` is called, then `exists()` will return `true`, - * and `get()` will return the non-null value. + * - If `remove()` is called, then `exists()` will return `false`, + * `get()` will throw `NoSuchElementException` and `getOption()` will return `None` + * - After that, if `update(newState)` is called, then `exists()` will again return `true`, + * `get()` and `getOption()`will return the updated value. * - * Scala example of using KeyedState` in `mapGroupsWithState`: + * Scala example of using KeyedState in `mapGroupsWithState`: * {{{ - * // A mapping function that maintains an integer state for string keys and returns a string. + * /* A mapping function that maintains an integer state for string keys and returns a string. */ * def mappingFunction(key: String, value: Iterator[Int], state: KeyedState[Int]): String = { * // Check if state exists * if (state.exists) { @@ -85,7 +88,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalKeyedState * * Java example of using `KeyedState`: * {{{ - * // A mapping function that maintains an integer state for string keys and returns a string. + * /* A mapping function that maintains an integer state for string keys and returns a string. */ * MapGroupsWithStateFunction mappingFunction = * new MapGroupsWithStateFunction() { * @@ -120,13 +123,18 @@ trait KeyedState[S] extends LogicalKeyedState[S] { /** Whether state exists or not. */ def exists: Boolean - /** Get the state object if it exists, or null. */ + /** Get the state value if it exists, or throw NoSuchElementException. */ + @throws[NoSuchElementException]("when state does not exist") def get: S + /** Get the state value as a scala Option. */ + def getOption: Option[S] + /** * Update the value of the state. Note that `null` is not a valid value, and it throws * IllegalArgumentException. */ + @throws[IllegalArgumentException]("when updating with null") def update(newState: S): Unit /** Remove this keyed state. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 6b4e1eb7b0256..199ba5ce6969b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -360,7 +360,7 @@ object MapGroupsExec { dataAttributes: Seq[Attribute], outputObjAttr: Attribute, child: SparkPlan): MapGroupsExec = { - val f = (key: Any, values: Iterator[Any]) => func(key, values, KeyedStateImpl[Any](null)) + val f = (key: Any, values: Iterator[Any]) => func(key, values, new KeyedStateImpl[Any](None)) new MapGroupsExec(f, keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, outputObjAttr, child) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala index d57f2aa461b57..eee7ec45dd77b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/KeyedStateImpl.scala @@ -20,38 +20,61 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.KeyedState /** Internal implementation of the [[KeyedState]] interface. Methods are not thread-safe. */ -private[sql] case class KeyedStateImpl[S](private var value: S) extends KeyedState[S] { - private var updated: Boolean = false // whether value has been updated (but not removed) - private var removed: Boolean = false // whether value has been removed +private[sql] class KeyedStateImpl[S](optionalValue: Option[S]) extends KeyedState[S] { + private var value: S = optionalValue.getOrElse(null.asInstanceOf[S]) + private var defined: Boolean = optionalValue.isDefined + private var updated: Boolean = false + // whether value has been updated (but not removed) + private var removed: Boolean = false // whether value has been removed // ========= Public API ========= - override def exists: Boolean = { value != null } + override def exists: Boolean = defined - override def get: S = value + override def get: S = { + if (defined) { + value + } else { + throw new NoSuchElementException("State is either not defined or has already been removed") + } + } + + override def getOption: Option[S] = { + if (defined) { + Some(value) + } else { + None + } + } override def update(newValue: S): Unit = { if (newValue == null) { throw new IllegalArgumentException("'null' is not a valid state value") - } else { - value = newValue - updated = true - removed = false } + value = newValue + defined = true + updated = true + removed = false } override def remove(): Unit = { - value = null.asInstanceOf[S] + defined = false updated = false removed = true } - override def toString: String = s"KeyedState($value)" + override def toString: String = { + s"KeyedState(${getOption.map(_.toString).getOrElse("")})" + } // ========= Internal API ========= /** Whether the state has been marked for removing */ - def isRemoved: Boolean = removed + def isRemoved: Boolean = { + removed + } /** Whether the state has been been updated */ - def isUpdated: Boolean = updated + def isUpdated: Boolean = { + updated + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index fc926c787ed65..1292452574594 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -272,8 +272,10 @@ case class MapGroupsWithStateExec( val numUpdatedStateRows = longMetric("numUpdatedStateRows") val numOutputRows = longMetric("numOutputRows") + // Generate a iterator that returns the rows grouped by the grouping function val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) + // Converters to and from object and rows val getKeyObj = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes) val getValueObj = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) @@ -281,28 +283,38 @@ case class MapGroupsWithStateExec( ObjectOperator.deserializeRowToObject(stateDeserializer) val outputStateObj = ObjectOperator.serializeObjectToRow(stateSerializer) - val finalIterator = groupedIter.flatMap { case (keyRow, valueRowIter) => - val key = keyRow.asInstanceOf[UnsafeRow] - val keyObj = getKeyObj(keyRow) - val valueObjIter = valueRowIter.map(getValueObj.apply) - val stateObjOption = store.get(key).map(getStateObj) - val wrappedState = KeyedStateImpl[Any](stateObjOption.orNull) - val mappedIterator = func(keyObj, valueObjIter, wrappedState) - - if (wrappedState.isRemoved) { - store.remove(key) - numUpdatedStateRows += 1 - } else if (wrappedState.isUpdated) { - store.put(key, outputStateObj(wrappedState.get)) - numUpdatedStateRows += 1 - } + // For every group, get the key, values and corresponding state and call the function, + // and return an iterator of rows + val allRowsIterator = groupedIter.flatMap { case (keyRow, valueRowIter) => - mappedIterator.map { obj => + val key = keyRow.asInstanceOf[UnsafeRow] + val keyObj = getKeyObj(keyRow) // convert key to objects + val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects + val stateObjOption = store.get(key).map(getStateObj) // get existing state if any + val wrappedState = new KeyedStateImpl(stateObjOption) + val mappedIterator = func(keyObj, valueObjIter, wrappedState).map { obj => numOutputRows += 1 - getOutputRow(obj) + getOutputRow(obj) // convert back to rows } + + // Return an iterator of rows generated this key, + // such that fully consumed, the updated state value will be saved + CompletionIterator[InternalRow, Iterator[InternalRow]]( + mappedIterator, { + // When the iterator is consumed, then write changes to state + if (wrappedState.isRemoved) { + store.remove(key) + numUpdatedStateRows += 1 + } else if (wrappedState.isUpdated) { + store.put(key, outputStateObj(wrappedState.get)) + numUpdatedStateRows += 1 + } + }) } - CompletionIterator[InternalRow, Iterator[InternalRow]](finalIterator, { + + // Return an iterator of all the rows generated by all the keys, such that when fully + // consumer, all the state updates will be committed by the state store + CompletionIterator[InternalRow, Iterator[InternalRow]](allRowsIterator, { store.commit() numTotalStateRows += store.numKeys() }) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala index 665beaefe13ac..0524898b15ead 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MapGroupsWithStateSuite.scala +++ 
@@ -37,7 +37,7 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll {
     StateStore.stop()
   }

-  test("state - get, exists, update, remove") {
+  test("KeyedState - get, exists, update, remove") {
     var state: KeyedStateImpl[String] = null

     def testState(
@@ -49,20 +49,23 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll {
         assert(state.get === expectedData.get)
       } else {
         assert(!state.exists)
-        assert(state.get === null)
+        intercept[NoSuchElementException] {
+          state.get
+        }
       }
+      assert(state.getOption === expectedData)
       assert(state.isUpdated === shouldBeUpdated)
       assert(state.isRemoved === shouldBeRemoved)
     }

     // Updating empty state
-    state = KeyedStateImpl[String](null)
+    state = new KeyedStateImpl[String](None)
     testState(None)
     state.update("")
     testState(Some(""), shouldBeUpdated = true)

     // Updating existing state
-    state = KeyedStateImpl[String]("2")
+    state = new KeyedStateImpl[String](Some("2"))
     testState(Some("2"))
     state.update("3")
     testState(Some("3"), shouldBeUpdated = true)
@@ -80,20 +83,21 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll {
     }
   }

-  test("state - primitive types") {
-    val intState = new KeyedStateImpl[Int](10)
+  test("KeyedState - primitive type") {
+    var intState = new KeyedStateImpl[Int](None)
+    intercept[NoSuchElementException] {
+      intState.get
+    }
+    assert(intState.getOption === None)
+
+    intState = new KeyedStateImpl[Int](Some(10))
     assert(intState.get == 10)
     intState.update(0)
     assert(intState.get == 0)
     intState.remove()
-    assert(intState.get == null)
-
-    val longState = new KeyedStateImpl[Long](10)
-    assert(longState.get == 10)
-    longState.update(0)
-    assert(longState.get == 0)
-    longState.remove()
-    assert(longState.get == null)
+    intercept[NoSuchElementException] {
+      intState.get
+    }
   }

   test("flatMapGroupsWithState - streaming") {
@@ -101,7 +105,7 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll {
     // Returns the data and the count if state is defined, otherwise does not return anything
     val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
-      val count = Option(state.get).map(_.count).getOrElse(0L) + values.size
+      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
       if (count == 3) {
         state.remove()
         Iterator.empty
@@ -137,13 +141,47 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll {
     )
   }

+  test("flatMapGroupsWithState - streaming + func returns iterator that updates state lazily") {
+    // Function to maintain running count up to 2, and then remove the count
+    // Returns the data and the count if state is defined, otherwise does not return anything
+    // Additionally, it updates state lazily as the returned iterator gets consumed
+    val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
+      values.flatMap { _ =>
+        val count = state.getOption.map(_.count).getOrElse(0L) + 1
+        if (count == 3) {
+          state.remove()
+          None
+        } else {
+          state.update(RunningCount(count))
+          Some((key, count.toString))
+        }
+      }
+    }
+
+    val inputData = MemoryStream[String]
+    val result =
+      inputData.toDS()
+        .groupByKey(x => x)
+        .flatMapGroupsWithState(stateFunc) // State: RunningCount, Out: (Str, Str)
+
+    testStream(result, Append)(
+      AddData(inputData, "a", "a", "b"),
+      CheckLastBatch(("a", "1"), ("a", "2"), ("b", "1")),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "b"), // should remove state for "a" and not return anything for "a"
+      CheckLastBatch(("b", "2")),
+      StopStream,
+      StartStream(),
+      AddData(inputData, "a", "c"), // should recreate state for "a" and return count as 1
+      CheckLastBatch(("a", "1"), ("c", "1"))
+    )
+  }
+
   test("flatMapGroupsWithState - batch") {
     // Function that returns running count only if it's even, otherwise does not return
     val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
       if (state.exists) throw new IllegalArgumentException("state.exists should be false")
-      if (state.get != null) {
-        throw new IllegalArgumentException("state.get should be empty")
-      }
       Iterator((key, values.size))
     }
     checkAnswer(
@@ -156,7 +194,7 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll {
     // Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
     val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
-      val count = Option(state.get).map(_.count).getOrElse(0L) + values.size
+      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
       if (count == 3) {
         state.remove()
         (key, "-1")
@@ -192,12 +230,12 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll {
     )
   }

-  test("mapGroupsWithState - streaming with aggregation later") {
+  test("mapGroupsWithState - streaming + aggregation") {

     // Function to maintain running count up to 2, and then remove the count
     // Returns the data and the count (-1 if count reached beyond 2 and state was just removed)
     val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
-      val count = Option(state.get).map(_.count).getOrElse(0L) + values.size
+      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
       if (count == 3) {
         state.remove()
         (key, "-1")
@@ -239,9 +277,6 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll {
   test("mapGroupsWithState - batch") {
     val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
       if (state.exists) throw new IllegalArgumentException("state.exists should be false")
-      if (state.get != null) {
-        throw new IllegalArgumentException("state.get should be empty")
-      }
       (key, values.size)
     }

@@ -256,7 +291,7 @@ class MapGroupsWithStateSuite extends StreamTest with BeforeAndAfterAll {
   testQuietly("StateStore.abort on task failure handling") {
     val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
       if (MapGroupsWithStateSuite.failInTask) throw new Exception("expected failure")
-      val count = Option(state.get).map(_.count).getOrElse(0L) + values.size
+      val count = state.getOption.map(_.count).getOrElse(0L) + values.size
       state.update(RunningCount(count))
       (key, count)
     }
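
For reference, a minimal sketch (not part of the patch) of how the KeyedState API exercised by
these tests could be called from user code. The KeyedState methods (getOption, update, remove),
the org.apache.spark.sql.KeyedState import, and the groupByKey/flatMapGroupsWithState calls are
taken from the diffs above; the SparkSession setup, the input file name, and the RunningCount
case class are illustrative assumptions for this sketch only:

    import org.apache.spark.sql.{KeyedState, SparkSession}

    // Illustrative state class, mirroring the one used in the tests above.
    case class RunningCount(count: Long)

    object KeyedStateExample {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[2]").appName("example").getOrCreate()
        import spark.implicits._

        // Maintain a running count per key; once the count reaches 3, drop the state and
        // emit nothing, mirroring the stateFunc used in the streaming tests above.
        val stateFunc = (key: String, values: Iterator[String], state: KeyedState[RunningCount]) => {
          val count = state.getOption.map(_.count).getOrElse(0L) + values.size
          if (count == 3) {
            state.remove()
            Iterator.empty
          } else {
            state.update(RunningCount(count))
            Iterator((key, count.toString))
          }
        }

        // "words.txt" is a placeholder input; any Dataset[String] works the same way.
        val counts = spark.read.textFile("words.txt")
          .groupByKey(identity)
          .flatMapGroupsWithState(stateFunc)

        counts.show()
      }
    }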