Skip to content
Prev Previous commit
Add compareAnswers to object SQLTestUtils
  • Loading branch information
zsxwing committed Aug 28, 2015
commit 7dcd502fc7278978fab5a233f4a81fefcca8bf72
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import scala.util.control.NonFatal
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row, SQLContext}
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.test.SQLTestUtils

/**
* Base class for writing tests for individual physical operators. For an example of how this
Expand Down Expand Up @@ -184,7 +184,7 @@ object SparkPlanTest {
return Some(errorMessage)
}

compareAnswers(actualAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
SQLTestUtils.compareAnswers(actualAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
s"""
| Results do not match.
| Actual result Spark plan:
Expand Down Expand Up @@ -229,7 +229,7 @@ object SparkPlanTest {
return Some(errorMessage)
}

compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
SQLTestUtils.compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
s"""
| Results do not match for Spark plan:
| $outputPlan
Expand All @@ -238,46 +238,6 @@ object SparkPlanTest {
}
}

private def compareAnswers(
sparkAnswer: Seq[Row],
expectedAnswer: Seq[Row],
sort: Boolean): Option[String] = {
def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
// For BigDecimal type, the Scala type has a better definition of equality test (similar to
// Java's java.math.BigDecimal.compareTo).
// For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for
// equality test.
// This function is copied from Catalyst's QueryTest
val converted: Seq[Row] = answer.map { s =>
Row.fromSeq(s.toSeq.map {
case d: java.math.BigDecimal => BigDecimal(d)
case b: Array[Byte] => b.toSeq
case o => o
})
}
if (sort) {
converted.sortBy(_.toString())
} else {
converted
}
}
if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
val errorMessage =
s"""
| == Results ==
| ${sideBySide(
s"== Expected Answer - ${expectedAnswer.size} ==" +:
prepareAnswer(expectedAnswer).map(_.toString()),
s"== Actual Answer - ${sparkAnswer.size} ==" +:
prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")}
""".stripMargin
Some(errorMessage)
} else {
None
}
}

private def executePlan(outputPlan: SparkPlan, _sqlContext: SQLContext): Seq[Row] = {
// A very simple resolver to make writing tests easier. In contrast to the real resolver
// this is always case sensitive and does not try to handle scoping or complex type resolution.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,8 @@ package org.apache.spark.sql.execution.local
import scala.util.control.NonFatal

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.test.SQLTestUtils

class LocalNodeTest extends SparkFunSuite {

Expand Down Expand Up @@ -137,53 +135,12 @@ object LocalNodeTest {
return Some(errorMessage)
}

compareAnswers(outputResult, expectedAnswer, sortAnswers).map { errorMessage =>
SQLTestUtils.compareAnswers(outputResult, expectedAnswer, sortAnswers).map { errorMessage =>
s"""
| Results do not match for local plan:
| $outputNode
| $errorMessage
""".stripMargin
}
}

private def compareAnswers(
answer: Seq[Row],
expectedAnswer: Seq[Row],
sort: Boolean): Option[String] = {
def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
// For BigDecimal type, the Scala type has a better definition of equality test (similar to
// Java's java.math.BigDecimal.compareTo).
// For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for
// equality test.
// This function is copied from Catalyst's QueryTest
val converted: Seq[Row] = answer.map { s =>
Row.fromSeq(s.toSeq.map {
case d: java.math.BigDecimal => BigDecimal(d)
case b: Array[Byte] => b.toSeq
case o => o
})
}
if (sort) {
converted.sortBy(_.toString())
} else {
converted
}
}
if (prepareAnswer(expectedAnswer) != prepareAnswer(answer)) {
val errorMessage =
s"""
| == Results ==
| ${sideBySide(
s"== Expected Answer - ${expectedAnswer.size} ==" +:
prepareAnswer(expectedAnswer).map(_.toString()),
s"== Actual Answer - ${answer.size} ==" +:
prepareAnswer(answer).map(_.toString())).mkString("\n")}
""".stripMargin
Some(errorMessage)
} else {
None
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,9 @@ import org.apache.hadoop.conf.Configuration
import org.scalatest.BeforeAndAfterAll

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits}
import org.apache.spark.sql.{DataFrame, Row, SQLContext, SQLImplicits}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.util.Utils

/**
Expand Down Expand Up @@ -179,3 +180,46 @@ private[sql] trait SQLTestUtils
DataFrame(_sqlContext, plan)
}
}

private[sql] object SQLTestUtils {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why put this in an object instead of in the already existing trait? It just makes invocation more verbose.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, because this is called in object SparkPlanTest and LocalNodeTest which don't use the already existing trait.


def compareAnswers(
sparkAnswer: Seq[Row],
expectedAnswer: Seq[Row],
sort: Boolean): Option[String] = {
def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
// For BigDecimal type, the Scala type has a better definition of equality test (similar to
// Java's java.math.BigDecimal.compareTo).
// For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for
// equality test.
// This function is copied from Catalyst's QueryTest
val converted: Seq[Row] = answer.map { s =>
Row.fromSeq(s.toSeq.map {
case d: java.math.BigDecimal => BigDecimal(d)
case b: Array[Byte] => b.toSeq
case o => o
})
}
if (sort) {
converted.sortBy(_.toString())
} else {
converted
}
}
if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
val errorMessage =
s"""
| == Results ==
| ${sideBySide(
s"== Expected Answer - ${expectedAnswer.size} ==" +:
prepareAnswer(expectedAnswer).map(_.toString()),
s"== Actual Answer - ${sparkAnswer.size} ==" +:
prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")}
""".stripMargin
Some(errorMessage)
} else {
None
}
}
}