respond to code review comments
dtenedor committed Mar 16, 2023
commit 2009b91543fe2f7f7c2a8581466e9c397a35a6f3
95 changes: 1 addition & 94 deletions sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -17,7 +17,6 @@

package org.apache.spark.sql

import java.io.File
import java.util.TimeZone

import scala.collection.JavaConverters._
@@ -30,6 +29,7 @@ import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.storage.StorageLevel


abstract class QueryTest extends PlanTest {

protected def spark: SparkSession
@@ -229,60 +229,6 @@ abstract class QueryTest extends PlanTest {
s"The physical plan has missing inputs:\n${query.queryExecution.executedPlan}")
}

/**
* Consumes contents from a single golden file and compares the expected results against the
* output of running a query.
*/
def readGoldenFileAndCompareResults(
resultFile: String,
outputs: Seq[QueryTestOutput],
makeOutput: (String, Option[String], String) => QueryTestOutput): Unit = {
// Read back the golden file.
val expectedOutputs: Seq[QueryTestOutput] = {
val goldenOutput = fileToString(new File(resultFile))
val segments = goldenOutput.split("-- !query.*\n")

val numSegments = outputs.map(_.numSegments).sum + 1
assert(segments.size == numSegments,
s"Expected $numSegments blocks in result file but got " +
s"${segments.size}. Try regenerate the result files.")
var curSegment = 0
outputs.map { output =>
val result = if (output.numSegments == 3) {
makeOutput(
segments(curSegment + 1).trim, // SQL
Some(segments(curSegment + 2).trim), // Schema
segments(curSegment + 3).replaceAll("\\s+$", "")) // Output
} else {
makeOutput(
segments(curSegment + 1).trim, // SQL
None, // Schema
segments(curSegment + 2).replaceAll("\\s+$", "")) // Output
}
curSegment += output.numSegments
result
}
}

// Compare results.
assertResult(expectedOutputs.size, s"Number of queries should be ${expectedOutputs.size}") {
outputs.size
}

outputs.zip(expectedOutputs).zipWithIndex.foreach { case ((output, expected), i) =>
assertResult(expected.sql, s"SQL query did not match for query #$i\n${expected.sql}") {
output.sql
}
assertResult(expected.schema,
s"Schema did not match for query #$i\n${expected.sql}: $output") {
output.schema
}
assertResult(expected.output, s"Result did not match" +
s" for query #$i\n${expected.sql}") {
output.output
}
}
}
}

object QueryTest extends Assertions {
@@ -480,45 +426,6 @@ object QueryTest extends Assertions {
}
}

/** A single SQL query's output. */
trait QueryTestOutput {
def sql: String
def schema: Option[String]
def output: String
def numSegments: Int
}

/** A single SQL query's output. */
protected case class ExecutionOutput(
sql: String,
schema: Option[String],
output: String) extends QueryTestOutput {
override def toString: String = {
// We are explicitly not using multi-line string due to stripMargin removing "|" in output.
s"-- !query\n" +
sql + "\n" +
s"-- !query schema\n" +
schema.get + "\n" +
s"-- !query output\n" +
output
}
override def numSegments: Int = 3
}

/** A single SQL query's analysis results. */
protected case class AnalyzerOutput(
sql: String,
schema: Option[String],
output: String) extends QueryTestOutput {
override def toString: String = {
// We are explicitly not using multi-line string due to stripMargin removing "|" in output.
s"-- !query\n" +
sql + "\n" +
s"-- !query analysis\n" +
output
}
override def numSegments: Int = 2
}

class QueryTestSuite extends QueryTest with test.SharedSparkSession {
test("SPARK-16940: checkAnswer should raise TestFailedException for wrong results") {
107 changes: 102 additions & 5 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
@@ -25,8 +25,8 @@ import org.apache.spark.{SparkConf, TestUtils}
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.util.{fileToString, stringToFile}
import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_SECOND
import org.apache.spark.sql.catalyst.util.stringToFile
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.TimestampTypes
@@ -381,6 +381,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
val testCaseName = absPath.stripPrefix(inputFilePath).stripPrefix(File.separator)
val analyzerTestCaseName = s"${testCaseName}_analyzer_test"

val newTestCase =
if (file.getAbsolutePath.startsWith(
s"$inputFilePath${File.separator}udf${File.separator}postgreSQL")) {
Seq(TestScalaUDF("udf"), TestPythonUDF("udf"), TestScalarPandasUDF("udf")).map { udf =>
@@ -403,13 +404,14 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
AnsiTestCase(testCaseName, absPath, resultFile) :: Nil
} else if (file.getAbsolutePath.startsWith(s"$inputFilePath${File.separator}timestampNTZ")) {
TimestampNTZTestCase(testCaseName, absPath, resultFile) :: Nil
} else if (analyzerTestCaseList.contains(file.getName.toLowerCase(Locale.ROOT))) {
Seq(
AnalyzerTestCase(analyzerTestCaseName, absPath, analyzerResultFile),
RegularTestCase(testCaseName, absPath, resultFile))
} else {
RegularTestCase(testCaseName, absPath, resultFile) :: Nil
}
if (analyzerTestCaseList.contains(file.getName.toLowerCase(Locale.ROOT))) {
AnalyzerTestCase(analyzerTestCaseName, absPath, analyzerResultFile) +: newTestCase
} else {
newTestCase
}
}.sortBy(_.name)
}
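
For clarity, the `+:` in the branch above prepends the analyzer variant to whatever list the earlier branches built, so the analyzer case always comes first. A self-contained sketch of that semantics, using hypothetical string stand-ins for the real test-case objects (which need suite state to construct):

    // Hypothetical stand-ins for RegularTestCase / AnalyzerTestCase values.
    val newTestCase = Seq("RegularTestCase(group-by.sql)")
    // "+:" prepends a single element to the sequence.
    val withAnalyzer = "AnalyzerTestCase(group-by.sql_analyzer_test)" +: newTestCase
    assert(withAnalyzer == Seq(
      "AnalyzerTestCase(group-by.sql_analyzer_test)",
      "RegularTestCase(group-by.sql)"))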

Expand Down Expand Up @@ -551,4 +553,99 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
super.afterAll()
}
}

/**
* Consumes contents from a single golden file and compares the expected results against the
* output of running a query.
*/
def readGoldenFileAndCompareResults(
resultFile: String,
outputs: Seq[QueryTestOutput],
makeOutput: (String, Option[String], String) => QueryTestOutput): Unit = {
// Read back the golden file.
val expectedOutputs: Seq[QueryTestOutput] = {
val goldenOutput = fileToString(new File(resultFile))
val segments = goldenOutput.split("-- !query.*\n")

val numSegments = outputs.map(_.numSegments).sum + 1
assert(segments.size == numSegments,
s"Expected $numSegments blocks in result file but got " +
s"${segments.size}. Try regenerate the result files.")
var curSegment = 0
outputs.map { output =>
val result = if (output.numSegments == 3) {
makeOutput(
segments(curSegment + 1).trim, // SQL
Some(segments(curSegment + 2).trim), // Schema
segments(curSegment + 3).replaceAll("\\s+$", "")) // Output
} else {
makeOutput(
segments(curSegment + 1).trim, // SQL
None, // Schema
segments(curSegment + 2).replaceAll("\\s+$", "")) // Output
}
curSegment += output.numSegments
result
}
}

// Compare results.
assertResult(expectedOutputs.size, s"Number of queries should be ${expectedOutputs.size}") {
outputs.size
}

outputs.zip(expectedOutputs).zipWithIndex.foreach { case ((output, expected), i) =>
assertResult(expected.sql, s"SQL query did not match for query #$i\n${expected.sql}") {
output.sql
}
assertResult(expected.schema,
s"Schema did not match for query #$i\n${expected.sql}: $output") {
output.schema
}
assertResult(expected.output, s"Result did not match" +
s" for query #$i\n${expected.sql}") {
output.output
}
}
}
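
For readers new to the golden-file format, here is a minimal, self-contained sketch of the layout this parser assumes; the sample query, schema, and result are hypothetical, not taken from the patch:

    object GoldenFileSplitDemo extends App {
      // One ExecutionOutput block: a "-- !query" marker, then SQL, schema, and
      // output sections, each introduced by its own marker line.
      val golden =
        "-- !query\n" +
        "SELECT 1\n" +
        "-- !query schema\n" +
        "struct<1:int>\n" +
        "-- !query output\n" +
        "1\n"
      // Splitting on the marker lines leaves an empty leading segment plus one
      // segment per section; hence the "+ 1" in the numSegments assertion above.
      val segments = golden.split("-- !query.*\n")
      assert(segments.length == 4) // empty prefix + SQL + schema + output
      println(segments.map(_.trim).mkString("[", " | ", "]"))
    }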

/** A single SQL query's output. */
trait QueryTestOutput {
def sql: String
def schema: Option[String]
def output: String
def numSegments: Int
}

/** A single SQL query's output. */
case class ExecutionOutput(
sql: String,
schema: Option[String],
output: String) extends QueryTestOutput {
override def toString: String = {
// We are explicitly not using multi-line string due to stripMargin removing "|" in output.
s"-- !query\n" +
sql + "\n" +
s"-- !query schema\n" +
schema.get + "\n" +
s"-- !query output\n" +
output
}
override def numSegments: Int = 3
}

/** A single SQL query's analysis results. */
case class AnalyzerOutput(
sql: String,
schema: Option[String],
output: String) extends QueryTestOutput {
override def toString: String = {
// We are explicitly not using multi-line string due to stripMargin removing "|" in output.
s"-- !query\n" +
sql + "\n" +
s"-- !query analysis\n" +
output
}
override def numSegments: Int = 2
}
}
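
As a quick usage sketch (assuming ExecutionOutput is in scope; the query is again illustrative), the serialized form below is exactly what readGoldenFileAndCompareResults later splits back apart:

    val out = ExecutionOutput(
      sql = "SELECT 1",
      schema = Some("struct<1:int>"),
      output = "1")
    // toString produces the three-segment golden-file form:
    println(out)
    // -- !query
    // SELECT 1
    // -- !query schema
    // struct<1:int>
    // -- !query output
    // 1

Note that AnalyzerOutput serializes only two marker sections (its toString skips the schema line even though a schema field is stored), which is why it reports numSegments = 2 while ExecutionOutput reports 3.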
@@ -26,7 +26,6 @@ import scala.util.control.NonFatal
import org.apache.commons.lang3.exception.ExceptionUtils

import org.apache.spark.SparkException
import org.apache.spark.sql.{ExecutionOutput, QueryTestOutput}
import org.apache.spark.sql.SQLQueryTestSuite
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.catalyst.util.fileToString