update
dtenedor committed Mar 15, 2023
commit a677845f56efc41e41bf94bc4fd830ba1a5c8971
@@ -130,7 +130,11 @@ class SQLAnalyzerTestSuite extends QueryTest with SharedSparkSession with SQLHel
val absPath = file.getAbsolutePath
val testCaseName = absPath.stripPrefix(inputFilePath).stripPrefix(File.separator)

RegularTestCase(testCaseName, absPath, resultFile) :: Nil
if (absPath.startsWith(s"$inputFilePath${File.separator}ansi")) {
AnsiTestCase(testCaseName, absPath, resultFile) :: Nil
} else {
RegularTestCase(testCaseName, absPath, resultFile) :: Nil
}
}.sortBy(_.name)
}

@@ -142,7 +146,12 @@ class SQLAnalyzerTestSuite extends QueryTest with SharedSparkSession with SQLHel
// This does not isolate catalog changes.
val localSparkSession = spark.newSession()

localSparkSession.conf.set(SQLConf.ANSI_ENABLED.key, false)
testCase match {
case _: AnsiTest =>
localSparkSession.conf.set(SQLConf.ANSI_ENABLED.key, true)
case _ =>
localSparkSession.conf.set(SQLConf.ANSI_ENABLED.key, false)
}

if (configSet.nonEmpty) {
// Execute the list of set operations in order to add the desired configurations.
@@ -201,4 +210,17 @@ class SQLAnalyzerTestSuite extends QueryTest with SharedSparkSession with SQLHel
// Get the output, but also get rid of the #1234 expression IDs that show up in plan strings.
(schema, Seq(replaceNotIncludedMsg(df.queryExecution.analyzed.toString)))
}

override def beforeAll(): Unit = {
super.beforeAll()
createTestTables(spark, conf)
}

override def afterAll(): Unit = {
try {
removeTestTables(spark)
} finally {
super.afterAll()
}
}
}
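
For readers skimming the diff: the new branch above dispatches on the AnsiTest marker trait (moved into SQLQueryTestHelper below) to pick the ANSI setting for each test case. A minimal, self-contained sketch of that dispatch pattern follows; the object name AnsiDispatchSketch and the method sessionFor are hypothetical and exist only for illustration.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

// Hypothetical sketch (not part of this commit): classify a test case by its
// marker trait and configure a fresh session accordingly, as runQueries does above.
object AnsiDispatchSketch {
  sealed trait TestCase { def name: String }
  trait AnsiTest
  case class RegularTestCase(name: String) extends TestCase
  case class AnsiTestCase(name: String) extends TestCase with AnsiTest

  def sessionFor(testCase: TestCase, spark: SparkSession): SparkSession = {
    val localSparkSession = spark.newSession()
    testCase match {
      case _: AnsiTest =>
        // ANSI-related golden files are produced with ANSI mode enabled.
        localSparkSession.conf.set(SQLConf.ANSI_ENABLED.key, true)
      case _ =>
        // All other golden files assume ANSI mode disabled.
        localSparkSession.conf.set(SQLConf.ANSI_ENABLED.key, false)
    }
    localSparkSession
  }
}
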
123 changes: 122 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql

import java.io.File
import java.net.URI

import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
@@ -32,9 +33,12 @@ import org.apache.spark.sql.catalyst.util.fileToString
import org.apache.spark.sql.execution.HiveResult.hiveResultString
import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.command.{DescribeColumnCommand, DescribeCommandBase}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.StructType
import org.apache.spark.util.Utils

trait SQLQueryTestHelper extends Logging {
trait SQLQueryTestHelper extends Logging with SQLTestUtils {

private val notIncludedMsg = "[not included in comparison]"
private val clsName = this.getClass.getCanonicalName
@@ -127,6 +131,13 @@ trait SQLQueryTestHelper extends Logging {
protected case class RegularTestCase(
name: String, inputFile: String, resultFile: String) extends TestCase

/** An ANSI-related test case. */
protected case class AnsiTestCase(
name: String, inputFile: String, resultFile: String) extends TestCase with AnsiTest

/** Marker trait for ANSI-related tests that should run with ANSI mode enabled. */
protected trait AnsiTest

/** Run a test case. */
protected def runTest(
testCase: TestCase,
@@ -247,4 +258,114 @@ trait SQLQueryTestHelper extends Logging {
val filteredFiles = files.filter(_.getName.endsWith(validFileExtensions))
filteredFiles ++ dirs.flatMap(listFilesRecursively)
}

/** Load built-in test tables into the SparkSession. */
protected def createTestTables(session: SparkSession, conf: SQLConf): Unit = {
import session.implicits._

// Before creating the test tables, delete any orphan directories in the warehouse dir.
Seq("testdata", "arraydata", "mapdata", "aggtest", "onek", "tenk1").foreach { dirName =>
val f = new File(new URI(s"${conf.warehousePath}/$dirName"))
if (f.exists()) {
Utils.deleteRecursively(f)
}
}

(1 to 100).map(i => (i, i.toString)).toDF("key", "value")
.repartition(1)
.write
.format("parquet")
.saveAsTable("testdata")

((Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: (Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil)
.toDF("arraycol", "nestedarraycol")
.write
.format("parquet")
.saveAsTable("arraydata")

(Tuple1(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) ::
Tuple1(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) ::
Tuple1(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) ::
Tuple1(Map(1 -> "a4", 2 -> "b4")) ::
Tuple1(Map(1 -> "a5")) :: Nil)
.toDF("mapcol")
.write
.format("parquet")
.saveAsTable("mapdata")

session
.read
.format("csv")
.options(Map("delimiter" -> "\t", "header" -> "false"))
.schema("a int, b float")
.load(testFile("test-data/postgresql/agg.data"))
.write
.format("parquet")
.saveAsTable("aggtest")

session
.read
.format("csv")
.options(Map("delimiter" -> "\t", "header" -> "false"))
.schema(
"""
|unique1 int,
|unique2 int,
|two int,
|four int,
|ten int,
|twenty int,
|hundred int,
|thousand int,
|twothousand int,
|fivethous int,
|tenthous int,
|odd int,
|even int,
|stringu1 string,
|stringu2 string,
|string4 string
""".stripMargin)
.load(testFile("test-data/postgresql/onek.data"))
.write
.format("parquet")
.saveAsTable("onek")

session
.read
.format("csv")
.options(Map("delimiter" -> "\t", "header" -> "false"))
.schema(
"""
|unique1 int,
|unique2 int,
|two int,
|four int,
|ten int,
|twenty int,
|hundred int,
|thousand int,
|twothousand int,
|fivethous int,
|tenthous int,
|odd int,
|even int,
|stringu1 string,
|stringu2 string,
|string4 string
""".stripMargin)
.load(testFile("test-data/postgresql/tenk.data"))
.write
.format("parquet")
.saveAsTable("tenk1")
}

protected def removeTestTables(session: SparkSession): Unit = {
session.sql("DROP TABLE IF EXISTS testdata")
session.sql("DROP TABLE IF EXISTS arraydata")
session.sql("DROP TABLE IF EXISTS mapdata")
session.sql("DROP TABLE IF EXISTS aggtest")
session.sql("DROP TABLE IF EXISTS onek")
session.sql("DROP TABLE IF EXISTS tenk1")
}
}
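
With createTestTables and removeTestTables now living in the shared helper, a suite only needs to wire them into its test lifecycle. A minimal sketch of that wiring follows, assuming a suite that mixes in SQLQueryTestHelper; the suite name ExampleGoldenFileSuite is hypothetical, and the beforeAll/afterAll shape mirrors the SQLAnalyzerTestSuite hunk above.

import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.test.SharedSparkSession

// Hypothetical sketch (not part of this commit): a golden-file suite that reuses
// the shared table setup and teardown from SQLQueryTestHelper.
class ExampleGoldenFileSuite extends QueryTest with SharedSparkSession with SQLQueryTestHelper {
  override def beforeAll(): Unit = {
    super.beforeAll()
    // Create testdata, arraydata, mapdata, aggtest, onek and tenk1 once per suite.
    createTestTables(spark, conf)
  }

  override def afterAll(): Unit = {
    try {
      // Drop the shared tables even if individual tests failed.
      removeTestTables(spark)
    } finally {
      super.afterAll()
    }
  }
}
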
122 changes: 2 additions & 120 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala
@@ -163,11 +163,6 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
*/
protected trait PgSQLTest

/**
* traits that indicate ANSI-related tests with the ANSI mode enabled.
*/
protected trait AnsiTest

/**
* traits that indicate the default timestamp type is TimestampNTZType.
*/
@@ -202,10 +197,6 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
resultFile: String,
udf: TestUDF) extends TestCase with UDFTest with PgSQLTest

/** An ANSI-related test case. */
protected case class AnsiTestCase(
name: String, inputFile: String, resultFile: String) extends TestCase with AnsiTest

/** A date-time test case with the default timestamp type as TimestampNTZType. */
protected case class TimestampNTZTestCase(
name: String, inputFile: String, resultFile: String) extends TestCase with TimestampNTZTest
@@ -281,6 +272,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
logInfo(s"Setting configs: ${setOperations.mkString(", ")}")
setOperations.foreach(localSparkSession.sql)
}
conf

// Run the SQL queries preparing them for comparison.
val outputs: Seq[QueryOutput] = queries.map { sql =>
@@ -371,119 +363,9 @@ class SQLQueryTestSuite extends QueryTest with SharedSparkSession with SQLHelper
}.sortBy(_.name)
}

/** Load built-in test tables into the SparkSession. */
private def createTestTables(session: SparkSession): Unit = {
import session.implicits._

// Before creating test tables, deletes orphan directories in warehouse dir
Seq("testdata", "arraydata", "mapdata", "aggtest", "onek", "tenk1").foreach { dirName =>
val f = new File(new URI(s"${conf.warehousePath}/$dirName"))
if (f.exists()) {
Utils.deleteRecursively(f)
}
}

(1 to 100).map(i => (i, i.toString)).toDF("key", "value")
.repartition(1)
.write
.format("parquet")
.saveAsTable("testdata")

((Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: (Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil)
.toDF("arraycol", "nestedarraycol")
.write
.format("parquet")
.saveAsTable("arraydata")

(Tuple1(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) ::
Tuple1(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) ::
Tuple1(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) ::
Tuple1(Map(1 -> "a4", 2 -> "b4")) ::
Tuple1(Map(1 -> "a5")) :: Nil)
.toDF("mapcol")
.write
.format("parquet")
.saveAsTable("mapdata")

session
.read
.format("csv")
.options(Map("delimiter" -> "\t", "header" -> "false"))
.schema("a int, b float")
.load(testFile("test-data/postgresql/agg.data"))
.write
.format("parquet")
.saveAsTable("aggtest")

session
.read
.format("csv")
.options(Map("delimiter" -> "\t", "header" -> "false"))
.schema(
"""
|unique1 int,
|unique2 int,
|two int,
|four int,
|ten int,
|twenty int,
|hundred int,
|thousand int,
|twothousand int,
|fivethous int,
|tenthous int,
|odd int,
|even int,
|stringu1 string,
|stringu2 string,
|string4 string
""".stripMargin)
.load(testFile("test-data/postgresql/onek.data"))
.write
.format("parquet")
.saveAsTable("onek")

session
.read
.format("csv")
.options(Map("delimiter" -> "\t", "header" -> "false"))
.schema(
"""
|unique1 int,
|unique2 int,
|two int,
|four int,
|ten int,
|twenty int,
|hundred int,
|thousand int,
|twothousand int,
|fivethous int,
|tenthous int,
|odd int,
|even int,
|stringu1 string,
|stringu2 string,
|string4 string
""".stripMargin)
.load(testFile("test-data/postgresql/tenk.data"))
.write
.format("parquet")
.saveAsTable("tenk1")
}

private def removeTestTables(session: SparkSession): Unit = {
session.sql("DROP TABLE IF EXISTS testdata")
session.sql("DROP TABLE IF EXISTS arraydata")
session.sql("DROP TABLE IF EXISTS mapdata")
session.sql("DROP TABLE IF EXISTS aggtest")
session.sql("DROP TABLE IF EXISTS onek")
session.sql("DROP TABLE IF EXISTS tenk1")
}

override def beforeAll(): Unit = {
super.beforeAll()
createTestTables(spark)
createTestTables(spark, conf)
RuleExecutor.resetMetrics()
CodeGenerator.resetCompileTime()
WholeStageCodegenExec.resetCodeGenTime()