working on stateful streaming
marmbrus committed Dec 11, 2015
commit 7a3590fe5dd659d1a47e645eec247e21280a3373
@@ -0,0 +1,152 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.execution.streaming

import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.types.{DoubleType, LongType}
import org.apache.spark.sql.{Strategy, SQLContext, Column, DataFrame}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, UnaryNode}
import org.apache.spark.sql.execution.{SparkPlanner, SparkPlan}

package object state {

  trait StateStore {
    def get(key: InternalRow): InternalRow
  }

  case class Window(
      eventtime: Expression,
      windowAttribute: AttributeReference,
      step: Int,
      closingTriggerDelay: Int,
      child: LogicalPlan) extends UnaryNode {
    override def output: Seq[Attribute] = windowAttribute +: child.output
    override def missingInput: AttributeSet = super.missingInput - windowAttribute
  }

  implicit object MaxWatermark extends AccumulatorParam[Watermark] {
    /**
     * Add additional data to the accumulator value. Is allowed to modify and return `r`
     * for efficiency (to avoid allocating objects).
     *
     * @param r the current value of the accumulator
     * @param t the data to be added to the accumulator
     * @return the new value of the accumulator
     */
    override def addAccumulator(r: Watermark, t: Watermark): Watermark = if (r > t) r else t

    /**
     * Merge two accumulated values together. Is allowed to modify and return the first value
     * for efficiency (to avoid allocating objects).
     *
     * @param r1 one set of accumulated data
     * @param r2 another set of accumulated data
     * @return both data sets merged together
     */
    override def addInPlace(r1: Watermark, r2: Watermark): Watermark = if (r1 > r2) r1 else r2

    /**
     * Return the "zero" (identity) value for an accumulator type, given its initial value. For
     * example, if R was a vector of N dimensions, this would return a vector of N zeroes.
     */
    override def zero(initialValue: Watermark): Watermark = new Watermark(-1)
  }
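
  // Illustrative usage, not part of this commit: because MaxWatermark is an
  // implicit AccumulatorParam, an accumulator built from it tracks the maximum
  // event time seen across all partitions. Assuming a SparkContext `sc` and
  // that Watermark wraps a numeric event time and defines `>`:
  //
  //   val maxEventTime = sc.accumulator(new Watermark(-1))
  //   maxEventTime += new Watermark(10)  // value is now Watermark(10)
  //   maxEventTime += new Watermark(5)   // still Watermark(10): the max wins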

  case class WindowAggregate(
      eventtime: Expression,
      eventtimeMax: Accumulator[Watermark],
      eventtimeWatermark: Watermark,
      windowAttribute: AttributeReference,
      step: Int,
      groupingExpressions: Seq[Expression],
      aggregateExpressions: Seq[NamedExpression],
      child: SparkPlan) extends SparkPlan {
    /**
     * Overridden by concrete implementations of SparkPlan.
     * Produces the result of the query as an RDD[InternalRow]
     */
    override protected def doExecute(): RDD[InternalRow] = {
      child.execute().mapPartitions { iter =>
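        // Bucket each event into the window that ends at the next multiple of
        // `step`: ceil(eventtime / step) * step. For example, with step = 10,
        // event times 11 through 20 all map to window 20.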
        val window =
          Multiply(Ceil(Divide(Cast(eventtime, DoubleType), Literal(step))), Literal(step))
        val windowAndGroupProjection =
          GenerateUnsafeProjection.generate(window +: groupingExpressions, child.output)
        iter.foreach { row =>
          val windowAndGroup = windowAndGroupProjection(row)
          println(windowAndGroup.toSeq((windowAttribute +: groupingExpressions).map(_.dataType)))

          eventtimeMax += new Watermark(windowAndGroup.getLong(0))
        }

        Iterator.empty
      }
    }

    override def output: Seq[Attribute] =
      windowAttribute +: aggregateExpressions.map(_.toAttribute)

    override def missingInput: AttributeSet = super.missingInput - windowAttribute

    /**
     * Returns a Seq of the children of this node.
     * Children should not change. Immutability required for containsChild optimization
     */
    override def children: Seq[SparkPlan] = child :: Nil
  }

  implicit class StatefulDataFrame(df: DataFrame) {
    def window(eventTime: Column, step: Int, closingTriggerDelay: Int): DataFrame =
      df.withPlan(
        Window(
          eventTime.expr,
          new AttributeReference("window", LongType)(),
          step,
          closingTriggerDelay,
          df.logicalPlan))
  }

  class StatefulPlanner(sqlContext: SQLContext, maxWatermark: Accumulator[Watermark])
    extends SparkPlanner(sqlContext) {

    override def strategies: Seq[Strategy] = WindowStrategy +: super.strategies

    object WindowStrategy extends Strategy with Logging {
      def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
        case Aggregate(grouping, aggregate,
            Window(et, windowAttribute, step, trigger, child)) =>
          println(s"triggering ${maxWatermark.value - trigger}")

          WindowAggregate(
            et,
            maxWatermark,
            maxWatermark.value - trigger,
            windowAttribute,
            step,
            grouping,
            aggregate,
            planLater(child)) :: Nil

        case other => Nil
      }
    }
  }
}
@@ -284,6 +284,7 @@ abstract class QueryTest extends PlanTest with Timeouts {
|== Stream ==
|Stream state: $currentWatermarks
|Thread state: $threadState
|Event time trigger: ${if (currentStream != null) currentStream.maxEventTime else ""}
|${if (streamDeathCause != null) stackTraceToString(streamDeathCause) else ""}
|
|== Sink ==
@@ -0,0 +1,49 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.streaming

import org.apache.spark.sql.{Row, QueryTest}
import org.apache.spark.sql.functions._

import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.test.SharedSQLContext

class StatefulStreamSuite extends QueryTest with SharedSQLContext {

  import testImplicits._

  test("windowed aggregation") {
    val inputData = MemoryStream[Int]
    val tenSecondCounts =
      inputData.toDF("eventTime")
        .window($"eventTime", step = 10, closingTriggerDelay = 20)
        .groupBy($"eventTime" % 2)
        .agg(count("*"))

    testStream(tenSecondCounts)(
      AddData(inputData, 1, 2, 3),
      CheckAnswer(),
      AddData(inputData, 11, 12),
      CheckAnswer(),
      AddData(inputData, 20),
      CheckAnswer(),
      AddData(inputData, 30),
      CheckAnswer((0, 0, 1), (0, 1, 2)))
  }
}
@@ -0,0 +1,83 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.streaming

import org.apache.spark.sql.{Row, QueryTest}
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.test.SharedSQLContext

class StreamSuite extends QueryTest with SharedSQLContext {

  import testImplicits._

  test("map with recovery") {
    val inputData = MemoryStream[Int]
    val mapped = inputData.toDS().map(_ + 1)

    testStream(mapped)(
      AddData(inputData, 1, 2, 3),
      CheckAnswer(2, 3, 4),
      StopStream,
      AddData(inputData, 4, 5, 6),
      StartStream,
      CheckAnswer(2, 3, 4, 5, 6, 7))
  }

  test("join") {
    // Make a table and ensure it will be broadcast.
    val smallTable = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word")

    // Join the input stream with a table.
    val inputData = MemoryStream[Int]
    val joined = inputData.toDS().toDF().join(smallTable, $"value" === $"number")

    testStream(joined)(
      AddData(inputData, 1, 2, 3),
      CheckAnswer(Row(1, 1, "one"), Row(2, 2, "two")),
      AddData(inputData, 4),
      CheckAnswer(Row(1, 1, "one"), Row(2, 2, "two"), Row(4, 4, "four")))
  }

  test("union two streams") {
    val inputData1 = MemoryStream[Int]
    val inputData2 = MemoryStream[Int]

    val unioned = inputData1.toDS().union(inputData2.toDS())

[Review comment]
Can the user reuse a Source? E.g., inputData1.toDS().union(inputData1.toDS())

[marmbrus (Owner, Author)]
Yes, that works already.

    testStream(unioned)(
      AddData(inputData1, 1, 3, 5),
      CheckAnswer(1, 3, 5),
      AddData(inputData2, 2, 4, 6),
      CheckAnswer(1, 2, 3, 4, 5, 6),
      StopStream,
      AddData(inputData1, 7),
      StartStream,
      AddData(inputData2, 8),
      CheckAnswer(1, 2, 3, 4, 5, 6, 7, 8))
  }
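
  // Illustrative sketch, not part of this commit, following up on the review
  // question above: the same MemoryStream can feed both sides of a union.
  // Dataset.union keeps duplicates (UNION ALL semantics), so each added
  // element should appear twice in the result.
  test("self-union of a single source (sketch)") {
    val inputData = MemoryStream[Int]
    val selfUnioned = inputData.toDS().union(inputData.toDS())

    testStream(selfUnioned)(
      AddData(inputData, 1, 2),
      CheckAnswer(1, 1, 2, 2))
  }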

test("sql queries") {
val inputData = MemoryStream[Int]
inputData.toDF().registerTempTable("stream")
val evens = sql("SELECT * FROM stream WHERE value % 2 = 0")

testStream(evens)(
AddData(inputData, 1, 2, 3, 4),
CheckAnswer(2, 4))
}
}