SortSuite.scala
@@ -18,9 +18,13 @@
package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.expressions.{BoundReference, Ascending, SortOrder}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+
+import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType}

class SortSuite extends SparkPlanTest {
+  import TestSQLContext.implicits.localSeqToDataFrameHolder
+
  test("basic sorting using ExternalSort") {

@@ -30,16 +34,14 @@ class SortSuite extends SparkPlanTest {
("World", 8)
)

val sortOrder = Seq(
SortOrder(BoundReference(0, StringType, nullable = false), Ascending),
SortOrder(BoundReference(1, IntegerType, nullable = false), Ascending)
)

checkAnswer(
input,
(child: SparkPlan) => new ExternalSort(sortOrder, global = false, child),
input.sorted
)
input.toDF("a", "b"),
ExternalSort('a.asc :: 'b.asc :: Nil, global = false, _: SparkPlan),
input.sorted)

checkAnswer(
input.toDF("a", "b"),
ExternalSort('b.asc :: 'a.asc :: Nil, global = false, _: SparkPlan),
input.sortBy(t => (t._2, t._1)))
}
}
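
Note: the DSL form above makes further cases cheap to write. As a sketch only (not part of this diff: the suite name and fixture data below are invented, and it assumes the same imports the diff adds), a descending-order test would follow the same pattern:

package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.test.TestSQLContext

// Hypothetical companion suite, sketched to show the new checkAnswer signature.
class DescendingSortSuite extends SparkPlanTest {
  import TestSQLContext.implicits.localSeqToDataFrameHolder

  test("basic sorting in descending order") {
    // Made-up fixture; the tie on "Hello" exercises the secondary sort key.
    val input = Seq(("Hello", 4), ("World", 8), ("Hello", 9))

    checkAnswer(
      input.toDF("a", "b"),
      ExternalSort('a.desc :: 'b.desc :: Nil, global = false, _: SparkPlan),
      input.sorted(Ordering[(String, Int)].reverse))
  }
}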
SparkPlanTest.scala
@@ -21,9 +21,13 @@ import scala.util.control.NonFatal
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.SparkFunSuite
+
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions.BoundReference
+import org.apache.spark.sql.catalyst.util._
+
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.{Row, DataFrame}
-import org.apache.spark.sql.catalyst.util._

/**
 * Base class for writing tests for individual physical operators. For an example of how this
@@ -48,6 +52,24 @@ class SparkPlanTest extends SparkFunSuite {
    }
  }

+  /**
+   * Runs the plan and makes sure the answer matches the expected result.
+   * @param input the input data to be used.
+   * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
+   *                     the physical operator that's being tested.
+   * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
+   */
+  protected def checkAnswer[A <: Product : TypeTag](
+      input: DataFrame,
+      planFunction: SparkPlan => SparkPlan,
+      expectedAnswer: Seq[A]): Unit = {
+    val expectedRows = expectedAnswer.map(Row.fromTuple)
+    SparkPlanTest.checkAnswer(input, planFunction, expectedRows) match {
+      case Some(errorMessage) => fail(errorMessage)
+      case None =>
+    }
+  }
+
  /**
   * Runs the plan and makes sure the answer matches the expected result.
   * @param input the input data to be used.
@@ -87,6 +109,23 @@ object SparkPlanTest {

    val outputPlan = planFunction(input.queryExecution.sparkPlan)

+    // A very simple resolver to make writing tests easier. In contrast to the real resolver
+    // this is always case sensitive and does not try to handle scoping or complex type resolution.
+    val resolvedPlan = outputPlan transform {
+      case plan: SparkPlan =>
+        val inputMap = plan.children.flatMap(_.output).zipWithIndex.map {
+          case (a, i) =>
+            (a.name, BoundReference(i, a.dataType, a.nullable))
+        }.toMap
+
+        plan.transformExpressions {
+          case UnresolvedAttribute(Seq(u)) =>
+            inputMap.get(u).getOrElse {
+              sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")
+            }
+        }
+    }
+
    def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
      // Converts data to types that we can do equality comparison using Scala collections.
      // For BigDecimal type, the Scala type has a better definition of equality test (similar to
@@ -105,7 +144,7 @@
    }

    val sparkAnswer: Seq[Row] = try {
-      outputPlan.executeCollect().toSeq
+      resolvedPlan.executeCollect().toSeq
    } catch {
      case NonFatal(e) =>
        val errorMessage =
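
Note: the resolver added above does one thing: it binds each unresolved column name to its ordinal position in the child plan's output. A standalone sketch of that binding step, assuming the two-column schema used in SortSuite (the object name and data here are invented):

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference}
import org.apache.spark.sql.types.{IntegerType, StringType}

object ResolverBindingSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical child output: the two columns a SortSuite test produces.
    val childOutput: Seq[Attribute] = Seq(
      AttributeReference("a", StringType, nullable = false)(),
      AttributeReference("b", IntegerType, nullable = false)())

    // Same shape as the test resolver: column name -> positional BoundReference.
    val inputMap = childOutput.zipWithIndex.map {
      case (attr, i) => attr.name -> BoundReference(i, attr.dataType, attr.nullable)
    }.toMap

    // 'b binds to ordinal 1, so sorting on 'b reads field 1 of each input row.
    println(inputMap("b").ordinal)  // 1
  }
}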
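
Note: the truncated prepareAnswer comment refers to a concrete pitfall: java.math.BigDecimal.equals compares scale as well as numeric value, while scala.math.BigDecimal compares numerically, which is why answers are converted before comparison. A short illustration (plain Scala, not from the patch):

val a = new java.math.BigDecimal("1.0")
val b = new java.math.BigDecimal("1.00")
println(a == b)                          // false: java.math.BigDecimal.equals also compares scale
println(BigDecimal(a) == BigDecimal(b))  // true: scala.math.BigDecimal compares numeric value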