now with event time windows
marmbrus committed Dec 12, 2015
commit c8a923831bbfc24f71eb744e36a15e432d6ae067
3 changes: 3 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -21,6 +21,8 @@ import java.beans.{BeanInfo, Introspector}
import java.util.Properties
import java.util.concurrent.atomic.AtomicReference

import org.apache.spark.sql.execution.streaming.state.GroupWindows

import scala.collection.JavaConverters._
import scala.collection.immutable
import scala.reflect.runtime.universe.TypeTag
@@ -192,6 +194,7 @@ class SQLContext private[sql](
protected[sql] lazy val analyzer: Analyzer =
new Analyzer(catalog, functionRegistry, conf) {
override val extendedResolutionRules =
GroupWindows ::
ExtractPythonUDFs ::
PreInsertCastAndRename ::
(if (conf.runSQLOnFile) new ResolveDataSource(self) :: Nil else Nil)
@@ -17,26 +17,53 @@

package org.apache.spark.sql.execution.streaming

import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.{Accumulator, Logging}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.execution.streaming.state.StatefulPlanner
import org.apache.spark.sql.execution.streaming.state.{GroupWindows, StatefulPlanner}
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.execution.{SparkPlan, QueryExecution, LogicalRDD}

class EventTimeSource(val max: Accumulator[Watermark]) extends Source {
override def watermark: Watermark = max.value

override def getSlice(
sqlContext: SQLContext, start: Watermark, end: Watermark): RDD[InternalRow] = ???

// HACK
override def equals(other: Any): Boolean = other.isInstanceOf[EventTimeSource]
override def hashCode: Int = 0

override def toString: String = "EventTime"
}

case object BatchId extends Source {
override def watermark: Watermark = new Watermark(-1)

override def getSlice(
sqlContext: SQLContext, start: Watermark, end: Watermark): RDD[InternalRow] = ???
}
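EventTimeSource and BatchId are not sources of rows at all: they exist so that the event-time high-water mark and the batch counter can ride along in StreamProgress, and therefore in the sink's recorded progress, exactly like ordinary source watermarks, which is why both leave getSlice unimplemented. A minimal sketch of the Source contract they stub out, inferred from the overrides above (the real trait is defined elsewhere in this branch, and the inclusivity of the slice bounds is an assumption):

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.catalyst.InternalRow

    trait Source {
      /** The highest watermark for which this source currently has data. */
      def watermark: Watermark

      /** Rows that arrived after `start` and up to `end`, as seen by this source. */
      def getSlice(sqlContext: SQLContext, start: Watermark, end: Watermark): RDD[InternalRow]
    }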

class StreamExecution(
sqlContext: SQLContext,
private[sql] val logicalPlan: LogicalPlan,
val sink: Sink) extends Logging {

/** All stream sources present in the query plan. */
private val sources = logicalPlan.collect { case s: Source => s: Source }

/** Tracks how much data we have processed from each input source. */
private[sql] val currentWatermarks = new StreamProgress

import org.apache.spark.sql.execution.streaming.state.MaxWatermark
private[sql] val maxEventTime = sqlContext.sparkContext.accumulator(Watermark(-1))(org.apache.spark.sql.execution.streaming.state.MaxWatermark)
private[sql] val maxEventTime =
sqlContext.sparkContext.accumulator(Watermark(-1))

private[sql] val eventTimeSource = new EventTimeSource(maxEventTime)

/** All stream sources present in the query plan. */
private val sources =
logicalPlan.collect { case s: Source => s: Source } ++ Seq(eventTimeSource, BatchId)

// Start the execution at the current watermark for the sink. (i.e. avoid reprocessing data
// that we have already processed).
@@ -45,6 +72,11 @@ class StreamExecution(
currentWatermarks.update(s, sourceWatermark)
}

// Restore the position of the eventtime watermark accumulator
currentWatermarks.get(eventTimeSource).foreach(eventTimeSource.max.setValue)

logInfo(s"Stream running at $currentWatermarks")

/** When false, signals to the microBatchThread that it should stop running. */
@volatile private var shouldRun = true

@@ -58,6 +90,8 @@ class StreamExecution(
microBatchThread.setDaemon(true)
microBatchThread.start()

var lastExecution: QueryExecution = null
Review comment: nit: volatile and private[streaming]?


/**
* Checks to see if any new data is present in any of the sources. When new data is available,
* a batch is executed and passed to the sink, updating the currentWatermarks.
@@ -71,7 +105,7 @@

if (newData.nonEmpty) {
val startTime = System.nanoTime()
logDebug(s"Running with new data up to: $newData")
logInfo(s"Running with new data up to: $newData")

// Replace sources in the logical plan with data that has arrived since the last batch.
val newPlan = logicalPlan transform {
@@ -83,22 +117,32 @@

val optimizerStart = System.nanoTime()

val executedPlan = new QueryExecution(sqlContext, newPlan) {
lastExecution = new QueryExecution(sqlContext, newPlan) {
// Skip the optimizer for now cause the streaming planner is not great.
override lazy val optimizedPlan: LogicalPlan = EliminateSubQueries(analyzed)
override lazy val sparkPlan: SparkPlan = {
SQLContext.setActive(sqlContext)
new StatefulPlanner(sqlContext, maxEventTime).plan(optimizedPlan).next()
}
}.executedPlan
val batchPlanner = new StatefulPlanner(
sqlContext,
maxEventTime,
currentWatermarks(BatchId))

logInfo(s"BatchPlanner: ${currentWatermarks(BatchId)}, $maxEventTime")

batchPlanner.plan(optimizedPlan).next()
}
}
val executedPlan = lastExecution.executedPlan
val optimizerTime = (System.nanoTime() - optimizerStart).toDouble / 1000000
logDebug(s"Optimized batch in ${optimizerTime}ms")

val results = executedPlan.execute().map(_.copy())
sink.addBatch(newData, results)
val updated = newData + (BatchId -> (currentWatermarks(BatchId) + 1))
sink.addBatch(updated, results)
updated.foreach(currentWatermarks.update)

logInfo(s"EventTime Watermark: ${maxEventTime.value}")
StreamExecution.this.synchronized {
newData.foreach(currentWatermarks.update)
StreamExecution.this.notifyAll()
}

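The interplay between the sink and the in-memory progress is what gives this loop its restart story: the batch's progress, including the incremented BatchId and the event-time high-water mark, is handed to the sink together with the rows, and only afterwards is currentWatermarks advanced. A minimal sketch of that ordering, with hypothetical names and values (someSource and results are stand-ins, not identifiers from this patch):

    // Mirrors the commit ordering in attemptBatch above; values are made up.
    val newData: Map[Source, Watermark] = Map(someSource -> new Watermark(9))
    val updated = newData + (BatchId -> (currentWatermarks(BatchId) + 1))
    sink.addBatch(updated, results)            // progress is recorded together with the rows
    updated.foreach(currentWatermarks.update)  // the in-memory view advances only after the sink accepts

    // On restart, the sink's last recorded progress seeds currentWatermarks again and the
    // EventTime entry is pushed back into the maxEventTime accumulator, so an interrupted
    // batch is re-run rather than lost.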
@@ -83,19 +83,22 @@ case class MemoryStream[A : Encoder](output: Seq[Attribute]) extends LeafNode wi
}

class MemorySink(schema: StructType) extends Sink with Logging {
private val currentWatermarks = new StreamProgress
private var batches = new ArrayBuffer[(StreamProgress, Seq[Row])]()

private val output = schema.toAttributes

def currentWatermarks: StreamProgress = batches.lastOption.map(_._1).getOrElse(new StreamProgress)
def currentWatermark(source: Source): Option[Watermark] = currentWatermarks.get(source)

def allData: Seq[Row] = batches.flatMap(_._2)

val externalRowConverter = RowEncoder(schema)
def addBatch(watermarks: Map[Source, Watermark], rdd: RDD[InternalRow]): Unit = {
watermarks.foreach(currentWatermarks.update)
batches.append((currentWatermarks.copy(), rdd.collect().map(externalRowConverter.fromRow)))
watermarks.foreach(currentWatermarks.update)
}

def dropBatches(num: Int): Unit = {
batches.remove(batches.size - num, num)
}

override def toString: String = batches.map(b => s"${b._1}: ${b._2.mkString(" ")}").mkString("\n")
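Turning currentWatermarks into a view over the last stored batch, rather than a separately mutated StreamProgress, means that discarding batches also rolls the reported progress back. A hypothetical test-side use (schema, source and the intervening batches are stand-ins):

    val sink = new MemorySink(schema)
    // ... StreamExecution appends a few batches ...
    sink.currentWatermark(source)   // e.g. Some(Watermark(9)), taken from the latest batch
    sink.dropBatches(1)             // forget the most recent batch
    sink.currentWatermark(source)   // now reports the previous batch's watermark, so the
                                    // dropped data would be reprocessed on restart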
@@ -19,8 +19,10 @@ package org.apache.spark.sql.execution.streaming
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.types.{DoubleType, LongType}
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeRowJoiner, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.{StructField, StructType, DoubleType, LongType}
import org.apache.spark.sql.{Strategy, SQLContext, Column, DataFrame}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -29,8 +31,45 @@ import org.apache.spark.sql.execution.{SparkPlanner, SparkPlan}

package object state {

trait StateStore {
def get(key: InternalRow): InternalRow
object StateStore extends Logging {
private val checkpoints = new scala.collection.mutable.HashMap[(Int, Watermark), StateStore]()

/** Checkpoints the state for a partition at a given batchId. */
def checkpoint(batchId: Watermark, store: StateStore): Unit = synchronized {
val partitionId = TaskContext.get().partitionId()
logInfo(s"Checkpointing ${(partitionId, batchId)}")
checkpoints.put((partitionId, batchId), store)
}

/** Gets the state for a partition at a given batchId. */
def getAt(batchId: Watermark): StateStore = synchronized {
val partitionId = TaskContext.get().partitionId()

if (batchId.offset == -1) {
new StateStore
} else {
val copy = new StateStore
checkpoints((partitionId, batchId)).data.foreach(copy.data.+=)
copy
}
}
}


class StateStore {
private val data = new scala.collection.mutable.HashMap[InternalRow, Long]()

def get(key: InternalRow): Option[Long] = data.get(key)

def put(key: InternalRow, value: Long): Unit = {
data.put(key, value)
}

def triggerWindowsBefore(watermark: Long): Seq[(InternalRow, Long)] = {
val triggeredWindows = data.filter(_._1.getLong(0) <= watermark).toArray
triggeredWindows.map(_._1).foreach(data.remove)
triggeredWindows.toSeq
}
}
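StateStore here is a deliberately simple stand-in: checkpoints live in a process-local HashMap keyed by (partitionId, batchId), getAt hands back a copy so earlier checkpoints stay replayable, and triggerWindowsBefore both emits and forgets windows whose end time has passed the watermark. A sketch of the per-partition lifecycle as WindowAggregate drives it (this only works inside a running task, since checkpoint and getAt read TaskContext; key and watermarkOffset are stand-ins for the grouping UnsafeRow and the event-time watermark):

    // Inside mapPartitions, for one micro-batch:
    val store = StateStore.getAt(lastCheckpoint)         // copy of the previous batch's state
    val newCount = store.get(key).getOrElse(0L) + 1      // count(*) is the only aggregate supported
    store.put(key.copy(), newCount)
    StateStore.checkpoint(lastCheckpoint + 1, store)     // register this batch's version of the state
    store.triggerWindowsBefore(watermarkOffset)          // emit and drop windows with end <= watermark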

case class Window(
@@ -71,33 +110,63 @@ package object state {
override def zero(initialValue: Watermark): Watermark = new Watermark(-1)
}

/**
* Simple incremental windowing aggregate that supports a fixed delay. Late data is dropped.
* Only supports grouped count(*) :P.
*/
case class WindowAggregate(
eventtime: Expression,
eventtimeMax: Accumulator[Watermark],
eventtimeWatermark: Watermark,
lastCheckpoint: Watermark,
windowAttribute: AttributeReference,
step: Int,
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
child: SparkPlan) extends SparkPlan {

override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(groupingExpressions.filterNot(_ == windowAttribute)) :: Nil

/**
* Overridden by concrete implementations of SparkPlan.
* Produces the result of the query as an RDD[InternalRow]
*/
override protected def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitions { iter =>
val stateStore = StateStore.getAt(lastCheckpoint)

// TODO: Move to window operator...
val window =
Multiply(Ceil(Divide(Cast(eventtime, DoubleType), Literal(step))), Literal(step))
val windowAndGroupProjection =
GenerateUnsafeProjection.generate(window +: groupingExpressions, child.output)
GenerateUnsafeProjection.generate(
window +: groupingExpressions.filterNot(_ == windowAttribute), child.output)

iter.foreach { row =>
val windowAndGroup = windowAndGroupProjection(row)
println(windowAndGroup.toSeq((windowAttribute +: groupingExpressions).map(_.dataType)))

eventtimeMax += new Watermark(windowAndGroup.getLong(0))
if (windowAndGroup.getLong(0) > eventtimeWatermark.offset) {
eventtimeMax += new Watermark(windowAndGroup.getLong(0))
val newCount = stateStore.get(windowAndGroup).getOrElse(0L) + 1
stateStore.put(windowAndGroup.copy(), newCount)
}
}

Iterator.empty
StateStore.checkpoint(lastCheckpoint + 1, stateStore)

val buildRow = GenerateUnsafeProjection.generate(
BoundReference(0, LongType, false) ::
BoundReference(1, LongType, false) ::
BoundReference(2, LongType, false) :: Nil)
val joinedRow = new JoinedRow
val countRow = new SpecificMutableRow(LongType :: Nil)

logDebug(s"Triggering windows < $eventtimeWatermark")
stateStore.triggerWindowsBefore(eventtimeWatermark.offset).toIterator.map {
case (key, count) =>
countRow.setLong(0, count)
buildRow(joinedRow(key, countRow))
}
}
}
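The window expression above rounds each event time up to the next multiple of step, so the first field of the grouping key is the end of the tumbling window the row falls into; rows whose window end is not strictly greater than eventtimeWatermark (computed as maxWatermark.value - trigger when the batch is planned) are treated as too late and silently dropped. A worked example of the same arithmetic in plain Scala, assuming step is expressed in the same units as the event time:

    def windowEnd(eventTime: Long, step: Int): Long =
      (math.ceil(eventTime.toDouble / step) * step).toLong

    windowEnd(23, 10)  // 30: counted in the (20, 30] window
    windowEnd(30, 10)  // 30: a boundary event stays in (20, 30]
    windowEnd(31, 10)  // 40: first event of the (30, 40] window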

@@ -124,21 +193,35 @@
df.logicalPlan))
}

class StatefulPlanner(sqlContext: SQLContext, maxWatermark: Accumulator[Watermark])
extends SparkPlanner(sqlContext) {

object GroupWindows extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Aggregate(grouping, aggregate,
Window(et, windowAttribute, step, trigger, child))
if !grouping.contains(windowAttribute) =>
Aggregate(windowAttribute +: grouping, windowAttribute +: aggregate,
Window(et, windowAttribute, step, trigger, child))
}
}
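The rule's only job is to make the window a grouping key: when an aggregate sits directly on a Window node without mentioning the window attribute, that attribute is prepended to both the grouping expressions and the output, so WindowStrategy below always sees one canonical shape. Schematically, for a hypothetical count grouped by word:

    before: Aggregate([word], [word, count(1)],
              Window(eventTime, window#1, step, trigger, child))
    after:  Aggregate([window#1, word], [window#1, word, count(1)],
              Window(eventTime, window#1, step, trigger, child))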


class StatefulPlanner(
sqlContext: SQLContext,
maxWatermark: Accumulator[Watermark],
lastCheckpoint: Watermark)
extends SparkPlanner(sqlContext) {

override def strategies: Seq[Strategy] = WindowStrategy +: super.strategies

object WindowStrategy extends Strategy with Logging {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case Aggregate(grouping, aggregate,
Window(et, windowAttribute, step, trigger, child)) =>
println(s"triggering ${maxWatermark.value - trigger}")

WindowAggregate(
et,
maxWatermark,
maxWatermark.value - trigger,
lastCheckpoint,
windowAttribute,
step,
grouping,