From ac6c491c0c1654bd31cb0eb9d4fddb9d8b11f424 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 11 Nov 2015 11:46:59 -0800 Subject: [PATCH 1/7] Add trackStateByKey Java API --- .../spark/api/java/function/Function4.java | 27 ++++ .../JavaStatefulNetworkWordCount.java | 43 +++--- .../streaming/StatefulNetworkWordCount.scala | 2 +- .../org/apache/spark/streaming/State.scala | 10 +- .../apache/spark/streaming/StateSpec.scala | 62 ++------ .../streaming/api/java/JavaPairDStream.scala | 46 +++++- .../spark/streaming/api/java/JavaState.scala | 104 +++++++++++++ .../streaming/api/java/JavaStateSpec.scala | 142 ++++++++++++++++++ .../api/java/JavaTrackStateDStream.scala | 44 ++++++ .../streaming/dstream/TrackStateDStream.scala | 1 + .../spark/streaming/rdd/TrackStateRDD.scala | 4 +- .../spark/streaming/util/StateMap.scala | 2 +- 12 files changed, 401 insertions(+), 86 deletions(-) create mode 100644 core/src/main/java/org/apache/spark/api/java/function/Function4.java create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaState.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStateSpec.scala create mode 100644 streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function4.java b/core/src/main/java/org/apache/spark/api/java/function/Function4.java new file mode 100644 index 000000000000..fd727d64863d --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/Function4.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; + +/** + * A four-argument function that takes arguments of type T1, T2, T3 and T4 and returns an R. + */ +public interface Function4 extends Serializable { + public R call(T1 v1, T2 v2, T3 v3, T4 v4) throws Exception; +} diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java index 99b63a2590ae..50be62c6305a 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java @@ -26,18 +26,13 @@ import com.google.common.base.Optional; import com.google.common.collect.Lists; -import org.apache.spark.HashPartitioner; import org.apache.spark.SparkConf; +import org.apache.spark.api.java.function.*; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.StorageLevels; -import org.apache.spark.api.java.function.FlatMapFunction; -import org.apache.spark.api.java.function.Function2; -import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.streaming.Durations; -import org.apache.spark.streaming.api.java.JavaDStream; -import org.apache.spark.streaming.api.java.JavaPairDStream; -import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; -import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.apache.spark.streaming.Time; 
+import org.apache.spark.streaming.api.java.*; /** * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every @@ -63,25 +58,12 @@ public static void main(String[] args) { StreamingExamples.setStreamingLogLevels(); - // Update the cumulative count function - final Function2, Optional, Optional> updateFunction = - new Function2, Optional, Optional>() { - @Override - public Optional call(List values, Optional state) { - Integer newSum = state.or(0); - for (Integer value : values) { - newSum += value; - } - return Optional.of(newSum); - } - }; - // Create the context with a 1 second batch size SparkConf sparkConf = new SparkConf().setAppName("JavaStatefulNetworkWordCount"); JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1)); ssc.checkpoint("."); - // Initial RDD input to updateStateByKey + // Initial RDD input to trackStateByKey @SuppressWarnings("unchecked") List> tuples = Arrays.asList(new Tuple2("hello", 1), new Tuple2("world", 1)); @@ -105,9 +87,22 @@ public Tuple2 call(String s) { } }); + // Update the cumulative count function + final Function4, JavaState, Optional>> trackStateFunc = + new Function4, JavaState, Optional>>() { + + @Override + public Optional> call(Time time, String word, Optional one, JavaState state) { + int sum = one.or(0) + state.getOption().or(0); + Tuple2 output = new Tuple2(word, sum); + state.update(sum); + return Optional.of(output); + } + }; + // This will give a Dstream made of state (which is the cumulative count of the words) - JavaPairDStream stateDstream = wordsDstream.updateStateByKey(updateFunction, - new HashPartitioner(ssc.sparkContext().defaultParallelism()), initialRDD); + JavaTrackStateDStream> stateDstream = + wordsDstream.trackStateByKey(JavaStateSpec.function(trackStateFunc).initialState(initialRDD)); stateDstream.print(); ssc.start(); diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala 
b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index be2ae0b47336..a4f847f118b2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -49,7 +49,7 @@ object StatefulNetworkWordCount { val ssc = new StreamingContext(sparkConf, Seconds(1)) ssc.checkpoint(".") - // Initial RDD input to updateStateByKey + // Initial RDD input to trackStateByKey val initialRDD = ssc.sparkContext.parallelize(List(("hello", 1), ("world", 1))) // Create a ReceiverInputDStream on target ip:port and count the diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala index 7dd1b72f8049..8c2c3d1bfebf 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/State.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala @@ -24,10 +24,9 @@ import org.apache.spark.annotation.Experimental /** * :: Experimental :: * Abstract class for getting and updating the tracked state in the `trackStateByKey` operation of - * a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a - * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). + * a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. 
* - * Scala example of using `State`: + * Example of using `State`: * {{{ * // A tracking function that maintains an integer state and return a String * def trackStateFunc(data: Option[Int], state: State[Int]): Option[String] = { @@ -49,11 +48,6 @@ import org.apache.spark.annotation.Experimental * } * * }}} - * - * Java example: - * {{{ - * TODO(@zsxwing) - * }}} */ @Experimental sealed abstract class State[S] { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala index c9fe35e74c1c..c72940e57b3b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala @@ -17,26 +17,20 @@ package org.apache.spark.streaming -import scala.reflect.ClassTag - import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.JavaPairRDD import org.apache.spark.rdd.RDD import org.apache.spark.util.ClosureCleaner import org.apache.spark.{HashPartitioner, Partitioner} - /** * :: Experimental :: * Abstract class representing all the specifications of the DStream transformation * `trackStateByKey` operation of a - * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a - * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). - * Use the [[org.apache.spark.streaming.StateSpec StateSpec.apply()]] or - * [[org.apache.spark.streaming.StateSpec StateSpec.create()]] to create instances of + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. + * Use the [[org.apache.spark.streaming.StateSpec StateSpec.function()]] to create instances of * this class. * - * Example in Scala: + * Example: * {{{ * def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = { * ... 
@@ -46,16 +40,6 @@ import org.apache.spark.{HashPartitioner, Partitioner} * * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec) * }}} - * - * Example in Java: - * {{{ - * StateStateSpec[KeyType, ValueType, StateType, EmittedDataType] spec = - * StateStateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction) - * .numPartition(10); - * - * JavaDStream[EmittedDataType] emittedRecordDStream = - * javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec); - * }}} */ @Experimental sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] extends Serializable { @@ -63,9 +47,6 @@ sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] exte /** Set the RDD containing the initial states that will be used by `trackStateByKey` */ def initialState(rdd: RDD[(KeyType, StateType)]): this.type - /** Set the RDD containing the initial states that will be used by `trackStateByKey` */ - def initialState(javaPairRDD: JavaPairRDD[KeyType, StateType]): this.type - /** * Set the number of partitions by which the state RDDs generated by `trackStateByKey` * will be partitioned. Hash partitioning will be used. @@ -92,30 +73,17 @@ sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] exte /** * :: Experimental :: * Builder object for creating instances of [[org.apache.spark.streaming.StateSpec StateSpec]] - * that is used for specifying the parameters of the DStream transformation - * `trackStateByKey` operation of a - * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a - * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). + * that is used for specifying the parameters of the DStream transformation `trackStateByKey` + * operation of a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. 
* - * Example in Scala: + * Example: * {{{ * def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = { * ... * } * - * val spec = StateSpec.function(trackingFunction).numPartitions(10) - * - * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec) - * }}} - * - * Example in Java: - * {{{ - * StateStateSpec[KeyType, ValueType, StateType, EmittedDataType] spec = - * StateStateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction) - * .numPartition(10); - * - * JavaDStream[EmittedDataType] emittedRecordDStream = - * javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec); + * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType]( + * StateSpec.function(trackingFunction).numPartitions(10)) * }}} */ @Experimental @@ -123,8 +91,8 @@ object StateSpec { /** * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications * `trackStateByKey` operation on a - * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a - * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. + * * @param trackingFunction The function applied on every data item to manage the associated state * and generate the emitted data * @tparam KeyType Class of the keys @@ -142,8 +110,8 @@ object StateSpec { /** * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications * `trackStateByKey` operation on a - * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a - * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. 
+ * * @param trackingFunction The function applied on every data item to manage the associated state * and generate the emitted data * @tparam ValueType Class of the values @@ -179,12 +147,6 @@ case class StateSpecImpl[K, V, S, T]( this } - override def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type = { - this.initialStateRDD = javaPairRDD.rdd - this - } - - override def numPartitions(numPartitions: Int): this.type = { this.partitioner(new HashPartitioner(numPartitions)) this diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index e2aec6c2f63e..9dbfe388f1e2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -28,8 +28,10 @@ import com.google.common.base.Optional import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.{JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} + import org.apache.spark.Partitioner -import org.apache.spark.api.java.{JavaPairRDD, JavaUtils} +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.{JavaPairRDD, JavaSparkContext, JavaUtils} import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} @@ -426,6 +428,48 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( ) } + /** + * :: Experimental :: + * Return a new [[JavaDStream]] of data generated by combining the key-value data in `this` stream + * with a continuously updated per-key state. The user-provided state tracking function is + * applied on each keyed data item along with its corresponding state. 
The function can choose to + * update/remove the state and return a transformed data, which forms the + * [[JavaTrackStateDStream]]. + * + * The specification of this transformation is made through the [[JavaStateSpec]] class. Besides + * the tracking function, there are a number of optional parameters - initial state data, number + * of partitions, timeouts, etc. See the [[JavaStateSpec]] for more details. + * + * Example of using `trackStateByKey`: + * {{{ + * // A tracking function that maintains an integer state and returns a String + * Function2, JavaState, Optional> trackStateFunc = + * new Function2, JavaState, Optional>() { + * + * @Override + * public Optional call(Optional one, JavaState state) { + * // Check if state exists, accordingly update/remove state and return transformed data + * } + * }; + * + * JavaTrackStateDStream[Integer, Integer, Integer, String] trackStateDStream = + * keyValueDStream.trackStateByKey[Int, String]( + * JavaStateSpec.function(trackStateFunc).numPartitions(10)); + * }}} + * + * @param spec Specification of this transformation + * @tparam StateType Class type of the state + * @tparam EmittedType Class type of the transformed data returned by the tracking function + */ + @Experimental + def trackStateByKey[StateType, EmittedType](spec: JavaStateSpec[K, V, StateType, EmittedType]): + JavaTrackStateDStream[K, V, StateType, EmittedType] = { + new JavaTrackStateDStream( + dstream.trackStateByKey(spec.stateSpec)( + JavaSparkContext.fakeClassTag, + JavaSparkContext.fakeClassTag)) + } + private def convertUpdateStateFunction[S](in: JFunction2[JList[V], Optional[S], Optional[S]]): (Seq[V], Option[S]) => Option[S] = { val scalaFunc: (Seq[V], Option[S]) => Option[S] = (values, state) => { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaState.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaState.scala new file mode 100644 index 000000000000..dfe93a3942ff --- /dev/null +++
b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaState.scala @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.api.java + +import com.google.common.base.Optional + +import org.apache.spark.annotation.Experimental +import org.apache.spark.streaming.State + +/** + * :: Experimental :: + * Class for getting and updating the tracked state in the `trackStateByKey` operation of a + * [[JavaPairDStream]]. 
+ * + * Example of using `State`: + * {{{ + * // A tracking function that maintains an integer state and return a String + * Function2, JavaState, Optional> trackStateFunc = + * new Function2, JavaState, Optional>() { + * + * @Override + * public Optional call(Optional one, JavaState state) { + * if (state.exists()) { + * int existingState = state.get(); // Get the existing state + * boolean shouldRemove = ...; // Decide whether to remove the state + * if (shouldRemove) { + * state.remove(); // Remove the state + * } else { + * int newState = ...; + * state.update(newState); // Set the new state + * } + * } else { + * int initialState = ...; // Set the initial state + * state.update(initialState); + * } + * // return something + * } + * }; + * }}} + */ +@Experimental +final class JavaState[S](state: State[S]) extends Serializable { + + /** Whether the state already exists */ + def exists(): Boolean = state.exists() + + /** + * Get the state if it exists, otherwise it will throw `java.util.NoSuchElementException`. + * Check with `exists()` whether the state exists or not before calling `get()`. + * + * @throws java.util.NoSuchElementException If the state does not exist. + */ + def get(): S = state.get() + + /** + * Update the state with a new value. + * + * State cannot be updated if it has been already removed (that is, `remove()` has already been + * called) or it is going to be removed due to timeout (that is, `isTimingOut()` is `true`). + * + * @throws java.lang.IllegalArgumentException If the state has already been removed, or is + * going to be removed + */ + def update(newState: S): Unit = state.update(newState) + + /** + * Remove the state if it exists. + * + * State cannot be updated if it has been already removed (that is, `remove()` has already been + * called) or it is going to be removed due to timeout (that is, `isTimingOut()` is `true`). 
+ */ + def remove(): Unit = state.remove() + + /** + * Whether the state is timing out and going to be removed by the system after the current batch. + * This timeout can occur if timeout duration has been specified in the + * [[org.apache.spark.streaming.StateSpec StateSpec]] and the key has not received any new data + * for that timeout duration. + */ + def isTimingOut(): Boolean = state.isTimingOut() + + /** + * Get the state as an `Optional`. It will be `Optional.of(state)` if it exists, otherwise + * `Optional.absent()`. + */ + def getOption(): Optional[S] = if (exists) Optional.of(get) else Optional.absent() + + override def toString(): String = state.toString() +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStateSpec.scala new file mode 100644 index 000000000000..7ed3d59bb933 --- /dev/null +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStateSpec.scala @@ -0,0 +1,142 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.streaming.api.java + +import com.google.common.base.Optional + +import org.apache.spark.Partitioner +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.{JavaPairRDD, JavaUtils} +import org.apache.spark.api.java.function.{Function2 => JFunction2, Function4 => JFunction4} +import org.apache.spark.streaming.{Time, Duration, State, StateSpec} + +/** + * :: Experimental :: + * Class representing all the specifications of the [[JavaDStream]] transformation + * `trackStateByKey` operation of a [[JavaPairDStream]]. + * Use the [[JavaStateSpec.function()]] to create instances of this class. + * + * Example: + * {{{ + * JavaStateSpec[KeyType, ValueType, StateType, EmittedDataType] spec = + * JavaStateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction) + * .numPartition(10); + * + * JavaTrackStateDStream[KeyType, ValueType, StateType, EmittedType] emittedRecordDStream = + * javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec); + * }}} + */ +@Experimental +final class JavaStateSpec[K, V, S, T]( + private[streaming] val stateSpec: StateSpec[K, V, S, T]) extends Serializable { + + /** Set the RDD containing the initial states that will be used by `trackStateByKey` */ + def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type = { + stateSpec.initialState(javaPairRDD.rdd) + this + } + + /** + * Set the number of partitions by which the state RDDs generated by `trackStateByKey` + * will be partitioned. Hash partitioning will be used. + */ + def numPartitions(numPartitions: Int): this.type = { + stateSpec.numPartitions(numPartitions) + this + } + + /** + * Set the partitioner by which the state RDDs generated by `trackStateByKey` will be + * be partitioned. + */ + def partitioner(partitioner: Partitioner): this.type = { + stateSpec.partitioner(partitioner) + this + } + + /** + * Set the duration after which the state of an idle key will be removed. 
A key and its state is + * considered idle if it has not received any data for at least the given duration. The state + * tracking function will be called one final time on the idle states that are going to be + * removed; [[JavaState.isTimingOut()]] set to `true` in that call. + */ + def timeout(interval: Duration): this.type = { + stateSpec.timeout(interval) + this + } +} + + +/** + * :: Experimental :: + * Builder object for creating instances of [[JavaStateSpec]] that is used for specifying the + * parameters of the DStream transformation `trackStateByKey` operation of a [[JavaPairDStream]]. + * + * Example: + * {{{ + * JavaStateSpec[KeyType, ValueType, StateType, EmittedDataType] spec = + * JavaStateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction) + * .numPartition(10); + * + * JavaTrackStateDStream[KeyType, ValueType, StateType, EmittedType] emittedRecordDStream = + * javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec); + * }}} + */ +@Experimental +final object JavaStateSpec { + + /** + * Create a [[JavaStateSpec]] for setting all the specifications `trackStateByKey` operation on a + * [[JavaPairDStream]]. 
+ * + * @param javaTrackingFunction The function applied on every data item to manage the associated + * state and generate the emitted data + * @tparam KeyType Class of the keys + * @tparam ValueType Class of the values + * @tparam StateType Class of the state data + * @tparam EmittedType Class of the emitted data + */ + def function[KeyType, ValueType, StateType, EmittedType](javaTrackingFunction: + JFunction4[Time, KeyType, Optional[ValueType], JavaState[StateType], Optional[EmittedType]]): + JavaStateSpec[KeyType, ValueType, StateType, EmittedType] = { + val trackingFunc = (time: Time, k: KeyType, v: Option[ValueType], s: State[StateType]) => { + val t = javaTrackingFunction.call(time, k, JavaUtils.optionToOptional(v), new JavaState(s)) + Option(t.orNull) + } + new JavaStateSpec(StateSpec.function(trackingFunc)) + } + + /** + * Create a [[JavaStateSpec]] for setting all the specifications `trackStateByKey` operation on a + * [[JavaPairDStream]]. + * + * @param javaTrackingFunction The function applied on every data item to manage the associated + * state and generate the emitted data + * @tparam ValueType Class of the values + * @tparam StateType Class of the state data + * @tparam EmittedType Class of the emitted data + */ + def function[KeyType, ValueType, StateType, EmittedType]( + javaTrackingFunction: JFunction2[Optional[ValueType], JavaState[StateType], EmittedType]): + JavaStateSpec[KeyType, ValueType, StateType, EmittedType] = { + val trackingFunc = (v: Option[ValueType], s: State[StateType]) => { + javaTrackingFunction.call(JavaUtils.optionToOptional(v), new JavaState(s)) + } + new JavaStateSpec(StateSpec.function(trackingFunc)) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala new file mode 100644 index 000000000000..e58d61e12bc7 --- /dev/null +++
b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming.api.java + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaSparkContext +import org.apache.spark.streaming.dstream.TrackStateDStream + +/** + * :: Experimental :: + * [[JavaDStream]] representing the stream of records emitted by the tracking function in the + * `trackStateByKey` operation on a [[JavaPairDStream]]. Additionally, it also gives access to the + * stream of state snapshots, that is, the state data of all keys after a batch has updated them.
+ * + * @tparam KeyType Class of the state key + * @tparam ValueType Class of the state value + * @tparam StateType Class of the state + * @tparam EmittedType Class of the emitted records + */ +@Experimental +class JavaTrackStateDStream[KeyType, ValueType, StateType, EmittedType]( + dstream: TrackStateDStream[KeyType, ValueType, StateType, EmittedType]) + extends JavaDStream[EmittedType](dstream)(JavaSparkContext.fakeClassTag) { + + def stateSnapshots(): JavaPairDStream[KeyType, StateType] = + new JavaPairDStream(dstream.stateSnapshots())( + JavaSparkContext.fakeClassTag, + JavaSparkContext.fakeClassTag) +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala index 58d89c93bcbe..98e881e6ae11 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/TrackStateDStream.scala @@ -35,6 +35,7 @@ import org.apache.spark.streaming.rdd.{TrackStateRDD, TrackStateRDDRecord} * all keys after a batch has updated them. * * @tparam KeyType Class of the state key + * @tparam ValueType Class of the state value * @tparam StateType Class of the state data * @tparam EmittedType Class of the emitted records */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala index ed7cea26d060..fc51496be47b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/TrackStateRDD.scala @@ -70,12 +70,14 @@ private[streaming] class TrackStateRDDPartition( * in the `prevStateRDD` to create `this` RDD * @param trackingFunction The function that will be used to update state and return new data * @param batchTime The time of the batch to which this RDD belongs to. 
Use to update + * @param timeoutThresholdTime The time to indicate which keys are timeout */ private[streaming] class TrackStateRDD[K: ClassTag, V: ClassTag, S: ClassTag, T: ClassTag]( private var prevStateRDD: RDD[TrackStateRDDRecord[K, S, T]], private var partitionedDataRDD: RDD[(K, V)], trackingFunction: (Time, K, Option[V], State[S]) => Option[T], - batchTime: Time, timeoutThresholdTime: Option[Long] + batchTime: Time, + timeoutThresholdTime: Option[Long] ) extends RDD[TrackStateRDDRecord[K, S, T]]( partitionedDataRDD.sparkContext, List( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index ed622ef7bf70..4479d9db683d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -267,7 +267,7 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( // Read the data of the delta val deltaMapSize = inputStream.readInt() - deltaMap = new OpenHashMap[K, StateInfo[S]]() + deltaMap = new OpenHashMap[K, StateInfo[S]](deltaMapSize) var deltaMapCount = 0 while (deltaMapCount < deltaMapSize) { val key = inputStream.readObject().asInstanceOf[K] From 74a9a83fa1453422c1a488c2f4f05abb093f4616 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 11 Nov 2015 13:58:21 -0800 Subject: [PATCH 2/7] Remove JavaState and JavaStateSpec --- .../JavaStatefulNetworkWordCount.java | 12 +- .../org/apache/spark/streaming/State.scala | 28 +++- .../apache/spark/streaming/StateSpec.scala | 77 +++++++++- .../streaming/api/java/JavaPairDStream.scala | 24 +-- .../spark/streaming/api/java/JavaState.scala | 104 ------------- .../streaming/api/java/JavaStateSpec.scala | 142 ------------------ .../streaming/JavaTrackStateByKeySuite.java | 67 +++++++++ 7 files changed, 188 insertions(+), 266 deletions(-) delete mode 100644 
streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaState.scala delete mode 100644 streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStateSpec.scala create mode 100644 streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java index 50be62c6305a..c400e4237abe 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java @@ -31,6 +31,8 @@ import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.StorageLevels; import org.apache.spark.streaming.Durations; +import org.apache.spark.streaming.State; +import org.apache.spark.streaming.StateSpec; import org.apache.spark.streaming.Time; import org.apache.spark.streaming.api.java.*; @@ -88,12 +90,12 @@ public Tuple2 call(String s) { }); // Update the cumulative count function - final Function4, JavaState, Optional>> trackStateFunc = - new Function4, JavaState, Optional>>() { + final Function4, State, Optional>> trackStateFunc = + new Function4, State, Optional>>() { @Override - public Optional> call(Time time, String word, Optional one, JavaState state) { - int sum = one.or(0) + state.getOption().or(0); + public Optional> call(Time time, String word, Optional one, State state) { + int sum = one.or(0) + (state.exists() ? 
state.get() : 0); Tuple2 output = new Tuple2(word, sum); state.update(sum); return Optional.of(output); @@ -102,7 +104,7 @@ public Optional> call(Time time, String word, Optional> stateDstream = - wordsDstream.trackStateByKey(JavaStateSpec.function(trackStateFunc).initialState(initialRDD)); + wordsDstream.trackStateByKey(StateSpec.function(trackStateFunc).initialState(initialRDD)); stateDstream.print(); ssc.start(); diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala index 8c2c3d1bfebf..a9f24b7b3b78 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/State.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala @@ -26,7 +26,7 @@ import org.apache.spark.annotation.Experimental * Abstract class for getting and updating the tracked state in the `trackStateByKey` operation of * a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. * - * Example of using `State`: + * Scala example of using `State`: * {{{ * // A tracking function that maintains an integer state and return a String * def trackStateFunc(data: Option[Int], state: State[Int]): Option[String] = { @@ -48,6 +48,32 @@ import org.apache.spark.annotation.Experimental * } * * }}} + * + * Java example of using `State`: + * {{{ + * // A tracking function that maintains an integer state and returns a String + * Function2, State, Optional> trackStateFunc = + * new Function2, State, Optional>() { + * + * @Override + * public Optional call(Optional one, State state) { + * if (state.exists()) { + * int existingState = state.get(); // Get the existing state + * boolean shouldRemove = ...; // Decide whether to remove the state + * if (shouldRemove) { + * state.remove(); // Remove the state + * } else { + * int newState = ...; + * state.update(newState); // Set the new state + * } + * } else { + * int initialState = ...; // Set the initial state + *
state.update(initialState); + * } + * // return something + * } + * }; + * }}} */ @Experimental sealed abstract class State[S] { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala index c72940e57b3b..15aa60f71ebd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala @@ -17,7 +17,10 @@ package org.apache.spark.streaming +import com.google.common.base.Optional import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.{JavaPairRDD, JavaUtils} +import org.apache.spark.api.java.function.{Function2 => JFunction2, Function4 => JFunction4} import org.apache.spark.rdd.RDD import org.apache.spark.util.ClosureCleaner import org.apache.spark.{HashPartitioner, Partitioner} @@ -30,7 +33,7 @@ import org.apache.spark.{HashPartitioner, Partitioner} * Use the [[org.apache.spark.streaming.StateSpec StateSpec.function()]] to create instances of * this class. * - * Example: + * Example in Scala: * {{{ * def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = { * ... 
@@ -40,6 +43,16 @@ import org.apache.spark.{HashPartitioner, Partitioner} * * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType](spec) * }}} + * + * Example in Java: + * {{{ + * StateSpec[KeyType, ValueType, StateType, EmittedDataType] spec = + * StateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction) + * .numPartitions(10); + * + * JavaTrackStateDStream[KeyType, ValueType, StateType, EmittedType] emittedRecordDStream = + * javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec); + * }}} */ @Experimental sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] extends Serializable { @@ -47,6 +60,9 @@ sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] exte /** Set the RDD containing the initial states that will be used by `trackStateByKey` */ def initialState(rdd: RDD[(KeyType, StateType)]): this.type + /** Set the RDD containing the initial states that will be used by `trackStateByKey` */ + def initialState(javaPairRDD: JavaPairRDD[KeyType, StateType]): this.type + /** * Set the number of partitions by which the state RDDs generated by `trackStateByKey` * will be partitioned. Hash partitioning will be used. @@ -76,7 +92,7 @@ sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] exte * that is used for specifying the parameters of the DStream transformation `trackStateByKey` * operation of a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. * - * Example: + * Example in Scala: * {{{ * def trackingFunction(data: Option[ValueType], wrappedState: State[StateType]): EmittedType = { * ...
@@ -85,6 +101,16 @@ sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] exte * val emittedRecordDStream = keyValueDStream.trackStateByKey[StateType, EmittedDataType]( * StateSpec.function(trackingFunction).numPartitions(10)) * }}} + * + * Example in Java: + * {{{ + * StateSpec[KeyType, ValueType, StateType, EmittedDataType] spec = + * StateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction) + * .numPartitions(10); + * + * JavaTrackStateDStream[KeyType, ValueType, StateType, EmittedType] emittedRecordDStream = + * javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec); + * }}} */ @Experimental object StateSpec { @@ -128,6 +154,48 @@ object StateSpec { } new StateSpecImpl(wrappedFunction) } + + /** + * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all + * the specifications `trackStateByKey` operation on a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]]. + * + * @param javaTrackingFunction The function applied on every data item to manage the associated + * state and generate the emitted data + * @tparam KeyType Class of the keys + * @tparam ValueType Class of the values + * @tparam StateType Class of the states data + * @tparam EmittedType Class of the emitted data + */ + def function[KeyType, ValueType, StateType, EmittedType](javaTrackingFunction: + JFunction4[Time, KeyType, Optional[ValueType], State[StateType], Optional[EmittedType]]): + StateSpec[KeyType, ValueType, StateType, EmittedType] = { + val trackingFunc = (time: Time, k: KeyType, v: Option[ValueType], s: State[StateType]) => { + val t = javaTrackingFunction.call(time, k, JavaUtils.optionToOptional(v), s) + Option(t.orNull) + } + StateSpec.function(trackingFunc) + } + + /** + * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications + * `trackStateByKey` operation on a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]].
+ * + * @param javaTrackingFunction The function applied on every data item to manage the associated + * state and generate the emitted data + * @tparam ValueType Class of the values + * @tparam StateType Class of the states data + * @tparam EmittedType Class of the emitted data + */ + def function[KeyType, ValueType, StateType, EmittedType]( + javaTrackingFunction: JFunction2[Optional[ValueType], State[StateType], EmittedType]): + StateSpec[KeyType, ValueType, StateType, EmittedType] = { + val trackingFunc = (v: Option[ValueType], s: State[StateType]) => { + javaTrackingFunction.call(JavaUtils.optionToOptional(v), s) + } + StateSpec.function(trackingFunc) + } } @@ -147,6 +215,11 @@ case class StateSpecImpl[K, V, S, T]( this } + override def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type = { + this.initialStateRDD = javaPairRDD.rdd + this + } + override def numPartitions(numPartitions: Int): this.type = { this.partitioner(new HashPartitioner(numPartitions)) this diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index 9dbfe388f1e2..f150d123c5f9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -436,25 +436,26 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * update/remove the state and return a transformed data, which forms the * [[JavaTrackStateDStream]]. * - * The specifications of this transformation is made through the [[JavaStateSpec]] class. Besides - * the tracking function, there are a number of optional parameters - initial state data, number - * of partitions, timeouts, etc. See the [[JavaStateSpec]] for more details. + * The specifications of this transformation is made through the + * [[org.apache.spark.streaming.StateSpec StateSpec]] class.
Besides the tracking function, there + * are a number of optional parameters - initial state data, number of partitions, timeouts, etc. + * See the [[org.apache.spark.streaming.StateSpec StateSpec]] for more details. * * Example of using `trackStateByKey`: * {{{ * // A tracking function that maintains an integer state and return a String - * Function2, JavaState, Optional> trackStateFunc = - * new Function2, JavaState, Optional>() { + * Function2, State, Optional> trackStateFunc = + * new Function2, State, Optional>() { * * @Override - * public Optional call(Optional one, JavaState state) { + * public Optional call(Optional one, State state) { * // Check if state exists, accordingly update/remove state and return transformed data * } * }; * * JavaTrackStateDStream[Integer, Integer, Integer, String] trackStateDStream = * keyValueDStream.trackStateByKey[Int, String]( - * JavaStateSpec.function(trackingFunction).numPartitions(10)); + * StateSpec.function(trackingFunction).numPartitions(10)); * }}} * * @param spec Specification of this transformation @@ -462,12 +463,11 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * @tparam EmittedType Class type of the tranformed data return by the tracking function */ @Experimental - def trackStateByKey[StateType, EmittedType](spec: JavaStateSpec[K, V, StateType, EmittedType]): + def trackStateByKey[StateType, EmittedType](spec: StateSpec[K, V, StateType, EmittedType]): JavaTrackStateDStream[K, V, StateType, EmittedType] = { - new JavaTrackStateDStream( - dstream.trackStateByKey(spec.stateSpec)( - JavaSparkContext.fakeClassTag, - JavaSparkContext.fakeClassTag)) + new JavaTrackStateDStream(dstream.trackStateByKey(spec)( + JavaSparkContext.fakeClassTag, + JavaSparkContext.fakeClassTag)) } private def convertUpdateStateFunction[S](in: JFunction2[JList[V], Optional[S], Optional[S]]): diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaState.scala 
b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaState.scala deleted file mode 100644 index dfe93a3942ff..000000000000 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaState.scala +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.streaming.api.java - -import com.google.common.base.Optional - -import org.apache.spark.annotation.Experimental -import org.apache.spark.streaming.State - -/** - * :: Experimental :: - * Class for getting and updating the tracked state in the `trackStateByKey` operation of a - * [[JavaPairDStream]]. 
- * - * Example of using `State`: - * {{{ - * // A tracking function that maintains an integer state and return a String - * Function2, JavaState, Optional> trackStateFunc = - * new Function2, JavaState, Optional>() { - * - * @Override - * public Optional call(Optional one, JavaState state) { - * if (state.exists()) { - * int existingState = state.get(); // Get the existing state - * boolean shouldRemove = ...; // Decide whether to remove the state - * if (shouldRemove) { - * state.remove(); // Remove the state - * } else { - * int newState = ...; - * state.update(newState); // Set the new state - * } - * } else { - * int initialState = ...; // Set the initial state - * state.update(initialState); - * } - * // return something - * } - * }; - * }}} - */ -@Experimental -final class JavaState[S](state: State[S]) extends Serializable { - - /** Whether the state already exists */ - def exists(): Boolean = state.exists() - - /** - * Get the state if it exists, otherwise it will throw `java.util.NoSuchElementException`. - * Check with `exists()` whether the state exists or not before calling `get()`. - * - * @throws java.util.NoSuchElementException If the state does not exist. - */ - def get(): S = state.get() - - /** - * Update the state with a new value. - * - * State cannot be updated if it has been already removed (that is, `remove()` has already been - * called) or it is going to be removed due to timeout (that is, `isTimingOut()` is `true`). - * - * @throws java.lang.IllegalArgumentException If the state has already been removed, or is - * going to be removed - */ - def update(newState: S): Unit = state.update(newState) - - /** - * Remove the state if it exists. - * - * State cannot be updated if it has been already removed (that is, `remove()` has already been - * called) or it is going to be removed due to timeout (that is, `isTimingOut()` is `true`). 
- */ - def remove(): Unit = state.remove() - - /** - * Whether the state is timing out and going to be removed by the system after the current batch. - * This timeout can occur if timeout duration has been specified in the - * [[org.apache.spark.streaming.StateSpec StatSpec]] and the key has not received any new data - * for that timeout duration. - */ - def isTimingOut(): Boolean = state.isTimingOut() - - /** - * Get the state as an `Optional`. It will be `Optional.of(state)` if it exists, otherwise - * `Optional.absent()`. - */ - def getOption(): Optional[S] = if (exists) Optional.of(get) else Optional.absent() - - override def toString(): String = state.toString() -} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStateSpec.scala deleted file mode 100644 index 7ed3d59bb933..000000000000 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStateSpec.scala +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.streaming.api.java - -import com.google.common.base.Optional - -import org.apache.spark.Partitioner -import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.{JavaPairRDD, JavaUtils} -import org.apache.spark.api.java.function.{Function2 => JFunction2, Function4 => JFunction4} -import org.apache.spark.streaming.{Time, Duration, State, StateSpec} - -/** - * :: Experimental :: - * Class representing all the specifications of the [[JavaDStream]] transformation - * `trackStateByKey` operation of a [[JavaPairDStream]]. - * Use the [[JavaStateSpec.function()]] to create instances of this class. - * - * Example: - * {{{ - * JavaStateSpec[KeyType, ValueType, StateType, EmittedDataType] spec = - * JavaStateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction) - * .numPartition(10); - * - * JavaTrackStateDStream[KeyType, ValueType, StateType, EmittedType] emittedRecordDStream = - * javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec); - * }}} - */ -@Experimental -final class JavaStateSpec[K, V, S, T]( - private[streaming] val stateSpec: StateSpec[K, V, S, T]) extends Serializable { - - /** Set the RDD containing the initial states that will be used by `trackStateByKey` */ - def initialState(javaPairRDD: JavaPairRDD[K, S]): this.type = { - stateSpec.initialState(javaPairRDD.rdd) - this - } - - /** - * Set the number of partitions by which the state RDDs generated by `trackStateByKey` - * will be partitioned. Hash partitioning will be used. - */ - def numPartitions(numPartitions: Int): this.type = { - stateSpec.numPartitions(numPartitions) - this - } - - /** - * Set the partitioner by which the state RDDs generated by `trackStateByKey` will be - * be partitioned. - */ - def partitioner(partitioner: Partitioner): this.type = { - stateSpec.partitioner(partitioner) - this - } - - /** - * Set the duration after which the state of an idle key will be removed. 
A key and its state is - * considered idle if it has not received any data for at least the given duration. The state - * tracking function will be called one final time on the idle states that are going to be - * removed; [[JavaState.isTimingOut()]] set to `true` in that call. - */ - def timeout(interval: Duration): this.type = { - stateSpec.timeout(interval) - this - } -} - - -/** - * :: Experimental :: - * Builder object for creating instances of [[JavaStateSpec]] that is used for specifying the - * parameters of the DStream transformation `trackStateByKey` operation of a [[JavaPairDStream]]. - * - * Example: - * {{{ - * JavaStateSpec[KeyType, ValueType, StateType, EmittedDataType] spec = - * JavaStateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction) - * .numPartition(10); - * - * JavaTrackStateDStream[KeyType, ValueType, StateType, EmittedType] emittedRecordDStream = - * javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec); - * }}} - */ -@Experimental -final object JavaStateSpec { - - /** - * Create a [[JavaStateSpec]] for setting all the specifications `trackStateByKey` operation on a - * [[JavaPairDStream]]. 
- * - * @param javaTrackingFunction The function applied on every data item to manage the associated - * state and generate the emitted data - * @tparam KeyType Class of the keys - * @tparam ValueType Class of the values - * @tparam StateType Class of the states data - * @tparam EmittedType Class of the emitted data - */ - def function[KeyType, ValueType, StateType, EmittedType](javaTrackingFunction: - JFunction4[Time, KeyType, Optional[ValueType], JavaState[StateType], Optional[EmittedType]]): - JavaStateSpec[KeyType, ValueType, StateType, EmittedType] = { - val trackingFunc = (time: Time, k: KeyType, v: Option[ValueType], s: State[StateType]) => { - val t = javaTrackingFunction.call(time, k, JavaUtils.optionToOptional(v), new JavaState(s)) - Option(t.orNull) - } - new JavaStateSpec(StateSpec.function(trackingFunc)) - } - - /** - * Create a [[JavaStateSpec]] for setting all the specifications `trackStateByKey` operation on a - * [[JavaPairDStream]]. - * - * @param javaTrackingFunction The function applied on every data item to manage the associated - * state and generate the emitted data - * @tparam ValueType Class of the values - * @tparam StateType Class of the states data - * @tparam EmittedType Class of the emitted data - */ - def function[KeyType, ValueType, StateType, EmittedType]( - javaTrackingFunction: JFunction2[Optional[ValueType], JavaState[StateType], EmittedType]): - JavaStateSpec[KeyType, ValueType, StateType, EmittedType] = { - val trackingFunc = (v: Option[ValueType], s: State[StateType]) => { - javaTrackingFunction.call(Optional.fromNullable(v.get), new JavaState(s)) - } - new JavaStateSpec(StateSpec.function(trackingFunc)) - } -} diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java new file mode 100644 index 000000000000..188aefba6130 --- /dev/null +++ 
b/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.streaming; + +import com.google.common.base.Optional; +import org.apache.spark.HashPartitioner; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.function.Function4; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.streaming.Durations; +import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaTrackStateDStream; +import org.junit.Test; +import scala.Tuple2; + +import java.io.Serializable; + +public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implements Serializable { + + /** + * This test is only for testing the APIs. It's not necessary to run it. 
+ */ + public void testAPI() { + // TODO +// JavaPairRDD initialRDD = null; +// JavaPairDStream wordsDstream = null; +// final Function4, State, Optional> +// trackStateFunc = +// new Function4, State, Optional>() { +// +// @Override +// public Optional call(Time time, String word, Optional one, +// State state) { +// // Use all State's methods here +// state.exists(); +// state.get(); +// state.isTimingOut(); +// state.remove(); +// state.update(10); +// return "test"; +// } +// }; +// +// JavaTrackStateDStream> stateDstream = +// wordsDstream.trackStateByKey( +// StateSpec.function(trackStateFunc) +// .initialState(initialRDD) +// .numPartitions(10) +// .partitioner(new HashPartitioner(10)) +// .timeout(Durations.seconds(10))); + } +} From 44cf5c21059cfef11355b9289e4b80afc460f8d2 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 11 Nov 2015 14:00:37 -0800 Subject: [PATCH 3/7] Fix docs --- .../src/main/scala/org/apache/spark/streaming/State.scala | 3 ++- .../main/scala/org/apache/spark/streaming/StateSpec.scala | 6 ++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala index a9f24b7b3b78..8e6e9498a5b1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/State.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala @@ -24,7 +24,8 @@ import org.apache.spark.annotation.Experimental /** * :: Experimental :: * Abstract class for getting and updating the tracked state in the `trackStateByKey` operation of - * a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. + * a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). 
* * Scala example of using `State`: * {{{ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala index 15aa60f71ebd..4a7a907e41b7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala @@ -29,8 +29,10 @@ import org.apache.spark.{HashPartitioner, Partitioner} * :: Experimental :: * Abstract class representing all the specifications of the DStream transformation * `trackStateByKey` operation of a - * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. - * Use the [[org.apache.spark.streaming.StateSpec StateSpec.function()]] to create instances of + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). + * Use one of the overloaded + * [[org.apache.spark.streaming.StateSpec StateSpec.function()]] methods to create instances of * this class.
* * Example in Scala: From de4ef2bfe4825c24ffec978ab67f72cc898bd06a Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 11 Nov 2015 15:26:56 -0800 Subject: [PATCH 4/7] Fix docs --- .../org/apache/spark/streaming/StateSpec.scala | 13 ++++++++----- .../streaming/api/java/JavaTrackStateDStream.scala | 2 +- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala index 4a7a907e41b7..3fd89135b19f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala @@ -92,7 +92,10 @@ sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] exte * :: Experimental :: * Builder object for creating instances of [[org.apache.spark.streaming.StateSpec StateSpec]] * that is used for specifying the parameters of the DStream transformation `trackStateByKey` - * operation of a [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. + * that is used for specifying the parameters of the DStream transformation + * `trackStateByKey` operation of a + * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]] (Scala) or a + * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]] (Java). * * Example in Scala: * {{{ @@ -118,7 +121,7 @@ sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] exte object StateSpec { /** * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications - * `trackStateByKey` operation on a + * of the `trackStateByKey` operation on a * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. 
* * @param trackingFunction The function applied on every data item to manage the associated state @@ -137,7 +140,7 @@ object StateSpec { /** * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications - * `trackStateByKey` operation on a + * of the `trackStateByKey` operation on a * [[org.apache.spark.streaming.dstream.PairDStreamFunctions pair DStream]]. * * @param trackingFunction The function applied on every data item to manage the associated state @@ -159,7 +162,7 @@ object StateSpec { /** * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all - * the specifications `trackStateByKey` operation on a + * the specifications of the `trackStateByKey` operation on a * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]]. * * @param javaTrackingFunction The function applied on every data item to manage the associated @@ -181,7 +184,7 @@ object StateSpec { /** * Create a [[org.apache.spark.streaming.StateSpec StateSpec]] for setting all the specifications - * `trackStateByKey` operation on a + * of the `trackStateByKey` operation on a * [[org.apache.spark.streaming.api.java.JavaPairDStream JavaPairDStream]]. * * @param javaTrackingFunction The function applied on every data item to manage the associated diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala index e58d61e12bc7..f459930d0660 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaTrackStateDStream.scala @@ -25,7 +25,7 @@ import org.apache.spark.streaming.dstream.TrackStateDStream * :: Experimental :: * [[JavaDStream]] representing the stream of records emitted by the tracking function in the * `trackStateByKey` operation on a [[JavaPairDStream]]. 
Additionally, it also gives access to the - * stream of state snapshots, that is, the state data of ll keys after a batch has updated them. + * stream of state snapshots, that is, the state data of all keys after a batch has updated them. * * @tparam KeyType Class of the state key * @tparam ValueType Class of the state value From 7160786a62acef797d95c0e4cd8d62eb02d809d3 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 11 Nov 2015 16:59:29 -0800 Subject: [PATCH 5/7] Add unit tests --- .../apache/spark/streaming/Java8APISuite.java | 47 +++- .../spark/streaming/util/StateMap.scala | 6 +- .../streaming/JavaTrackStateByKeySuite.java | 209 +++++++++++++++--- 3 files changed, 226 insertions(+), 36 deletions(-) diff --git a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java index 73091cfe2c09..872e9d348318 100644 --- a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java +++ b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java @@ -31,9 +31,12 @@ import org.apache.spark.HashPartitioner; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.Function4; import org.apache.spark.api.java.function.PairFunction; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaPairDStream; +import org.apache.spark.streaming.api.java.JavaTrackStateDStream; /** * Most of these tests replicate org.apache.spark.streaming.JavaAPISuite using java 8 @@ -617,7 +620,7 @@ public void testCombineByKey() { JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream combined = pairStream.combineByKey(i -> i, - (x, y) -> x + y, (x, y) -> x + y, new HashPartitioner(2)); + (x, y) -> x + y, (x, y) -> x + y, new HashPartitioner(2)); 
JavaTestUtils.attachTestOutputStream(combined); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -700,7 +703,7 @@ public void testReduceByKeyAndWindowWithInverse() { JavaPairDStream reduceWindowed = pairStream.reduceByKeyAndWindow((x, y) -> x + y, (x, y) -> x - y, new Duration(2000), - new Duration(1000)); + new Duration(1000)); JavaTestUtils.attachTestOutputStream(reduceWindowed); List>> result = JavaTestUtils.runStreams(ssc, 3, 3); @@ -831,4 +834,44 @@ public void testFlatMapValues() { Assert.assertEquals(expected, result); } + /** + * This test is only for testing the APIs. It's not necessary to run it. + */ + public void testTrackStateByAPI() { + JavaPairRDD initialRDD = null; + JavaPairDStream wordsDstream = null; + + JavaTrackStateDStream stateDstream = + wordsDstream.trackStateByKey( + StateSpec. function((time, key, value, state) -> { + // Use all State's methods here + state.exists(); + state.get(); + state.isTimingOut(); + state.remove(); + state.update(true); + return Optional.of(2.0); + }).initialState(initialRDD) + .numPartitions(10) + .partitioner(new HashPartitioner(10)) + .timeout(Durations.seconds(10))); + + JavaPairDStream emittedRecords = stateDstream.stateSnapshots(); + + JavaTrackStateDStream stateDstream2 = + wordsDstream.trackStateByKey( + StateSpec.function((value, state) -> { + state.exists(); + state.get(); + state.isTimingOut(); + state.remove(); + state.update(true); + return 2.0; + }).initialState(initialRDD) + .numPartitions(10) + .partitioner(new HashPartitioner(10)) + .timeout(Durations.seconds(10))); + + JavaPairDStream emittedRecords2 = stateDstream2.stateSnapshots(); + } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index 4479d9db683d..34287c3e0090 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ 
-267,7 +267,11 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( // Read the data of the delta val deltaMapSize = inputStream.readInt() - deltaMap = new OpenHashMap[K, StateInfo[S]](deltaMapSize) + deltaMap = if (deltaMapSize != 0) { + new OpenHashMap[K, StateInfo[S]](deltaMapSize) + } else { + new OpenHashMap[K, StateInfo[S]](initialCapacity) + } var deltaMapCount = 0 while (deltaMapCount < deltaMapSize) { val key = inputStream.readObject().asInstanceOf[K] diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java index 188aefba6130..eac4cdd14a68 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaTrackStateByKeySuite.java @@ -17,18 +17,30 @@ package org.apache.spark.streaming; +import java.io.Serializable; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Set; + +import scala.Tuple2; + import com.google.common.base.Optional; +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.Function; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.util.ManualClock; +import org.junit.Assert; +import org.junit.Test; + import org.apache.spark.HashPartitioner; import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.Function4; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.streaming.Durations; import org.apache.spark.streaming.api.java.JavaPairDStream; import org.apache.spark.streaming.api.java.JavaTrackStateDStream; -import org.junit.Test; -import scala.Tuple2; - -import java.io.Serializable; public class 
JavaTrackStateByKeySuite extends LocalJavaStreamingContext implements Serializable { @@ -36,32 +48,163 @@ public class JavaTrackStateByKeySuite extends LocalJavaStreamingContext implemen * This test is only for testing the APIs. It's not necessary to run it. */ public void testAPI() { - // TODO -// JavaPairRDD initialRDD = null; -// JavaPairDStream wordsDstream = null; -// final Function4, State, Optional> -// trackStateFunc = -// new Function4, State, Optional>() { -// -// @Override -// public Optional call(Time time, String word, Optional one, -// State state) { -// // Use all State's methods here -// state.exists(); -// state.get(); -// state.isTimingOut(); -// state.remove(); -// state.update(10); -// return "test"; -// } -// }; -// -// JavaTrackStateDStream> stateDstream = -// wordsDstream.trackStateByKey( -// StateSpec.function(trackStateFunc) -// .initialState(initialRDD) -// .numPartitions(10) -// .partitioner(new HashPartitioner(10)) -// .timeout(Durations.seconds(10))); + JavaPairRDD initialRDD = null; + JavaPairDStream wordsDstream = null; + + final Function4, State, Optional> + trackStateFunc = + new Function4, State, Optional>() { + + @Override + public Optional call( + Time time, String word, Optional one, State state) { + // Use all State's methods here + state.exists(); + state.get(); + state.isTimingOut(); + state.remove(); + state.update(true); + return Optional.of(2.0); + } + }; + + JavaTrackStateDStream stateDstream = + wordsDstream.trackStateByKey( + StateSpec.function(trackStateFunc) + .initialState(initialRDD) + .numPartitions(10) + .partitioner(new HashPartitioner(10)) + .timeout(Durations.seconds(10))); + + JavaPairDStream emittedRecords = stateDstream.stateSnapshots(); + + final Function2, State, Double> trackStateFunc2 = + new Function2, State, Double>() { + + @Override + public Double call(Optional one, State state) { + // Use all State's methods here + state.exists(); + state.get(); + state.isTimingOut(); + state.remove(); + 
state.update(true); + return 2.0; + } + }; + + JavaTrackStateDStream stateDstream2 = + wordsDstream.trackStateByKey( + StateSpec. function(trackStateFunc2) + .initialState(initialRDD) + .numPartitions(10) + .partitioner(new HashPartitioner(10)) + .timeout(Durations.seconds(10))); + + JavaPairDStream emittedRecords2 = stateDstream2.stateSnapshots(); + } + + @Test + public void testBasicFunction() { + List> inputData = Arrays.asList( + Collections.emptyList(), + Arrays.asList("a"), + Arrays.asList("a", "b"), + Arrays.asList("a", "b", "c"), + Arrays.asList("a", "b"), + Arrays.asList("a"), + Collections.emptyList() + ); + + List> outputData = Arrays.asList( + Collections.emptySet(), + Sets.newHashSet(1), + Sets.newHashSet(2, 1), + Sets.newHashSet(3, 2, 1), + Sets.newHashSet(4, 3), + Sets.newHashSet(5), + Collections.emptySet() + ); + + List>> stateData = Arrays.asList( + Collections.>emptySet(), + Sets.newHashSet(new Tuple2("a", 1)), + Sets.newHashSet(new Tuple2("a", 2), new Tuple2("b", 1)), + Sets.newHashSet( + new Tuple2("a", 3), + new Tuple2("b", 2), + new Tuple2("c", 1)), + Sets.newHashSet( + new Tuple2("a", 4), + new Tuple2("b", 3), + new Tuple2("c", 1)), + Sets.newHashSet( + new Tuple2("a", 5), + new Tuple2("b", 3), + new Tuple2("c", 1)), + Sets.newHashSet( + new Tuple2("a", 5), + new Tuple2("b", 3), + new Tuple2("c", 1)) + ); + + Function2, State, Integer> trackStateFunc = + new Function2, State, Integer>() { + + @Override + public Integer call(Optional value, State state) throws Exception { + int sum = value.or(0) + (state.exists() ? 
state.get() : 0); + state.update(sum); + return sum; + } + }; + testOperation( + inputData, + StateSpec.function(trackStateFunc), + outputData, + stateData); + } + + private void testOperation( + List> input, + StateSpec trackStateSpec, + List> expectedOutputs, + List>> expectedStateSnapshots) { + int numBatches = expectedOutputs.size(); + JavaDStream inputStream = JavaTestUtils.attachTestInputStream(ssc, input, 2); + JavaTrackStateDStream trackeStateStream = + JavaPairDStream.fromJavaDStream(inputStream.map(new Function>() { + @Override + public Tuple2 call(K x) throws Exception { + return new Tuple2(x, 1); + } + })).trackStateByKey(trackStateSpec); + + final List> collectedOutputs = + Collections.synchronizedList(Lists.>newArrayList()); + trackeStateStream.foreachRDD(new Function, Void>() { + @Override + public Void call(JavaRDD rdd) throws Exception { + collectedOutputs.add(Sets.newHashSet(rdd.collect())); + return null; + } + }); + final List>> collectedStateSnapshots = + Collections.synchronizedList(Lists.>>newArrayList()); + trackeStateStream.stateSnapshots().foreachRDD(new Function, Void>() { + @Override + public Void call(JavaPairRDD rdd) throws Exception { + collectedStateSnapshots.add(Sets.newHashSet(rdd.collect())); + return null; + } + }); + BatchCounter batchCounter = new BatchCounter(ssc.ssc()); + ssc.start(); + ((ManualClock) ssc.ssc().scheduler().clock()) + .advance(ssc.ssc().progressListener().batchDuration() * numBatches + 1); + batchCounter.waitUntilBatchesCompleted(numBatches, 10000); + + Assert.assertEquals(expectedOutputs, collectedOutputs); + Assert.assertEquals(expectedStateSnapshots, collectedStateSnapshots); } } From 04845c8b37d30636d78ef95d04ce545e18b61bc2 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 11 Nov 2015 17:50:38 -0800 Subject: [PATCH 6/7] Fix the example codes in docs --- .../scala/org/apache/spark/streaming/State.scala | 6 +++--- .../org/apache/spark/streaming/StateSpec.scala | 16 ++++++++-------- 
.../streaming/api/java/JavaPairDStream.scala | 6 +++--- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/State.scala b/streaming/src/main/scala/org/apache/spark/streaming/State.scala index 8e6e9498a5b1..604e64fc6163 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/State.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/State.scala @@ -53,11 +53,11 @@ import org.apache.spark.annotation.Experimental * Java example of using `State`: * {{{ * // A tracking function that maintains an integer state and return a String - * Function2, JavaState, Optional> trackStateFunc = - * new Function2, JavaState, Optional>() { + * Function2, State, Optional> trackStateFunc = + * new Function2, State, Optional>() { * * @Override - * public Optional call(Optional one, JavaState state) { + * public Optional call(Optional one, State state) { * if (state.exists()) { * int existingState = state.get(); // Get the existing state * boolean shouldRemove = ...; // Decide whether to remove the state diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala index 3fd89135b19f..bea5b9df20b5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala @@ -48,12 +48,12 @@ import org.apache.spark.{HashPartitioner, Partitioner} * * Example in Java: * {{{ - * StateSpec[KeyType, ValueType, StateType, EmittedDataType] spec = - * StateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction) + * StateSpec spec = + * StateSpec.function(trackingFunction) * .numPartition(10); * - * JavaTrackStateDStream[KeyType, ValueType, StateType, EmittedType] emittedRecordDStream = - * javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec); + * JavaTrackStateDStream emittedRecordDStream = + * 
javaPairDStream.trackStateByKey(spec); * }}} */ @Experimental @@ -109,12 +109,12 @@ sealed abstract class StateSpec[KeyType, ValueType, StateType, EmittedType] exte * * Example in Java: * {{{ - * StateSpec[KeyType, ValueType, StateType, EmittedDataType] spec = - * StateSpec.function[KeyType, ValueType, StateType, EmittedDataType](trackingFunction) + * StateSpec spec = + * StateSpec.function(trackingFunction) * .numPartition(10); * - * JavaTrackStateDStream[KeyType, ValueType, StateType, EmittedType] emittedRecordDStream = - * javaPairDStream.trackStateByKey[StateType, EmittedDataType](spec); + * JavaTrackStateDStream emittedRecordDStream = + * javaPairDStream.trackStateByKey(spec); * }}} */ @Experimental diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index f150d123c5f9..70e32b383e45 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -453,9 +453,9 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])( * } * }; * - * JavaTrackStateDStream[Integer, Integer, Integer, String] trackStateDStream = - * keyValueDStream.trackStateByKey[Int, String]( - * StateSpec.function(trackingFunction).numPartitions(10)); + * JavaTrackStateDStream trackStateDStream = + * keyValueDStream.trackStateByKey( + * StateSpec.function(trackStateFunc).numPartitions(10)); * }}} * * @param spec Specification of this transformation From f6c45cb3610747d48ce8c9cd96c0dee6742f6c45 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 11 Nov 2015 21:58:12 -0800 Subject: [PATCH 7/7] Remove unnecessary changes --- .../test/java/org/apache/spark/streaming/Java8APISuite.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java 
b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java index 872e9d348318..163ae92c12c6 100644 --- a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java +++ b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java @@ -620,7 +620,7 @@ public void testCombineByKey() { JavaPairDStream pairStream = JavaPairDStream.fromJavaDStream(stream); JavaPairDStream combined = pairStream.combineByKey(i -> i, - (x, y) -> x + y, (x, y) -> x + y, new HashPartitioner(2)); + (x, y) -> x + y, (x, y) -> x + y, new HashPartitioner(2)); JavaTestUtils.attachTestOutputStream(combined); List>> result = JavaTestUtils.runStreams(ssc, 2, 2); @@ -703,7 +703,7 @@ public void testReduceByKeyAndWindowWithInverse() { JavaPairDStream reduceWindowed = pairStream.reduceByKeyAndWindow((x, y) -> x + y, (x, y) -> x - y, new Duration(2000), - new Duration(1000)); + new Duration(1000)); JavaTestUtils.attachTestOutputStream(reduceWindowed); List>> result = JavaTestUtils.runStreams(ssc, 3, 3);