Merge remote-tracking branch 'origin/master' into streaming-infra
Conflicts:
	sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
marmbrus committed Jan 5, 2016
commit e3c4c8301fdcfaaa0bd56ed92a81e4e1d2db64a8
325 changes: 92 additions & 233 deletions sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -17,30 +17,23 @@

package org.apache.spark.sql

import java.lang.Thread.UncaughtExceptionHandler
import java.util.{Locale, TimeZone}

import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.scalatest.concurrent.Timeouts
import org.scalatest.time.SpanSugar._

import scala.collection.JavaConverters._

import org.apache.spark.sql.catalyst.encoders.{RowEncoder, encoderFor}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.Queryable

import scala.collection.mutable
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.{LogicalRDD, Queryable}

abstract class QueryTest extends PlanTest with Timeouts {
abstract class QueryTest extends PlanTest {

protected def sqlContext: SQLContext

@@ -196,241 +189,107 @@ abstract class QueryTest extends PlanTest with Timeouts {
planWithCaching)
}

// ==========================
// Streaming helper functions
// ==========================

implicit class RichSource(s: Source) {
def toDF(): DataFrame = new DataFrame(sqlContext, StreamingRelation(s))
}
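
With this implicit in scope, any Source can be wrapped as a streaming DataFrame; for example (construction of the source is assumed from elsewhere on this branch):

val source: Source = MemoryStream[Int]
val df: DataFrame = source.toDF() // wraps the source in a StreamingRelation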

/** How long to wait for an active stream to catch up when checking a result. */
val streamingTimeout = 10.seconds

/** A trait for actions that can be performed while testing a streaming DataFrame. */
trait StreamAction

/** A trait to mark actions that require the stream to be actively running. */
trait StreamMustBeRunning

/**
* Adds the given data to the stream. Subsequent check answers will block until this data has
* been processed.
*/
object AddData {
def apply[A](source: MemoryStream[A], data: A*): AddDataMemory[A] =
AddDataMemory(source, data)
}

/** A trait that can be extended when testing other sources. */
trait AddData extends StreamAction {
def source: Source
def addData(): Offset
}
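
Source-specific implementations extend this trait; a hypothetical sketch for a file-backed source (FakeFileSource and its append/currentOffset methods are assumptions, not part of this patch):

case class AddFileData(source: FakeFileSource, lines: Seq[String]) extends AddData {
  override def toString: String = s"AddFileData to $source: ${lines.mkString(",")}"

  override def addData(): Offset = {
    lines.foreach(source.append) // hypothetical: write the lines into the source's directory
    source.currentOffset         // hypothetical: an Offset covering the newly added data
  }
}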

case class AddDataMemory[A](source: MemoryStream[A], data: Seq[A]) extends AddData {
override def toString: String = s"AddData to $source: ${data.mkString(",")}"

override def addData(): Offset = {
source.addData(data)
private def checkJsonFormat(df: DataFrame): Unit = {
val logicalPlan = df.queryExecution.analyzed
// bypass some cases that we can't handle currently.
logicalPlan.transform {
case _: MapPartitions[_, _] => return
case _: MapGroups[_, _, _] => return
case _: AppendColumns[_, _] => return
case _: CoGroup[_, _, _, _] => return
case _: LogicalRelation => return
}.transformAllExpressions {
case a: ImperativeAggregate => return
}
}

case class AwaitEventTime(time: Long) extends StreamAction with StreamMustBeRunning
// Bypass Hive tests until all corner cases in the hive module are fixed.
if (this.getClass.getName.startsWith("org.apache.spark.sql.hive")) return
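
The bare `return`s above rely on Scala's non-local return: a `return` inside a closure exits the enclosing method, so any plan containing one of the listed node types skips the JSON round-trip check entirely. A minimal standalone illustration of the mechanism:

def firstNegative(xs: Seq[Int]): Option[Int] = {
  xs.foreach { x =>
    // Non-local return: exits firstNegative itself, not just this closure.
    if (x < 0) return Some(x)
  }
  None
}

firstNegative(Seq(3, -1, 2)) // Some(-1)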

/**
* Checks to make sure that the current data stored in the sink matches the `expectedAnswer`.
* This operation automatically blocks until all added data has been processed.
*/
object CheckAnswer {
def apply[A : Encoder](data: A*): CheckAnswerRows = {
val encoder = encoderFor[A]
val toExternalRow = RowEncoder(encoder.schema)
CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))))
val jsonString = try {
logicalPlan.toJSON
} catch {
case e: Throwable =>
fail(
s"""
|Failed to serialize the logical plan to JSON:
|${logicalPlan.treeString}
""".stripMargin, e)
}

def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows)
}

case class CheckAnswerRows(expectedAnswer: Seq[Row])
extends StreamAction with StreamMustBeRunning {
override def toString: String = s"CheckAnswer: ${expectedAnswer.mkString(",")}"
}

case class DropBatches(num: Int) extends StreamAction

/** Stops the stream. It must currently be running. */
case object StopStream extends StreamAction
// Scala functions are not serializable to JSON, so replace them with null so that the
// plans can be compared later.
val normalized1 = logicalPlan.transformAllExpressions {
case udf: ScalaUDF => udf.copy(function = null)
case gen: UserDefinedGenerator => gen.copy(function = null)
}
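
Blanking the function fields is what makes the later equality check meaningful: two structurally identical lambdas are never `==` in Scala, so plans holding live closures would always compare unequal. A small illustrative sketch (the Node class is hypothetical):

case class Node(name: String, fn: Int => Int)

val a = Node("inc", _ + 1)
val b = Node("inc", _ + 1)
assert(a != b)                                 // distinct function objects never compare equal
assert(a.copy(fn = null) == b.copy(fn = null)) // equal once the closures are blanked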

/** Starts the stream, resuming if data has already been processed. It must not be running. */
case object StartStream extends StreamAction
// RDDs/data are not serializable to JSON, so we collect the LogicalPlan nodes that contain
// this non-serializable state and use the originals to replace the null placeholders in the
// logical plan parsed back from JSON.
var logicalRDDs = logicalPlan.collect { case l: LogicalRDD => l }
var localRelations = logicalPlan.collect { case l: LocalRelation => l }
var inMemoryRelations = logicalPlan.collect { case i: InMemoryRelation => i }

/** Restarts all sources that implement a `restart()` method. */
case object RestartSources extends StreamAction
val jsonBackPlan = try {
TreeNode.fromJSON[LogicalPlan](jsonString, sqlContext.sparkContext)
} catch {
case e: Throwable =>
fail(
s"""
|Failed to rebuild the logical plan from JSON:
|${logicalPlan.treeString}
|
|${logicalPlan.prettyJson}
""".stripMargin, e)
}

/** Signals that a failure is expected and should not kill the test. */
case object ExpectFailure extends StreamAction
val normalized2 = jsonBackPlan transformDown {
case l: LogicalRDD =>
val origin = logicalRDDs.head
logicalRDDs = logicalRDDs.drop(1)
LogicalRDD(l.output, origin.rdd)(sqlContext)
case l: LocalRelation =>
val origin = localRelations.head
localRelations = localRelations.drop(1)
l.copy(data = origin.data)
case l: InMemoryRelation =>
val origin = inMemoryRelations.head
inMemoryRelations = inMemoryRelations.drop(1)
InMemoryRelation(
l.output,
l.useCompression,
l.batchSize,
l.storageLevel,
origin.child,
l.tableName)(
origin.cachedColumnBuffers,
l._statistics,
origin._batchStats)
}

/** A helper for running actions on a Streaming Dataset. See `checkAnswer(DataFrame)`. */
def testStream(stream: Dataset[_])(actions: StreamAction*): Unit =
testStream(stream.toDF())(actions: _*)
assert(logicalRDDs.isEmpty)
assert(localRelations.isEmpty)
assert(inMemoryRelations.isEmpty)
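
The head/drop(1) pairing above is sound because `collect` and `transformDown` both visit the tree in pre-order, so placeholders are consumed in exactly the order they were gathered; these asserts then confirm that every collected original was used. A minimal sketch of the same queue-style consumption (illustrative only):

var pending = Seq("rdd1", "rdd2")
def nextOriginal(): String = { val h = pending.head; pending = pending.drop(1); h }
assert(nextOriginal() == "rdd1" && pending == Seq("rdd2"))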

/**
* Executes the specified actions on the given streaming DataFrame and provides helpful
* error messages in the case of failures or incorrect answers.
*
* Note that if the stream is not explicitly started before an action that requires it to be
* running, then it will be started automatically before performing any other actions.
*/
def testStream(stream: DataFrame)(actions: StreamAction*): Unit = {
var pos = 0
var currentPlan: LogicalPlan = stream.logicalPlan
var currentStream: StreamExecution = null
val awaiting = new mutable.HashMap[Source, Offset]()
val sink = new MemorySink(stream.schema)

@volatile
var streamDeathCause: Throwable = null

// If the test doesn't manually start the stream, we do it automatically at the beginning.
val startedManually =
actions.takeWhile(_.isInstanceOf[StreamMustBeRunning]).contains(StartStream)
val startedTest = if (startedManually) actions else StartStream +: actions

def testActions = actions.zipWithIndex.map {
case (a, i) =>
if ((pos == i && startedManually) || (pos == (i + 1) && !startedManually)) {
"=> " + a.toString
} else {
" " + a.toString
}
}.mkString("\n")

def currentOffsets =
if (currentStream != null) currentStream.currentOffsets.toString else "not started"

def threadState =
if (currentStream != null && currentStream.microBatchThread.isAlive) "alive" else "dead"
def testState =
s"""
|== Progress ==
|$testActions
|
|== Stream ==
|Stream state: $currentOffsets
|Thread state: $threadState
|Event time trigger: ${if (currentStream != null) currentStream.maxEventTime else ""}
|${if (streamDeathCause != null) stackTraceToString(streamDeathCause) else ""}
|
|== Sink ==
|$sink
|
|== Plan ==
|${if (currentStream != null) currentStream.lastExecution else ""}
"""

def checkState(check: Boolean, error: String) = if (!check) {
if (normalized1 != normalized2) {
fail(
s"""
|Invalid State: $error
|$testState
""".stripMargin)
|== FAIL: the logical plan parsed from JSON does not match the original one ==
|${sideBySide(logicalPlan.treeString, normalized2.treeString).mkString("\n")}
""".stripMargin)
}
}

val testThread = Thread.currentThread()

try {
startedTest.foreach { action =>
action match {
case StartStream =>
checkState(currentStream == null, "stream already running")

currentPlan = currentPlan transform {
case StreamingRelation(s, _) =>
StreamingRelation(s.restart())
}

currentStream = new StreamExecution(sqlContext, stream.logicalPlan, sink)
currentStream.microBatchThread.setUncaughtExceptionHandler(
new UncaughtExceptionHandler {
override def uncaughtException(t: Thread, e: Throwable): Unit = {
streamDeathCause = e
testThread.interrupt()
}
})

case StopStream =>
checkState(currentStream != null, "cannot stop a stream that is not running")
currentStream.stop()
currentStream = null

case DropBatches(num) =>
checkState(currentStream == null, "dropping batches while running leads to corruption")
sink.dropBatches(num)

case ExpectFailure =>
try failAfter(streamingTimeout) {
while (streamDeathCause == null) {
Thread.sleep(100)
}
} catch {
case _: InterruptedException =>
case _: org.scalatest.exceptions.TestFailedDueToTimeoutException =>
fail(
s"""
|Timed out while waiting for failure.
|$testState
""".stripMargin)
}

currentStream = null
streamDeathCause = null

case a: AddData =>
awaiting.put(a.source, a.addData())

case AwaitEventTime(time) =>
checkState(currentStream != null, "stream not running")
failAfter(streamingTimeout) {
currentStream.awaitOffset(currentStream.eventTimeSource, LongOffset(time))
}
case CheckAnswerRows(expectedAnswer) =>
checkState(currentStream != null, "stream not running")

// Block until all data added has been processed
awaiting.foreach { case (source, offset) =>
failAfter(streamingTimeout) {
currentStream.awaitOffset(source, offset)
}
}
QueryTest.sameRows(expectedAnswer, sink.allData).foreach {
error => fail(
s"""
|$error
|$testState
""".stripMargin)
}
}
pos += 1
}
} catch {
case _: InterruptedException if streamDeathCause != null =>
fail(
s"""
|Stream Thread Died
|$testState
""".stripMargin)
case _: org.scalatest.exceptions.TestFailedDueToTimeoutException =>
fail(
s"""
|Timed out waiting for stream
|$testState
""".stripMargin)
} finally {
if (currentStream != null && currentStream.microBatchThread.isAlive) {
currentStream.stop()
}
}
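
Taken together, these pieces let a test be written as a linear script of actions; a hedged usage sketch (MemoryStream construction and the Int encoder are assumed to be available from elsewhere on this branch):

val input = MemoryStream[Int]
testStream(input.toDF())(
  AddData(input, 1, 2, 3), // the next check blocks until this data is processed
  CheckAnswer(1, 2, 3),
  StopStream,              // stop, then resume from the recorded offsets
  AddData(input, 4),
  StartStream,
  CheckAnswer(1, 2, 3, 4)
)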
/**
* Asserts that a given [[Queryable]] has no missing inputs in its analyzed, optimized, and
* physical plans.
*/
def assertEmptyMissingInput(query: Queryable): Unit = {
assert(query.queryExecution.analyzed.missingInput.isEmpty,
s"The analyzed logical plan has missing inputs: ${query.queryExecution.analyzed}")
assert(query.queryExecution.optimizedPlan.missingInput.isEmpty,
s"The optimized logical plan has missing inputs: ${query.queryExecution.optimizedPlan}")
assert(query.queryExecution.executedPlan.missingInput.isEmpty,
s"The physical plan has missing inputs: ${query.queryExecution.executedPlan}")
}
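
A hypothetical call site, checking all three plan stages of a simple query:

val df = sqlContext.range(10).filter("id > 5")
assertEmptyMissingInput(df) // analyzed, optimized, and physical plans are all verified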
}
