Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
respond to code review comments
  • Loading branch information
dtenedor committed Mar 18, 2023
commit 492069ba1b2d2df2c3749ad0c81b5492bb438992
Original file line number Diff line number Diff line change
Expand Up @@ -126,23 +126,4 @@ trait SQLQueryTestHelper extends Logging {
(emptySchema, Seq(e.getClass.getName, e.getMessage))
}
}

protected def splitWithSemicolon(seq: Seq[String]): Array[String] = {
seq.mkString("\n").split("(?<=[^\\\\]);")
}

protected def splitCommentsAndCodes(input: String): (Array[String], Array[String]) =
input.split("\n").partition { line =>
val newLine = line.trim
newLine.startsWith("--") && !newLine.startsWith("--QUERY-DELIMITER")
}

/** Returns all the files (not directories) in a directory, recursively. */
protected def listFilesRecursively(path: File): Seq[File] = {
val (dirs, files) = path.listFiles().partition(_.isDirectory)
// Filter out test files with invalid extensions such as temp files created
// by vi (.swp), Mac (.DS_Store) etc.
val filteredFiles = files.filter(_.getName.endsWith(validFileExtensions))
filteredFiles ++ dirs.flatMap(listFilesRecursively)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,13 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
/** List of test cases to perform analyzer tests for. */
protected def analyzerTestCaseList = Seq("array.sql")

/** A test case. */
protected trait TestCase {
val name: String
val inputFile: String
val resultFile: String
}

/**
* traits that indicate UDF or PgSQL to trigger the code path specific to each. For instance,
* PgSQL tests require to register some UDF functions.
Expand Down Expand Up @@ -261,27 +268,22 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
case _ =>
// Create a test case to run this case.
test(testCase.name) {
runSqlTestCase(testCase, listTestCases, runQueries)
runSqlTestCase(testCase, listTestCases)
}
}
}

/** A test case. */
protected trait TestCase {
val name: String
val inputFile: String
val resultFile: String
}

/** Run a test case. */
protected def runSqlTestCase(
testCase: TestCase,
listTestCases: Seq[TestCase],
runQueries: (
Seq[String], // queries
TestCase, // test case
Seq[(String, String)] // config set
) => Unit): Unit = {
protected def runSqlTestCase(testCase: TestCase, listTestCases: Seq[TestCase]): Unit = {
def splitWithSemicolon(seq: Seq[String]) = {
seq.mkString("\n").split("(?<=[^\\\\]);")
}

def splitCommentsAndCodes(input: String) = input.split("\n").partition { line =>
val newLine = line.trim
newLine.startsWith("--") && !newLine.startsWith("--QUERY-DELIMITER")
}

val input = fileToString(new File(testCase.inputFile))

val (comments, code) = splitCommentsAndCodes(input)
Expand Down Expand Up @@ -338,7 +340,6 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
conf.trim -> value.substring(1).trim
})

val regenerateGoldenFiles: Boolean = System.getenv("SPARK_GENERATE_GOLDEN_FILES") == "1"
if (regenerateGoldenFiles) {
runQueries(queries, testCase, settings)
} else {
Expand Down Expand Up @@ -491,8 +492,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
val testCaseName = absPath.stripPrefix(inputFilePath).stripPrefix(File.separator)
val analyzerTestCaseName = s"${testCaseName}_analyzer_test"

val newTestCase =
if (file.getAbsolutePath.startsWith(
val newTestCase = if (file.getAbsolutePath.startsWith(
s"$inputFilePath${File.separator}udf${File.separator}postgreSQL")) {
Seq(TestScalaUDF("udf"), TestPythonUDF("udf"), TestScalarPandasUDF("udf")).map { udf =>
UDFPgSQLTestCase(
Expand Down Expand Up @@ -525,6 +525,15 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
}.sortBy(_.name)
}

/** Returns all the files (not directories) in a directory, recursively. */
protected def listFilesRecursively(path: File): Seq[File] = {
val (dirs, files) = path.listFiles().partition(_.isDirectory)
// Filter out test files with invalid extensions such as temp files created
// by vi (.swp), Mac (.DS_Store) etc.
val filteredFiles = files.filter(_.getName.endsWith(validFileExtensions))
filteredFiles ++ dirs.flatMap(listFilesRecursively)
}

/** Load built-in test tables into the SparkSession. */
protected def createTestTables(session: SparkSession): Unit = {
import session.implicits._
Expand Down Expand Up @@ -669,9 +678,9 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
* output of running a query.
*/
def readGoldenFileAndCompareResults(
resultFile: String,
outputs: Seq[QueryTestOutput],
makeOutput: (String, Option[String], String) => QueryTestOutput): Unit = {
resultFile: String,
outputs: Seq[QueryTestOutput],
makeOutput: (String, Option[String], String) => QueryTestOutput): Unit = {
// Read back the golden file.
val expectedOutputs: Seq[QueryTestOutput] = {
val goldenOutput = fileToString(new File(resultFile))
Expand Down Expand Up @@ -727,7 +736,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
def numSegments: Int
}

/** A single SQL query's output. */
/** A single SQL query's execution output. */
case class ExecutionOutput(
sql: String,
schema: Option[String],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import scala.util.control.NonFatal
import org.apache.commons.lang3.exception.ExceptionUtils

import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.sql.SQLQueryTestSuite
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
import org.apache.spark.sql.catalyst.util.fileToString
Expand Down Expand Up @@ -68,7 +69,7 @@ import org.apache.spark.sql.types._
* 4. Support UDAF testing.
*/
// scalastyle:on line.size.limit
class ThriftServerQueryTestSuite extends SQLQueryTestSuite with SharedThriftServer {
class ThriftServerQueryTestSuite extends SQLQueryTestSuite with SharedThriftServer with Logging {


override def mode: ServerMode.Value = ServerMode.binary
Expand Down Expand Up @@ -235,7 +236,7 @@ class ThriftServerQueryTestSuite extends SQLQueryTestSuite with SharedThriftServ
} else {
// Create a test case to run this case.
test(testCase.name) {
runSqlTestCase(testCase, Seq.empty, runQueries)
runSqlTestCase(testCase, listTestCases)
}
}
}
Expand Down